
Merge pull request #35 from lsst/tickets/DM-40441
DM-40441: Avoid use of deprecated Pipeline.toExpandedPipeline()
TallJimbo authored Apr 8, 2024
2 parents a9cb4f5 + 3ac26e5 commit 2b2b4ad
Showing 4 changed files with 64 additions and 53 deletions.
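
The change is mechanical but repeated throughout: every consumer of the deprecated Pipeline.toExpandedPipeline() iterator moves to the PipelineGraph API. A minimal sketch of the migration pattern, assuming Pipeline.fromFile for loading and an illustrative pipeline path:

    from lsst.pipe.base import Pipeline

    # Illustrative pipeline definition; any pipeline YAML would do.
    pipeline = Pipeline.fromFile("example_pipeline.yaml")

    # Deprecated pattern removed by this commit:
    #   for taskDef in pipeline.toExpandedPipeline():
    #       ... use taskDef.label, taskDef.taskClass ...

    # Replacement: build a PipelineGraph and walk its task nodes.
    for task_node in pipeline.to_graph().tasks.values():
        print(task_node.label, task_node.task_class)

    # When only the labels are needed, Pipeline exposes them directly.
    all_tasks = set(pipeline.task_labels)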
.pre-commit-config.yaml (8 changes: 4 additions & 4 deletions)
@@ -1,12 +1,12 @@
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.4.0
+    rev: v4.5.0
     hooks:
       - id: check-yaml
       - id: end-of-file-fixer
       - id: trailing-whitespace
   - repo: https://github.com/psf/black
-    rev: 23.3.0
+    rev: 24.3.0
     hooks:
       - id: black
         # It is recommended to specify the latest version of Python
@@ -15,12 +15,12 @@ repos:
         # https://pre-commit.com/#top_level-default_language_version
         language_version: python3.10
   - repo: https://github.com/PyCQA/isort
-    rev: 5.12.0
+    rev: 5.13.2
     hooks:
       - id: isort
         name: isort (python)
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
-    rev: v0.0.277
+    rev: v0.3.4
     hooks:
       - id: ruff
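
These hook pins are the kind of routine refresh that pre-commit autoupdate performs, rewriting each rev to the latest tagged release of its hook repository.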
python/lsst/source/injection/utils/make_injection_pipeline.py (72 changes: 39 additions & 33 deletions)
@@ -23,20 +23,13 @@

 __all__ = ["make_injection_pipeline"]

+import itertools
 import logging

 from lsst.analysis.tools.interfaces import AnalysisPipelineTask
 from lsst.pipe.base import LabelSpecifier, Pipeline


-def _get_dataset_type_names(conns, fields):
-    """Return the name of a connection's dataset type."""
-    dataset_type_names = set()
-    for field in fields:
-        dataset_type_names.add(getattr(conns, field).name)
-    return dataset_type_names
-
-
 def _parse_config_override(config_override: str) -> tuple[str, str, str]:
     """Parse a config override string into a label, a key and a value.

@@ -160,7 +153,7 @@ def make_injection_pipeline(
     # Remove all tasks which are not to be included in the injection pipeline.
     if isinstance(excluded_tasks, str):
         excluded_tasks = set(excluded_tasks.split(","))
-    all_tasks = {taskDef.label for taskDef in pipeline.toExpandedPipeline()}
+    all_tasks = set(pipeline.task_labels)
     preserved_tasks = all_tasks - excluded_tasks
     label_specifier = LabelSpecifier(labels=preserved_tasks)
     # EDIT mode removes tasks from parent subsets but keeps the subset itself.
@@ -179,46 +172,62 @@
     injected_types = {dataset_type_name}
     precursor_injection_task_labels = set()
     # Loop over all tasks in the pipeline.
-    for taskDef in pipeline.toExpandedPipeline():
-        # Add override for Analysis Tools taskDefs. Connections in Analysis
+    for task_node in pipeline.to_graph().tasks.values():
+        # Add override for Analysis Tools tasks. Connections in Analysis
         # Tools are dynamically assigned, and so are not able to be modified in
         # the same way as a static connection. Instead, we add a config
         # override here to the connections.outputName field. This field is
         # prepended to all Analysis Tools connections, and so will prepend the
         # injection prefix to all plot/metric outputs. Further processing of
-        # this taskDef will be skipped thereafter.
-        if issubclass(taskDef.taskClass, AnalysisPipelineTask):
+        # this task will be skipped thereafter.
+        if issubclass(task_node.task_class, AnalysisPipelineTask):
             pipeline.addConfigOverride(
-                taskDef.label, "connections.outputName", prefix + taskDef.config.connections.outputName
+                task_node.label,
+                "connections.outputName",
+                prefix + task_node.config.connections.outputName,
             )
             continue

-        conns = taskDef.connections
-        input_types = _get_dataset_type_names(conns, conns.initInputs | conns.inputs)
-        output_types = _get_dataset_type_names(conns, conns.initOutputs | conns.outputs)
+        input_types = {
+            read_edge.dataset_type_name
+            for read_edge in itertools.chain(task_node.inputs.values(), task_node.init.inputs.values())
+        }
+        output_types = {
+            write_edge.dataset_type_name
+            for write_edge in itertools.chain(task_node.outputs.values(), task_node.init.outputs.values())
+        }
         all_connection_type_names |= input_types | output_types
         # Identify the precursor task: allows appending inject task to subset.
         if dataset_type_name in output_types:
-            precursor_injection_task_labels.add(taskDef.label)
+            precursor_injection_task_labels.add(task_node.label)
         # If the task has any injected dataset type names as inputs, add the
         # task to a set of tasks touched by injection, and add all of the
         # outputs of this task to the set of injected types.
         if len(input_types & injected_types) > 0:
-            injected_tasks |= {taskDef.label}
+            injected_tasks |= {task_node.label}
             injected_types |= output_types
         # Add the injection prefix to all affected dataset type names.
-        for field in conns.initInputs | conns.inputs | conns.initOutputs | conns.outputs:
-            if hasattr(taskDef.config.connections.ConnectionsClass, field):
+        for edge in itertools.chain(
+            task_node.inputs.values(),
+            task_node.outputs.values(),
+            task_node.init.inputs.values(),
+            task_node.init.outputs.values(),
+        ):
+            if hasattr(task_node.config.connections.ConnectionsClass, edge.connection_name):
                 # If the connection type is not dynamic, modify as usual.
-                if (conn_type := getattr(conns, field).name) in injected_types:
-                    pipeline.addConfigOverride(taskDef.label, "connections." + field, prefix + conn_type)
+                if edge.parent_dataset_type_name in injected_types:
+                    pipeline.addConfigOverride(
+                        task_node.label,
+                        "connections." + edge.connection_name,
+                        prefix + edge.dataset_type_name,
+                    )
             else:
                 # Add log warning if the connection type is dynamic.
                 logger.warning(
                     "Dynamic connection %s in task %s is not supported here. This connection will "
                     "neither be modified nor merged into the output injection pipeline.",
-                    field,
-                    taskDef.label,
+                    edge.connection_name,
+                    task_node.label,
                 )
     # Raise if the injected dataset type does not exist in the pipeline.
     if dataset_type_name not in all_connection_type_names:
@@ -261,21 +270,18 @@ def make_injection_pipeline(
     )
     pipeline.mergePipeline(injection_pipeline)
     # Loop over all injection tasks and modify the connection names.
-    for injection_taskDef in injection_pipeline.toExpandedPipeline():
-        injected_tasks |= {injection_taskDef.label}
-        conns = injection_taskDef.connections
-        pipeline.addConfigOverride(
-            injection_taskDef.label, "connections.input_exposure", dataset_type_name
-        )
+    for injection_task_label in injection_pipeline.task_labels:
+        injected_tasks.add(injection_task_label)
+        pipeline.addConfigOverride(injection_task_label, "connections.input_exposure", dataset_type_name)
         pipeline.addConfigOverride(
-            injection_taskDef.label, "connections.output_exposure", prefix + dataset_type_name
+            injection_task_label, "connections.output_exposure", prefix + dataset_type_name
         )
     # Optionally update subsets to include the injection task.
     if not exclude_subsets:
         for label in precursor_injection_task_labels:
             precursor_subsets = pipeline.findSubsetsWithLabel(label)
             for subset in precursor_subsets:
-                pipeline.addLabelToSubset(subset, injection_taskDef.label)
+                pipeline.addLabelToSubset(subset, injection_task_label)

     # Create injected subsets.
     injected_label_specifier = LabelSpecifier(labels=injected_tasks)
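
The set comprehensions above replace the deleted _get_dataset_type_names helper by reading names off the graph's edge objects. A standalone sketch of the equivalent collection under the new API, with an illustrative helper name and a task_node assumed to come from pipeline.to_graph().tasks.values():

    import itertools

    def collect_io_dataset_type_names(task_node):
        # Gather regular and init inputs/outputs from the task node's
        # edges, mirroring the comprehensions introduced in the diff above.
        input_types = {
            edge.dataset_type_name
            for edge in itertools.chain(task_node.inputs.values(), task_node.init.inputs.values())
        }
        output_types = {
            edge.dataset_type_name
            for edge in itertools.chain(task_node.outputs.values(), task_node.init.outputs.values())
        }
        return input_types, output_types

Note the two name attributes the diff relies on: edge.parent_dataset_type_name is the component-free parent dataset type name used for the membership test, while edge.dataset_type_name is the possibly component-qualified name used to build the override value.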
requirements.txt (4 changes: 2 additions & 2 deletions)
@@ -3,7 +3,7 @@ astropy >= 4.0
 esutil >= 0.6.9
 galsim >= 2.4.6
 git+https://github.com/lsst/daf_butler@main#egg=lsst-daf-butler
-git+https://github.com/lsst/pipe_base@main#egg=lsst-pipe-base
-git+https://github.com/lsst/utils@main#egg=lsst-utils
+git+https://github.com/lsst/pipe_base@tickets/DM-40441#egg=lsst-pipe-base
+git+https://github.com/lsst/utils@tickets/DM-40441#egg=lsst-utils
 git+https://github.com/lsst/resources@main#egg=lsst-resources
 git+https://github.com/lsst/pex_config@main#egg=lsst-pex-config
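
The pipe_base and utils requirements now point at the matching DM-40441 ticket branches rather than main, presumably a temporary pin so that CI resolves the new PipelineGraph APIs before those branches themselves merge.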
tests/test_utils.py (33 changes: 19 additions & 14 deletions)
@@ -100,23 +100,28 @@ def test_make_injection_pipeline(self):
             instrument="lsst.obs.subaru.HyperSuprimeCam",
             log_level=logging.DEBUG,
         )
-        expanded_pipeline = merged_pipeline.toExpandedPipeline()
+        pipeline_graph = merged_pipeline.to_graph()
         expected_subset_tasks = ["isr", "inject_exposure", "characterizeImage"]
         merged_task_subsets = [merged_pipeline.findSubsetsWithLabel(x) for x in expected_subset_tasks]
         self.assertEqual(len(merged_task_subsets), len(expected_subset_tasks))
-        for taskDef in expanded_pipeline:
-            conns = taskDef.connections
-            if taskDef.label == "isr":
-                self.assertEqual(conns.outputExposure.name, "postISRCCD")
-            elif taskDef.label == "inject_exposure":
-                self.assertEqual(conns.input_exposure.name, "postISRCCD")
-                self.assertEqual(conns.output_exposure.name, "injected_postISRCCD")
-                self.assertEqual(conns.output_catalog.name, "injected_postISRCCD_catalog")
-            elif taskDef.label == "characterizeImage":
-                self.assertEqual(conns.exposure.name, "injected_postISRCCD")
-                self.assertEqual(conns.characterized.name, "injected_icExp")
-                self.assertEqual(conns.backgroundModel.name, "injected_icExpBackground")
-                self.assertEqual(conns.sourceCat.name, "injected_icSrc")
+        for task_node in pipeline_graph.tasks.values():
+            if task_node.label == "isr":
+                self.assertEqual(task_node.outputs["outputExposure"].dataset_type_name, "postISRCCD")
+            elif task_node.label == "inject_exposure":
+                self.assertEqual(task_node.inputs["input_exposure"].dataset_type_name, "postISRCCD")
+                self.assertEqual(
+                    task_node.outputs["output_exposure"].dataset_type_name, "injected_postISRCCD"
+                )
+                self.assertEqual(
+                    task_node.outputs["output_catalog"].dataset_type_name, "injected_postISRCCD_catalog"
+                )
+            elif task_node.label == "characterizeImage":
+                self.assertEqual(task_node.inputs["exposure"].dataset_type_name, "injected_postISRCCD")
+                self.assertEqual(task_node.outputs["characterized"].dataset_type_name, "injected_icExp")
+                self.assertEqual(
+                    task_node.outputs["backgroundModel"].dataset_type_name, "injected_icExpBackground"
+                )
+                self.assertEqual(task_node.outputs["sourceCat"].dataset_type_name, "injected_icSrc")

     def test_ingest_injection_catalog(self):
         input_dataset_refs = ingest_injection_catalog(
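
The same graph-based checks work interactively. A minimal sketch, assuming make_injection_pipeline is importable from lsst.source.injection, that dataset_type_name and reference_pipeline are its leading parameters, and an illustrative reference pipeline path:

    from lsst.source.injection import make_injection_pipeline  # assumed import path

    merged_pipeline = make_injection_pipeline(
        dataset_type_name="postISRCCD",
        reference_pipeline="example_reference_pipeline.yaml",  # hypothetical path
    )
    graph = merged_pipeline.to_graph()
    inject = graph.tasks["inject_exposure"]
    assert inject.inputs["input_exposure"].dataset_type_name == "postISRCCD"
    assert inject.outputs["output_exposure"].dataset_type_name == "injected_postISRCCD"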
