# base_plait.py 3.87 KB
# Name of this operation; compute_app() strips spaces from it to build the
# name handed to parent.compute_11().
title = "Plait90"
# Short description of the operation. Not referenced elsewhere in this file;
# presumably displayed by the host application -- TODO confirm.
tip = "applies sinus^2 filter to fft data"
# Passed unchanged as the last argument of parent.compute_11(); its meaning is
# defined by that method, which is not visible here.
onein = False

import numpy as np
import scipy.signal as sps

from guidata.dataset.datatypes import DataSet
from guidata.dataset.dataitems import (IntItem, FloatArrayItem, StringItem,
                                       ChoiceItem, FloatItem, DictItem,
                                       BoolItem)
from guiqwt.config import _

class NOD3_App:
    """Combine maps in the Fourier domain with a sin^2 taper ("Plait").

    For every map whose MAPTYPE starts with "I" the 2-D FFT is multiplied by
    a sin^2 weight that depends on the spatial frequency along the map's
    scan direction (ftwt).  The weighted transforms are summed, divided by
    the summed weights (ftnorm) and transformed back (function).
    """

    # Lower bound applied by wtfun(); keeps the summed weights used as a
    # divisor in ftnorm() away from zero.
    _MIN_WEIGHT = 0.1

    def __init__(self, parent):
        """Remember the parent window and bring it to the front."""
        self.parent = parent
        self.parent.activateWindow()

    def compute_app(self):
        """Build the parameter set and hand the operation to the parent.

        The parameter set has a single float item, ``scale`` (0.0 .. 1.0,
        default 0.15).  ``self.function`` is passed to
        ``parent.compute_11`` together with the module-level ``title``
        (spaces removed) and ``onein``.
        """
        class Param(DataSet):
            scale = FloatItem('Scale', default=0.15, min=0.0, max=1.0)
        param = Param(_("Plait"), "Set scale lenght of scanning")
        name = title.replace(" ", "")
        self.parent.compute_11(name, lambda m, p: self.function(m, p), param, onein)

    def wtfun0(self, z, scale):
        """Return the un-floored sin^2 taper for frequency ``z``.

        The weight is ``sin(z*scale*pi/2)**2`` while ``|z*scale| < 1`` and
        1.0 otherwise.  ``z`` may be a scalar or a NumPy array; a scalar
        input yields a scalar result.
        """
        zsc = np.asarray(z, dtype=float) * scale
        wt = np.where(np.abs(zsc) >= 1.0, 1.0, np.sin(zsc * np.pi / 2.0) ** 2)
        # Indexing with () turns a 0-d result back into a scalar and leaves
        # real arrays untouched.
        return wt[()]

    def wtfun(self, z, scale):
        """Return the sin^2 taper of wtfun0() limited to at least 0.1.

        The lower limit matters because ftnorm() divides by the sum of
        these weights.  ``z`` may be a scalar or a NumPy array.
        """
        return np.maximum(self.wtfun0(z, scale), self._MIN_WEIGHT)

    @staticmethod
    def _signed_index(n):
        """Fold the indices 0..n-1 of an FFT axis into signed indices.

        Indices up to and including ``n // 2`` are kept, larger ones are
        shifted down by ``n``.  Floor division is used explicitly so the
        result does not depend on Python 2 vs. Python 3 ``/`` semantics.
        """
        idx = np.arange(n)
        return idx - (idx // (n // 2 + 1)) * n

    def _frequency_grid(self, m, shape):
        """Return broadcastable x (1, cols) and y (rows, 1) frequency axes.

        The frequency step of each axis is ``1 / (|CDELT| * (n - 1))`` with
        CDELT1/CDELT2 read from ``m.header``.
        """
        irows, icols = shape
        xsamp = abs(m.header['CDELT1'])
        ysamp = abs(m.header['CDELT2'])
        f1i = 1.0/(xsamp*float(icols - 1))
        f1j = 1.0/(ysamp*float(irows - 1))
        x = self._signed_index(icols).astype(float) * f1i
        y = self._signed_index(irows).astype(float) * f1j
        return x[np.newaxis, :], y[:, np.newaxis]

    def ftnorm(self, m, fftdata):
        """Divide ``fftdata`` in place by the summed weights of all maps.

        The weights are rebuilt from the (cosphi, sinphi, scale) triples
        recorded by ftwt(), using the pixel sizes found in ``m.header``.
        Returns the same ``fftdata`` array.
        """
        x, y = self._frequency_grid(m, fftdata.shape)
        sumwt = np.zeros(fftdata.shape)
        for cosphi, sinphi, scale in zip(self.cosphi, self.sinphi, self.scale):
            sumwt += self.wtfun(x*cosphi + y*sinphi, scale)
        fftdata /= sumwt
        return fftdata

    def ftwt(self, m, p):
        """Return the 2-D FFT of ``m.data`` multiplied by its taper.

        The taper runs along x (angle 0) when SCANDIR is one of RA, HA,
        ALON, GLON, XLON and along y (angle 90) otherwise.  Its scale is
        the map extent along that axis (pixel count * |CDELT|) times
        ``p.scale``.  The direction cosines and the scale are appended to
        ``self.cosphi``, ``self.sinphi`` and ``self.scale`` so that
        ftnorm() can rebuild the same weights.
        """
        fftdata = np.fft.fft2(m.data)
        irows, icols = m.data.shape
        if m.header['SCANDIR'] in ('RA', 'HA', 'ALON', 'GLON', 'XLON'):
            angle = 0.0
            scale = icols*abs(m.header['CDELT1'])*p.scale
        else:
            angle = 90.0
            scale = irows*abs(m.header['CDELT2'])*p.scale
        phi = angle * np.pi/180.0
        sinphi = np.sin(phi)
        cosphi = np.cos(phi)
        self.sinphi.append(sinphi)
        self.cosphi.append(cosphi)
        self.scale.append(scale)
        x, y = self._frequency_grid(m, (irows, icols))
        fftdata *= self.wtfun(x*cosphi + y*sinphi, scale)
        return fftdata

    def function(self, ms, p):
        """Combine the maps in ``ms`` and return ``(m, p)``.

        Only maps whose MAPTYPE starts with "I" contribute.  The result is
        stored in the ``data`` of the last map of ``ms``, which is the map
        that is returned.  On that map SCANDIR is removed unless it is
        "ALON" or "ALAT", and SCANNUM is removed when it is set to a true
        value.

        Raises:
            ValueError: if ``ms`` contains no map whose MAPTYPE starts
                with "I".
        """
        self.sinphi = []
        self.cosphi = []
        self.scale = []
        # None marks "nothing accumulated yet"; comparing an ndarray with
        # [] is not a usable emptiness test.
        sum_fftdata = None
        m = None
        for m in ms:
            if m.header['MAPTYPE'][0] == "I":
                fftdata = self.ftwt(m, p)
                if sum_fftdata is None:
                    sum_fftdata = fftdata
                else:
                    sum_fftdata += fftdata
        if sum_fftdata is None:
            raise ValueError("Plait: no map with MAPTYPE starting with 'I'")
        fftdata = self.ftnorm(m, sum_fftdata)
        m.data = np.fft.ifft2(fftdata).real
        if m.header['SCANDIR'] not in ("ALON", "ALAT"):
            del m.header['SCANDIR']
        if m.header['SCANNUM']:
            del m.header['SCANNUM']
        return m, p