# pol_zzzPIrms.py 5.58 KB
# Display name of this plugin; used below as the error-dialog caption and,
# with spaces removed, as the operation name handed to the host application.
title = "PI rms"
# Short description of the plugin. Not referenced elsewhere in this file;
# presumably read by the host application as a tooltip -- TODO confirm.
tip = "estimate the noise (rms) of the PI map"
# Passed as the last argument to parent.compute_11(); presumably the number
# of input maps the operation expects (PI and PI*) -- TODO confirm.
onein = 2

import numpy as np
import scipy.stats as ss
from scipy import optimize, signal
from scipy.special import iv
try:
   from nodfitting import curve_fit
except:
   from scipy.optimize import curve_fit

from guiqwt import pyplot
from guidata.qt.QtGui import QMessageBox
from guidata.dataset.datatypes import DataSet
from guidata.dataset.dataitems import (IntItem, StringItem, ChoiceItem, FloatItem, BoolItem)
from guiqwt.config import _

import warnings
# Silence every warning category for the whole process (e.g. the NumPy/SciPy
# warnings emitted while fitting); note that this also hides deprecations.
warnings.filterwarnings("ignore")

class NOD3_App:
    """Plugin that estimates the noise (rms) from a PI and a PI* map.

    The difference of the intensity histograms of both maps is fitted with
    ``ricegauss`` and the fitted curve is shown together with the histograms.
    """

    def __init__(self, parent):
        """Remember the host window *parent* and bring it to the front."""
        self.parent = parent
        self.parent.activateWindow()

    def Error(self, msg):
        """Show *msg* in a critical message box captioned with the plugin title."""
        QMessageBox.critical(self.parent.parent(), title,
                             _(u"Error:") + "\n%s" % str(msg))

    def compute_app(self, **args):
        """Entry point called by the host: run ``function`` via ``compute_11``.

        ``args`` are optional scripted parameters. They are still forwarded to
        ``parent.ScriptParameter`` as before, but the plugin itself takes no
        parameters, so ``None`` is what is finally passed on.
        """
        class FuncParam(DataSet):
            s = StringItem('s', default="string")
            i = IntItem('i', default=0, max=100, min=0)
            a = FloatItem('a', default=1.)
            b = BoolItem("bool", default=True)
            choice = ChoiceItem("Unit", ("Degree", "Arcmin", "Arcsec"), default=2)

        name = title.replace(" ", "")
        if not args:
            param = FuncParam(_(title), "description")
        else:
            param = self.parent.ScriptParameter(name, args)

        # No parameter is needed by this plugin, so the dataset built above is
        # deliberately discarded and None is handed to the host.
        param = None
        self.parent.compute_11(name, lambda m, p: self.function(m, p), param, onein)

    def plot(self, x, y1, y2, p, f):
        """Plot the histogram difference with its fit, and both histograms.

        x      -- bin positions
        y1, y2 -- histogram counts of the PI and the PI* map
        p      -- fitted parameters; the last one is reported as sigma
        f      -- model function, called as ``f(x, *p)``
        """
        diff = y1 - y2
        self.parent.fig = pyplot.figure("Noise distribution")
        sig = str("Sigma = %f" % abs(p[-1]))

        pyplot.subplot(1, 1, 1)
        pyplot.plot(x, diff, "b*", label="Data")
        pyplot.legend()
        pyplot.plot(x, f(x, *p), "r-", label=sig)
        pyplot.zlabel("Noise distribution")

        # NOTE(review): index 2 is requested on a grid declared as 1x1; this
        # is kept unchanged -- confirm how guiqwt.pyplot.subplot places it.
        pyplot.subplot(1, 1, 2)
        pyplot.legend()
        pyplot.plot(x, y1, "g-", label="PI")
        pyplot.plot(x, y2, "b--", label="PI*")

        pyplot.show(mainloop=False)

    def get_rms(self, data):
        """Return the position of the histogram peak of *data*.

        NaNs are dropped, values above 2.5 times the median are clipped to that
        limit, and a 64-bin histogram is built. A 7th-degree polynomial is
        fitted to the counts below the limit; the result is the stationary
        point of that polynomial with the highest value inside the fitted
        range. If there is none, the bin edge where the polynomial is largest
        is returned instead.
        """
        hdata = data[~np.isnan(data)]
        med = 2.5*np.median(hdata)
        bins = 64
        hdata = np.where(hdata < med, hdata, med)
        counts, edges = np.histogram(hdata, bins=bins)

        # Index of the last left bin edge that is still below the clip limit;
        # the bin holding the clipped pile-up is thereby kept out of the fit.
        below = np.nonzero(edges[:-1] < med)[0]
        lh = below[-1] if below.size else len(counts)

        poly = np.poly1d(np.polyfit(edges[:lh], counts[:lh], 7))
        rms = edges[poly(edges).argmax()]

        # Refine with the stationary points of the fit. Several of them can
        # lie inside the range (a degree-7 fit wiggles), so take the highest
        # one instead of whichever np.roots happens to list last.
        lower, upper = edges[1], edges[lh]
        candidates = [r.real for r in np.roots(poly.deriv())
                      if r.imag == 0.0 and lower < r.real < upper]
        if candidates:
            rms = max(candidates, key=poly)
        return rms

    def histo(self, data1, data2, rms):
        """Histogram both maps on a common grid spanning -4*rms .. +8*rms.

        Pixels that are NaN in *data1* are ignored in both maps; NaNs left in
        *data2* are dropped as well. Values outside the range are clipped into
        the outermost bins. *rms* must be positive.

        Returns ``(x, y1, y2)``: the left bin edges lying strictly between
        -3*rms and 6*rms and the matching counts of *data1* and *data2*.
        """
        valid = ~np.isnan(data1)
        hdata1 = data1[valid]
        hdata2 = data2[valid]
        hdata2 = hdata2[~np.isnan(hdata2)]

        bins = 64
        rmin = -4*rms
        rmax = +8*rms
        # An explicit range gives both histograms identical edges without
        # having to move the extreme samples onto rmin/rmax.
        hist1, edges = np.histogram(np.clip(hdata1, rmin, rmax),
                                    bins=bins, range=(rmin, rmax))
        hist2 = np.histogram(np.clip(hdata2, rmin, rmax),
                             bins=bins, range=(rmin, rmax))[0]

        left = edges[:-1]
        keep = (left > -3*rms) & (left < 6*rms)
        return left[keep], hist1[keep], hist2[keep]

    def ricegauss(self, x, a, b, m, sig):
        """Model: a Gaussian of amplitude *a* minus a Rice-like term.

        The Rice-like term ``b*x*I0(x*m/sig**2)*exp(-(x**2+m**2)/(2*sig**2))``
        is only applied for ``x >= 0``; both terms share the width *sig*.
        """
        sig2 = sig**2
        x2 = x*x
        bx = np.where(x < 0, 0.0, b*x*iv(0, x*m/sig2))
        gauss = a*np.exp(-0.5*x2/sig2)
        rice = bx*np.exp(-0.5*(x2 + m*m)/sig2)
        return gauss - rice

    @staticmethod
    def _fit_failed(pcov):
        """Return True if the covariance *pcov* is empty or starts with inf."""
        pcov = np.asarray(pcov, dtype=float)
        return bool(pcov.size == 0 or np.isinf(pcov.flat[0]))

    def function(self, ms, p):
        """Fit the PI - PI* histogram difference and plot the result.

        ms -- maps; each needs ``header["MAPTYPE"]`` ('PI' or 'PI*') and ``data``
        p  -- parameter object, returned unchanged

        Always returns ``([], p)``; problems are reported through ``Error``.
        """
        data1 = data2 = None
        for m in ms:
            maptype = m.header["MAPTYPE"].strip()
            if maptype == 'PI':
                data1 = m.data
            elif maptype == 'PI*':
                data2 = m.data
            else:
                self.Error("sorry, PI and PI* are needed")
                return [], p
        # Both kinds of map are required; two maps of the same kind would
        # otherwise leave one of the inputs undefined.
        if data1 is None or data2 is None:
            self.Error("sorry, PI and PI* are needed")
            return [], p

        rms = self.get_rms(data2)
        if not np.isfinite(rms) or rms <= 0.0:
            self.Error("could not estimate the noise level of the PI* map")
            return [], p

        x, y1, y2 = self.histo(data1, data2, rms)
        dy = y1 - y2
        inside = np.where(x < np.sqrt(2*np.pi)*rms)[0]
        if inside.size == 0:
            self.Error("try to scale map")
            return [], p
        xp = inside[-1]

        amp = dy.max()
        p0 = (amp, amp/len(x), 1.0, rms)
        # Per-point uncertainties: points at x <= 0 get ten times less weight.
        # (The condition has to look at the abscissa, not at the index xp.)
        sigma = np.where(x[:xp] <= 0.0, 10.0, 1.0)
        try:
            result = curve_fit(self.ricegauss, x[:xp], dy[:xp], p0=p0, sigma=sigma)
        except Exception:
            self.Error("try to scale map")
            return [], p

        # Retry unweighted with ever smaller start widths while the fit does
        # not deliver a usable covariance; give up below rms/1000.
        scale = 1.0
        while self._fit_failed(result[1]):
            p0 = (amp, amp, 1.0, scale*rms)
            try:
                result = curve_fit(self.ricegauss, x[:xp], dy[:xp], p0=p0)
            except Exception:
                # best effort: keep the previous result and try the next width
                pass
            scale = scale/2.0
            if scale < 1./1000:
                break

        popt = result[0]
        self.plot(x, y1, y2, popt, self.ricegauss)
        return [], p