fit_gaussfit.py 7.5 KB
# Plugin metadata. `title` (with spaces removed) is also used below as the
# parameter-set name and as the default results file name ("<title>.out").
title = "Gaussfit"
# Help text for the plugin; presumably shown as a tooltip by the host
# application -- TODO confirm against the plugin loader.
tip = "2-Dim Gauss Fit (exit with <Space>)"
# NOTE(review): meaning is defined by the host plugin loader, not by this
# file -- confirm what -1 selects there.
onein = -1

import numpy 
import nodfitting
import nodastro

from guidata.dataset.datatypes import DataSet
from guidata.dataset.dataitems import (IntItem, FloatArrayItem, StringItem,
                                       ChoiceItem, FloatItem, DictItem,
                                       BoolItem)
from guiqwt.config import _
from guiqwt.plot import CurveDialog
from guidata.qt.QtGui import QListView, QStandardItemModel, QStandardItem, QFont, QFileDialog
from guidata.qt.QtCore import QObject, SIGNAL, SLOT, QModelIndex, pyqtSlot

class NOD3ListView(QListView):
    """Source-list view; clicking anywhere in it offers to save its contents."""

    @pyqtSlot("QModelIndex")
    def ItemClicked(self, index=None):
        """Ask for a file name and write every row of the model to it.

        index: the clicked item, as delivered by the ``clicked(QModelIndex)``
        signal.  It is accepted (optionally) to match the declared slot
        signature but is not used -- the whole list is always saved.
        """
        name = title.replace(" ", "") + ".out"
        filename = QFileDialog.getSaveFileName(self, "Save file", name, "Results (*.out)")
        if filename == "":
            # Dialog returned an empty name -- nothing to save.
            return
        model = self.model()
        # Context manager guarantees the file is closed even if a write fails.
        with open(filename, 'w') as f:
            for row in range(model.rowCount()):
                out = model.data(model.index(row, 0)).toString()
                f.write(out + "\n")

