# filter_clean.py (5.63 KB)
# Plugin metadata read by the host: display name and tooltip text.
title, tip = "Clean", "clean tool tip"

import time
import numpy as np
from guiqwt.config import _
from guidata.qt.QtGui import QMessageBox, qApp
from guidata.qt.QtCore import QEventLoop
from guidata.dataset.datatypes import DataSet
from guidata.dataset.dataitems import (IntItem, FloatArrayItem, StringItem,
                                       ChoiceItem, FloatItem, DictItem,
                                       BoolItem)

# Bojan Nikolic <b.nikolic@mrao.cam.ac.uk>, <bojan@bnikolic.co.uk> 
# Initial version August 2010
#
# This file is part of pydeconv. This work is licensed under GNU GPL
# V2 (http://www.gnu.org/licenses/gpl.html)
"""
Clean based deconvolution, using numpy
"""

def _overlap_axis(n1, n2, shift):
    """Return ``(beg1, end1, beg2, end2)`` of the overlap along one axis.

    Element ``i`` of the second array is aligned with element ``i + shift``
    of the first. The slices ``[beg1:end1]`` (first array) and
    ``[beg2:end2]`` (second array) select the overlapping range and always
    have the same, possibly zero, length.
    """
    beg1 = max(0, shift)
    # Clamp so the range never runs backwards into negative indices, which
    # Python slicing would otherwise interpret as counting from the end.
    end1 = max(beg1, min(n1, n2 + shift))
    return beg1, end1, beg1 - shift, end1 - shift


def overlapIndices(a1, a2, shiftx, shifty):
    """Return slice bounds of the region where ``a2`` shifted overlaps ``a1``.

    ``a2[i, j]`` is aligned with ``a1[i + shiftx, j + shifty]``. The arrays
    may have different shapes; shifts are expected to be integers.

    :returns: ``((a1xbeg, a1xend, a1ybeg, a1yend),
                 (a2xbeg, a2xend, a2ybeg, a2yend))`` such that
        ``a1[a1xbeg:a1xend, a1ybeg:a1yend]`` and
        ``a2[a2xbeg:a2xend, a2ybeg:a2yend]`` have identical shapes
        (empty when the arrays do not overlap).
    """
    a1xbeg, a1xend, a2xbeg, a2xend = _overlap_axis(a1.shape[0], a2.shape[0], shiftx)
    a1ybeg, a1yend, a2ybeg, a2yend = _overlap_axis(a1.shape[1], a2.shape[1], shifty)
    return (a1xbeg, a1xend, a1ybeg, a1yend), (a2xbeg, a2xend, a2ybeg, a2yend)
        


