# proc_tnt.py (6.29 KB)
# Display name of the processor; also used (with spaces removed) as its
# script name and as the title of error dialogs.
title = "TnT"
# Short description, presumably shown by the host GUI as a tooltip.
tip = "separates thermal and non-thermal emission"
# Passed to parent.compute_11; presumably the number of input maps this
# processor takes (function() reads exactly two) -- confirm against compute_11.
onein = 2

import numpy as np
import copy as cp
import scipy.ndimage as spi

from guidata.qt.QtGui import QMessageBox
from guidata.dataset.datatypes import DataSet
from guidata.dataset.dataitems import (IntItem, FloatArrayItem, StringItem,
                                       ChoiceItem, FloatItem, DictItem,
                                       BoolItem)
from guiqwt.config import _
from nodfitting import gaussian, correlate
from nodmath import nan_interpolation, map_zoom, same_size, register_translation, subpixel_shift

def nextpow2(i):
    """Return the smallest power of two that is not less than *i*.

    Inputs of 1 or less yield 1; *i* may be an int or a float.
    """
    exponent = 0
    while (1 << exponent) < i:
        exponent += 1
    return 1 << exponent

class NOD3_App:
    """NOD3 processing app: split two radio maps into thermal and non-thermal parts.

    Each pixel is modelled as the sum of a thermal power law with spectral
    index -0.1 and a non-thermal power law with index -|nonth| (see
    ``function``).  The two maps are first given a common projection code,
    beam and pixel size, aligned, and then decomposed; a third map
    extrapolated to a user-chosen frequency is produced as well.
    """

    def __init__(self, parent):
        # Keep a handle on the host window and bring it to the front.
        self.parent = parent
        self.parent.activateWindow()

    def Error(self, msg):
        """Show *msg* in a critical message box titled with the app title."""
        QMessageBox.critical(self.parent.parent(), title,
                              _(u"Error:")+"\n%s" % str(msg))

    def compute_app(self, **args):
        """Obtain the parameters and hand the work to the host's compute_11.

        With no keyword arguments a ``FuncParam`` dataset is built for the
        user to fill in; otherwise the parameters are taken from *args* via
        ``parent.ScriptParameter`` (scripted use).
        """
        class FuncParam(DataSet):
            # Its absolute value, negated, is the non-thermal spectral index.
            # NOTE: the minimum 0.1 equals the thermal index magnitude, which
            # makes the decomposition singular (reported as an error).
            nonth = FloatItem('Non-thermal index', default=1.0, min=0.1)
            # Frequency of the extrapolated output map, in GHz.
            freq = FloatItem('Frequency', default=6.0, min=0.0)
            # Projection code written into CTYPE1/CTYPE2 of both maps.
            proj = ChoiceItem("Projection", (("CAR", "CAR"), #("ARC", "ARC"),
                                             ("SIN", "SIN"),
                                             ("TAN", "TAN"), ("NCP", "NCP"), ("SFL", "SFL"),
                                             ("AIT", "AIT")))
        name = title.replace(" ", "")
        if not args:
            param = FuncParam(_(title), "Image Merging:")
        else:
            param = self.parent.ScriptParameter(name, args)
        self.parent.compute_11(name, lambda m, p: self.function(m, p), param, onein)

    def same_proj(self, m1, m2, proj):
        """Give both maps the projection code *proj*.

        Only the last three characters of CTYPE1/CTYPE2 are replaced; the
        pixel data are not reprojected.  Returns ``(m1, m2)``, modified in
        place.
        """
        for m in (m1, m2):
            for key in ('CTYPE1', 'CTYPE2'):
                m.header[key] = m.header[key][:-3] + proj
        return m1, m2

    def convolve(self, m1, m2):
        """Smooth *m1* to the beam of *m2*.

        The Gaussian kernel widths are the quadrature differences of the
        beams, converted to pixels with m1's CDELT1/CDELT2 and from FWHM to
        sigma.  BMAJ is treated as the x-axis width and BMIN as the y-axis
        width (no beam position angle is used).  The data are also scaled
        by the ratio of the beam areas.  Returns m1, modified in place (data
        and BMAJ/BMIN).
        """
        fwhm_per_sigma = 2.0*np.sqrt(2.0*np.log(2.0))
        hpbwx1 = m1.header["BMAJ"]
        hpbwy1 = m1.header["BMIN"]
        hpbwx2 = m2.header["BMAJ"]
        hpbwy2 = m2.header["BMIN"]
        delx1 = abs(m1.header['CDELT1'])
        dely1 = abs(m1.header['CDELT2'])
        # Clip at zero so an axis along which m2's beam is not wider is left
        # unsmoothed, instead of taking the sqrt of a negative number (NaN).
        varx = (hpbwx2/delx1)**2 - (hpbwx1/delx1)**2
        vary = (hpbwy2/dely1)**2 - (hpbwy1/dely1)**2
        sigmax = np.sqrt(max(varx, 0.0)) / fwhm_per_sigma
        sigmay = np.sqrt(max(vary, 0.0)) / fwhm_per_sigma
        # Ratio of beam areas -- presumably keeps per-beam flux units consistent.
        fac = (hpbwx2*hpbwy2) / (hpbwx1*hpbwy1)
        # gaussian_filter takes sigmas in array-axis order: (rows = y, cols = x).
        m1.data = spi.gaussian_filter(m1.data, (sigmay, sigmax)) * fac
        m1.header['BMAJ'] = m2.header['BMAJ']
        m1.header['BMIN'] = m2.header['BMIN']
        return m1

    def resample(self, m1, m2):
        """Regrid *m2* to m1's pixel size.

        The zoom factors are the absolute CDELT ratios m2/m1.  NAXIS, CRPIX
        and CDELT of m2 are updated to match.  Returns m2, modified in place.
        """
        dx1, dy1 = m1.header['CDELT1'], m1.header['CDELT2']
        dx2, dy2 = m2.header['CDELT1'], m2.header['CDELT2']
        # NOTE(review): rebin is built as (x-factor, y-factor), but the header
        # updates below apply rebin[1] to axis 1 (x) and rebin[0] to axis 2 (y).
        # The two disagree unless both ratios are equal (square pixels).  Which
        # one is right depends on the axis order map_zoom expects -- confirm.
        rebin = (abs(dx2/dx1), abs(dy2/dy1))
        m2.data = map_zoom(m2.data, rebin, order=3, prefilter=True)
        m2.header['NAXIS2'], m2.header['NAXIS1'] = m2.data.shape
        # Scale the reference pixel about the pixel edge at 0.5.
        m2.header['CRPIX1'] = (m2.header['CRPIX1'] - 0.5) * rebin[1] + 0.5
        m2.header['CRPIX2'] = (m2.header['CRPIX2'] - 0.5) * rebin[0] + 0.5
        m2.header['CDELT1'] /= rebin[1]
        m2.header['CDELT2'] /= rebin[0]
        return m2

    def seperate_tnt_1(self, m1, m2, m3, bth, bnt, fghz3):
        """Legacy decomposition with both components referenced to 1 GHz.

        Solves ``S_i = T*f_i**bth + N*f_i**bnt`` per pixel for T and N, where
        f_i is the map's CRVAL3 converted from Hz to GHz.  Not called within
        this class.  *m3* is unused.

        Returns ``(thermal, non_thermal, total at fghz3 GHz)``.
        Raises ValueError if the 2x2 system is singular.
        """
        fghz1 = m1.header['CRVAL3']*1.e-9
        fghz2 = m2.header['CRVAL3']*1.e-9
        # Determinant of [[f1**bth, f1**bnt], [f2**bth, f2**bnt]] (Cramer's rule).
        det = fghz1**bth * fghz2**bnt - fghz1**bnt * fghz2**bth
        if det == 0:
            raise ValueError("thermal/non-thermal separation is singular: "
                             "equal spectral indices or equal frequencies")
        d = 1.0 / det
        data_th = (d*fghz2**bnt) * m1.data - (d*fghz1**bnt) * m2.data
        data_nt = (d*fghz1**bth) * m2.data - (d*fghz2**bth) * m1.data
        data = data_th*fghz3**bth + data_nt*fghz3**bnt
        return data_th, data_nt, data

    def seperate_tnt(self, m1, m2, m3, bth, bnt, fghz3):
        """Decompose two maps into thermal and non-thermal parts at m2's frequency.

        Per-pixel model with ``x = f/f2``: ``S(f) = T*x**bth + N*x**bnt``.
        Hence ``S2 = T + N``, so T and N are the thermal and non-thermal
        fluxes at m2's frequency (m2's CRVAL3).  *m3* receives a copy of
        m1's header with CRVAL3 set to *fghz3* GHz, and the model evaluated
        at *fghz3* as its data.

        Returns ``(thermal, non_thermal, m3.data)``.
        Raises ValueError when bth == bnt or the two frequencies are equal
        (the system is singular).
        """
        fghz1 = m1.header['CRVAL3']*1.e-9
        fghz2 = m2.header['CRVAL3']*1.e-9
        x = fghz1/fghz2
        x_nt = x**bnt  # weight of the non-thermal term at m1's frequency
        x_th = x**bth  # weight of the thermal term at m1's frequency
        denom = x_th - x_nt
        if denom == 0:
            raise ValueError("thermal/non-thermal separation is singular: "
                             "equal spectral indices or equal frequencies")
        inv = 1.0 / denom
        data_th = +inv * (m1.data - x_nt*m2.data)
        data_nt = -inv * (m1.data - x_th*m2.data)
        m3.header = m1.header.copy()
        m3.header['CRVAL3'] = fghz3*1.e9
        xn = fghz3/fghz2
        m3.data = data_th * xn**bth + data_nt * xn**bnt
        return data_th, data_nt, m3.data

    def center_of_gravity(self, data1, data2, clip=None):
        """Return the (row, col) offset of data2's intensity centroid from data1's.

        If *clip* is given, only pixels brighter than ``clip * max`` in BOTH
        maps contribute.  With ``clip=None`` every non-NaN pixel is used.
        """
        if clip is None:
            d1 = np.asarray(data1)
            d2 = np.asarray(data2)
        else:
            common = ((data1 > clip*np.nanmax(data1)) &
                      (data2 > clip*np.nanmax(data2)))
            d1 = np.where(common, data1, np.nan)
            d2 = np.where(common, data2, np.nan)

        def _centroid(d):
            # Intensity-weighted mean row and column index (NaNs ignored).
            rows, cols = np.mgrid[:d.shape[0], :d.shape[1]]
            total = np.nansum(d)
            return np.nansum(d*rows)/total, np.nansum(d*cols)/total

        b1x, b1y = _centroid(d1)
        b2x, b2y = _centroid(d2)
        return b2x-b1x, b2y-b1y

    def subpixel_shift(self, data1, data2):
        """Shift *data2* so that its intensity centroid coincides with data1's.

        NaNs in data2 are interpolated first.  The (possibly fractional)
        shift is applied in Fourier space.  Returns the real part of the
        shifted data2.  Note: ``function`` calls the module-level
        ``subpixel_shift`` imported from nodmath, not this method.
        """
        off_x, off_y = self.center_of_gravity(data1, data2, clip=0.1)
        data2 = nan_interpolation(data2)
        # fourier_shift moves the image content by +shift, so move data2 by
        # the negated centroid offset to bring it onto data1.
        shifted = spi.fourier_shift(np.fft.fft2(data2), (-off_x, -off_y))
        return np.fft.ifft2(shifted).real

    def function(self, ms, p):
        """Separate two maps into thermal and non-thermal components.

        *ms* is the list of two input maps (objects with ``.data`` and a
        FITS-style ``.header``); the list is reordered in place.  *p*
        provides ``proj``, ``nonth`` and ``freq`` (GHz).

        Returns ``([thermal, non_thermal, extrapolated], p)``, or ``([], p)``
        after reporting an error through ``Error``.
        """
        hpbw = []
        for m in ms:
            # Effective beam width: geometric mean of the major and minor axes.
            hpbw.append(np.sqrt(m.header['BMAJ'] * m.header['BMIN']))
            m.data = nan_interpolation(m.data)
        # Put the map with the smaller beam first.
        if hpbw[0] > hpbw[1]:
            ms.reverse()
        m1 = ms[0]  # smaller beam, assumed to be the higher frequency
        m2 = ms[1]  # larger beam, assumed to be the lower frequency
        # Shallow copy: seperate_tnt replaces its header and data later.
        m3 = cp.copy(m1)
        m1, m2 = self.same_proj(m1, m2, p.proj)
        m1 = self.convolve(m1, m2)
        m2 = self.resample(m1, m2)
        out, ms = same_size([m1, m2], True)
        m1, m2 = ms[0], ms[1]
        if out > 0:
            self.Error("one or more maps have different scales or coordinate systems")
            return [], p
        m1.data = nan_interpolation(m1.data)
        m2.data = nan_interpolation(m2.data)
        # Module-level subpixel_shift from nodmath (not the method of the same
        # name); presumably aligns m1.data to m2.data and also returns the
        # offsets, which are unused here -- confirm in nodmath.
        dx, dy, m1.data = subpixel_shift(m2.data, m1.data)
        # seperate_tnt gives both components at m2's frequency (S2 = T + N).
        f_ref = m2.header['CRVAL3']
        try:
            m1.data, m2.data, m3.data = self.seperate_tnt(
                m1, m2, m3, -0.1, -abs(p.nonth), p.freq)
        except ValueError as err:
            self.Error(err)
            return [], p
        m1.header["MAPTYPE"] = "TH"
        m2.header["MAPTYPE"] = "NT"
        # Label the components with the frequency they actually refer to.
        m1.header["CRVAL3"] = f_ref
        m2.header["CRVAL3"] = f_ref
        m3.header["MAPTYPE"] = "I"
        return [m1, m2, m3], p