class NOD3_App():
    """Interactive 2-D Gaussian point-source fit tool.

    After `compute_app` collects the fit parameters, the parent window calls
    `function` for each cursor pick.  A Gaussian is fitted in a box of
    +/- GRoI pixels around the pick.  Accepted fits are appended to a
    source-list window and subtracted from the map data in place.
    """

    # Converts a fitted width to a half-power beam width: 2*sqrt(2*ln 2).
    FWHM_PER_SIGMA = 2.0*numpy.sqrt(2.0*numpy.log(2.0))
    # Degrees -> selected output angular unit (assumes header CDELT values
    # are in degrees -- TODO confirm for non-celestial maps).
    UNIT_SCALE = {"Degree": 1.0, "Arcmin": 60.0, "Arcsec": 3600.0}

    def __init__(self, parent):
        """Remember the parent window and bring it to the front."""
        self.parent = parent
        self.parent.activateWindow()

    def compute_app(self):
        """Ask for the fit parameters, then hand cursor handling to the parent.

        The dialog is pre-filled from and saved back to the parent's defaults.
        If the dialog is not confirmed, nothing happens.  Otherwise the unit
        and fit-box half-size are stored on self, and `function` is registered
        as the per-cursor callback.  The "Remove Source" flag is forwarded to
        the parent as ``addmap`` (its effect is defined by the parent).
        """
        class GaussParam(DataSet):
            GRoI = IntItem("GaussROI:", default=6)
            Unit = ChoiceItem("Unit", (("Degree", "Degree"), ("Arcmin", "Arcmin"),
                                       ("Arcsec", "Arcsec")), default="Arcmin")
            DelGauss = BoolItem("Remove Source", default=False)

        self.name = title.replace(" ", "")
        param = GaussParam(self.name, "2-dim Point Source Gaussian Fit <br><br> (press [SPACE] key to exit)<br>")
        param = self.parent.read_defaults(param)
        if not param.edit():
            return
        self.parent.write_defaults(param)
        self.Unit = param.Unit
        self.GRoI = param.GRoI
        self.parent.get_cursor_positions(self.name, self.function,
                                         param, addmap=param.DelGauss)

    def create_sourcelist(self):
        """Create the source-list window and its model, and write the header row.

        Also registers the view as ``parent.LView``, sets the ``parent.N``
        flag, and resets the per-list source counter ``self.N``.
        """
        header = "#Num      Amp        PixX       PixY       Long        Lat    HPBW(M)    HPBW(m)     RotAng"
        self.sourcelist = NOD3ListView()
        self.sourcelist.setWindowTitle('NOD3 Source List (click into list to save it)')
        self.sourcelist.setMinimumSize(700, 10)
        self.model = QStandardItemModel(self.sourcelist)
        self.sourcelist.setModel(self.model)
        self.parent.LView = self.sourcelist
        # Clicking the list saves it (see NOD3ListView.ItemClicked).
        QObject.connect(self.sourcelist, SIGNAL("clicked(QModelIndex)"),
                        self.sourcelist, SLOT("ItemClicked(QModelIndex)"))
        self.model.appendRow(QStandardItem(header))
        self.N = 0
        self.parent.N = True

    def _measure_fit(self, h, gpar, err, l, b):
        """Turn raw fit output into catalogue quantities and an accept flag.

        If h has no 'CDELT1', pixel defaults (CDELT1=-1, CDELT2=1, CRPIX/CRVAL
        1, CTYPE 'Pixel') are written into h in place.

        Returns (accepted, values, errors):
          values = (amp, x, y, hpbw_x, hpbw_y, rotang)
          errors = (d_amp, d_x, d_y, d_l, d_b, d_hpbw_x, d_hpbw_y, d_rotang)
        x, y are full-map pixel positions (cut-out corner l, b plus the fitted
        offsets).  Widths are in self.Unit.  rotang is in degrees, folded into
        [0, 180].
        """
        rad = numpy.pi/180
        if self.Unit not in self.UNIT_SCALE:
            raise ValueError("unsupported unit %r" % (self.Unit,))
        f = self.UNIT_SCALE[self.Unit] * self.FWHM_PER_SIGMA
        if 'CDELT1' not in h:
            for key, value in (('CDELT1', -1.0), ('CDELT2', 1.0),
                               ('CRPIX1', 1.0), ('CRPIX2', 1.0),
                               ('CTYPE1', 'Pixel'), ('CTYPE2', 'Pixel'),
                               ('CRVAL1', 1.0), ('CRVAL2', 1.0)):
                h[key] = value
        cdelt1 = abs(h['CDELT1'])
        cdelt2 = abs(h['CDELT2'])
        sx = cdelt1 * f
        sy = cdelt2 * f

        amp = gpar[3]
        x = l + gpar[4]
        y = b + gpar[5]
        # Widths along the two fitted axes.  rotang describes the orientation
        # of the first axis, so the widths are reported in fit order (not
        # sorted into major/minor), and their errors use the matching indices.
        hpbw_x = sx*gpar[6]
        hpbw_y = sy*gpar[7]
        rotang = (gpar[-1] / rad) % 360.0
        if rotang > 180.0:
            rotang -= 180.0

        d_amp = abs(err[3] - err[0])
        d_hpbw_x = sx*sx*err[6]/f
        d_hpbw_y = sy*sy*err[7]/f
        d_rotang = 0.1*err[8] / rad
        d_l = cdelt1 * err[4]
        d_b = cdelt2 * err[5]

        # Plausibility filter: non-zero amplitude, and each width between half
        # a pixel and GRoI/2 pixels (dx, dy are half a pixel in output units).
        dx = sx/2
        dy = sy/2
        accepted = bool(abs(amp) > 0
                        and dx < hpbw_x < self.GRoI*dx
                        and dy < hpbw_y < self.GRoI*dy)

        values = (amp, x, y, hpbw_x, hpbw_y, rotang)
        errors = (d_amp, err[4], err[5], d_l, d_b, d_hpbw_x, d_hpbw_y, d_rotang)
        return accepted, values, errors

    def printer(self, h, gpar, err, p, l, b):
        """Check one Gaussian fit and, if plausible, append it to the source list.

        h: map header (missing pixel-scale keys are filled in place).
        gpar: fitted parameters from nodfitting.fitgaussian, or None if no fit.
            Indices 0-2 are the baseline terms zeroed by `function` before
            subtraction.  Index 3 is the amplitude, 4/5 the centre offsets
            within the cut-out, 6/7 the widths, and the last the rotation in
            radians.
        err: uncertainties with the same indexing as gpar.
        p: the parameter set (unused here; kept for the caller's signature).
        l, b: pixel offsets of the cut-out's corner in the full map.

        Returns True if the fit was accepted and listed, otherwise False.
        The list window is shown in both cases.
        """
        if gpar is None:
            return False
        accepted, values, errors = self._measure_fit(h, gpar, err, l, b)
        if not accepted:
            self.sourcelist.show()
            return False
        amp, x, y, hpbw_x, hpbw_y, rotang = values
        self.N += 1
        lon, lat = self.parent.get_plot_coordinates(x, y)
        if self.mat is not None:
            lon, lat = self.nt.getCoord(lon, lat, self.mat)
        lon = lon % 360.0
        row = (self.N, amp, x, y, lon, lat, hpbw_x, hpbw_y, rotang)
        source = "%4d %10.5g %10.2f %10.2f %10.4f %10.4f %10.2f %10.2f %10.2f" % row
        serror = "  +/-  %10.5g %10.2f %10.2f %10.4f %10.4f %10.2g %10.2g %10.2f" % tuple(errors)
        for text in (source, serror, 130*"-"):
            self.model.appendRow(QStandardItem(text))
        self.sourcelist.scrollToBottom()
        self.sourcelist.show()
        return True

    def function(self, m, p, l, b):
        """Fit a Gaussian around cursor position (l, b) and subtract it if accepted.

        m: map object with ``.data`` (2-D; unpacked as (rows, cols)) and
            ``.header``.
        p: parameter set providing GRoI (fit-box half-size in pixels).
        l, b: cursor position, presumably in pixel (column, row) coordinates;
            truncated with int() to place the fit box.
        """
        self.nt = nodastro.nodtrafo()
        # Also (re)create the list if this instance never built one, e.g. a
        # fresh NOD3_App while parent.N is already set -- printer needs
        # self.model / self.sourcelist / self.N.
        if not hasattr(self.parent, 'N') or not hasattr(self, 'model'):
            self.create_sourcelist()
        # The optional sky-rotation matrix (nt.rotmat for "DES" projections)
        # is disabled; coordinates come straight from the parent's transform.
        self.mat = None
        jmax, imax = m.data.shape
        i1 = max(0, int(l) - p.GRoI)
        i2 = min(imax, int(l) + p.GRoI)
        j1 = max(0, int(b) - p.GRoI)
        j2 = min(jmax, int(b) + p.GRoI)
        mdata = self.parent.nan_check(m.data, numpy.nanmin(m.data))
        data = mdata[j1:j2, i1:i2]
        if data.size == 0:
            # Cursor box lies entirely outside the map: nothing to fit.
            return
        gauss_params, err = nodfitting.fitgaussian(data, l-i1, b-j1)
        if self.printer(m.header, gauss_params, err, p, i1, j1):
            # Subtract only the Gaussian itself, not the fitted baseline terms.
            gauss_params_nobase = [0.0, 0.0, 0.0] + list(gauss_params[3:])
            gfit = nodfitting.gaussian(*gauss_params_nobase)
            m.data[j1:j2, i1:i2] -= gfit(*numpy.indices(data.shape))
        return