Skip to content

Commit

Permalink
feat(pipeline): add built-in list of wait keys
Browse files Browse the repository at this point in the history
  • Loading branch information
ljgray committed Dec 11, 2024
1 parent a50ba09 commit de74fe9
Showing 1 changed file with 46 additions and 10 deletions.
56 changes: 46 additions & 10 deletions caput/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -953,8 +953,9 @@ def _validate_task_inputs(self):
for task_spec in self.task_specs:
in_ = task_spec.get("in", None)
requires = task_spec.get("requires", None)
wait = task_spec.get("wait", None)

for key, value in (["in", in_], ["requires", requires]):
for key, value in (["in", in_], ["requires", requires], ["wait", wait]):
if value is None:
continue
if not isinstance(value, list):
Expand All @@ -970,7 +971,7 @@ def _get_task_from_spec(self, task_spec: dict):
"""Set up a pipeline task from the spec given in the tasks list."""
# Check that only the expected keys are in the task spec.
for key in task_spec.keys():
if key not in ["type", "params", "requires", "in", "out"]:
if key not in ["type", "params", "requires", "in", "out", "wait"]:
raise config.CaputConfigError(
f"Task got an unexpected key '{key}' in 'tasks' list."
)
Expand Down Expand Up @@ -1073,9 +1074,10 @@ def _check_duplicate(key0: str, key1: str, d0: dict, d1: dict):
requires = _check_duplicate("requires", "requires", task_spec, kwargs)
in_ = _check_duplicate("in", "in_", task_spec, kwargs)
out = _check_duplicate("out", "out", task_spec, kwargs)
wait = _check_duplicate("wait", "wait", task_spec, kwargs)

try:
task._setup_keys(in_, out, requires)
task._setup_keys(in_, out, requires, wait)
# Want to blindly catch errors
except Exception as e:
raise config.CaputConfigError(
Expand Down Expand Up @@ -1126,12 +1128,17 @@ class TaskBase(config.Reader):
If true, signals to the pipeline runner to make a call to `breakpoint` each
time this task is run. This will drop the interpreter into pdb, allowing for
interactive debugging of the current pipeline and task state. Default is False.
single_wait : bool
If true, keys in the wait queue only have to be received once, even if `next`
iterates multiple times. Otherwise, `wait` keys must be received prior to
each iteration of `next`. Default is False.
"""

broadcast_inputs = config.Property(proptype=bool, default=False)
limit_outputs = config.Property(proptype=int, default=None)
base_priority = config.Property(proptype=int, default=0)
breakpoint = config.Property(proptype=bool, default=False)
single_wait = config.Property(proptype=bool, default=False)

# Overridable Attributes
# -----------------------
Expand Down Expand Up @@ -1231,6 +1238,13 @@ def _pipeline_is_available(self):
# This task hasn't been initialized
return False

if self._wait is not None and not bool(
min((q.qsize() for q in self._wait), default=1)
):
# If wait flags are required and have not been received,
# this task can't be run
return False

if self._pipeline_state == "setup":
# True if all `requires` items have been provided
# This also returns True is `self._requires` is empty
Expand Down Expand Up @@ -1311,12 +1325,13 @@ def _from_config(cls, config):

return self

def _setup_keys(self, in_=None, out=None, requires=None):
"""Setup the 'requires', 'in' and 'out' keys for this task."""
def _setup_keys(self, in_=None, out=None, requires=None, wait=None):
"""Setup the 'in', 'out', 'requires', and 'wait' keys for this task."""
# Parse the task spec.
requires = _format_product_keys(requires)
in_ = _format_product_keys(in_)
out = _format_product_keys(out)
wait = _format_product_keys(wait)

# Inspect the `setup` method to see how many arguments it takes.
setup_argspec = inspect.getfullargspec(self.setup)
Expand Down Expand Up @@ -1380,6 +1395,11 @@ def _setup_keys(self, in_=None, out=None, requires=None):
# produce multiple values, queue up items which may be used in the
# future
self._in = [queue.Queue() for _ in range(n_in)]
# Store wait keys
self._wait_keys = wait
# Make a list with a queue for each wait key. Use queue because this can
# be buffered similarly to the inputs
self._wait = [queue.Queue() for _ in range(len(wait))]
# Store output keys
self._out_keys = out
# Keep track of the number of times this task has produced output
Expand Down Expand Up @@ -1434,6 +1454,7 @@ def _pipeline_advance_state(self):
)

self._in = None
self._wait = None
self._pipeline_state = "finish"

elif self._pipeline_state == "finish":
Expand Down Expand Up @@ -1476,6 +1497,10 @@ def _pipeline_next(self):
else: # noqa RET506
# Get the next set of data to be run.
args = tuple(in_.get() for in_ in self._in)
# If `wait` flags are not pinned, remove them
# from the queue
if not self.single_wait:
_ = [w.get() for w in self._wait]

# Call the next iteration of `next`. If it is done running,
# advance the task state and continue
Expand Down Expand Up @@ -1548,13 +1573,24 @@ def _pipeline_queue_product(self, key, product):
logger.debug(f"{self!s} stowing data product with key {key} for `in`.")
if self._in is None:
raise PipelineRuntimeError(
"Tried to queue 'in' data product, but `next()` already run."
f"Tried to queue 'in' data product, but task state is `{self._pipeline_state}`."
)

self._in[ii].put(product)

result = True

if key in self._wait_keys:
ii = self._wait_keys.index(key)
logger.debug(f"{self!s} setting wait flag with key {key}.")
if self._wait is None:
raise PipelineRuntimeError(
f"Tried to queue `wait` flag, but task state is `{self._pipeline_state}`."
)
# This data product isn't needed here - just have to record
# that it was received
self._wait[ii].put(True)

return result


Expand Down Expand Up @@ -2089,10 +2125,10 @@ def next(self, in_):
def _format_product_keys(keys):
"""Formats the pipeline task product keys.
In the pipeline config task list, the values of 'requires', 'in' and 'out'
are keys representing data products. This function gets that key from the
task's entry of the task list, defaults to zero, and ensures it's formated
as a sequence of strings.
In the pipeline config task list, the values of 'requires', 'in', 'out' and
'wait' are keys representing data products. This function gets that key
from the task's entry of the task list, defaults to zero, and ensures it's
formated as a sequence of strings.
"""
if keys is None:
return []
Expand Down

0 comments on commit de74fe9

Please sign in to comment.