# base_PEMrestore.py (12.1 KB)
# Plugin metadata.  `title` is used for the parameter dialog and error boxes
# of NOD3_App and is passed to parent.compute_11 (with spaces removed) as the
# operation name; `onein` is forwarded to parent.compute_11 (its meaning is
# defined by the host -- verify); `tip` is not used in this module and is
# presumably read by the host as tooltip text.
title, tip, onein = (
    "PEM-Restore",
    "restore multi-horn maps scanned in Azimuth",
    False,
)

import copy as cp
import numpy as np
import scipy.signal as sps
from scipy.ndimage import gaussian_filter1d
from scipy import stats
#from scipy.interpolate import interp1d

from guidata.qt.QtGui import QMessageBox
from guidata.dataset.datatypes import DataSet
from guidata.dataset.dataitems import (IntItem, StringItem, ChoiceItem, FloatItem, BoolItem)
from guiqwt.config import _

from nodmath import extract, nan_interpol

def factorial(n):
    """Return ``n!`` for an integer *n*; values of ``n <= 1`` give 1.

    Implemented with a plain loop so it runs on both Python 2 and 3
    (no dependency on the Python 2 builtin ``reduce`` or on ``range``
    returning a list).
    """
    result = 1
    for k in range(2, n + 1):
        result *= k
    return result

def combinations(horns):
    """Return every unordered pair ``[horns[j], horns[i]]`` with ``j < i``.

    Pairs are listed in index order, e.g.
    ``combinations([1, 2, 3]) == [[1, 2], [1, 3], [2, 3]]``.
    Fewer than two entries yield an empty list.
    """
    count = len(horns)
    return [[horns[first], horns[second]]
            for first in range(count)
            for second in range(first + 1, count)]

