Skip to content

Commit

Permalink
Generalize amp offset code
Browse files Browse the repository at this point in the history
  • Loading branch information
leeskelvin authored and enourbakhsh committed May 31, 2023
1 parent 98536aa commit 4fec803
Show file tree
Hide file tree
Showing 2 changed files with 215 additions and 29 deletions.
242 changes: 214 additions & 28 deletions python/lsst/ip/isr/ampOffset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,88 +18,274 @@
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
# import os

__all__ = ["AmpOffsetConfig", "AmpOffsetTask"]

import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
from lsst.meas.algorithms import (SubtractBackgroundTask, SourceDetectionTask)
import warnings

import numpy as np
from lsst.afw.math import MEANCLIP, StatisticsControl, makeStatistics
from lsst.afw.table import SourceTable
from lsst.meas.algorithms import SourceDetectionTask, SubtractBackgroundTask
from lsst.pex.config import Config, ConfigurableField, Field
from lsst.pipe.base import Task

class AmpOffsetConfig(pexConfig.Config):
"""Configuration parameters for AmpOffsetTask.
"""
ampEdgeInset = pexConfig.Field(

class AmpOffsetConfig(Config):
"""Configuration parameters for AmpOffsetTask."""

ampEdgeInset = Field(
doc="Number of pixels the amp edge strip is inset from the amp edge. A thin strip of pixels running "
"parallel to the edge of the amp is used to characterize the average flux level at the amp edge.",
dtype=int,
default=5,
)
ampEdgeWidth = pexConfig.Field(
ampEdgeWidth = Field(
doc="Pixel width of the amp edge strip, starting at ampEdgeInset and extending inwards.",
dtype=int,
default=64,
)
ampEdgeMinFrac = pexConfig.Field(
ampEdgeMinFrac = Field(
doc="Minimum allowed fraction of viable pixel rows along an amp edge. No amp offset estimate will be "
"generated for amp edges that do not have at least this fraction of unmasked pixel rows.",
dtype=float,
default=0.5,
)
ampEdgeMaxOffset = pexConfig.Field(
ampEdgeMaxOffset = Field(
doc="Maximum allowed amp offset ADU value. If a measured amp offset value is larger than this, the "
"result will be discarded and therefore not used to determine amp pedestal corrections.",
dtype=float,
default=5.0,
)
ampEdgeWindow = pexConfig.Field(
ampEdgeWindow = Field(
doc="Pixel size of the sliding window used to generate rolling average amp offset values.",
dtype=int,
default=512,
default=512, # we probably need to reconsider this (e.g. for lsst)
)
doBackground = pexConfig.Field(
doBackground = Field(
doc="Estimate and subtract background prior to amp offset estimation?",
dtype=bool,
default=True,
)
background = pexConfig.ConfigurableField(
background = ConfigurableField(
doc="An initial background estimation step run prior to amp offset calculation.",
target=SubtractBackgroundTask,
)
doDetection = pexConfig.Field(
doDetection = Field(
doc="Detect sources and update cloned exposure prior to amp offset estimation?",
dtype=bool,
default=True,
)
detection = pexConfig.ConfigurableField(
detection = ConfigurableField(
doc="Source detection to add temporary detection footprints prior to amp offset calculation.",
target=SourceDetectionTask,
)


class AmpOffsetTask(pipeBase.Task):
"""Calculate and apply amp offset corrections to an exposure.
"""
class AmpOffsetTask(Task):
"""Calculate and apply amp offset corrections to an exposure."""

ConfigClass = AmpOffsetConfig
_DefaultName = "isrAmpOffset"

def __init__(self, *args, **kwargs):
super().__init__(**kwargs)
# always load background subtask, even if doBackground=False;
# this allows for default plane bit masks to be defined
super().__init__(*args, **kwargs)
# Always load background subtask, even if doBackground=False;
# this allows for default plane bit masks to be defined.
self.makeSubtask("background")
if self.config.doDetection:
self.makeSubtask("detection")

def run(self, exposure):
"""Calculate amp offset values, determine corrective pedestals for each
amp, and update the input exposure in-place. This task is currently not
implemented, and should be retargeted by a camera specific version.
amp, and update the input exposure in-place.
Parameters
----------
exposure : `lsst.afw.image.Exposure`
Exposure to be corrected for any amp offsets.
exposure: `lsst.afw.image.Exposure`
Exposure to be corrected for amp offsets.
"""
raise NotImplementedError("Amp offset task should be retargeted by a camera specific version.")

# Generate an exposure clone to work on and establish the bit mask.
exp = exposure.clone()
bitMask = exp.mask.getPlaneBitMask(self.background.config.ignoredPixelMask)
self.log.info(
"Ignored mask planes for amp offset estimation: [%s].",
", ".join(self.background.config.ignoredPixelMask),
)

# Fit and subtract background.
if self.config.doBackground:
maskedImage = exp.getMaskedImage()
bg = self.background.fitBackground(maskedImage)
bgImage = bg.getImageF(self.background.config.algorithm, self.background.config.undersampleStyle)
maskedImage -= bgImage

# Detect sources and update cloned exposure mask planes in-place.
if self.config.doDetection:
schema = SourceTable.makeMinimalSchema()
table = SourceTable.make(schema)
# Detection sigma, used for smoothing and to grow detections, is
# normally measured from the PSF of the exposure. As the PSF hasn't
# been measured at this stage of processing, sigma is instead
# set to an approximate value here (which should be sufficient).
_ = self.detection.run(table=table, exposure=exp, sigma=2)

# Safety check: do any pixels remain for amp offset estimation?
if (exp.mask.array & bitMask).all():
self.log.warning("All pixels masked: cannot calculate any amp offset corrections.")
else:
# Set up amp offset inputs.
im = exp.image
im.array[(exp.mask.array & bitMask) > 0] = np.nan
amps = exp.getDetector().getAmplifiers()

# Determine amplifier interface geometry.
ampAreas = {amp.getBBox().getArea() for amp in amps}
if len(ampAreas) > 1:
raise NotImplementedError(
"Amp offset correction is not yet implemented for detectors with differing amp sizes."
)

A, interfaces = self.getAmpAssociations(amps)
B = self.getAmpOffsets(im, amps, A, interfaces)

# if least-squares minimization fails, convert NaNs to zeroes,
# ensuring that no values are erroneously added/subtracted
pedestals = np.nan_to_num(np.linalg.lstsq(A, B, rcond=None)[0])
metadata = exposure.getMetadata()
for ii, (amp, pedestal) in enumerate(zip(amps, pedestals)):
ampIm = exposure.image[amp.getBBox()].array
ampIm -= pedestal
metadata.set(
f"PEDESTAL{ii + 1}", float(pedestal), f"Pedestal level subtracted from amp {ii + 1}"
)
self.log.info(f"amp pedestal values: {', '.join([f'{x:.4f}' for x in pedestals])}")

# raise NotImplementedError("Amp offset task should be retargeted
# by a camera specific version.")

def getAmpAssociations(self, amps):
"""Determine amp geometry and amp associations from a list of amplifiers.
Parse an input list of amplifiers to determine the layout of amps
within a detector, and identify all amp interfaces (i.e., the
horizontal and vertical junctions between amps).
Returns a matrix with a shape corresponding to the geometry of the amps
in the detector.
Parameters
----------
amps: `list` [`lsst.afw.cameraGeom.Amplifier`]
List of amplifier objects.
Returns
-------
ampAssociations: `numpy.ndarray`
Matrix with amp interface association information.
"""
xCenters = [amp.getBBox().getCenterX() for amp in amps]
yCenters = [amp.getBBox().getCenterY() for amp in amps]
xIndices = np.ceil(xCenters / np.min(xCenters) / 2).astype(int) - 1
yIndices = np.ceil(yCenters / np.min(yCenters) / 2).astype(int) - 1

nAmps = len(amps)
ampIds = np.zeros((len(set(yIndices)), len(set(xIndices))), dtype=int)

for ampId, xIndex, yIndex in zip(np.arange(nAmps), xIndices, yIndices):
ampIds[yIndex, xIndex] = ampId

# ampIds = np.array([[0, 1, 2, 3, 4, 5, 6, 7],[15, 14, 13, 12, 11, 10, 9, 8]]) # LSST!!

ampAssociations = np.zeros((nAmps, nAmps), dtype=int)
ampInterfaces = np.full_like(ampAssociations, -1)

for ampId in ampIds.ravel():
adjacents, interfaces = self.getAdjacents(ampIds, ampId)
ampAssociations[ampId, adjacents] = -1
ampInterfaces[ampId, adjacents] = interfaces
ampAssociations[ampId, ampId] = -ampAssociations[ampId].sum()

if ampAssociations.sum() != 0:
raise RuntimeError("The `ampAssociations` array does not sum to zero.")

return ampAssociations, ampInterfaces

def getAdjacents(self, ampIds, ampId):
m, n = ampIds.shape
r, c = np.ravel(np.where(ampIds == ampId))
adjacents, interfaces = [], []
interfaceLookup = {
0: (r + 1, c),
1: (r, c + 1),
2: (r - 1, c),
3: (r, c - 1),
}
for interface, (row, column) in interfaceLookup.items():
if 0 <= row < m and 0 <= column < n:
adjacents.append(ampIds[row][column])
interfaces.append(interface)
return adjacents, interfaces

def getInterfaceAmpOffset(self, ampIdA, ampIdB, edgeA, edgeB):
sctrl = StatisticsControl()
edgeDiff = edgeA - edgeB # edgeB - edgeA # !!! needed to change this to fix the sign flip
# compute rolling averages
edgeDiffSum = np.convolve(np.nan_to_num(edgeDiff), np.ones(self.config.ampEdgeWindow), "same")
edgeDiffNum = np.convolve(~np.isnan(edgeDiff), np.ones(self.config.ampEdgeWindow), "same")
edgeDiffAvg = edgeDiffSum / np.clip(edgeDiffNum, 1, None)
edgeDiffAvg[np.isnan(edgeDiff)] = np.nan
# take clipped mean of rolling average data as amp offset value
ampOffset = makeStatistics(edgeDiffAvg, MEANCLIP, sctrl).getValue()
# perform a couple of do-no-harm safety checks:
# a) the fraction of unmasked pixel rows is > ampEdgeMinFrac,
# b) the absolute offset ADU value is < ampEdgeMaxOffset
ampEdgeGoodFrac = 1 - (np.sum(np.isnan(edgeDiffAvg)) / len(edgeDiffAvg))
minFracFail = ampEdgeGoodFrac < self.config.ampEdgeMinFrac
maxOffsetFail = np.abs(ampOffset) > self.config.ampEdgeMaxOffset
if minFracFail or maxOffsetFail:
ampOffset = 0
self.log.debug(
f"amp edge {ampIdA}{ampIdB} : "
f"viable edge frac = {ampEdgeGoodFrac}, "
f"edge offset = {ampOffset:.3f}"
)
return ampOffset

def getAmpOffsets(self, im, amps, associations, interfaces):
ampsOffsets = np.zeros(len(amps))
ampsEdges = self.getAmpEdges(im, amps, interfaces)
for ampId, ampAssociations in enumerate(associations):
ampAdjacents = np.ravel(np.where(ampAssociations < 0))
for ampAdjacent in ampAdjacents:
ampInterface = interfaces[ampId][ampAdjacent]
edgeA = ampsEdges[ampId][ampInterface]
edgeB = ampsEdges[ampAdjacent][(ampInterface + 2) % 4]
ampsOffsets[ampId] += self.getInterfaceAmpOffset(ampId, ampAdjacent, edgeA, edgeB)
return ampsOffsets

def getAmpEdges(self, im, amps, ampInterfaces):
ampEdgeOuter = self.config.ampEdgeInset + self.config.ampEdgeWidth
ampEdges = {}
slice_map = {
0: (slice(-ampEdgeOuter, -self.config.ampEdgeInset), slice(None)),
1: (slice(None), slice(-ampEdgeOuter, -self.config.ampEdgeInset)),
2: (slice(self.config.ampEdgeInset, ampEdgeOuter), slice(None)),
3: (slice(None), slice(self.config.ampEdgeInset, ampEdgeOuter)),
}
for ampId, (amp, ampInterfaces) in enumerate(zip(amps, ampInterfaces)):
ampEdges[ampId] = {}
ampIm = im[amp.getBBox()].array
# loop over identified interfaces
for ampInterface in ampInterfaces:
if ampInterface < 0:
continue
strip = ampIm[slice_map[ampInterface]]
# catch warnings to prevent all-NaN slice RuntimeWarning
with warnings.catch_warnings():
warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered")
ampEdges[ampId][ampInterface] = np.nanmedian(
strip, axis=ampInterface % 2
) # 1D medianified strip
return ampEdges
2 changes: 1 addition & 1 deletion python/lsst/ip/isr/isrTask.py
Original file line number Diff line number Diff line change
Expand Up @@ -2620,4 +2620,4 @@ def getSaturation(self):
return self._saturation

def getSuspectLevel(self):
return float("NaN")
return float("NaN")

0 comments on commit 4fec803

Please sign in to comment.