# base_basketw.py (5.81 KB)
# Plugin metadata.  ``title`` is shown in error dialogs and used (without
# spaces) as the operation name passed to the host; ``onein`` is forwarded
# as the last argument of ``parent.compute_11`` (its meaning is defined by
# the host -- TODO confirm).  ``tip`` is not read in this module; presumably
# the host application displays it.
title = "BasketWeaving"
tip = ("Corrects baseline effects by smoothing "
       "in both scanning directions")
onein = False

import numpy as np
from nodmath import smooth, nan_interpolation

from guidata.qt.QtGui import QMessageBox
from guidata.dataset.datatypes import DataSet
from guidata.dataset.dataitems import (IntItem, StringItem, ChoiceItem, FloatItem, BoolItem)
from guiqwt.config import _

from nodmath import nanmedian

class NOD3_App:
    """NOD3 plugin that corrects baseline effects by basket weaving.

    Maps scanned in a longitude-like direction are combined with maps
    scanned in a latitude-like direction.  The difference between the two
    is smoothed along rows and columns and subtracted, pass by pass with
    shrinking windows, before the two directions are merged into one map.
    """

    # SCANDIR header values treated as longitude-like / latitude-like scans.
    _LON_DIRS = ('ALON', 'XLON', 'ULON', 'GLON', 'RA', 'HA')
    _LAT_DIRS = ('ALAT', 'XLAT', 'ULAT', 'GLAT', 'DEC')

    def __init__(self, parent):
        """Store the host *parent* and call its ``activateWindow()``."""
        self.parent = parent
        self.parent.activateWindow()

    def Error(self, msg):
        """Show *msg* in a critical QMessageBox titled with the plugin title."""
        QMessageBox.critical(self.parent.parent(), title,
                             _(u"Error:") + "\n%s" % str(msg))

    def compute_app(self, **args):
        """Build the parameter set and hand :meth:`function` to the host.

        Without keyword arguments a ``FuncParam`` data set is created;
        otherwise parameters come from
        ``self.parent.ScriptParameter(name, args)`` (scripted use).  The
        work is delegated to ``self.parent.compute_11``.
        """
        class FuncParam(DataSet):
            # Sigma-clipping factor for the difference map; values <= 0
            # disable clipping (see sig_clip).
            clip = FloatItem('SigmaClip', default=-1, max=10.0, min=-1.0)

        name = title.replace(" ", "")
        if not args:
            param = FuncParam(_(title), "Apply a smoothing fit in both scanning directions")
        else:
            param = self.parent.ScriptParameter(name, args)
        self.parent.compute_11(name, self.function, param, onein)

    def sig_clip(self, a, clip, n):
        """Iteratively blank outliers of *a*.

        Up to *n* times, every element whose absolute value exceeds ``clip``
        times the standard deviation of the still-valid (non-NaN) elements
        is replaced by NaN.  When ``clip <= 0`` *a* is returned unchanged
        (the same object).
        """
        if clip > 0:
            for _pass in range(n):
                # nanstd: values blanked on an earlier pass must be ignored,
                # otherwise the std is NaN and later passes clip nothing.
                limit = clip * np.nanstd(a)
                a = np.where(np.abs(a) > limit, np.nan, a)
        return a

    def weight(self, data, sigma):
        """Return the element-wise Gaussian weight ``exp(-data**2 / (2*sigma**2))``."""
        sigma2 = 2 * sigma * sigma
        return np.exp(-np.square(data) / sigma2)

    def Rowfit(self, diff, window):
        """Smooth every row of *diff* with ``smooth(row, window)``.

        Rows with no more than *window* non-NaN values are left as zeros in
        the returned ``(rows, cols)`` float array.
        """
        rows, cols = diff.shape
        pfit_row = np.zeros((rows, cols))
        for row in range(rows):
            if np.count_nonzero(~np.isnan(diff[row])) > window:
                pfit_row[row] = smooth(diff[row], window)
        return pfit_row

    def Colfit(self, diff, window):
        """Smooth every column of *diff* with ``smooth(column, window)``.

        Columns with no more than *window* non-NaN values are left as zeros
        in the returned ``(rows, cols)`` float array.
        """
        rows, cols = diff.shape
        pfit_col = np.zeros((rows, cols))
        for col in range(cols):
            if np.count_nonzero(~np.isnan(diff[:, col])) > window:
                pfit_col[:, col] = smooth(diff[:, col], window)
        return pfit_col

    def function(self, ms, p):
        """Basket-weave the maps in *ms* into one baseline-corrected map.

        Maps are grouped by their ``SCANDIR`` header into longitude-like and
        latitude-like scans.  Each group is summed together with the weight
        maps returned by ``self.parent.nan_check`` and normalised.  On every
        pass, twice the row-smoothed difference (lon - lat) is subtracted
        from the longitude map, and twice the column-smoothed difference
        (lat - lon) is subtracted from the latitude map.  The windows shrink
        by 4 per pass while they are at least 5.

        :param ms: iterable of maps with ``data`` and ``header`` attributes.
        :param p: parameter set with a ``clip`` attribute, or None (no clip).
        :returns: ``(m, p)``, where ``m`` is the last map of *ms* now holding
            the merged data and an updated header, or ``([], p)`` after an
            error has been reported via :meth:`Error`.
        """
        dummy = 0.0
        # -1 is the FuncParam default and disables sigma clipping.
        clip = p.clip if p is not None else -1
        lon_data = lat_data = wlon = wlat = None
        lon = lat = 0
        meanEL = []
        for m in ms:
            data, w = self.parent.nan_check(m.data, dummy, weight=True)
            if "SCANDIR" not in m.header:
                self.Error("missing SCANDIR keyword in header")
                return [], p
            if "MEANEL" in m.header:
                meanEL.append(m.header['MEANEL'])
            scandir = m.header['SCANDIR']
            # Accumulate out of place: '+=' would write into the first map's
            # arrays returned by nan_check, which may alias m.data.
            if scandir in self._LON_DIRS:
                lon_data = data if lon_data is None else lon_data + data
                wlon = w if wlon is None else wlon + w
                lon += 1
            elif scandir in self._LAT_DIRS:
                lat_data = data if lat_data is None else lat_data + data
                wlat = w if wlat is None else wlat + w
                lat += 1
            else:
                self.Error("sorry, SCANDIR=%s not defined" % scandir)
                return [], p
        if lon == 0 or lat == 0:
            sd = "Longitude" if lon == 0 else "Latitude"
            self.Error("sorry, missing maps scanning direction %s" % sd)
            return [], p

        lon_data, w1 = self.parent.nan_check(lon_data / wlon, dummy, weight=True)
        lat_data, w2 = self.parent.nan_check(lat_data / wlat, dummy, weight=True)
        w = w1 + w2
        # Combined weight floored at 1, used only to normalise the differences.
        wc = np.where(w < 1.0, 1.0, w)

        rows, cols = lon_data.shape
        # Floor division keeps the window lengths integral (a plain '/' gives
        # floats under Python 3, which are then passed to smooth()).
        windowL = cols // 2
        windowB = rows // 2
        minwin = 5
        window = max(windowL, windowB)
        window -= 1 - window % 2  # round the pass counter down to an odd value
        while window >= minwin:
            if windowB >= minwin:
                diff = self.sig_clip((lon_data - lat_data) / wc, clip, 3)
                lon_data -= 2 * self.Rowfit(diff, windowB)
                windowB -= 4
            if windowL >= minwin:
                diff = self.sig_clip((lat_data - lon_data) / wc, clip, 3)
                lat_data -= 2 * self.Colfit(diff, windowL)
                windowL -= 4
            window -= 4

        # NOTE(review): divides by w, not wc, so pixels with zero total weight
        # yield inf/NaN -- presumably intended to blank them; confirm.
        m.data = (lon_data + lat_data) / w
        if m.header['SCANDIR'] not in ("ALON", "ALAT"):
            del m.header['SCANDIR']
        if 'SCANNUM' in m.header and m.header['SCANNUM']:
            del m.header['SCANNUM']
        for key in ('MAPTYPE', 'EXTNAME'):
            if key in m.header:
                m.header[key] = m.header[key].replace("I1", "I").replace("I2", "I")
        if meanEL:
            m.header['MEANEL'] = np.mean(meanEL)
        return m, p