def median(x):
    """Return the lower-tercile order statistic of *x*.

    NOTE: despite its name this returns ``sorted(x)[len(x) // 3]``, the
    value one third of the way up the sorted data, not the true median.
    Presumably this is deliberate, to bias baseline estimates below positive
    emission; confirm before changing.  Raises IndexError for an empty *x*.
    """
    ordered = sorted(x)
    # Floor division keeps the index an int on Python 3 as well as Python 2.
    return ordered[len(ordered) // 3]

def nint(x):
    """Round *x* to the nearest integer, with halves rounded away from zero.

    This matches Fortran ``NINT``: 2.5 -> 3, -2.5 -> -3, 0 -> 0.
    """
    shifted = x + 0.5 if x > 0 else x - 0.5
    return int(shifted)

class NOD3_App:
    """NOD3 plugin: restore multi-horn maps that were scanned in azimuth.

    Maps of two receiver horns (header ``RXHORN``) that share a ``MAPTYPE``
    are paired and combined into one restored map per pair; see
    :meth:`function` for the processing chain.
    """

    def __init__(self, parent):
        # parent is the hosting NOD3 window; bring it to the front.
        self.parent = parent
        self.parent.activateWindow()

    def Error(self, msg):
        """Show *msg* in a critical message box titled with the plugin title."""
        QMessageBox.critical(self.parent.parent(), title,
                              _(u"Error:")+"\n%s" % str(msg))

    def compute_app(self, **args):
        """Collect the parameters and run :meth:`function` through the host.

        Without keyword arguments a parameter dialog (``FuncParam``) is
        created.  Otherwise *args* are converted with
        ``parent.ScriptParameter``.  The work is dispatched through
        ``parent.compute_11`` with ``MType=True``.
        """
        class FuncParam(DataSet):
            adjust = FloatItem('Adjust(%)', default=10, max=50, min=0)
            flatten = BoolItem("Flatten", default=True)
        name = title.replace(" ", "")
        if not args:
            param = FuncParam(_(title), "Correct offset via adjust")
        else:
            param = self.parent.ScriptParameter(name, args)

        # If no parameter is needed, set param to None here instead.
        self.parent.compute_11(name, lambda m, p: self.function(m, p), param,
                               onein, MType=True)

    def autocal(self, data1, data2, ix):
        """Identity stub: return ``(data1, data2)`` unchanged.

        Cross-calibration between the two horns is currently disabled here.
        *ix* is accepted only for call compatibility.  :meth:`autocal1`
        implements a least-squares variant.
        """
        return data1, data2

    def autocal1(self, data1, data2, ix, scale=0.0):
        """Linearly rescale *data2* to match *data1*.

        The maps are compared with a column offset of *ix* pixels: columns
        ``ix:`` of *data1* against columns ``:-ix`` of *data2*, or all
        columns when ``ix == 0``.  Both value sets are sorted, and a line
        ``data1 = a0 + a1 * (s * data2)`` is fitted to the top
        ``max(10, N // 10)`` ranked pairs, where ``s = scale`` or 1 if
        ``scale == 0``.  With ``scale == 0`` the fitted slope and offset are
        used.  Otherwise the slope is fixed to *scale* and only the fitted
        offset is kept.

        Returns ``(data1, a0 + a1 * data2)``.
        """
        s = 1.0 if scale == 0.0 else scale
        # data2[:, :-0] would select no columns, so treat ix == 0 explicitly.
        tail2 = data2 if ix == 0 else data2[:, :-ix]
        d1 = sorted(np.ravel(data1[:, ix:]))
        d2 = sorted(np.ravel(s * tail2))
        nl = max(10, len(d1) // 10)  # integer slice length on Python 2 and 3
        a1, a0 = np.polyfit(d2[-nl:], d1[-nl:], 1)
        if scale != 0.0:
            a1 = scale
        return data1, a0 + a1 * data2

    def yintpn(self, a, b, c, d, y):
        """Four-point (cubic Lagrange) interpolation.

        *a*, *b*, *c* and *d* are samples at positions -1, 0, 1 and 2.  The
        interpolating cubic is evaluated at position *y* (normally
        ``0 <= y < 1``).  Returns *b* itself when ``y == 0``.
        """
        if y == 0.0:
            return b
        e = b - a
        f = b - c
        g = e + f + f + d - c
        g = g * y - e - f - g
        g = g * y - a + c
        return 0.5 * g * y + b

    def restore(self, d1, d2, dx):
        """Return the cumulative difference profile of two horn scans.

        With ``dx2 = int(dx/2 + 0.5)``, for every sample ``i`` in
        ``[dx2, len(d1) - dx2)`` the term ``(d2[i+dx2] - d1[i-dx2]) / dx`` is
        accumulated, and the running sum is stored at ``i``.  All other
        entries are ``0.0 * d1`` (0, or NaN where *d1* is NaN).  *dx* is the
        horn separation in pixels and is expected to be positive.
        """
        dx2 = int(dx / 2 + 0.5)
        weather = 0.0 * d1
        wsum = 0.0
        for i in range(dx2, len(d1) - dx2):
            wsum += (d2[i + dx2] - d1[i - dx2]) / dx
            weather[i] = wsum
        return weather

    def average(self, d1, d2, dx):
        """Average 2-D *d1* in place with *d2* shifted left by ``ic`` columns.

        ``ic = nint(dx/2 - 0.5)``.  Columns ``0 .. cols-ic-1`` of *d1* become
        ``(d1[:, col] + d2[:, col + ic]) / 2``.  The last ``ic`` columns are
        left untouched.  Returns *d1*.

        Raises ValueError if *dx* yields a negative column offset.
        """
        cols = d1.shape[1]
        ic = nint(dx / 2 - 0.5)
        if ic < 0:
            raise ValueError("average: dx=%r gives a negative column offset" % (dx,))
        span = cols - ic
        if span > 0:
            d1[:, :span] = (d1[:, :span] + d2[:, ic:cols]) / 2.0
        return d1

    def _remove_edge_baseline(self, scan, l):
        """Subtract, in place, the line joining the edge levels of 1-D *scan*.

        The level of the first and of the last *l* samples is taken with the
        module-level :func:`median` over their non-NaN values.  Each level is
        placed at its window centre (``l/2`` and ``len - l/2``).  Nothing is
        done when either edge window has no finite sample.
        """
        head = scan[:l]
        tail = scan[-l:]
        head_ok = ~np.isnan(head)
        tail_ok = ~np.isnan(tail)
        if not (head_ok.any() and tail_ok.any()):
            return
        am = median(head[head_ok])
        bm = median(tail[tail_ok])
        ax = float(l) / 2.0
        bx = len(scan) - ax
        slope = (am - bm) / (ax - bx)
        scan -= (am - slope * ax) + slope * np.arange(len(scan))

    def adjust(self, m, proz=0.1):
        """Remove a linear edge baseline from every scan of map *m* (in place).

        Scans are rows when ``m.header['SCANDIR']`` is one of "LON", "ALON",
        "ULON", "GLON" or "RA", and columns otherwise.  The edge windows are
        ``l`` samples long, where ``l`` is the largest odd number
        ``<= int(n * proz)`` (0 when ``n * proz < 1``, in which case nothing
        is done).

        Returns the modified ``m.data``.  If ``SCANDIR`` is missing, an error
        box is shown and ``([], None)`` is returned instead.
        """
        if 'SCANDIR' not in m.header:
            self.Error("sorry: keyword 'SCANDIR' not found in FITS header")
            return [], None
        rows, cols = m.data.shape
        if m.header['SCANDIR'] in ("LON", "ALON", "ULON", "GLON", "RA"):
            length = cols
            scans = [m.data[row] for row in range(rows)]         # views
        else:
            length = rows
            scans = [m.data[:, col] for col in range(cols)]      # views
        l = int(length * proz) - int(length * proz - 1) % 2
        for scan in scans:
            self._remove_edge_baseline(scan, l)
        return m.data

    def presse(self, Data, itera, Order, w):
        """Flatten 2-D *Data* row by row with clipped polynomial fits.

        Each of *itera* passes does the following:
        - high-pass filters along axis 0
          (``data - gaussian_filter1d(data, 1.5, axis=0)``);
        - for every row, fits a degree-*Order* polynomial (over the column
          index) to the residuals lying strictly within
          ``median +- w*sigma`` of all finite residuals, and subtracts it;
        - shrinks *w* by 25 %.

        Rows with at most *Order* usable residuals are skipped for that pass,
        and the remaining rows are still processed.  *Data* is not modified;
        the filtered copy is returned.
        """
        data = 1 * Data
        cols = data.shape[1]
        x = np.arange(cols)
        for sweep in range(itera):
            diff = data - gaussian_filter1d(data, 1.5, axis=0, order=0)
            finite = diff[~np.isnan(diff)]
            med = np.median(finite)
            sig = np.std(finite)
            lo, hi = med - w * sig, med + w * sig
            for row in range(data.shape[0]):
                keep = (diff[row] > lo) & (diff[row] < hi)
                if np.count_nonzero(keep) <= Order:
                    # Too few samples for a stable fit: skip only this row.
                    continue
                poly = np.poly1d(np.polyfit(x[keep], diff[row][keep], Order))
                data[row] -= poly(x)
            w *= 0.75
        return data

    def _shift_line(self, src, dst, off, frac):
        """Write *src* resampled at positions ``i + off + frac`` into *dst*.

        :meth:`yintpn` is applied to the neighbours ``r-1, r, r+1, r+2`` of
        ``r = i + off``, each clamped to ``[0, len(src) - 1]``.  Positions
        where ``r`` falls outside *src* become NaN.
        """
        last = len(src) - 1
        for i in range(len(src)):
            r = i + off
            if r < 0 or r > last:
                dst[i] = np.nan
            else:
                dst[i] = self.yintpn(src[max(0, r - 1)], src[r],
                                     src[min(last, r + 1)],
                                     src[min(last, r + 2)], frac)

    def shiftr(self, m):
        """Shift map *m* along axis 0 by ``PATLAT / CDELT2`` pixels, in place.

        Shifts smaller than 0.001 pixel are ignored.  Rows shifted in from
        outside the map become NaN.  Returns *m*.
        """
        dlat = m.header['PATLAT'] / m.header['CDELT2']
        if abs(dlat) < 0.001:
            return m
        off = int(dlat)
        frac = dlat - off
        for col in range(m.data.shape[1]):
            self._shift_line(1 * m.data[:, col], m.data[:, col], off, frac)
        return m

    def shiftc(self, m):
        """Shift map *m* along axis 1 by ``PATLONG / CDELT1`` pixels, in place.

        Shifts smaller than 0.001 pixel are ignored.  Columns shifted in from
        outside the map become NaN.  Returns *m*.
        """
        dlon = m.header['PATLONG'] / m.header['CDELT1']
        if abs(dlon) < 0.001:
            return m
        off = int(dlon)
        frac = dlon - off
        for row in range(m.data.shape[0]):
            self._shift_line(1 * m.data[row], m.data[row], off, frac)
        return m

    def sortRX2(self, ms):
        """Return the sorted distinct ``RXHORN`` header values of *ms*."""
        return np.unique([m.header['RXHORN'] for m in ms])

    def sortRX(self, ms, horns):
        """Pair copies of the maps of ``horns[0]`` and ``horns[1]`` by MAPTYPE.

        Returns two equally long lists ``(h1, h2)``.  ``h1[n]`` (horn
        ``horns[0]``) and ``h2[n]`` (horn ``horns[1]``) share a ``MAPTYPE``.
        Map types present for only one of the horns are dropped.  Each map is
        deep-copied; if that fails, a shallow copy with its own data and
        header is used.
        """
        horn1, horn2 = [], []
        mtyp1, mtyp2 = [], []
        for M in ms:
            try:
                m = cp.deepcopy(M)
            except Exception:
                m = cp.copy(M)
                m.data = 1 * M.data
                m.header = M.header.copy()
            if m.header['RXHORN'] == horns[0]:
                horn1.append(m)
                mtyp1.append(m.header['MAPTYPE'])
            elif m.header['RXHORN'] == horns[1]:
                horn2.append(m)
                mtyp2.append(m.header['MAPTYPE'])
        h1, h2 = [], []
        for mt in mtyp1:
            if mt in mtyp2:
                h1.append(horn1[mtyp1.index(mt)])
                h2.append(horn2[mtyp2.index(mt)])
        return h1, h2

    def _apply_adjust(self, maps, p):
        """Run :meth:`adjust` on each map, storing the result in ``m.data``.

        Returns False as soon as :meth:`adjust` reports an error (it returns
        a ``([], None)`` tuple and has already shown a message), else True.
        """
        for m in maps:
            adjusted = self.adjust(m, proz=p.adjust / 100.0)
            if isinstance(adjusted, tuple):
                return False
            m.data = adjusted
        return True

    def function(self, ms, p):
        """Restore every horn pair found in *ms*.

        For each pair of distinct ``RXHORN`` values, maps with equal
        ``MAPTYPE`` are paired (copies, see :meth:`sortRX`).  Each pair goes
        through these steps:
        1. Edge-baseline adjustment when ``p.adjust > 0`` (percent).
        2. NaN interpolation with ``nan_interpol``.
        3. Ordering so that the horn separation
           ``dx = dPATLONG / CDELT1`` (pixels) is positive.
        4. Combination by map type:
           - 'iU'/'iQ' maps are shifted with :meth:`shiftc`, cropped and
             NaN-averaged;
           - other maps are restored with :meth:`restore`, cropped, averaged
             with :meth:`average`, then optionally re-adjusted and flattened
             with :meth:`presse` (``p.flatten``).

        Returns ``(list_of_restored_maps, p)``, or ``([], p)`` after an
        error box has been shown.
        """
        mout = []
        for comb in combinations(self.sortRX2(ms)):
            h1, h2 = self.sortRX(ms, comb)
            if not h2:
                self.Error("sorry, exactly two horns with same channels are accepted")
                return [], p
            for n in range(len(h1)):
                m1, m2 = h1[n], h2[n]
                if p.adjust > 0 and not self._apply_adjust((m1, m2), p):
                    return [], p
                m1.data = nan_interpol(m1.data)[1]
                m2.data = nan_interpol(m2.data)[1]
                dx = (m2.header['PATLONG'] - m1.header['PATLONG']) / m2.header['CDELT1']
                if abs(dx) < 1:
                    self.Error(str("no difference in horn offset: %d pixel" % dx))
                    return [], p
                if dx < 0:
                    # Order the pair by pixel offset (not raw PATLONG) so dx > 0
                    # also when CDELT1 < 0; the crops below, restore() and
                    # average() all assume a positive separation.
                    m1, m2 = h2[n], h1[n]
                    dx = (m2.header['PATLONG'] - m1.header['PATLONG']) / m2.header['CDELT1']
                if m1.header['MAPTYPE'] in ('iU', 'iQ'):
                    m1 = self.shiftc(m1)
                    m2 = self.shiftc(m2)
                    ix = nint(dx + 0.5)
                    cols = m1.data.shape[1]
                    # Crop both maps to their common width and average them,
                    # ignoring NaNs (e.g. the gaps introduced by shiftc).
                    # NOTE(review): both maps were already shifted by their own
                    # PATLONG; confirm the additional ix-1 column crop is intended.
                    m1.data = m1.data[:, ix - 1:cols]
                    m2.data = m2.data[:, :cols - ix + 1]
                    m1.data = np.nanmean(np.array([m1.data, m2.data]), axis=0)
                    m1.header['CRPIX1'] -= dx
                    m1.header['PATLONG'] = 0.0
                    m1.header['SIDPIX'] = (dx, 'Offset pixel to sidmap')
                else:
                    ix = nint(dx + 0.5)
                    ix2 = nint(dx / 2 + 0.5)
                    data1, data2 = self.autocal(m1.data, m2.data, ix2)
                    rows, cols = data1.shape
                    for row in range(rows):
                        diff = self.restore(data1[row], data2[row], dx)
                        data1[row] -= diff
                        data2[row] -= diff
                    m1.header['CRPIX1'] += m1.header['PATLONG'] / m1.header['CDELT1']
                    m1.header['PATLONG'] = 0.0
                    m1.header['SIDPIX'] = (dx, 'Offset pixel to sidmap')
                    m2.header['PATLONG'] = 0.0
                    m2.header['SIDPIX'] = (dx, 'Offset pixel to sidmap')
                    m1.data = data1[:, :cols - ix + 1]
                    m2.data = data2[:, ix2 - 1:cols - ix2 + 1]
                    m1.data = self.average(m1.data, m2.data, dx)
                    if p.adjust > 0 and not self._apply_adjust((m1,), p):
                        return [], p
                    if p.flatten:
                        m1.data = self.presse(m1.data, 2, 4.0, 1.0)
                self.parent.SidOut = True
                self.parent.ParOut = True
                amap = m1.parmap.ravel()
                # The stored value is the middle sample of parmap.
                m1.header["PARANG"] = (amap[len(amap) // 2], 'Mean parallactic angle')
                m1.header['NAXIS1'] = m1.data.shape[1]  # x-axis has changed
                mout.append(m1)
        return mout, p