# filter_immerg.py (5.22 KB)
# Plugin metadata:
#   title - used as the error-box title, as the computation name (with spaces
#           removed) and as the parameter-set title in NOD3_App.
#   tip   - short description; not referenced in this file (presumably read
#           by the host application).
#   onein - passed to parent.compute_11; presumably the number of input
#           images (NOD3_App.function reads two maps) - confirm.
title, tip, onein = "ImMerge", "add missing flux to maps", 2

import numpy as np
#from scipy import fftpack
import scipy.ndimage as spi

from guidata.qt.QtGui import QMessageBox
from guidata.dataset.datatypes import DataSet
from guidata.dataset.dataitems import (IntItem, FloatArrayItem, StringItem,
                                       ChoiceItem, FloatItem, DictItem,
                                       BoolItem)
from guiqwt.config import _
from nodfitting import gaussian, correlate
from nodmath import nan_interpolation, subpixel_shift

def nextpow2(i):
    """Return the smallest power of two that is >= i (never less than 1)."""
    power = 1
    while power < i:
        power <<= 1  # double until the bound is reached
    return power

class NOD3_App:
    """ImMerge filter: add missing flux to one map from a second map.

    The host application (``parent``) drives the computation through
    ``compute_app``; ``function`` performs the actual merge of two maps.
    """

    def __init__(self, parent):
        """Store *parent* (the host window) and activate it."""
        self.parent = parent
        self.parent.activateWindow()

    def Error(self, msg):
        """Show *msg* in a critical message box titled with ``title``."""
        QMessageBox.critical(self.parent.parent(), title,
                              _(u"Error:")+"\n%s" % str(msg))

    def compute_app(self, **args):
        """Build the parameter set and run ``function`` via the host.

        Without keyword arguments a ``FuncParam`` data set (Center / Improve
        switches) is created; otherwise ``parent.ScriptParameter`` builds the
        parameters from *args*.
        """
        class FuncParam(DataSet):
            center = BoolItem(u"Center:", default=False)
            improve = BoolItem(u"Improve:", default=False)
        name = title.replace(" ", "")
        if not args:
            param = FuncParam(_(title), "Image Merging:")
        else:
            param = self.parent.ScriptParameter(name, args)
        # if no parameter needed set param to None. activate next line
        #param = None
        self.parent.compute_11(name, self.function, param, onein)

    def def_ring(self, shape, r):
        """Return a boolean disc mask (x**2 + y**2 <= r**2) for *shape*.

        The disc is centred on pixel (rows // 2, cols // 2). Despite the
        name, the mask is a filled disc, not a ring.
        """
        rows, cols = shape
        # Floor division keeps the grid integral and centred under both
        # Python 2 and Python 3.
        r1 = rows // 2
        r2 = rows - r1
        c1 = cols // 2
        c2 = cols - c1
        y, x = np.ogrid[-r1:r2, -c1:c2]
        return x*x + y*y <= r*r

    def weight(self, fac, ampl, phas, beam):
        """Combine two spectra, taking the second where the first is weaker.

        Where ``ampl[0] < fac*ampl[1]`` amplitude ``fac*ampl[1]`` and phase
        ``phas[1]`` are used; elsewhere ``ampl[0]`` and ``phas[0]``. The
        selection is ifftshift-ed and inverse transformed with
        ``np.fft.ifft2``. *beam* is currently unused.

        Returns the real part plus the imaginary part of the inverse FFT.
        """
        use_second = ampl[0] < fac*ampl[1]
        amp = np.fft.ifftshift(np.where(use_second, fac*ampl[1], ampl[0]))
        ang = np.fft.ifftshift(np.where(use_second, phas[1], phas[0]))
        # Allocate the complex spectrum directly; no FFT is needed for that.
        data = np.empty(amp.shape, dtype=complex)
        data.real = amp*np.cos(ang)
        data.imag = amp*np.sin(ang)
        ifftdata = np.fft.ifft2(data)
        iamp = np.abs(ifftdata)
        iang = np.angle(ifftdata)
        # NOTE(review): this is real + imaginary part; returning only the
        # real part may have been intended - confirm before relying on it.
        return iamp*np.cos(iang) + iamp*np.sin(iang)

    def get_beam(self, header):
        """Return the mean beam as a Gaussian sigma in pixel units.

        The mean of ``BMAJ`` and ``BMIN`` is treated as a FWHM and converted
        to sigma (FWHM = 2*sqrt(2*ln 2)*sigma). It is then divided by the mean
        absolute pixel size from ``CDELT1``/``CDELT2``. This assumes the beam
        and the pixel size share units - confirm for the input headers.
        """
        b = (header['BMAJ'] + header['BMIN']) / 2.0
        b /= (2.0*np.sqrt(2.0*np.log(2.0)))
        b /= (abs(header['CDELT1']) + abs(header['CDELT2'])) / 2.0
        return b

    def polyfit(self, data1, data2, fmax=0.0, out=False, sort=False, order=2):
        """Fit *data2* as a polynomial of *data1* with ``np.polyfit``.

        Points where either value is NaN or below *fmax* are excluded. With
        *sort*, both flattened inputs are sorted independently before
        fitting, and the exclusion is applied to the sorted values.

        Returns the coefficients (highest power first). With *out*, returns
        ``(coeffs, d1, d2)``, where d1/d2 are the points actually fitted. If
        the (NaN-ignoring) maximum of *data1* is below 1, an error box is
        shown and None (or ``(None, None, None)`` with *out*) is returned.
        """
        # Copies, so the in-place sort below never touches the caller's data.
        d1 = 1*np.ravel(data1)
        d2 = 1*np.ravel(data2)
        # nanmax: a plain max() is NaN for NaN-containing data, which would
        # silently skip this check.
        if np.nanmax(d1) < 1:
            self.Error("numbers too small, please rescale data")
            if out:
                return None, None, None
            return None
        if sort:
            d1.sort()
            d2.sort()
        # Build the mask from the arrays that are actually fitted, so it stays
        # aligned with them after sorting.
        keep = ~((d1 < fmax) | np.isnan(d1) | (d2 < fmax) | np.isnan(d2))
        d1 = d1[keep]
        d2 = d2[keep]
        pf = np.polyfit(d1, d2, order)
        if out:
            return pf, d1, d2
        return pf

    def correlate_plot(self, data1, data2, fmax=0.0, plot=True, order=2):
        """Fit *data2* against *data1* and optionally draw a "TT-plot".

        The coefficients are stored in ``self.pf``. Returns them as an
        ``np.poly1d``, or None when ``polyfit`` refused the data. With *plot*,
        the points and the fitted curve are drawn in a guiqwt figure stored
        in ``self.parent.fig``.
        """
        pf, d1, d2 = self.polyfit(data1, data2, fmax=fmax, out=True, order=order)
        self.pf = pf
        if pf is None:
            return None
        p = np.poly1d(pf)
        if not plot:
            return p
        # guiqwt is only needed for drawing, so it is imported lazily.
        from guiqwt import pyplot
        x = np.arange(np.nanmin(d1), np.nanmax(d1))
        self.parent.fig = pyplot.figure("TT-plot")
        lpf = len(pf)
        # One legend label per coefficient, lowest power first.
        fit = [str(_(u"x^%d : %10.3g \r") % (i, pf[lpf-i-1]))
               for i in range(lpf)]
        pyplot.plot(d1, d2, "g+", label="Data points")
        pyplot.legend(pos="TL")
        pyplot.plot(x, p(x), "r-", label=fit[0])
        # Dummy curves (style "  ", presumably invisible) that only carry the
        # remaining coefficient labels into the legend.
        for label in fit[1:]:
            pyplot.plot(x[:2], p(x)[:2], "  ", label=label)
        pyplot.ylabel("VLA")
        pyplot.xlabel("Effelsberg")
        pyplot.zlabel("TT-plot")
        pyplot.show(mainloop=False)
        return p

    def function(self, ms, p):
        """Merge the two maps in *ms*; return ``(merged_map, p)``.

        The map with the smaller sqrt(BMAJ*BMIN) becomes the base map ``m``
        and the other becomes ``m1``. The caller's *ms* list is reversed in
        place when needed.

        Both maps have their NaN pixels interpolated (``nan_interpolation``;
        both ``data`` arrays are replaced). ``m`` is Gaussian-smoothed with
        sigma b = sqrt(beam1**2 - beam0**2) and divided by
        fac = beam0**2 / beam1**2 (beams from ``get_beam``). Then
        ``fac*(m1 - smooth)`` is added to ``m``. Pixels that were NaN in
        either input are NaN in the result.

        p.center:  pass m1 through ``subpixel_shift`` against the smoothed map.
        p.improve: repeat the smoothing and correction once on the result.
        """
        hpbw = [np.sqrt(m.header['BMAJ'] * m.header['BMIN']) for m in ms]
        if hpbw[0] > hpbw[1]:
            ms.reverse()
        masks = []
        beam = []
        for m in ms:
            masks.append(np.isnan(m.data))
            m.data = nan_interpolation(m.data)
            beam.append(self.get_beam(m.header))
        fac = beam[0]**2 / beam[1]**2
        m, m1 = ms[0], ms[1]
        # NOTE(review): b is NaN if beam[0] > beam[1]. That can happen because
        # the ordering above uses the geometric mean in header units, while
        # get_beam uses the arithmetic mean in pixels.
        b = np.sqrt(beam[1]**2 - beam[0]**2)
        smooth = spi.gaussian_filter(m.data, (b, b)) / fac
        if p.center:
            dx, dy, m1.data = subpixel_shift(smooth, m1.data, clip=0.1)
        m.data += fac*(m1.data - smooth)
        if p.improve:
            smooth = spi.gaussian_filter(m.data, (b, b)) / fac
            m.data += fac*(m1.data - smooth)
        # Restore blanking: NaN wherever either input was NaN.
        mask = masks[0] | masks[1]
        m.data = np.where(mask, np.nan, m.data)
        return m, p