# base_pattern_adjust.py (5.39 KB)
# Plugin metadata.
# title: caption of the error dialog (NOD3_App.Error) and of the parameter
# dialog; with spaces removed it is also the operation name passed to
# parent.compute_11 / parent.ScriptParameter.
title = "Pattern-Adjust LB"
# tip: short description; not read in this file (presumably shown by the host
# as a tooltip -- confirm).
# NOTE(review): the text says "iteratively", but NOD3_App.function applies a
# single adjust_lb pass per scanning direction.
tip = "Corrects baseline effects by fitting polynomials in both scanning directions iteratively"
# onein: passed as the last argument of parent.compute_11; its meaning is
# defined by the host (presumably "operation takes a single input map" --
# confirm; function() itself iterates over several maps).
onein = False

import numpy as np

from guidata.qt.QtGui import QMessageBox
from guidata.dataset.datatypes import DataSet
from guidata.dataset.dataitems import (IntItem, StringItem, ChoiceItem, FloatItem, BoolItem)
from guiqwt.config import _

class NOD3_App:
    """NOD3 plugin: cross-adjust maps scanned in longitude- and latitude-like
    directions and combine them into a single map.

    ``parent`` is the host application object; this class calls its
    ``activateWindow``, ``parent``, ``ScriptParameter``, ``compute_11`` and
    ``nan_check`` methods.
    """

    def __init__(self, parent):
        self.parent = parent
        self.parent.activateWindow()

    def Error(self, msg):
        """Show ``msg`` in a critical message box captioned with the plugin title."""
        QMessageBox.critical(self.parent.parent(), title,
                             _(u"Error:") + "\n%s" % str(msg))

    def compute_app(self, **args):
        """Build the Edge/Center parameters and run :meth:`function` via the host.

        Without keyword arguments a ``FuncParam`` data set is created
        (presumably edited interactively by the host -- confirm); otherwise
        ``args`` are handed to ``parent.ScriptParameter`` (scripted use).
        """
        class FuncParam(DataSet):
            # Samples excluded from the fit at each end of every scan line.
            edge = IntItem('Edge', default=5, max=13, min=0)
            # Size of the region excluded around the middle of every scan line
            # (2 * (Center // 2) samples, see adjust_lb).
            center = IntItem('Center', default=9, max=13, min=1)

        name = title.replace(" ", "")
        if not args:
            param = FuncParam(_(title), "Apply a n-order polynomial fit in both scanning directions")
        else:
            param = self.parent.ScriptParameter(name, args)
        self.parent.compute_11(name, lambda m, p: self.function(m, p), param, onein)

    def weight(self, data, sigma):
        """Return the elementwise Gaussian weights ``exp(-data**2 / (2*sigma**2))``.

        NaN entries give NaN weights. Not called elsewhere in this class.
        """
        sigma2 = 2.0 * sigma * sigma
        # float conversion avoids integer overflow when squaring int arrays
        data = np.asanyarray(data, dtype=float)
        return np.exp(-np.square(data) / sigma2)

    def autocal(self, data1, data2, dummy):
        """Linearly rescale the non-``dummy`` values of ``data2`` onto ``data1``.

        The brightest values of both maps are paired rank by rank: the top
        ``max(10, N1 // 10)`` sorted values (N1 = number of non-dummy values in
        ``data1``), limited to the size of the smaller sample, are fitted with
        ``d1 = a1*d2 + a0``. ``data2`` is modified in place; both arrays are
        returned.
        """
        valid1 = data1 != dummy
        valid2 = data2 != dummy
        d1 = np.sort(data1[valid1])
        d2 = np.sort(data2[valid2])
        # Integer count (Python 3 safe), clipped so both polyfit inputs have
        # the same length even when the maps have different valid counts.
        nl = max(10, len(d1) // 10)
        nl = min(nl, len(d1), len(d2))
        a1, a0 = np.polyfit(d2[-nl:], d1[-nl:], 1)
        data2[valid2] = a0 + a1 * data2[valid2]
        return data1, data2

    def adjust_lb(self, map1, map2, sdir, p):
        """Subtract from ``map1`` a polynomial fitted line by line to ``map2 - map1``.

        ``sdir == 0``: lines are rows (fit along the column index);
        otherwise lines are columns (fit along the row index).
        In every line, the first and last ``p.edge`` samples and the
        ``2 * (p.center // 2)`` samples around the middle are excluded from the
        fit, as are samples where either map is NaN/inf. Lines with no more
        than ``order`` usable samples are left unchanged.

        ``map1`` is modified in place and returned.
        """
        order = 1
        edge = max(int(p.edge), 0)
        center = max(int(p.center), 0)
        # Iterate over rows of the maps or of their transposes; ``.T`` is a view,
        # so in-place updates of its rows write through to ``map1``.
        lines1 = map1 if sdir == 0 else map1.T
        lines2 = map2 if sdir == 0 else map2.T
        n = lines1.shape[1]
        pos = np.arange(n)
        # Samples excluded from every fit. Explicit, clamped bounds avoid the
        # negative-slice pitfalls: ``a[-0:]`` is the whole line (edge == 0), and
        # a negative start index would wrap around (center > n).
        keep = np.ones(n, dtype=bool)
        keep[:edge] = False
        keep[max(n - edge, 0):] = False
        mid = n // 2
        keep[max(mid - center // 2, 0):mid + center // 2] = False
        for line1, line2 in zip(lines1, lines2):
            diff = line2 - line1
            # NaN/inf in either map makes diff non-finite; ``x != np.nan`` would
            # always be True and cannot detect missing samples.
            usable = keep & np.isfinite(diff)
            if np.count_nonzero(usable) > order:
                coeffs = np.polyfit(pos[usable], diff[usable], order)
                # NOTE(review): the fit of (map2 - map1) is subtracted, which
                # moves map1 away from map2 on the fitted component; confirm the
                # sign against the algorithm's reference.
                line1 -= np.polyval(coeffs, pos)
        return map1

    def function(self, ms, p):
        """Cross-adjust and combine maps of both scanning directions.

        ``ms`` is a sequence of maps exposing ``.data`` and
        ``.header['SCANDIR']``. Maps are grouped by SCANDIR into a
        longitude-like and a latitude-like set. Each set is averaged with the
        weights returned by ``parent.nan_check``, the two averages are
        corrected against each other with :meth:`adjust_lb`, and their
        weighted mean is stored in ``.data`` of the LAST map of ``ms``.

        Returns ``(map, p)`` on success, or ``([], p)`` after reporting an error
        (unknown SCANDIR, or one scanning direction missing).
        """
        dummy = 0.0
        autocal = False  # True: rescale the latitude map onto the longitude map
        lon_dirs = ('ALON', 'XLON', 'ULON', 'GLON', 'RA', 'HA')
        lat_dirs = ('ALAT', 'XLAT', 'ULAT', 'GLAT', 'DEC')
        lon = lat = 0
        lon_data = wlon = lat_data = wlat = None
        for m in ms:
            data, w = self.parent.nan_check(m.data, dummy, weight=True)
            scandir = m.header['SCANDIR']
            # ``is None`` replaces ``array == []`` (not a valid emptiness test
            # for numpy arrays), and ``+`` instead of ``+=`` leaves the arrays
            # returned by nan_check (possibly views of m.data) untouched.
            if scandir in lon_dirs:
                lon_data = data if lon_data is None else lon_data + data
                wlon = w if wlon is None else wlon + w
                lon += 1
            elif scandir in lat_dirs:
                lat_data = data if lat_data is None else lat_data + data
                wlat = w if wlat is None else wlat + w
                lat += 1
            else:
                self.Error(str("sorry, SCANDIR=%s not defined" % scandir))
                return [], p
        if lon == 0 or lat == 0:
            sd = "Longitude" if lon == 0 else "Latitude"
            self.Error(str("sorry, missing maps scanning direction %s" % sd))
            return [], p
        # Zero weights give NaN/inf here; these are expected (presumably cleaned
        # up again by nan_check), so the floating-point warnings are silenced.
        with np.errstate(divide='ignore', invalid='ignore'):
            lon_mean = lon_data / wlon
            lat_mean = lat_data / wlat
        lon_data, w1 = self.parent.nan_check(lon_mean, dummy, weight=True)
        lat_data, w2 = self.parent.nan_check(lat_mean, dummy, weight=True)
        if autocal:
            lon_data, lat_data = self.autocal(lon_data, lat_data, dummy)
        with np.errstate(divide='ignore', invalid='ignore'):
            # Pixels with zero weight become NaN/inf, which adjust_lb excludes
            # from its polynomial fits.
            data_l = lon_data / w1
            data_b = lat_data / w2
        # Longitude map corrected along columns against the latitude map, then
        # the latitude map along rows against the corrected longitude map.
        data_l = self.adjust_lb(data_l, data_b, 1, p)
        data_b = self.adjust_lb(data_b, data_l, 0, p)
        # NOTE(review): where w1 (or w2) is 0, data_l (data_b) is presumably
        # NaN, so ``w1*data_l`` is NaN and the pixel stays NaN even when the
        # other map covers it -- confirm this is intended.
        with np.errstate(divide='ignore', invalid='ignore'):
            m.data = (w1 * data_l + w2 * data_b) / (w1 + w2)
        return m, p