# base_twist.py 3.33 KB
# Plugin name: used below as the error-dialog title and, with spaces removed,
# as the name passed to the parent's compute_11() / ScriptParameter().
title = "Twist"
# Short description of the plugin. Not referenced anywhere else in this file
# (presumably shown as a tooltip by the host application -- TODO confirm).
tip = "Corrects background effects by fitting a plane"
# Passed unchanged as the last argument of parent.compute_11(); its meaning is
# defined by the host application, not by this file.
onein = True

import numpy as np
from scipy.ndimage import gaussian_filter1d
import scipy.optimize

from guidata.qt.QtGui import QMessageBox
from guidata.dataset.datatypes import DataSet
from guidata.dataset.dataitems import (IntItem, StringItem, ChoiceItem, FloatItem, BoolItem)
from guiqwt.config import _

def median(data):
    """Return the median of the non-NaN values in *data*.

    The input is flattened, NaN entries are discarded and the remaining
    values are sorted.  The element at index ``n // 2`` of the sorted values
    is returned, so for an even number of values this is the upper of the two
    middle values (no averaging is done).

    Parameters
    ----------
    data : array_like
        Values of any shape; may contain NaN.

    Returns
    -------
    scalar
        The middle element of the sorted non-NaN values.

    Raises
    ------
    IndexError
        If *data* contains no non-NaN value.
    """
    values = np.asarray(data).ravel()
    values = np.sort(values[~np.isnan(values)])
    if values.size == 0:
        raise IndexError("median() needs at least one non-NaN value")
    # Floor division keeps the index an integer on Python 3 as well.
    return values[values.size // 2]

def fitPlaneSolve(data):
    """Fit a least-squares plane to the non-NaN pixels of a 2-D array.

    The plane ``z = a*row + b*col + c`` is obtained by solving the 3x3 normal
    equations built from all pixels of *data* that are not NaN.  The fitted
    plane is then evaluated at every pixel position, including those that
    were NaN in the input.

    Parameters
    ----------
    data : numpy.ndarray
        2-D array; NaN entries are ignored in the fit.  Not modified.

    Returns
    -------
    numpy.ndarray
        The fitted plane, with the same shape as *data*.

    Raises
    ------
    numpy.linalg.LinAlgError
        May be raised by ``numpy.linalg.solve`` when the normal matrix is
        singular, e.g. when *data* holds no non-NaN pixel at all.
    """
    grid = np.indices(data.shape)
    x = grid[0].ravel()
    y = grid[1].ravel()
    z = data.ravel()

    # Select the valid samples with one boolean mask instead of a Python loop.
    valid = ~np.isnan(z)
    X = x[valid]
    Y = y[valid]
    Z = z[valid]
    npts = X.size

    # Normal equations A . (a, b, c) = B of the least-squares plane fit.
    A = np.array([[np.sum(X*X), np.sum(X*Y), np.sum(X)],
                  [np.sum(X*Y), np.sum(Y*Y), np.sum(Y)],
                  [np.sum(X),   np.sum(Y),   npts]])
    B = np.array([np.sum(X*Z), np.sum(Y*Z), np.sum(Z)])
    a, b, c = np.linalg.solve(A, B)

    plane = a*x + b*y + c
    return plane.reshape(data.shape)

class NOD3_App:
    """Plugin object that removes a background plane or twist from a map.

    The background is estimated from the four corner regions of the map only.
    Depending on the ``plane`` parameter either a least-squares plane fitted
    to the corners is subtracted, or a bilinear ("twisted") surface defined by
    the corner medians is removed iteratively.
    """

    def __init__(self, parent):
        """Store the *parent* object and activate its window."""
        self.parent = parent
        self.parent.activateWindow()

    def Error(self, msg):
        """Show *msg* in a critical message box titled with the plugin title."""
        QMessageBox.critical(self.parent.parent(), title,
                             _(u"Error:")+"\n%s" % str(msg))

    def compute_app(self, **args):
        """Build the parameter set and hand the computation to the parent.

        Without keyword arguments a ``FuncParam`` data set with its default
        values is created; otherwise the parameters are obtained from
        ``parent.ScriptParameter`` using the given keyword arguments.  The
        actual work is dispatched through ``parent.compute_11``.
        """
        class FuncParam(DataSet):
            # Size of each corner region, in percent of the map size.
            proz = FloatItem('Percent', default=25, max=50, min=5)
            # True: subtract a fitted plane; False: remove a twisted surface.
            plane = BoolItem('Plane', default=False)

        name = title.replace(" ", "")
        if not args:
            param = FuncParam(_(title), "Apply a plane fit for background correction")
        else:
            param = self.parent.ScriptParameter(name, args)

        # A plugin that needs no parameters would set ``param = None`` here.
        self.parent.compute_11(name, lambda m, p: self.function(m, p), param, onein)

    def function(self, m, p):
        """Subtract the corner-based background from ``m.data`` in place.

        Parameters
        ----------
        m : object
            Must expose the map as a 2-D array ``m.data``; it is modified in
            place.
        p : object
            Parameters: ``p.proz`` is the corner size in percent of the map
            size, ``p.plane`` selects the plane fit instead of the twist.

        Returns
        -------
        tuple
            ``(m, p)``, the objects that were passed in.
        """
        rows, cols = m.data.shape
        # Use at least one row/column per corner: a width of 0 would make the
        # negative slices (``-0:``) select the whole axis and leave the
        # positive ones empty.
        drows = max(1, int(rows*p.proz/100.0))
        dcols = max(1, int(cols*p.proz/100.0))

        if p.plane:
            # Copy of the map in which everything but the corners is NaN, so
            # that only the corners contribute to the plane fit.
            data = np.nan*m.data
            data[:drows, :dcols] = m.data[:drows, :dcols]
            data[:drows, -dcols:] = m.data[:drows, -dcols:]
            data[-drows:, :dcols] = m.data[-drows:, :dcols]
            data[-drows:, -dcols:] = m.data[-drows:, -dcols:]
            m.data -= fitPlaneSolve(data)
            return m, p

        # Twist: interpolate linearly between the four corner medians and
        # subtract the resulting surface; repeated to refine the medians.
        l_r = 0.5*drows            # row position assigned to a corner median
        l_c = 0.5*dcols            # column position assigned to a corner median
        row_span = float(rows-drows)
        col_span = float(cols-dcols)
        columns = np.arange(cols)
        for iteration in range(3):
            m_ll = median(m.data[:drows, :dcols])
            m_lr = median(m.data[:drows, -dcols:])
            m_ul = median(m.data[-drows:, :dcols])
            m_ur = median(m.data[-drows:, -dcols:])
            # Slope and row-0 offset along the left and the right edge; each
            # edge uses its own slope.
            s_l = (m_ul-m_ll) / row_span
            s_r = (m_ur-m_lr) / row_span
            o_l = m_ll - s_l*l_r
            o_r = m_lr - s_r*l_r
            for row in range(rows):
                f_l = o_l + s_l * row
                f_r = o_r + s_r * row
                # Linear baseline across this row between the two edge values.
                s_c = (f_r-f_l) / col_span
                o_c = f_l - s_c*l_c
                m.data[row] -= o_c + s_c*columns
        return m, p