class NOD3_App:
    """NOD3 plugin that deconvolves the current map with a Hogbom CLEAN.

    ``parent`` is the hosting NOD3 widget. This class uses its
    ``activateWindow()``, ``parent()`` and ``compute_11()`` methods, its
    ``imagewidget.Progress`` status object and its boolean ``STOP`` flag.
    """

    def __init__(self, parent):
        self.parent = parent
        self.parent.activateWindow()

    def Error(self, msg):
        """Show ``msg`` in a modal critical-error dialog titled ``title``."""
        QMessageBox.critical(self.parent.parent(), title,
                             _(u"Error:") + "\n%s" % str(msg))

    def compute_app(self):
        """Ask the host to collect the CLEAN parameters and run :meth:`function`."""
        # Parameter form shown to the user (guidata DataSet).
        class CleanParam(DataSet):
            gain = FloatItem("Gain (loop):", default=0.1, min=0.0, max=1.0)
            thres = FloatItem("Threshold:", default=0.0)
            niter = IntItem("Iterations:", default=100000, min=1000)
            clean = BoolItem("CleanMap", default=True)

        param = CleanParam('Clean Filter', "Simple Clean Algorithm (Hogbom):<br>")
        name = title.replace(" ", "")
        self.parent.compute_11(name, lambda m, p: self.function(m, p), param)

    @staticmethod
    def _peak_position(res, window):
        """Return the (row, col) index of the largest ``|res|`` inside ``window``."""
        # Pixels outside the window get -1, below any |res| >= 0, so they can
        # never win, and argmax yields a flat index into the full 2-D array.
        # (Taking argmax of ``res[window]`` would index the compressed
        # selection and point at the wrong pixel for a partial window.)
        search = np.where(window, np.fabs(res), -1.0)
        return np.unravel_index(search.argmax(), res.shape)

    def hogbom(self, dirty, psf, window, gain, thresh, niter):
        """
        Hogbom clean

        :param dirty: The dirty image, i.e., the image to be deconvolved.
        NaN pixels are interpolated for cleaning and set back to NaN in
        the outputs. The caller's array is not modified.

        :param psf: The point spread-function, same shape as ``dirty``,
        with its peak at ``(shape[0] // 2, shape[1] // 2)``

        :param window: Regions where clean components are allowed. If
        True, then all of the dirty image is assumed to be allowed for
        clean components

        :param gain: The "loop gain", i.e., the fraction of the brightest
        pixel that is removed in each iteration

        :param thresh: Cleaning stops when the maximum of the absolute
        deviation of the residual is less than this value

        :param niter: Maximum number of components to make if the
        threshold "thresh" is not hit

        :returns: ``(comps, res)``: the clean-component map and the
        residual map, both NaN where ``dirty`` was NaN. If the user stops
        the run (``parent.STOP``), the partial result is returned and
        ``self.stopped`` is set to True.
        """
        self.stopped = False
        # Work on a float copy so filling NaNs does not alter the caller's map.
        dirty = np.array(dirty, dtype=float)
        mask = np.isnan(dirty)
        if mask.all():
            # No valid pixel to clean (np.interp would fail on an empty sample set).
            return np.full(dirty.shape, np.nan), np.full(dirty.shape, np.nan)
        if mask.any():
            # Fill NaNs by linear interpolation along the flattened image.
            dirty[mask] = np.interp(np.flatnonzero(mask), np.flatnonzero(~mask), dirty[~mask])
        comps = np.zeros(dirty.shape)
        res = dirty.copy()
        if window is True:
            window = np.ones(dirty.shape, bool)
        else:
            window = np.asarray(window, dtype=bool)
        if not window.any():
            # No pixel may receive a component: nothing to clean.
            niter = 0
        # Integer PSF centre; must match the peak position used by gauss_psf.
        xc = dirty.shape[0] // 2
        yc = dirty.shape[1] // 2
        self.progress = self.parent.imagewidget.Progress
        self.progress.showMessage(_("stop program with ^C"), 5000)
        t0 = time.time()
        for i in range(niter):
            # Keep the GUI responsive and let a stop request be delivered.
            qApp.processEvents(QEventLoop.AllEvents, 100000)
            if i % 100 == 0:
                dt = round(time.time() - t0, 2)
                self.progress.showMessage(_("Iteration %d running %.2f sec" % (i, dt)))
            if self.parent.STOP:
                self.stop_text = "program stopped, no solution found"
                self.progress.showMessage(_(self.stop_text), 5000)
                self.parent.STOP = False
                self.stopped = True
                break
            mx, my = self._peak_position(res, window)
            mval = res[mx, my] * gain
            comps[mx, my] += mval
            # Subtract the PSF, centred on the peak, scaled by the component.
            a1o, a2o = overlapIndices(dirty, psf, mx - xc, my - yc)
            res[a1o[0]:a1o[1], a1o[2]:a1o[3]] -= psf[a2o[0]:a2o[1], a2o[2]:a2o[3]] * mval
            if np.fabs(res).max() < thresh:
                break
        # Restore NaNs on both the normal and the stopped path.
        comps = np.where(mask, np.nan, comps)
        res = np.where(mask, np.nan, res)
        return comps, res

    def function(self, m, p):
        """Clean ``m.data`` with the parameters in ``p``.

        ``m.data`` is replaced by the clean-component map when ``p.clean``
        is true, otherwise by the residual map. Returns ``(m, p)``.
        """
        psf = self.gauss_psf(m)
        cmap, res = self.hogbom(m.data, psf, True, p.gain, p.thres, p.niter)
        if getattr(self, "stopped", False):
            self.Error("Program stopped with ^C")
        m.data = cmap if p.clean else res
        return m, p

    def gauss_psf(self, m):
        """Return a unit-peak circular Gaussian beam on the grid of ``m.data``.

        The peak lies at ``(shape[0] // 2, shape[1] // 2)``, matching the
        centre assumed by :meth:`hogbom`. The beam FWHM is
        ``1.22 * lambda / DIAMETER`` converted from radians to degrees
        (``DIAMETER`` defaults to 100.0 when absent from the header). It is
        then expressed in pixels via ``|CDELT1|``. This presumably assumes
        CDELT1 is in degrees and DIAMETER in metres; confirm.
        """
        delt = abs(m.header['CDELT1'])
        if m.header['CTYPE3'].upper() == "LAMBDA":
            # NOTE(review): with this 1.22 factor, clam below becomes
            # CRVAL3 / 1.22 and the 1.22 in hpbw cancels out; confirm intended.
            fghz = 1.22 * 0.299792458 / m.header['CRVAL3']
        else:
            fghz = m.header['CRVAL3'] * 1.e-9
        try:
            diam = m.header['DIAMETER']
        except KeyError:
            diam = 100.0
        clam = 0.299792458 / fghz
        hpbw = 180.0 / np.pi * 1.22 * clam / diam
        sigma = hpbw / 2.35482 / delt   # 2.35482 = 2.*sqrt(2.*ln(2.))
        nx, ny = m.data.shape[0], m.data.shape[1]
        # Integer division keeps the peak at the same pixel hogbom uses
        # (int(n/2 + 0.5) shifted it by one for odd n under Python 3).
        cx, cy = nx // 2, ny // 2
        x, y = np.mgrid[-cx:nx - cx, -cy:ny - cy]
        return np.exp(-0.5 * ((x / sigma) ** 2 + (y / sigma) ** 2))