Skip to content

Commit

Permalink
Create task to combine injection catalogs for matching
Browse files Browse the repository at this point in the history
  • Loading branch information
jtmccann committed Jan 17, 2024
1 parent b5039a9 commit 03102d6
Show file tree
Hide file tree
Showing 2 changed files with 282 additions and 0 deletions.
1 change: 1 addition & 0 deletions python/lsst/source/injection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
utils : Utility functions for injection tasks.
"""

from .combine import * # noqa: F401,F403
from .inject_base import * # noqa: F401,F403
from .inject_coadd import * # noqa: F401,F403
from .inject_engine import * # noqa: F401,F403
Expand Down
281 changes: 281 additions & 0 deletions python/lsst/source/injection/combine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,281 @@
# This file is part of source_injection.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = [
"CombineInjectionCatalogsConnections",
"CombineInjectionCatalogsConfig",
"CombineInjectionCatalogsTask",
]
from typing import cast

import numpy as np
from astropy.table import Table, vstack
from lsst.geom import SpherePoint, degrees
from lsst.pex.config import Field
from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct
from lsst.pipe.base.connections import InputQuantizedConnection
from lsst.pipe.base.connectionTypes import Input, Output, PrerequisiteInput
from lsst.skymap import BaseSkyMap
from lsst.skymap.ringsSkyMap import RingsSkyMap
from smatch.matcher import Matcher


class CombineInjectionCatalogsConnections(
PipelineTaskConnections,
dimensions=("instrument", "skymap", "tract"),
defaultTemplates={
"injection_prefix": "injection_",
"injected_prefix": "injected_",
},
):
"""Base connections for source injection tasks."""

catalogs = PrerequisiteInput(
doc="Set of catalogs of sources to draw inputs from.",
name="{injected_prefix}deepCoadd_catalog",
dimensions=("skymap", "tract", "patch", "band"),
storageClass="ArrowAstropy",
minimum=0,
multiple=True,
)
skyMap = Input(
doc="Input definition of geometry/bbox and projection/wcs for warped exposures",
name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
storageClass="SkyMap",
dimensions=("skymap",),
)
output_catalog = Output(
doc="Combined multiband catalog of injected sources.",
name="combined_{injected_prefix}deepCoadd_catalog",
storageClass="ArrowAstropy",
dimensions=("skymap", "tract", "patch"),
)


class CombineInjectionCatalogsConfig( # type: ignore [call-arg]
PipelineTaskConfig, pipelineConnections=CombineInjectionCatalogsConnections
):
"""Base configuration for source injection tasks."""

col_ra = Field[str](
doc="Column name for right ascension (in degrees).",
default="ra",
)
col_dec = Field[str](
doc="Column name for declination (in degrees).",
default="dec",
)
col_mag = Field[str](
doc="Column name for magnitude.",
default="mag",
)
col_source_type = Field[str](
doc="Column name for the source type used in the input catalog. Must match one of the surface "
"brightness profiles defined by GalSim.",
default="source_type",
)


class CombineInjectionCatalogsTask(PipelineTask):
"""Class for combining all tables in a collection of injection_catalogs
into one table.
"""

_DefaultName = "combineInjectionCatalogsTask"
ConfigClass = CombineInjectionCatalogsConfig

def runQuantum(self, butler_quantum_context, input_refs, output_refs):
inputs = butler_quantum_context.get(input_refs)
catalog_dict, tract = self._get_catalogs(inputs, input_refs)
outputs = self.run(catalog_dict, inputs["skyMap"], tract)
butler_quantum_context.put(outputs, output_refs)

def _get_catalogs(
self,
inputs: dict,
input_refs: InputQuantizedConnection,
) -> tuple[dict, int]:
"""Organize all catalogs in a dictionary with photometry band keys."""
catalog_dict = {}
tracts = set()
for ref, catalog in zip(input_refs.catalogs, inputs["catalogs"]):
band = ref.dataId.band.name
if band not in catalog_dict:
catalog_dict[band] = []
# load the patch number to check for patch overlap duplicates later
catalog["patch"] = ref.dataId.patch.id
catalog_dict[band].append(catalog)
tracts.add(ref.dataId.tract.id)
# vstack all the catalogs for each band
for band, catalog_list in catalog_dict.items():
catalog_dict[band] = vstack(catalog_list)
if len(tracts) != 1:
raise RuntimeError(f"len({tracts=}) != 1")
return (catalog_dict, list(tracts)[0])

def _remove_duplicates(
self,
catalog: Table,
tractInfo,
) -> Table:
"""Remove tract and patch overlap duplicates."""
self.config = cast(CombineInjectionCatalogsConfig, self.config)
# TODO: add isPatchInner column
duplicates = []
for ind, row in enumerate(catalog):
tractInner = tractInfo.getInnerSkyPolygon()
# check if the source is within inner tract bounds
if tractInner.contains(np.deg2rad(row[self.config.col_ra]), np.deg2rad(row[self.config.col_dec])):
spherePoint = SpherePoint(
row[self.config.col_ra] * degrees, row[self.config.col_dec] * degrees
)
# get the patch number from the source's ra,dec
patchInfo = tractInfo.findPatch(spherePoint)
# check against the patch column
if row["patch"] != patchInfo.sequential_index:
duplicates.append(ind)
# if the source is not within inner tract bounds, remove it
else:
duplicates.append(ind)
catalog.remove_rows(duplicates)
return catalog

def _make_multiband_catalog(
self,
bands: list,
catalog_dict: dict,
match_radius: float,
) -> Table:
"""Combine multiple band-specific catalogs into one multiband
catalog.
"""
self.config = cast(CombineInjectionCatalogsConfig, self.config)
# load the first catalog then loop to add info for the other bands
multiband_catalog = catalog_dict[bands[0]]
multiband_catalog.rename_column(self.config.col_mag, f"{bands[0]}_mag")
for band in bands[1:]:
# make a masked column for the new band
multiband_catalog.add_column(np.ma.masked_all(len(multiband_catalog)), name=f"{band}_mag")
# match the input catalog for this band to the existing
# multiband catalog
catalog_next_band = catalog_dict[band]
catalog_next_band.rename_column(self.config.col_mag, f"{band}_mag")
with Matcher(multiband_catalog[self.config.col_ra], multiband_catalog[self.config.col_dec]) as m:
idx, multiband_match_inds, next_band_match_inds, dists = m.query_radius(
catalog_next_band[self.config.col_ra],
catalog_next_band[self.config.col_dec],
match_radius,
return_indices=True,
)
# if there are matches...
if len(multiband_match_inds) > 0 and len(next_band_match_inds) > 0:
# choose the coordinates in the brightest band
for i, j in zip(multiband_match_inds, next_band_match_inds):
mags = []
for col in multiband_catalog.colnames:
if "_mag" in col:
mags.append((col, multiband_catalog[i][col]))
bright_mag = min([x[1] for x in mags])
if catalog_next_band[f"{band}_mag"][j] < bright_mag:
multiband_catalog["ra"][i] = catalog_next_band["ra"][j]
multiband_catalog["dec"][i] = catalog_next_band["dec"][j]
# pick the source_type which shows multiple components if
# one exisits
if "+" in catalog_next_band[self.config.col_source_type][next_band_match_inds]:
multiband_catalog[self.config.col_source_type][multiband_match_inds] = catalog_next_band[
self.config.col_source_type
][next_band_match_inds]
# remove the mask and fill the new mag value
multiband_catalog.mask[f"{band}_mag"][multiband_match_inds] = False
multiband_catalog[f"{band}_mag"][multiband_match_inds] = catalog_next_band[f"{band}_mag"][
next_band_match_inds
]
# add rows for all the sources without matches
not_next_band_match_inds = np.full(len(catalog_next_band), True, dtype=bool)
not_next_band_match_inds[next_band_match_inds] = False
multiband_catalog = vstack([multiband_catalog, catalog_next_band[not_next_band_match_inds]])
# otherwise just stack the tables
else:
multiband_catalog = vstack([multiband_catalog, catalog_next_band])
return multiband_catalog

def run(
self,
catalog_dict: dict,
skymap: RingsSkyMap,
tract: int,
pixel_match_radius: float = 0.1,
) -> Table:
"""Combine all tables in catalog_dict into one table.
Parameters
----------
catalog_dict: `dict`
A dictionary with photometric bands for keys and astropy tables for
items.
skymap: `lsst.skymap.ringsSkyMap.RingsSkyMap`
A rings skymap.
tract: `int`
The tract where sources have been injected.
pixel_match_radius: `float`
Match radius in pixels to use for self-matching catalogs across
different bands.
Returns
-------
output_struct : `lsst.pipe.base.Struct`
contains :
multiband_catalog: `astropy.table.Table`
A single table containing all information of the separate
tables in catalog_dict
"""
self.config = cast(CombineInjectionCatalogsConfig, self.config)
bands = list(catalog_dict.keys())
# convert the pixel match radius to degrees
tractInfo = skymap.generateTract(tract)
tractWcs = tractInfo.getWcs()
pixel_scale = tractWcs.getPixelScale()
match_radius = pixel_match_radius * pixel_scale.asDegrees()
for band in bands:
catalog = catalog_dict[band]
# remove tract and patch overlap duplicates
catalog = self._remove_duplicates(catalog, tractInfo)

if len(bands) > 1:
# match the catalogs across bands
output_catalog = self._make_multiband_catalog(bands, catalog_dict, match_radius)
else:
output_catalog = catalog_dict[bands[0]]
output_catalog.rename_column(self.config.col_mag, f"{bands[0]}_mag")
# remove sources outside tract boundaries
for index, (ra, dec) in enumerate(
list(zip(output_catalog[self.config.col_ra], output_catalog[self.config.col_dec]))
):
point = SpherePoint(ra * degrees, dec * degrees)
if not tractInfo.contains(point):
output_catalog.remove_row(index)
# replace injection_id column with a new injected_id column
output_catalog["injection_id"] = list(range(len(output_catalog)))
output_catalog.rename_column("injection_id", "injected_id")
output_struct = Struct(output_catalog=output_catalog)
return output_struct

0 comments on commit 03102d6

Please sign in to comment.