diff --git a/caput/pipeline.py b/caput/pipeline.py index 010b7ccd..fbd09c60 100644 --- a/caput/pipeline.py +++ b/caput/pipeline.py @@ -953,8 +953,9 @@ def _validate_task_inputs(self): for task_spec in self.task_specs: in_ = task_spec.get("in", None) requires = task_spec.get("requires", None) + wait = task_spec.get("wait", None) - for key, value in (["in", in_], ["requires", requires]): + for key, value in (["in", in_], ["requires", requires], ["wait", wait]): if value is None: continue if not isinstance(value, list): @@ -970,7 +971,7 @@ def _get_task_from_spec(self, task_spec: dict): """Set up a pipeline task from the spec given in the tasks list.""" # Check that only the expected keys are in the task spec. for key in task_spec.keys(): - if key not in ["type", "params", "requires", "in", "out"]: + if key not in ["type", "params", "requires", "in", "out", "wait"]: raise config.CaputConfigError( f"Task got an unexpected key '{key}' in 'tasks' list." ) @@ -1073,9 +1074,10 @@ def _check_duplicate(key0: str, key1: str, d0: dict, d1: dict): requires = _check_duplicate("requires", "requires", task_spec, kwargs) in_ = _check_duplicate("in", "in_", task_spec, kwargs) out = _check_duplicate("out", "out", task_spec, kwargs) + wait = _check_duplicate("wait", "wait", task_spec, kwargs) try: - task._setup_keys(in_, out, requires) + task._setup_keys(in_, out, requires, wait) # Want to blindly catch errors except Exception as e: raise config.CaputConfigError( @@ -1126,12 +1128,17 @@ class TaskBase(config.Reader): If true, signals to the pipeline runner to make a call to `breakpoint` each time this task is run. This will drop the interpreter into pdb, allowing for interactive debugging of the current pipeline and task state. Default is False. + single_wait : bool + If true, keys in the wait queue only have to be received once, even if `next` + iterates multiple times. Otherwise, `wait` keys must be received prior to + each iteration of `next`. Default is False. """ broadcast_inputs = config.Property(proptype=bool, default=False) limit_outputs = config.Property(proptype=int, default=None) base_priority = config.Property(proptype=int, default=0) breakpoint = config.Property(proptype=bool, default=False) + single_wait = config.Property(proptype=bool, default=False) # Overridable Attributes # ----------------------- @@ -1231,6 +1238,13 @@ def _pipeline_is_available(self): # This task hasn't been initialized return False + if self._wait is not None and not bool( + min((q.qsize() for q in self._wait), default=1) + ): + # If wait flags are required and have not been received, + # this task can't be run + return False + if self._pipeline_state == "setup": # True if all `requires` items have been provided # This also returns True is `self._requires` is empty @@ -1311,12 +1325,13 @@ def _from_config(cls, config): return self - def _setup_keys(self, in_=None, out=None, requires=None): - """Setup the 'requires', 'in' and 'out' keys for this task.""" + def _setup_keys(self, in_=None, out=None, requires=None, wait=None): + """Setup the 'in', 'out', 'requires', and 'wait' keys for this task.""" # Parse the task spec. requires = _format_product_keys(requires) in_ = _format_product_keys(in_) out = _format_product_keys(out) + wait = _format_product_keys(wait) # Inspect the `setup` method to see how many arguments it takes. setup_argspec = inspect.getfullargspec(self.setup) @@ -1380,6 +1395,11 @@ def _setup_keys(self, in_=None, out=None, requires=None): # produce multiple values, queue up items which may be used in the # future self._in = [queue.Queue() for _ in range(n_in)] + # Store wait keys + self._wait_keys = wait + # Make a list with a queue for each wait key. Use queue because this can + # be buffered similarly to the inputs + self._wait = [queue.Queue() for _ in range(len(wait))] # Store output keys self._out_keys = out # Keep track of the number of times this task has produced output @@ -1434,6 +1454,7 @@ def _pipeline_advance_state(self): ) self._in = None + self._wait = None self._pipeline_state = "finish" elif self._pipeline_state == "finish": @@ -1476,6 +1497,10 @@ def _pipeline_next(self): else: # noqa RET506 # Get the next set of data to be run. args = tuple(in_.get() for in_ in self._in) + # If `wait` flags are not pinned, remove them + # from the queue + if not self.single_wait: + _ = [w.get() for w in self._wait] # Call the next iteration of `next`. If it is done running, # advance the task state and continue @@ -1548,13 +1573,24 @@ def _pipeline_queue_product(self, key, product): logger.debug(f"{self!s} stowing data product with key {key} for `in`.") if self._in is None: raise PipelineRuntimeError( - "Tried to queue 'in' data product, but `next()` already run." + f"Tried to queue 'in' data product, but task state is `{self._pipeline_state}`." ) self._in[ii].put(product) result = True + if key in self._wait_keys: + ii = self._wait_keys.index(key) + logger.debug(f"{self!s} setting wait flag with key {key}.") + if self._wait is None: + raise PipelineRuntimeError( + f"Tried to queue `wait` flag, but task state is `{self._pipeline_state}`." + ) + # This data product isn't needed here - just have to record + # that it was received + self._wait[ii].put(True) + return result @@ -2089,10 +2125,10 @@ def next(self, in_): def _format_product_keys(keys): """Formats the pipeline task product keys. - In the pipeline config task list, the values of 'requires', 'in' and 'out' - are keys representing data products. This function gets that key from the - task's entry of the task list, defaults to zero, and ensures it's formated - as a sequence of strings. + In the pipeline config task list, the values of 'requires', 'in', 'out' and + 'wait' are keys representing data products. This function gets that key + from the task's entry of the task list, defaults to zero, and ensures it's + formated as a sequence of strings. """ if keys is None: return []