-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Create task to combine injection catalogs for matching
- Loading branch information
Showing
2 changed files
with
282 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,281 @@ | ||
# This file is part of source_injection. | ||
# | ||
# Developed for the LSST Data Management System. | ||
# This product includes software developed by the LSST Project | ||
# (https://www.lsst.org). | ||
# See the COPYRIGHT file at the top-level directory of this distribution | ||
# for details of code ownership. | ||
# | ||
# This program is free software: you can redistribute it and/or modify | ||
# it under the terms of the GNU General Public License as published by | ||
# the Free Software Foundation, either version 3 of the License, or | ||
# (at your option) any later version. | ||
# | ||
# This program is distributed in the hope that it will be useful, | ||
# but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
# GNU General Public License for more details. | ||
# | ||
# You should have received a copy of the GNU General Public License | ||
# along with this program. If not, see <https://www.gnu.org/licenses/>. | ||
|
||
from __future__ import annotations | ||
|
||
__all__ = [ | ||
"CombineInjectionCatalogsConnections", | ||
"CombineInjectionCatalogsConfig", | ||
"CombineInjectionCatalogsTask", | ||
] | ||
from typing import cast | ||
|
||
import numpy as np | ||
from astropy.table import Table, vstack | ||
from lsst.geom import SpherePoint, degrees | ||
from lsst.pex.config import Field | ||
from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct | ||
from lsst.pipe.base.connections import InputQuantizedConnection | ||
from lsst.pipe.base.connectionTypes import Input, Output, PrerequisiteInput | ||
from lsst.skymap import BaseSkyMap | ||
from lsst.skymap.ringsSkyMap import RingsSkyMap | ||
from smatch.matcher import Matcher | ||
|
||
|
||
class CombineInjectionCatalogsConnections(
    PipelineTaskConnections,
    dimensions=("instrument", "skymap", "tract"),
    defaultTemplates={
        "injection_prefix": "injection_",
        "injected_prefix": "injected_",
    },
):
    """Connections for the task that combines injected-source catalogs.

    Declares the per-patch, per-band input catalogs, the sky map, and the
    single combined output catalog.
    """

    # NOTE(review): the "injection_prefix" template is not referenced by any
    # connection declared in this class - confirm whether it is still needed.

    # Zero or more per-(patch, band) catalogs; declared as a prerequisite
    # input with ``minimum=0``.
    catalogs = PrerequisiteInput(
        name="{injected_prefix}deepCoadd_catalog",
        storageClass="ArrowAstropy",
        dimensions=("skymap", "tract", "patch", "band"),
        multiple=True,
        minimum=0,
        doc="Set of catalogs of sources to draw inputs from.",
    )

    # Sky map dataset passed to the task's ``run`` method.
    skyMap = Input(
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        storageClass="SkyMap",
        dimensions=("skymap",),
        doc="Input definition of geometry/bbox and projection/wcs for warped exposures",
    )

    # NOTE(review): this dataset carries a "patch" dimension while the task
    # dimensions above stop at "tract" - confirm this is intended.
    output_catalog = Output(
        name="combined_{injected_prefix}deepCoadd_catalog",
        storageClass="ArrowAstropy",
        dimensions=("skymap", "tract", "patch"),
        doc="Combined multiband catalog of injected sources.",
    )
|
||
|
||
class CombineInjectionCatalogsConfig(  # type: ignore [call-arg]
    PipelineTaskConfig, pipelineConnections=CombineInjectionCatalogsConnections
):
    """Configuration for the task that combines injected-source catalogs.

    Every field names a column that the task looks up in the input catalogs.
    """

    # Sky-coordinate columns.
    col_ra = Field[str](
        default="ra",
        doc="Column name for right ascension (in degrees).",
    )
    col_dec = Field[str](
        default="dec",
        doc="Column name for declination (in degrees).",
    )

    # Photometry column; the task renames it to "<band>_mag" when combining.
    col_mag = Field[str](
        default="mag",
        doc="Column name for magnitude.",
    )

    # Source profile column.
    col_source_type = Field[str](
        default="source_type",
        doc="Column name for the source type used in the input catalog. "
        "Must match one of the surface brightness profiles defined by GalSim.",
    )
