From 063dc6e2c16d2628d6cd8909391850670be6b692 Mon Sep 17 00:00:00 2001 From: Nadia Dencheva Date: Fri, 13 Apr 2018 11:13:32 -0400 Subject: [PATCH] Use astropy in mktrace --- lib/stistools/mktrace.py | 224 ++++++++++++++++++++------------------- 1 file changed, 114 insertions(+), 110 deletions(-) diff --git a/lib/stistools/mktrace.py b/lib/stistools/mktrace.py index 34984db..6327cff 100644 --- a/lib/stistools/mktrace.py +++ b/lib/stistools/mktrace.py @@ -26,31 +26,28 @@ - Python version: Nadia Dencheva """ -from __future__ import division, print_function # confidence high +from __future__ import division, print_function -import numpy as N -from astropy.io import fits as pyfits import os.path +import numpy as np +from astropy.io import fits +from astropy.modeling import models, fitting from scipy import signal from scipy import ndimage as ni -from stsci.tools import gfit, linefit +from stsci.tools import linefit from stsci.tools import fileutil as fu -__version__ = '1.2' -__vdate__ = '2016-03-02' +__version__ = '3.0.0' +__vdate__ = '2018-04-20' def mktrace(fname, tracecen=0.0, weights=None): """ Refine a stis spectroscopic trace. """ - #import time - - #start=time.time() - try: - hdulist = pyfits.open(fname) + hdulist = fits.open(fname) except IOError: print("\nUNABLE TO OPEN FITS FILE: %s \n" % fname) return @@ -63,45 +60,44 @@ def mktrace(fname, tracecen=0.0, weights=None): kwinfo = getKWInfo(hdr0, hdr1) if kwinfo['instrument'] != 'STIS': print("This trace tool works only on STIS spectroscopic observations.\n") - print("Not processing file %s.\n" %fname) + print("Not processing file {}.\n".format(fname)) return sizex, sizey = data.shape - if weights == None: - wei = N.ones(sizey) + if weights is None: + wei = np.ones(sizey) else: if not iterable(weights) or not iterable(weights[0]): print("Weights must be a list of tuples, for example:\n") - print("weights=[(23,45),(300,670)] \n") + print("weights=[(23, 45),(300, 670)] \n") return - wei = N.zeros(sizey) + wei = np.zeros(sizey) - for i in N.arange(len(weights)): - for j in N.arange(weights[i][0], weights[i][1]): + for i in np.arange(len(weights)): + for j in np.arange(weights[i][0], weights[i][1]): wei[j] = 1 - #wind are weights indices in the image frame which may be a subarray - wind = N.nonzero(wei)[0] + # wind are weights indices in the image frame which may be a subarray + wind = np.nonzero(wei)[0] tr = Trace(fname, kwinfo) - a2center, trace1024 = tr.generateTrace(data,kwinfo, tracecen=tracecen, wind=wind) - #compute the full frame a2center + a2center, trace1024 = tr.generateTrace(data, kwinfo, tracecen=tracecen, wind=wind) + # compute the full frame a2center ffa2center = a2center*kwinfo['binaxis2'] tr_ind, a2disp_ind = tr.getTraceInd(ffa2center) - #print 'tr_ind', tr_ind tr2 = tr.readTrace(tr_ind) if tr_ind != a2disp_ind[0]: - tr1 = tr.readTrace(tr_ind -1) + tr1 = tr.readTrace(tr_ind - 1) interp_trace = trace_interp(tr1, tr2, ffa2center) else: interp_trace = tr2 - #convert the weights array into full frame - ind = N.nonzero(wei)[0] * kwinfo['binaxis1'] - w = N.zeros(1024) + # convert the weights array into full frame + ind = np.nonzero(wei)[0] * kwinfo['binaxis1'] + w = np.zeros(1024) w[ind] = 1 - X = N.arange(1024).astype(N.float) + X = np.arange(1024).astype(np.float) sparams = linefit.linefit(X, trace1024, weights=w) rparams = linefit.linefit(X, interp_trace, weights=w) sciline = sparams[0] + sparams[1] * X @@ -109,8 +105,8 @@ def mktrace(fname, tracecen=0.0, weights=None): deltaline = sciline - refline - #create a complete trace similar to a row in a _1dt file - #used only for debugging + # create a complete trace similar to a row in a _1dt file + # used only for debugging tr._a2displ = trace1024 tr._a1center = tr1['a1center'] tr._a2center = a2center @@ -118,12 +114,13 @@ def mktrace(fname, tracecen=0.0, weights=None): tr._pedigree = tr1['pedigree'] tr._snr_thresh = tr1['snr_thresh'] - tr.writeTrace(fname, sciline, refline, interp_trace, trace1024, tr_ind, a2disp_ind) + tr.writeTrace(fname, sciline, refline, interp_trace, + trace1024, tr_ind, a2disp_ind) - #print 'time', time.time()-start - #the minus sign is for consistency withthe way x2d reports the rotation - print("Traces were rotated by %f degrees \n" % (-(sparams[1]-rparams[1])*180 / N.pi)) - print('trace is centered on row %f' % tr._a2center) + # the minus sign is for consistency withthe way x2d reports the rotation + print("Traces were rotated by {0} degrees \n".format( + - (sparams[1] - rparams[1]) * 180 / np.pi)) + print('Trace is centered on row {}'.format(tr._a2center)) return tr @@ -135,35 +132,35 @@ def iterable(v): return False -def interp(y,n): +def interp(y, n): """ - Given a 1D array of size m, interpolates it to a size n (m a2center)[0][0] + ind[0][0] + ind = np.nonzero(a2disp_ind) + i = np.nonzero(self.sptrctab[ind].field('A2CENTER') > a2center)[0][0] + ind[0][0] return i, a2disp_ind @@ -269,7 +263,6 @@ def readTrace(self, tr_ind): return tr - def writeTrace(self, fname, sciline, refline, interp_trace, trace1024, tr_ind, a2disp_ind): """ The 'writeTrace' method performs the following steps: @@ -290,105 +283,111 @@ def writeTrace(self, fname, sciline, refline, interp_trace, trace1024, tr_ind, a infile = fname.split('.') newname = infile[0] + '_1dt.' + infile[1] - #refine all traces for this CENWAVE, OPT_ELEM + # refine all traces for this CENWAVE, OPT_ELEM fu.copyFile(fpath, newname) - hdulist = pyfits.open(newname, mode='update') + hdulist = fits.open(newname, mode='update') tab = hdulist[1].data - ind = N.nonzero(a2disp_ind)[0] - for i in N.arange(ind[0], ind[-1]+1): - tab[i].setfield('A2DISPL', tab[i].field('A2DISPL') + (sciline-refline)) + ind = np.nonzero(a2disp_ind)[0] + for i in np.arange(ind[0], ind[-1] + 1): + tab[i].setfield('A2DISPL', tab[i].field('A2DISPL') + (sciline - refline)) if 'DEGPERYR' in tab.names: - for i in N.arange(ind[0], ind[-1]+1): + for i in np.arange(ind[0], ind[-1] + 1): tab[i].setfield('DEGPERYR', 0.0) hdulist.flush() hdulist.close() - #update SPTRCTAB keyword in the science file primary header - hdulist = pyfits.open(fname, mode='update') + # update SPTRCTAB keyword in the science file primary header + hdulist = fits.open(fname, mode='update') hdr0 = hdulist[0].header hdr0['SPTRCTAB'] = newname hdulist.close() - #write out the fit to the interpolated trace ('_interpfit' file) - refhdu = pyfits.PrimaryHDU(refline) - refname=infile[0] + '_1dt_interpfit.' + infile[1] + # write out the fit to the interpolated trace ('_interpfit' file) + refhdu = fits.PrimaryHDU(refline) + refname = infile[0] + '_1dt_interpfit.' + infile[1] if os.path.exists(refname): os.remove(refname) refhdu.writeto(refname) - #write out the interpolated trace ('_interp' file) - inthdu = pyfits.PrimaryHDU(interp_trace) - intname=infile[0] + '_1dt_interp.' + infile[1] + # write out the interpolated trace ('_interp' file) + inthdu = fits.PrimaryHDU(interp_trace) + intname = infile[0] + '_1dt_interp.' + infile[1] if os.path.exists(intname): os.remove(intname) inthdu.writeto(intname) - #write out the the fit to the science trace ('_scifit' file) - scihdu = pyfits.PrimaryHDU(sciline) + # write out the the fit to the science trace ('_scifit' file) + scihdu = fits.PrimaryHDU(sciline) sciname = infile[0] + '_1dt_scifit.' + infile[1] if os.path.exists(sciname): os.unlink(sciname) scihdu.writeto(sciname) - #write out the science trace ('_sci' file) - trhdu = pyfits.PrimaryHDU(trace1024) + # write out the science trace ('_sci' file) + trhdu = fits.PrimaryHDU(trace1024) trname = infile[0] + '_1dt_sci.' + infile[1] if os.path.exists(trname): os.unlink(trname) trhdu.writeto(trname) - def generateTrace(self, data, kwinfo, tracecen=0.0, wind=None): """ Generates a trace from a science file. """ - if kwinfo['sizaxis2'] != None and kwinfo['sizaxis2'] < 1023: + if kwinfo['sizaxis2'] is not None and kwinfo['sizaxis2'] < 1023: subarray = True - else: subarray = False + else: + subarray = False if tracecen == 0: if subarray: - _tracecen = kwinfo['sizaxis2']/2.0 + _tracecen = kwinfo['sizaxis2'] / 2.0 else: _tracecen = kwinfo['crpix2'] else: _tracecen = tracecen - sizex,sizey = data.shape + sizex, sizey = data.shape subim_size = 40 - y1 = int(_tracecen - subim_size/2.) - y2 = int(_tracecen + subim_size/2.) - if y1 < 0: y1 = 0 - if y2 > (sizex -1): y2 = sizex - 1 - specimage = data[y1:y2+1,:] + y1 = int(_tracecen - subim_size / 2.) + y2 = int(_tracecen + subim_size / 2.) + if y1 < 0: + y1 = 0 + if y2 > (sizex - 1): + y2 = sizex - 1 + specimage = data[y1: y2 + 1, :] smoytrace = self.gFitTrace(specimage, y1, y2) - yshift = int(N.median(smoytrace) - 20) + yshift = int(np.median(smoytrace) - 20) y1 = y1 + yshift y2 = y2 + yshift - if (y1 < 0): y1 = 0 - if y2 > sizex: y2 = sizex - specimage = data[y1:y2+1,:] + if (y1 < 0): + y1 = 0 + if y2 > sizex: + y2 = sizex + specimage = data[y1: y2 + 1, :] smoytrace = self.gFitTrace(specimage, y1, y2) - med11smoytrace = ni.median_filter(smoytrace,11) + med11smoytrace = ni.median_filter(smoytrace, 11) med11smoytrace[0] = med11smoytrace[2] diffmed = abs(smoytrace - med11smoytrace) - tolerence = 3 * N.median(abs(smoytrace[wind] - med11smoytrace[wind])) - if tolerence < 0.1: tolerence = 0.1 - badpoint = N.where(diffmed > tolerence)[0] + tolerence = 3 * np.median(abs(smoytrace[wind] - med11smoytrace[wind])) + if tolerence < 0.1: + tolerence = 0.1 + badpoint = np.where(diffmed > tolerence)[0] if len(badpoint) != 0: - N.put(smoytrace, badpoint, med11smoytrace[badpoint]) + np.put(smoytrace, badpoint, med11smoytrace[badpoint]) - #convolve with a gaussian to smooth it + # Convolve with a gaussian to smooth it. fwhm = 10. - sigma = fwhm/2.355 + sigma = fwhm / 2.355 gaussconvxsmoytrace = ni.gaussian_filter1d(smoytrace, sigma) - #compute the trace center as the median of the pixels with nonzero weights - tracecen = N.median(gaussconvxsmoytrace[wind]) + # Compute the trace center as the median of the pixels + # with nonzero weights. + tracecen = np.median(gaussconvxsmoytrace[wind]) gaussconvxsmoytrace = gaussconvxsmoytrace - tracecen - trace1024 = interp(gaussconvxsmoytrace,1024) * kwinfo['binaxis2'] - tracecen = tracecen + y1 +1.0 + trace1024 = interp(gaussconvxsmoytrace, 1024) * kwinfo['binaxis2'] + tracecen = tracecen + y1 + 1.0 if subarray: tracecen = tracecen - kwinfo['ltv2'] self.trace1024 = trace1024 @@ -399,15 +398,20 @@ def gFitTrace(self, specimage, y1, y2): Fit a gaussian to each column of an image. """ - sizex,sizey = specimage.shape - smoytrace = N.zeros(sizey).astype(N.float) + sizex, sizey = specimage.shape + smoytrace = np.zeros(sizey).astype(np.float) boxcar_kernel = signal.boxcar(3) / 3.0 - - for c in N.arange(sizey): - col = specimage[:,c] - col = col - N.median(col) - smcol = ni.convolve(col, boxcar_kernel).astype(N.float) - fit = gfit.gfit1d(smcol, quiet=1, maxiter=15) - smoytrace[c] = fit.params[1] - - return N.array(smoytrace) + fitter = fitting.LevMarLSQFitter() + + for c in np.arange(sizey): + col = specimage[:, c] + col = col - np.median(col) + smcol = ni.convolve(col, boxcar_kernel).astype(np.float) + #fit = gfit.gfit1d(smcol, quiet=1, maxiter=15) + x = np.arange(len(smcol)).astype(np.float) + gauss = models.Gaussian1D(amplitude=smcol.max(), mean=x.mean()) + fit = fitter(gauss, x, smcol) + #smoytrace[c] = fit.params[1] + smoytrace[c] = fit.mean.value + + return np.array(smoytrace)