From 72a695d5467e96c13958da3ebbf3fd54750a9365 Mon Sep 17 00:00:00 2001 From: Lee Kelvin Date: Wed, 19 Apr 2023 11:09:11 -0700 Subject: [PATCH] Generalize amp offset code --- python/lsst/ip/isr/ampOffset.py | 242 ++++++++++++++++++++++++++++---- python/lsst/ip/isr/isrTask.py | 2 +- 2 files changed, 215 insertions(+), 29 deletions(-) diff --git a/python/lsst/ip/isr/ampOffset.py b/python/lsst/ip/isr/ampOffset.py index cc59c6ffb..ef51a79b5 100644 --- a/python/lsst/ip/isr/ampOffset.py +++ b/python/lsst/ip/isr/ampOffset.py @@ -18,88 +18,274 @@ # # You should have received a copy of the GNU General Public License # along with this program. If not, see . -# import os __all__ = ["AmpOffsetConfig", "AmpOffsetTask"] -import lsst.pex.config as pexConfig -import lsst.pipe.base as pipeBase -from lsst.meas.algorithms import (SubtractBackgroundTask, SourceDetectionTask) +import warnings +import numpy as np +from lsst.afw.math import MEANCLIP, StatisticsControl, makeStatistics +from lsst.afw.table import SourceTable +from lsst.meas.algorithms import SourceDetectionTask, SubtractBackgroundTask +from lsst.pex.config import Config, ConfigurableField, Field +from lsst.pipe.base import Task -class AmpOffsetConfig(pexConfig.Config): - """Configuration parameters for AmpOffsetTask. - """ - ampEdgeInset = pexConfig.Field( + +class AmpOffsetConfig(Config): + """Configuration parameters for AmpOffsetTask.""" + + ampEdgeInset = Field( doc="Number of pixels the amp edge strip is inset from the amp edge. A thin strip of pixels running " "parallel to the edge of the amp is used to characterize the average flux level at the amp edge.", dtype=int, default=5, ) - ampEdgeWidth = pexConfig.Field( + ampEdgeWidth = Field( doc="Pixel width of the amp edge strip, starting at ampEdgeInset and extending inwards.", dtype=int, default=64, ) - ampEdgeMinFrac = pexConfig.Field( + ampEdgeMinFrac = Field( doc="Minimum allowed fraction of viable pixel rows along an amp edge. No amp offset estimate will be " "generated for amp edges that do not have at least this fraction of unmasked pixel rows.", dtype=float, default=0.5, ) - ampEdgeMaxOffset = pexConfig.Field( + ampEdgeMaxOffset = Field( doc="Maximum allowed amp offset ADU value. If a measured amp offset value is larger than this, the " "result will be discarded and therefore not used to determine amp pedestal corrections.", dtype=float, default=5.0, ) - ampEdgeWindow = pexConfig.Field( + ampEdgeWindow = Field( doc="Pixel size of the sliding window used to generate rolling average amp offset values.", dtype=int, - default=512, + default=512, # we probably need to reconsider this (e.g. for lsst) ) - doBackground = pexConfig.Field( + doBackground = Field( doc="Estimate and subtract background prior to amp offset estimation?", dtype=bool, default=True, ) - background = pexConfig.ConfigurableField( + background = ConfigurableField( doc="An initial background estimation step run prior to amp offset calculation.", target=SubtractBackgroundTask, ) - doDetection = pexConfig.Field( + doDetection = Field( doc="Detect sources and update cloned exposure prior to amp offset estimation?", dtype=bool, default=True, ) - detection = pexConfig.ConfigurableField( + detection = ConfigurableField( doc="Source detection to add temporary detection footprints prior to amp offset calculation.", target=SourceDetectionTask, ) -class AmpOffsetTask(pipeBase.Task): - """Calculate and apply amp offset corrections to an exposure. - """ +class AmpOffsetTask(Task): + """Calculate and apply amp offset corrections to an exposure.""" + ConfigClass = AmpOffsetConfig _DefaultName = "isrAmpOffset" def __init__(self, *args, **kwargs): - super().__init__(**kwargs) - # always load background subtask, even if doBackground=False; - # this allows for default plane bit masks to be defined + super().__init__(*args, **kwargs) + # Always load background subtask, even if doBackground=False; + # this allows for default plane bit masks to be defined. self.makeSubtask("background") if self.config.doDetection: self.makeSubtask("detection") def run(self, exposure): """Calculate amp offset values, determine corrective pedestals for each - amp, and update the input exposure in-place. This task is currently not - implemented, and should be retargeted by a camera specific version. + amp, and update the input exposure in-place. Parameters ---------- - exposure : `lsst.afw.image.Exposure` - Exposure to be corrected for any amp offsets. + exposure: `lsst.afw.image.Exposure` + Exposure to be corrected for amp offsets. """ - raise NotImplementedError("Amp offset task should be retargeted by a camera specific version.") + + # Generate an exposure clone to work on and establish the bit mask. + exp = exposure.clone() + bitMask = exp.mask.getPlaneBitMask(self.background.config.ignoredPixelMask) + self.log.info( + "Ignored mask planes for amp offset estimation: [%s].", + ", ".join(self.background.config.ignoredPixelMask), + ) + + # Fit and subtract background. + if self.config.doBackground: + maskedImage = exp.getMaskedImage() + bg = self.background.fitBackground(maskedImage) + bgImage = bg.getImageF(self.background.config.algorithm, self.background.config.undersampleStyle) + maskedImage -= bgImage + + # Detect sources and update cloned exposure mask planes in-place. + if self.config.doDetection: + schema = SourceTable.makeMinimalSchema() + table = SourceTable.make(schema) + # Detection sigma, used for smoothing and to grow detections, is + # normally measured from the PSF of the exposure. As the PSF hasn't + # been measured at this stage of processing, sigma is instead + # set to an approximate value here (which should be sufficient). + _ = self.detection.run(table=table, exposure=exp, sigma=2) + + # Safety check: do any pixels remain for amp offset estimation? + if (exp.mask.array & bitMask).all(): + self.log.warning("All pixels masked: cannot calculate any amp offset corrections.") + else: + # Set up amp offset inputs. + im = exp.image + im.array[(exp.mask.array & bitMask) > 0] = np.nan + amps = exp.getDetector().getAmplifiers() + + # Determine amplifier interface geometry. + ampAreas = {amp.getBBox().getArea() for amp in amps} + if len(ampAreas) > 1: + raise NotImplementedError( + "Amp offset correction is not yet implemented for detectors with differing amp sizes." + ) + + A, interfaces = self.getAmpAssociations(amps) + B = self.getAmpOffsets(im, amps, A, interfaces) + + # if least-squares minimization fails, convert NaNs to zeroes, + # ensuring that no values are erroneously added/subtracted + pedestals = np.nan_to_num(np.linalg.lstsq(A, B, rcond=None)[0]) + metadata = exposure.getMetadata() + for ii, (amp, pedestal) in enumerate(zip(amps, pedestals)): + ampIm = exposure.image[amp.getBBox()].array + ampIm -= pedestal + metadata.set( + f"PEDESTAL{ii + 1}", float(pedestal), f"Pedestal level subtracted from amp {ii + 1}" + ) + self.log.info(f"amp pedestal values: {', '.join([f'{x:.2f}' for x in pedestals])}") + + # raise NotImplementedError("Amp offset task should be retargeted + # by a camera specific version.") + + def getAmpAssociations(self, amps): + """Determine amp geometry and amp associations from a list of amplifiers. + + Parse an input list of amplifiers to determine the layout of amps + within a detector, and identify all amp interfaces (i.e., the + horizontal and vertical junctions between amps). + + Returns a matrix with a shape corresponding to the geometry of the amps + in the detector. + + Parameters + ---------- + amps: `list` [`lsst.afw.cameraGeom.Amplifier`] + List of amplifier objects. + + Returns + ------- + ampAssociations: `numpy.ndarray` + Matrix with amp interface association information. + """ + xCenters = [amp.getBBox().getCenterX() for amp in amps] + yCenters = [amp.getBBox().getCenterY() for amp in amps] + xIndices = np.ceil(xCenters / np.min(xCenters) / 2).astype(int) - 1 + yIndices = np.ceil(yCenters / np.min(yCenters) / 2).astype(int) - 1 + + nAmps = len(amps) + ampIds = np.zeros((len(set(yIndices)), len(set(xIndices))), dtype=int) + + for ampId, xIndex, yIndex in zip(np.arange(nAmps), xIndices, yIndices): + ampIds[yIndex, xIndex] = ampId + + # ampIds = np.array([[0, 1, 2, 3, 4, 5, 6, 7],[15, 14, 13, 12, 11, 10, 9, 8]]) # LSST!! + + ampAssociations = np.zeros((nAmps, nAmps), dtype=int) + ampInterfaces = np.full_like(ampAssociations, -1) + + for ampId in ampIds.ravel(): + adjacents, interfaces = self.getAdjacents(ampIds, ampId) + ampAssociations[ampId, adjacents] = -1 + ampInterfaces[ampId, adjacents] = interfaces + ampAssociations[ampId, ampId] = -ampAssociations[ampId].sum() + + if ampAssociations.sum() != 0: + raise RuntimeError("The `ampAssociations` array does not sum to zero.") + + return ampAssociations, ampInterfaces + + def getAdjacents(self, ampIds, ampId): + m, n = ampIds.shape + r, c = np.ravel(np.where(ampIds == ampId)) + adjacents, interfaces = [], [] + interfaceLookup = { + 0: (r + 1, c), + 1: (r, c + 1), + 2: (r - 1, c), + 3: (r, c - 1), + } + for interface, (row, column) in interfaceLookup.items(): + if 0 <= row < m and 0 <= column < n: + adjacents.append(ampIds[row][column]) + interfaces.append(interface) + return adjacents, interfaces + + def getInterfaceAmpOffset(self, ampIdA, ampIdB, edgeA, edgeB): + sctrl = StatisticsControl() + edgeDiff = edgeA - edgeB # edgeB - edgeA # !!! needed to change this to fix the sign flip + # compute rolling averages + edgeDiffSum = np.convolve(np.nan_to_num(edgeDiff), np.ones(self.config.ampEdgeWindow), "same") + edgeDiffNum = np.convolve(~np.isnan(edgeDiff), np.ones(self.config.ampEdgeWindow), "same") + edgeDiffAvg = edgeDiffSum / np.clip(edgeDiffNum, 1, None) + edgeDiffAvg[np.isnan(edgeDiff)] = np.nan + # take clipped mean of rolling average data as amp offset value + ampOffset = makeStatistics(edgeDiffAvg, MEANCLIP, sctrl).getValue() + # perform a couple of do-no-harm safety checks: + # a) the fraction of unmasked pixel rows is > ampEdgeMinFrac, + # b) the absolute offset ADU value is < ampEdgeMaxOffset + ampEdgeGoodFrac = 1 - (np.sum(np.isnan(edgeDiffAvg)) / len(edgeDiffAvg)) + minFracFail = ampEdgeGoodFrac < self.config.ampEdgeMinFrac + maxOffsetFail = np.abs(ampOffset) > self.config.ampEdgeMaxOffset + if minFracFail or maxOffsetFail: + ampOffset = 0 + self.log.debug( + f"amp edge {ampIdA}{ampIdB} : " + f"viable edge frac = {ampEdgeGoodFrac}, " + f"edge offset = {ampOffset:.4f}" + ) + return ampOffset + + def getAmpOffsets(self, im, amps, associations, interfaces): + ampsOffsets = np.zeros(len(amps)) + ampsEdges = self.getAmpEdges(im, amps, interfaces) + for ampId, ampAssociations in enumerate(associations): + ampAdjacents = np.ravel(np.where(ampAssociations < 0)) + for ampAdjacent in ampAdjacents: + ampInterface = interfaces[ampId][ampAdjacent] + edgeA = ampsEdges[ampId][ampInterface] + edgeB = ampsEdges[ampAdjacent][(ampInterface + 2) % 4] + ampsOffsets[ampId] += self.getInterfaceAmpOffset(ampId, ampAdjacent, edgeA, edgeB) + return ampsOffsets + + def getAmpEdges(self, im, amps, ampInterfaces): + ampEdgeOuter = self.config.ampEdgeInset + self.config.ampEdgeWidth + ampEdges = {} + slice_map = { + 0: (slice(-ampEdgeOuter, -self.config.ampEdgeInset), slice(None)), + 1: (slice(None), slice(-ampEdgeOuter, -self.config.ampEdgeInset)), + 2: (slice(self.config.ampEdgeInset, ampEdgeOuter), slice(None)), + 3: (slice(None), slice(self.config.ampEdgeInset, ampEdgeOuter)), + } + for ampId, (amp, ampInterfaces) in enumerate(zip(amps, ampInterfaces)): + ampEdges[ampId] = {} + ampIm = im[amp.getBBox()].array + # loop over identified interfaces + for ampInterface in ampInterfaces: + if ampInterface < 0: + continue + strip = ampIm[slice_map[ampInterface]] + # catch warnings to prevent all-NaN slice RuntimeWarning + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered") + ampEdges[ampId][ampInterface] = np.nanmedian( + strip, axis=ampInterface % 2 + ) # 1D medianified strip + return ampEdges diff --git a/python/lsst/ip/isr/isrTask.py b/python/lsst/ip/isr/isrTask.py index a73dc6cd4..f2469f8dc 100644 --- a/python/lsst/ip/isr/isrTask.py +++ b/python/lsst/ip/isr/isrTask.py @@ -2620,4 +2620,4 @@ def getSaturation(self): return self._saturation def getSuspectLevel(self): - return float("NaN") + return float("NaN") \ No newline at end of file