|
||
|
||
class CombineInjectionCatalogsTask(PipelineTask):
    """Combine all tables in a collection of injected-source catalogs into
    one multiband table.

    The per-patch catalogs of each band are stacked, patch/tract overlap
    duplicates are dropped, the per-band tables are cross-matched on sky
    position, and the merged table is given a fresh ``injected_id`` column.
    """

    _DefaultName = "combineInjectionCatalogsTask"
    ConfigClass = CombineInjectionCatalogsConfig

    def runQuantum(self, butler_quantum_context, input_refs, output_refs):
        """Load the inputs of one quantum, run the task, store the outputs.

        Parameters
        ----------
        butler_quantum_context
            Object used to ``get`` the input datasets and ``put`` the outputs.
        input_refs
            References to the input datasets; ``input_refs.catalogs`` is
            iterated in step with the loaded catalogs.
        output_refs
            References to the output datasets.
        """
        inputs = butler_quantum_context.get(input_refs)
        catalog_dict, tract = self._get_catalogs(inputs, input_refs)
        outputs = self.run(catalog_dict, inputs["skyMap"], tract)
        butler_quantum_context.put(outputs, output_refs)

    def _get_catalogs(
        self,
        inputs: dict,
        input_refs: InputQuantizedConnection,
    ) -> tuple[dict, int]:
        """Organize all catalogs in a dictionary with photometry band keys.

        Parameters
        ----------
        inputs : `dict`
            Loaded input datasets; ``inputs["catalogs"]`` holds the tables.
        input_refs : `lsst.pipe.base.connections.InputQuantizedConnection`
            Dataset references; ``input_refs.catalogs`` must be ordered like
            ``inputs["catalogs"]``.

        Returns
        -------
        catalog_dict : `dict`
            One vertically stacked table per band, keyed by band name. Each
            table has an extra ``patch`` column holding the patch id of the
            dataset the row came from.
        tract : `int`
            The single tract id shared by all input catalogs.

        Raises
        ------
        RuntimeError
            Raised if the inputs do not belong to exactly one tract.
        """
        tables_by_band: dict = {}
        tracts = set()
        for ref, catalog in zip(input_refs.catalogs, inputs["catalogs"]):
            data_id = ref.dataId
            # Record the originating patch so that patch-overlap duplicates
            # can be identified later.
            catalog["patch"] = data_id.patch.id
            tables_by_band.setdefault(data_id.band.name, []).append(catalog)
            tracts.add(data_id.tract.id)
        if len(tracts) != 1:
            raise RuntimeError(f"len({tracts=}) != 1")
        # Stack all the per-patch catalogs of each band.
        catalog_dict = {band: vstack(tables) for band, tables in tables_by_band.items()}
        return (catalog_dict, next(iter(tracts)))

    def _remove_duplicates(
        self,
        catalog: Table,
        tractInfo,
    ) -> Table:
        """Remove tract and patch overlap duplicates.

        A row is kept only if its position lies inside the tract's inner sky
        polygon and the patch found for that position has the same
        sequential index as the row's ``patch`` value.

        Parameters
        ----------
        catalog : `astropy.table.Table`
            Catalog with coordinate columns and a ``patch`` column. It is
            modified in place.
        tractInfo
            Tract object providing ``getInnerSkyPolygon`` and ``findPatch``.

        Returns
        -------
        catalog : `astropy.table.Table`
            The same table object with the duplicate rows removed.
        """
        config = cast(CombineInjectionCatalogsConfig, self.config)
        # TODO: add isPatchInner column
        # The inner polygon does not depend on the row, so fetch it once.
        tract_inner = tractInfo.getInnerSkyPolygon()
        duplicates = []
        for ind, row in enumerate(catalog):
            ra = row[config.col_ra]
            dec = row[config.col_dec]
            # Sources outside the inner tract bounds are always removed.
            if not tract_inner.contains(np.deg2rad(ra), np.deg2rad(dec)):
                duplicates.append(ind)
                continue
            # Get the patch from the source's ra,dec and check it against
            # the patch the row was loaded from.
            patch_info = tractInfo.findPatch(SpherePoint(ra * degrees, dec * degrees))
            if row["patch"] != patch_info.sequential_index:
                duplicates.append(ind)
        catalog.remove_rows(duplicates)
        return catalog

    def _make_multiband_catalog(
        self,
        bands: list,
        catalog_dict: dict,
        match_radius: float,
    ) -> Table:
        """Combine multiple band-specific catalogs into one multiband
        catalog.

        Parameters
        ----------
        bands : `list`
            Band names; the first band seeds the multiband catalog and the
            remaining bands are matched into it in order.
        catalog_dict : `dict`
            One table per band. The tables are modified in place (their
            magnitude column is renamed to ``<band>_mag``).
        match_radius : `float`
            Radius passed to ``Matcher.query_radius`` when matching on the
            configured coordinate columns.

        Returns
        -------
        multiband_catalog : `astropy.table.Table`
            Table with one ``<band>_mag`` column per band. Magnitudes are
            masked where a source has no match in that band.
        """
        config = cast(CombineInjectionCatalogsConfig, self.config)
        col_ra = config.col_ra
        col_dec = config.col_dec
        col_source_type = config.col_source_type

        # Load the first catalog then loop to add info for the other bands.
        first_band = bands[0]
        multiband_catalog = catalog_dict[first_band]
        multiband_catalog.rename_column(config.col_mag, f"{first_band}_mag")
        # Magnitude columns of the bands merged so far.
        merged_mag_cols = [f"{first_band}_mag"]

        for band in bands[1:]:
            band_mag = f"{band}_mag"
            # Make a masked column for the new band.
            multiband_catalog.add_column(np.ma.masked_all(len(multiband_catalog)), name=band_mag)
            # Match the input catalog for this band to the existing
            # multiband catalog.
            catalog_next_band = catalog_dict[band]
            catalog_next_band.rename_column(config.col_mag, band_mag)
            with Matcher(multiband_catalog[col_ra], multiband_catalog[col_dec]) as m:
                idx, multiband_match_inds, next_band_match_inds, dists = m.query_radius(
                    catalog_next_band[col_ra],
                    catalog_next_band[col_dec],
                    match_radius,
                    return_indices=True,
                )
            if len(multiband_match_inds) > 0 and len(next_band_match_inds) > 0:
                for i, j in zip(multiband_match_inds, next_band_match_inds):
                    # Choose the coordinates in the brightest band. Only
                    # unmasked magnitudes of already-merged bands are
                    # compared; masked entries carry no brightness.
                    known_mags = [
                        mag
                        for mag in (multiband_catalog[col][i] for col in merged_mag_cols)
                        if not np.ma.is_masked(mag)
                    ]
                    next_mag = catalog_next_band[band_mag][j]
                    if known_mags and not np.ma.is_masked(next_mag) and next_mag < min(known_mags):
                        multiband_catalog[col_ra][i] = catalog_next_band[col_ra][j]
                        multiband_catalog[col_dec][i] = catalog_next_band[col_dec][j]
                    # Pick the source_type which shows multiple components
                    # (marked by a "+") if one exists.
                    next_source_type = catalog_next_band[col_source_type][j]
                    if "+" in str(next_source_type):
                        multiband_catalog[col_source_type][i] = next_source_type
                # Remove the mask and fill the new mag value.
                multiband_catalog.mask[band_mag][multiband_match_inds] = False
                multiband_catalog[band_mag][multiband_match_inds] = catalog_next_band[band_mag][
                    next_band_match_inds
                ]
                # Add rows for all the sources without matches.
                unmatched = np.ones(len(catalog_next_band), dtype=bool)
                unmatched[next_band_match_inds] = False
                multiband_catalog = vstack([multiband_catalog, catalog_next_band[unmatched]])
            else:
                # No matches at all: just stack the tables.
                multiband_catalog = vstack([multiband_catalog, catalog_next_band])
            merged_mag_cols.append(band_mag)
        return multiband_catalog

    def run(
        self,
        catalog_dict: dict,
        skymap: RingsSkyMap,
        tract: int,
        pixel_match_radius: float = 0.1,
    ) -> Struct:
        """Combine all tables in catalog_dict into one table.

        Parameters
        ----------
        catalog_dict : `dict`
            A dictionary with photometric bands for keys and astropy tables
            for items. The tables are modified in place.
        skymap : `lsst.skymap.ringsSkyMap.RingsSkyMap`
            A rings skymap.
        tract : `int`
            The tract where sources have been injected.
        pixel_match_radius : `float`, optional
            Match radius in pixels to use for self-matching catalogs across
            different bands.

        Returns
        -------
        output_struct : `lsst.pipe.base.Struct`
            Struct with a single attribute:

            ``output_catalog``
                A single table containing all information of the separate
                tables in catalog_dict (`astropy.table.Table`).
        """
        config = cast(CombineInjectionCatalogsConfig, self.config)
        bands = list(catalog_dict)
        # Convert the pixel match radius to degrees.
        tractInfo = skymap.generateTract(tract)
        pixel_scale = tractInfo.getWcs().getPixelScale()
        match_radius = pixel_match_radius * pixel_scale.asDegrees()

        # Remove tract and patch overlap duplicates from every band.
        for band in bands:
            catalog_dict[band] = self._remove_duplicates(catalog_dict[band], tractInfo)

        if len(bands) > 1:
            # Match the catalogs across bands.
            output_catalog = self._make_multiband_catalog(bands, catalog_dict, match_radius)
        else:
            output_catalog = catalog_dict[bands[0]]
            output_catalog.rename_column(config.col_mag, f"{bands[0]}_mag")

        # Remove sources outside tract boundaries. Collect the indices first
        # and drop them in one call: removing rows one by one would shift
        # the positions of all later rows.
        outside = [
            index
            for index, (ra, dec) in enumerate(zip(output_catalog[config.col_ra], output_catalog[config.col_dec]))
            if not tractInfo.contains(SpherePoint(ra * degrees, dec * degrees))
        ]
        output_catalog.remove_rows(outside)

        # Replace injection_id column with a new injected_id column.
        output_catalog["injection_id"] = list(range(len(output_catalog)))
        output_catalog.rename_column("injection_id", "injected_id")
        return Struct(output_catalog=output_catalog)