# fit_autogaussfit.py (10.2 KB)
# Plugin metadata: display title, tooltip text, and the ``onein`` flag that
# compute_app forwards to the host's compute_11 (presumably "operates on a
# single input image" -- confirm against the host).
title, tip, onein = ("AutoGaussfit",
                     "2-Dim Gauss Fit automatic run",
                     True)

import numpy 
import nodfitting
import nodastro
import time

from guidata.qt.QtGui import QMessageBox
from guidata.dataset.datatypes import DataSet
from guidata.dataset.dataitems import (IntItem, FloatArrayItem, StringItem,
                                       ChoiceItem, FloatItem, DictItem,
                                       BoolItem)
from guiqwt.config import _
from guiqwt.plot import CurveDialog
from guidata.qt.QtGui import QSplitter, QListWidget
from guiqwt.tools import SelectPointTool

class NOD3_App():
    """Automatic 2-D Gaussian point-source finder plugin for the NOD3 GUI.

    The host application passes itself as *parent*; this class uses the
    host's ScriptParameter/read_defaults/compute_11 entry points to obtain
    parameters and run :meth:`function` on an image object ``m`` that
    exposes ``m.data`` (2-D array) and ``m.header`` (dict-like).
    """

    def __init__(self, parent):
        self.parent = parent
        self.parent.activateWindow()

    def Error(self, msg):
        """Show *msg* in a critical message box titled with the plugin title."""
        QMessageBox.critical(self.parent.parent(), title,
                             _(u"Error:") + "\n%s" % str(msg))

    def compute_app(self, **args):
        """Collect fit parameters (dialog or script *args*) and run the fit."""
        class GaussParam(DataSet):
            GRoI = IntItem("GROI:", default=4, min=3)      # fit window half-size (pixels)
            Iter = IntItem("Iter:", default=3)             # number of search passes
            Clip = FloatItem("Clip:", default=0.0)         # amplitude threshold
            Unit = ChoiceItem("Unit:", (("Degree", "Degree"), ("Arcmin", "Arcmin"),
                                        ("Arcsec", "Arcsec")), default="Arcmin")
            Sorted = ChoiceItem("Sort:", (("Amplitude", "Amplitude"), ("Longitude", "Longitude"),
                                          ("Latitude", "Latitude")), default="Longitude")
            File = StringItem("Filename:", default="GaussFits")
            TP = BoolItem("Total Power", default=True)
            DelGauss = BoolItem("Remove Source", default=False)

        name = title.replace(" ", "")
        if not args:
            param = GaussParam(name, "2-dim Automatic Point Source Gaussian Fit")
        else:
            param = self.parent.ScriptParameter(name, args)
        param = self.parent.read_defaults(param)
        self.parent.compute_11(name, lambda m, p: self.function(m, p), param, onein)

    def printer(self, h, gpar, err, p, l, b):
        """Validate one Gaussian fit and, if accepted, record it in self.sort.

        h     -- image header (dict-like). If it has no CDELT1, a default
                 pixel-based WCS (CDELT/CRPIX/CTYPE/CRVAL) is written into it.
        gpar  -- fit parameters from nodfitting.fitgaussian, or None when the
                 fit failed. Indices used here: 3 amplitude, 4/5 x/y offset
                 inside the fit window, 6/7 widths (presumably Gaussian sigmas
                 in pixels, since they are scaled by the FWHM factor),
                 8 rotation angle in radians.
        err   -- uncertainties indexed like gpar.
        p     -- GaussParam instance (Unit, Clip, TP, GRoI, Sorted are read).
        l, b  -- integer pixel offset (column, row) of the fit window.

        Returns True when the amplitude exceeds p.Clip and both widths lie in
        the accepted beam-size range; then self.N is incremented and the
        result is stored in self.sort. Returns False otherwise.
        Raises ValueError for an unknown p.Unit.
        """
        if gpar is None:
            return False
        rad = numpy.pi / 180
        unit_scale = {"Degree": 1.0, "Arcmin": 60.0, "Arcsec": 3600.0}
        if p.Unit not in unit_scale:
            raise ValueError("Unknown unit: %r" % (p.Unit,))
        f1 = unit_scale[p.Unit]
        # sigma -> FWHM conversion combined with the degree -> output unit scale
        f = f1 * (2 * numpy.sqrt(2.0 * numpy.log(2.0)))
        if 'CDELT1' not in h:
            for key, value in (('CDELT1', -1.0), ('CDELT2', 1.0),
                               ('CRPIX1', 1.0), ('CRPIX2', 1.0),
                               ('CTYPE1', 'Pixel'), ('CTYPE2', 'Pixel'),
                               ('CRVAL1', 1.0), ('CRVAL2', 1.0)):
                h[key] = value
        sx = abs(h['CDELT1']) * f
        sy = abs(h['CDELT2']) * f
        amp = gpar[3]
        x = l + gpar[4]
        y = b + gpar[5]
        # Report the larger width as the major axis; each width keeps its
        # own uncertainty (err[6] belongs to gpar[6], err[7] to gpar[7]).
        # NOTE(review): sx*sx*err/f equals sx*|CDELT|*err rather than the
        # plain sx*err one might expect -- confirm the units of err.
        if sx * gpar[6] > sy * gpar[7]:
            hpbw_x, hpbw_y = sx * gpar[6], sy * gpar[7]
            d_hpbw_x = sx * sx * err[6] / f
            d_hpbw_y = sy * sy * err[7] / f
        else:
            hpbw_x, hpbw_y = sx * gpar[7], sy * gpar[6]
            d_hpbw_x = sx * sx * err[7] / f
            d_hpbw_y = sy * sy * err[6] / f
        rotang = (gpar[8] / rad) % 360.0
        if rotang > 180.0:
            rotang -= 180.0
        d_amp = abs(err[3] - err[0])
        d_x = err[4]
        d_y = err[5]
        d_rotang = 0.1 * err[8] / rad

        # Accepted width range [dx1, dx2] / [dy1, dy2]: from the header beam
        # (BMAJ/BMIN), else from the diffraction limit at CRVAL3 (read as Hz)
        # and DIAMETER (default 100), else from the pixel size and p.GRoI.
        if 'BMAJ' in h and 'BMIN' in h:
            hpbw = (h['BMIN'] + h['BMAJ']) / 2.0
            dx1 = dy1 = f1 * hpbw * 0.75
            dx2 = dy2 = f1 * hpbw * 1.5
        elif 'CRVAL3' in h:
            fghz = h['CRVAL3'] * 1.e-9
            diam = h['DIAMETER'] if 'DIAMETER' in h else 100.0
            clam = 0.299792458 / fghz
            hpbw = 180.0 / numpy.pi * 1.22 * clam / diam
            dx1 = dy1 = f1 * hpbw * 0.7
            dx2 = dy2 = f1 * hpbw * 1.5
        else:
            dx = f * abs(h['CDELT1']) / 2
            dy = f * abs(h['CDELT2']) / 2
            dx1, dy1 = dx * 1.7, dy * 1.7
            dx2, dy2 = p.GRoI * dx, p.GRoI * dy

        Amp = amp if p.TP else abs(amp)
        if not (Amp > p.Clip and dx1 < hpbw_x < dx2 and dy1 < hpbw_y < dy2):
            return False

        self.N += 1
        l, b = self.parent.get_plot_coordinates(x, y)
        if self.mat is not None:
            l, b = self.nt.getCoord(l, b, self.mat)
        l = l % 360.0
        d_l = abs(h['CDELT1']) * err[4]
        d_b = abs(h['CDELT2']) * err[5]
        each = [[amp, x, y, l, b, hpbw_x, hpbw_y, rotang],
                [d_amp, d_x, d_y, d_l, d_b, d_hpbw_x, d_hpbw_y, d_rotang]]
        if p.Sorted == "Amplitude":
            key, self.reverse = amp, True
        elif p.Sorted == "Longitude":
            key, self.reverse = l, False
        elif p.Sorted == "Latitude":
            key, self.reverse = b, False
        else:
            return True
        # self.N makes the key unique, so sources sharing a sort value are
        # all kept instead of overwriting each other.
        self.sort[(key, self.N)] = each
        return True

    def sortedDictValues(self, adict):
        """Return the values of *adict* ordered by key.

        Ascending, or descending when self.reverse is true (printer sets it
        for amplitude sorting). Uses sorted(), so it runs on Python 2 and 3.
        """
        pairs = sorted(adict.items(), key=lambda kv: kv[0])
        if getattr(self, "reverse", False):
            pairs.reverse()
        return [value for key, value in pairs]

    def _show_progress(self, progress, it, timeout):
        """Show the running source count and the 1-based pass number."""
        progress.showMessage(
            _("Number of sources: %d, Iteration %d ") % (self.N, it + 1), timeout)

    def _masked_copy(self, m):
        """Return a copy of m.data passed through parent.nan_check with its nanmin.

        Presumably nan_check replaces NaNs by the given value -- confirm.
        """
        data = m.data.copy()
        return self.parent.nan_check(data, numpy.nanmin(data))

    def _lambda_to_freq(self, header):
        """Convert a LAMBDA third axis to FREQ in place (CRVAL3 -> 3e8 / CRVAL3).

        Headers without CTYPE3 are left untouched. The wavelength is
        presumably in metres, giving a frequency in Hz.
        """
        if 'CTYPE3' in header and header['CTYPE3'] == "LAMBDA":
            header['CRVAL3'] = 300000000. / header['CRVAL3']
            header['CTYPE3'] = "FREQ"

    def _search_pass(self, m, p, progress, it, Dmax, xdim, ydim):
        """Run search pass *it*: fit Gaussians around successive peaks.

        Repeatedly takes the brightest unmasked pixel, fits a Gaussian in a
        window of half-size max(3, p.GRoI * 0.9**it) around it, masks the
        window, and subtracts accepted fits from m.data. Stops when all
        pixels are masked, when an accepted fit leaves no pixel above Dmax,
        or when parent.STOP is set.
        """
        half = max(3.0, p.GRoI * (0.9 ** it))
        Data = self._masked_copy(m)
        vmax = Data.max()
        vmin = Data.min()
        while vmax > vmin:
            if self.parent.STOP:
                progress.showMessage(_("stopped by ^C"), 5000)
                break
            b, l = numpy.unravel_index(Data.argmax(), Data.shape)
            b, l = int(b), int(l)
            # Integer window bounds: numpy >= 1.12 rejects float slice
            # indices, and integer offsets keep l - i1 / x = i1 + dx exact.
            i1 = max(0, int(l - half))
            i2 = min(xdim, int(l + half))
            j1 = max(0, int(b - half))
            j2 = min(ydim, int(b + half))
            data = m.data[j1:j2, i1:i2]
            gauss_params, err = nodfitting.fitgaussian(data, l - i1, b - j1, fast=True)
            accepted = self.printer(m.header, gauss_params, err, p, i1, j1)
            self._show_progress(progress, it, 5)
            # Mask the window so the next argmax moves on to a new peak.
            Data[j1:j2, i1:i2] = vmin
            if not accepted:
                vmax = Data.max()
                continue
            # Subtract the fitted Gaussian without its baseline terms (0..2).
            base_free = [0.0, 0.0, 0.0] + list(gauss_params[3:])
            gfit = nodfitting.gaussian(*base_free)
            m.data[j1:j2, i1:i2] -= gfit(*numpy.indices(data.shape))
            remaining = Data.max()
            if remaining > Dmax:
                vmax = remaining
            else:
                vmin = vmax  # nothing above Dmax left: end this pass

    def _write_results(self, p, rows):
        """Write the fit table to parent.cwd + p.File (fallback /tmp/p.File).

        Each source in *rows* produces a value line, an uncertainty line and
        a separator line.
        """
        header = ("#Num        Amp       PixX       PixY       Long        Lat"
                  "    HPBW(M)    HPBW(m)     RotAng")
        try:
            out = open(self.parent.cwd + p.File, "w")
        except (IOError, OSError):
            self.Error("Sorry, permission denied for writing\n /tmp is used")
            out = open("/tmp/" + p.File, "w")
        with out:
            out.write(header + "\n")
            for n, (values, errors) in enumerate(rows, 1):
                out.write(("%4d %10.5g %10.2f %10.2f %10.4f %10.4f %10.2f %10.2f %10.2f"
                           % tuple([n] + list(values))) + "\n")
                out.write(("     %10.2f %10.2f %10.2f %10.4f %10.4f %10.2f %10.2f %10.2f"
                           % tuple(errors)) + "\n")
                out.write("#" + 91 * "-" + "\n")

    def function(self, m, p):
        """Find, fit and optionally remove point sources in image *m*.

        Runs up to p.Iter search passes (see _search_pass), writes the
        accepted sources, sorted per p.Sorted, to a text table, and prints
        the source count and elapsed time.

        NOTE(review): m.data is modified in place (fitted Gaussians are
        subtracted) even when p.DelGauss is False -- confirm that the host's
        compute_11 passes a copy.

        Returns (m, p) when p.DelGauss is set, otherwise ([], p).
        """
        self.parent.STOP = False
        self.sort = {}
        self.N = 0
        # Initialised here so sortedDictValues works when nothing is found.
        self.reverse = (p.Sorted == "Amplitude")
        Data = self._masked_copy(m)
        ydim, xdim = Data.shape
        if p.Clip > 0.0:
            Dmax = p.Clip + numpy.fabs(Data).min()
        else:
            Dmax = numpy.nanmin(Data)
        t = time.time()
        self._lambda_to_freq(m.header)

        self.nt = nodastro.nodtrafo()
        # Coordinate rotation (e.g. for CTYPE1 ending in "DES") is disabled:
        # printer only applies self.nt.getCoord when self.mat is set.
        self.mat = None

        progress = self.parent.imagewidget.Progress
        it = 0
        for it in range(p.Iter):
            self._show_progress(progress, it, 5)
            self._search_pass(m, p, progress, it, Dmax, xdim, ydim)
            if self.parent.STOP:
                break
        self._show_progress(progress, it, 5000)

        self._write_results(p, self.sortedDictValues(self.sort))
        print("%d sources found in %s sec" % (self.N, time.time() - t))
        if p.DelGauss:
            return m, p
        return [], p