base_flatten.py 5.53 KB
# Plugin metadata: the menu/dialog title, its tooltip text, and the flag that
# ``NOD3_App.compute_app`` forwards to ``parent.compute_11`` (presumably all
# three are read by the NOD3 host application — confirm against the loader).
title, tip, onein = (
    "Flatten",
    "Corrects baseline effects by smoothing the data perpendicular to the scanning direction",
    True,
)

import numpy as np
from scipy.ndimage import gaussian_filter1d

from guidata.qt.QtGui import QMessageBox
from guidata.dataset.datatypes import DataSet
from guidata.dataset.dataitems import (IntItem, StringItem, ChoiceItem, FloatItem, BoolItem)
from guiqwt.config import _

def median(x):
    """Return the lower-third order statistic of *x* (not the true median).

    The values are sorted and the element at index ``len(x) // 3`` is
    returned, i.e. roughly the 33rd percentile.  ``NOD3_App.adjust`` applies
    it to the samples at each end of a scan line; the bias toward low values
    looks deliberate (perhaps to limit the influence of bright emission), so
    it is kept.  Raises IndexError for an empty sequence.
    """
    ordered = sorted(x)
    # Floor division: with "/" Python 3 yields a float index -> TypeError.
    return ordered[len(ordered) // 3]

class NOD3_App:
    """NOD3 "Flatten" plugin: remove baseline ripples along the scan lines.

    Scan lines are the rows of the map when ``SCANDIR`` is one of
    ``_ROW_SCANDIRS`` and the columns otherwise.  The map is smoothed
    perpendicular to the scanning direction, and a polynomial fitted to the
    residual of each scan line is subtracted from that line.
    """

    # SCANDIR values meaning "scan lines are rows"; anything else means columns.
    _ROW_SCANDIRS = ("LON", "ALON", "ULON", "GLON", "RA")

    def __init__(self, parent):
        self.parent = parent
        self.parent.activateWindow()

    def Error(self, msg):
        """Show *msg* in a critical message box titled with the plugin title."""
        QMessageBox.critical(self.parent.parent(), title,
                             _(u"Error:")+"\n%s" % str(msg))

    def compute_app(self, **args):
        """Collect the filter parameters and run the filter via the parent.

        Without keyword arguments an interactive parameter dialog is built;
        otherwise *args* are converted by ``parent.ScriptParameter``
        (scripted use).  The work is delegated to ``parent.compute_11`` with
        ``self.function`` as the per-map callback.
        """
        class FuncParam(DataSet):
            itera = IntItem('Iteration', default=2, max=10, min=1)
            order = IntItem('Order', default=7, max=13, min=1)
            clip = FloatItem('Clip', default=1.0, min=-1.0)
            adjust = BoolItem('Adjust', default=False)

        name = title.replace(" ", "")
        if not args:
            param = FuncParam(_(title), "Apply a n-order polynomial fit in scanning direction")
        else:
            param = self.parent.ScriptParameter(name, args)
        self.parent.compute_11(name, lambda m, p: self.function(m, p), param, onein)

    def adjust(self, m, proz=0.1):
        """Remove a linear end-to-end gradient from every scan line, in place.

        For each scan line the ``median`` helper is evaluated on the first and
        last ``l`` samples (``l`` is about ``proz`` times the line length,
        forced to be odd).  The straight line through those two levels, placed
        at the centres of the two end segments, is subtracted.  Lines whose
        leading or trailing segment is entirely NaN are left untouched.

        Returns ``m.data`` (modified in place), or ``([], None)`` after
        showing an error when the header has no ``SCANDIR`` keyword.
        """
        if 'SCANDIR' not in m.header:
            self.Error("sorry: keyword 'SCANDIR' not found in FITS header")
            return [], None
        # Present the scan lines as rows.  ``.T`` is a view, so the in-place
        # updates below write straight into ``m.data``.
        if m.header['SCANDIR'] in self._ROW_SCANDIRS:
            lines = m.data
        else:
            lines = m.data.T
        length = lines.shape[1]
        l = int(length*proz) - int(length*proz-1) % 2
        if l <= 0:
            # Empty end segments: nothing can be estimated.
            return m.data
        ax = float(l)/2.0          # centre of the leading segment
        bx = length - ax           # centre of the trailing segment
        pos = np.arange(length)
        for line in lines:
            a = line[:l]
            b = line[-l:]
            # "~" (not unary "-") negates a boolean mask in NumPy.
            a = a[~np.isnan(a)]
            b = b[~np.isnan(b)]
            if a.size == 0 or b.size == 0:
                continue
            am = median(a)
            bm = median(b)
            slope = (am - bm)/(ax - bx)
            line -= am - slope*ax + slope*pos
        return m.data

    @staticmethod
    def _fill_blanks(data, blank):
        """Return a copy of *data* with the *blank* pixels linearly interpolated.

        Interpolation runs over the flattened (row-major) pixel index, since
        ``np.interp`` works on 1-D data.  Requires at least one non-blank
        pixel.
        """
        filled = data.copy()
        if blank.any():
            filled[blank] = np.interp(np.flatnonzero(blank),
                                      np.flatnonzero(~blank), data[~blank])
        return filled

    def function(self, m, p):
        """Flatten the map *m* and return ``(m, p)``; ``m.data`` is replaced.

        *m* provides ``header`` (dict-like) and ``data`` (2-D array).  *p*
        provides:

        - ``itera``: number of passes.
        - ``order``: polynomial order.
        - ``clip``: residuals outside ``median +/- clip*std`` are ignored by
          the fit; ``clip <= 0`` effectively disables clipping.  The factor
          shrinks by 5 % after each pass.
        - ``adjust``: first remove a linear gradient with ``adjust``.

        If the header lacks ``SCANDIR`` the user is asked for it.  The keyword
        is removed again before returning, and ``([], p)`` is returned if the
        dialog is cancelled.  A ``SCANDIR`` already present in the header is
        left alone.
        """
        added_scandir = False
        if 'SCANDIR' not in m.header:
            class ScanParam(DataSet):
                scandir = ChoiceItem("SCANDIR", (("LON", "LON"), ("LAT", "LAT")),
                                     default="LON")
            param = ScanParam(_(title), "Scan direction")
            if not param.edit():
                return [], p
            m.header['SCANDIR'] = param.scandir
            added_scandir = True
        # axis 0: scan lines are rows, so smooth down the columns (across scans).
        axis = 0 if m.header['SCANDIR'] in self._ROW_SCANDIRS else 1
        w = p.clip if p.clip > 0 else 1.e9
        if p.adjust:
            data = self.adjust(m, proz=0.25)
        else:
            data = m.data.copy()
        blank = np.isnan(data)
        if not blank.all():
            # Row views of the scan lines; in-place edits update ``data``.
            lines = data if axis == 0 else data.T
            pos = np.arange(lines.shape[1])
            for _iteration in range(p.itera):
                # Re-fill the blanks on every pass so the smoothing does not
                # smear NaNs into neighbouring valid pixels.
                gdata = self._fill_blanks(data, blank)
                diff = gdata - gaussian_filter1d(gdata, 1.5, axis=axis, order=0)
                diff[blank] = np.nan
                valid = diff[~np.isnan(diff)]
                med = np.median(valid)
                sig = np.std(valid)
                dlines = diff if axis == 0 else diff.T
                for line, dline in zip(lines, dlines):
                    # NaN compares False, so blanked pixels are excluded too.
                    good = (dline > med - w*sig) & (dline < med + w*sig)
                    if np.count_nonzero(good) <= p.order:
                        # Too few points for this line only; keep fitting the rest.
                        continue
                    poly = np.poly1d(np.polyfit(pos[good], dline[good], p.order))
                    line -= poly(pos)
                w *= 0.95
        if added_scandir:
            del m.header['SCANDIR']
        m.data = data
        return m, p