diff --git a/caput/pipeline.py b/caput/pipeline.py index 97505864..7239100b 100644 --- a/caput/pipeline.py +++ b/caput/pipeline.py @@ -1,4 +1,4 @@ -"""Data Analysis and Simulation Pipeline. +r"""Data Analysis and Simulation Pipeline. A data analysis pipeline is completely specified by a YAML file that specifies both what tasks are to be run and the parameters that go to those tasks. @@ -170,7 +170,7 @@ requires A 'pipeline product key' or list of keys representing values to be passed as arguments to :meth:`setup`. -in +in\_ A 'pipeline product key' or list of keys representing values to be passed as arguments to :meth:`next`. @@ -180,10 +180,31 @@ Execution Order --------------- -When the above pipeline is executed is produces the following output. +There are two options when choosing how to execute a pipeline: standard and legacy. +When the above pipeline is executed in standard mode, it produces the following output. >>> local_tasks.update(globals()) # Required for interactive sessions. ->>> Manager.from_yaml_str(spam_config).run() +>>> m = Manager.from_yaml_str(spam_config) +>>> m.run() +Setting up PrintEggs. +Setting up GetEggs. +Setting up CookEggs. +Spam and green eggs. +Spam and duck eggs. +Spam and ostrich eggs. +Finished PrintEggs. +Cooking fried green eggs. +Cooking fried duck eggs. +Cooking fried ostrich eggs. +Finished GetEggs. +Finished CookEggs. + +When executed in legacy mode, it produces this output. + +>>> local_tasks.update(globals()) # Required for interactive sessions. +>>> m = Manager.from_yaml_str(spam_config) +>>> m.execution_order = "legacy" +>>> m.run() Setting up PrintEggs. Setting up GetEggs. Setting up CookEggs. @@ -197,11 +218,24 @@ Finished GetEggs. Finished CookEggs. -The rules for execution order are as follows: +To understand the differences, compare the rules for each strategy. +The `standard` method uses a priority system based on the following criteria, +in decreasing importance: + +1. Task must be available to execute some step. +2. Task priority. This is set by two factors: + * Dynamic priority: tasks which have a higher net consumption + (inputs consumed minus outputs created). + * Base priority: user-configurable base priority is added to + the dynamic priority. +3. Pipeline configuration order. + +If no tasks are available to run, the `legacy` method is used, which uses the +following execution order rules: 1. One of the methods `setup()`, `next()` or `finish()`, as appropriate, will be executed from each task, in order. -2. If the task method is missing its input, as specified by the 'requires' or 'in' +2. If the task method is missing its input, as specified by the 'requires' or 'in\_' keys, restart at the beginning of the `tasks` list. 3. If the input to `next()` is missing and the task is at the beginning of the list there will be no opportunity to generate this input. Stop iterating @@ -210,7 +244,13 @@ 5. Once a method from the last member of the `tasks` list is executed, restart at the beginning of the list. -If the above rules seem somewhat opaque, consider the following example which +The difference in outputs is because `PrintEggs` will always have higher priority +than `GetEggs`, so it will run to completion _before_ `GetEggs` starts generating +anything. Only once `PrintEggs` is done will the other tasks run. Even though +`CookEggs` has the highest priority, it cannot do anything without `GetEggs` running +first. 
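+
+The selection step itself is compact. As a rough standalone sketch (the
+availability flags and priority numbers below are invented for illustration,
+not computed from the real tasks), the `standard` strategy amounts to::
+
+    # (name, available, priority, base_priority); list order is config order
+    tasks = [
+        ("PrintEggs", True, 1, 0),
+        ("GetEggs", True, -1, 0),
+        ("CookEggs", False, 2, 0),  # highest priority, but no input to consume yet
+    ]
+    available = [t for t in tasks if t[1]]
+    # `sorted` is stable, so equal priorities fall back to config order
+    chosen = sorted(available, key=lambda t: (t[2], t[3]), reverse=True)[0]
+    # chosen[0] == "PrintEggs"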
+ +If the above `legacy` rules seem somewhat opaque, consider the following example which illustrates these rules in a pipeline with a slightly more non-trivial flow. >>> class DoNothing(TaskBase): @@ -252,10 +292,12 @@ ... no_params: {} ... ''' -The following would error, because the pipeline config is checked for errors, like an 'in' parameter without a +The following would error, because the pipeline config is checked for errors, like an 'in\_' parameter without a corresponding 'out':: - Manager.from_yaml_str(new_spam_config).run() + m = Manager.from_yaml_str(new_spam_config) + m.execution_order = "legacy" + m.run() But this is what it would produce otherwise:: @@ -310,6 +352,7 @@ >>> m.add_task(save_output, in_="key1") >>> print_output = Output(lambda x: print("I love %s eggs!" % x)) >>> m.add_task(print_output, in_="key1") +>>> m.execution_order = "legacy" >>> m.run() Setting up CookEggs. Cooking coddled platypus eggs. @@ -348,12 +391,13 @@ import logging import os import queue +import traceback import warnings from copy import deepcopy import yaml -from . import config, fileformats, misc +from . import config, fileformats, misc, mpiutil, tools # Set the module logger. logger = logging.getLogger(__name__) @@ -421,20 +465,23 @@ def _get_versions(modules): modules = [modules] if not isinstance(modules, list): raise config.CaputConfigError( - f"Value of 'save_versions' is of type '{type(modules).__name__}' (expected 'str' or 'list(str)')." + f"Value of 'save_versions' is of type '{type(modules).__name__}' " + "(expected 'str' or 'list(str)')." ) versions = {} for module in modules: if not isinstance(module, str): raise config.CaputConfigError( - f"Found value of type '{type(module).__name__}' in list 'save_versions' (expected 'str')." + f"Found value of type '{type(module).__name__}' in list " + "'save_versions' (expected 'str')." ) try: versions[module] = importlib.import_module(module).__version__ except ModuleNotFoundError as err: raise config.CaputConfigError( - f"Failure getting versions requested with config parameter 'save_versions': {err}" - ) + "Failure getting versions requested with config parameter " + "'save_versions'." + ) from err return versions @@ -454,22 +501,28 @@ class Manager(config.Reader): TODO cluster : dict TODO - tasks : list + task_specs : list Configuration of pipeline tasks. + execution_order : str + Set the task execution order for this pipeline instance. `legacy` round-robins + through all tasks based on the config order, and tries to clear out finished + tasks as soon as possible. `standard` uses a priority and availability system + to select the next task to run, and falls back to `legacy` if nothing is available. save_versions : list Module names (str). This list together with the version strings from these - modules are attached to output metadata. Default: []. + modules are attached to output metadata. Default is []. save_config : bool If this is True, the global pipeline configuration is attached to output - metadata. Default: True. + metadata. Default is `True`. psutil_profiling : bool - Use psutil to profile CPU and memory usage. Default `False`. + Use psutil to profile CPU and memory usage. Default is `False`. 
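+
+    As a sketch of how the execution order is chosen (`spam_config` is the
+    example YAML document from the module docstring)::
+
+        m = Manager.from_yaml_str(spam_config)
+        m.execution_order = "legacy"  # or "standard", the default
+        m.run()
+
+    It can equivalently be set as an `execution_order` key in the `pipeline`
+    block of the YAML file.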
""" logging = config.logging_config(default={"root": "WARNING"}) multiprocessing = config.Property(default=1, proptype=int) cluster = config.Property(default={}, proptype=dict) task_specs = config.Property(default=[], proptype=list, key="tasks") + execution_order = config.enum(["standard", "legacy"], default="standard") # Options to be stored in self.all_tasks_params versions = config.Property(default=[], proptype=_get_versions, key="save_versions") @@ -537,7 +590,8 @@ def from_yaml_str(cls, yaml_doc, lint=False, psutil_profiling=False): try: if not isinstance(yaml_params["pipeline"], dict): raise config.CaputConfigError( - f"Value 'pipeline' in YAML configuration is of type '{type(yaml_params['pipeline']).__name__}' (expected a YAML block here).", + "Value 'pipeline' in YAML configuration is of type " + f"`{type(yaml_params['pipeline']).__name__}` (expected a dict here).", location=yaml_params, ) except TypeError as e: @@ -545,6 +599,7 @@ def from_yaml_str(cls, yaml_doc, lint=False, psutil_profiling=False): "Couldn't find key 'pipeline' in YAML configuration document.", location=yaml_params, ) from e + self = cls.from_config( yaml_params["pipeline"], psutil_profiling=psutil_profiling ) @@ -553,6 +608,7 @@ def from_yaml_str(cls, yaml_doc, lint=False, psutil_profiling=False): "versions": self.versions, "pipeline_config": self.all_params if self.save_config else None, } + self._setup_logging(lint) self._setup_tasks() @@ -579,7 +635,7 @@ def _setup_logging(self, lint=False): logging.getLogger(module).setLevel(getattr(logging, level)) def run(self): - """Main driver method for the pipeline. + """Main driver for the pipeline. This function initializes all pipeline tasks and runs the pipeline through to completion. @@ -590,45 +646,172 @@ def run(self): If a task stage returns the wrong number of outputs. """ + from .profile import PSUtilProfiler + + # Log MPI information + if mpiutil._comm is not None: + logger.debug(f"Running with {mpiutil.size} MPI process(es)") + else: + logger.debug("Running in single process without MPI.") + + # Index of first task in the list which has + # not finished running + self._task_head = 0 + # Pointer to next task index + self._task_idx = 0 + + # Choose how to order tasks based on the execution order + next_task = ( + self._iter_tasks if self.execution_order == "legacy" else self._next_task + ) + + logger.debug(f"Using `{self.execution_order}` iteration method.") + # Run the pipeline. - while self.tasks: - for task in list(self.tasks): # Copy list so we can alter it. - # These lines control the flow of the pipeline. - from .profile import PSUtilProfiler + while True: + # Get the next task. `StopIteration` is raised when there are no + # non-None tasks left in the tasks list + try: + task = next_task() + except StopIteration: + # No tasks remaining + break - name_profiling = f"{task.__class__.__name__}.{task._pipeline_state}" + with PSUtilProfiler( + self._psutil_profiling, str(task), logger=getattr(task, "log", logging) + ): + try: + out = task._pipeline_next() + # Raised if either `setup` or `next` was called without + # enough available inputs + except _PipelineMissingData: + # If this is the first task in the task list, it can't receive + # any more inputs and should advance its state + if self._task_idx == self._task_head: + logger.debug( + f"{task!s} missing input data and " + "is at beginning of task list. Advancing state." 
+ ) + task._pipeline_advance_state() + else: + # Restart from the beginning of the task list + self._task_idx = self._task_head + continue + # Raised if the task has finished + except _PipelineFinished: + # Overwrite the task to maintain task list indices + self.tasks[self._task_idx] = None + # Update the first available task index + for ii, t in enumerate(self.tasks[self._task_head :]): + if t is not None: + self._task_head += ii + break + continue - if hasattr(task, "log"): - psutil_log = task.log - else: - psutil_log = logging - with PSUtilProfiler( - self._psutil_profiling, name_profiling, logger=psutil_log - ): - try: - out = task._pipeline_next() - except _PipelineMissingData: - if self.tasks.index(task) == 0: - msg = ( - f"{task.__class__.__name__} missing input data and is at beginning of" - " task list. Advancing state." - ) - logger.debug(msg) - task._pipeline_advance_state() - break - except _PipelineFinished: - self.tasks.remove(task) - continue - # Now pass the output data products to any task that needs them. - out = self._check_task_output(out, task) - if out is None: + if self.execution_order == "legacy": + # Advance the task pointer + self._task_idx += 1 + + # Ensure the output(s) are correctly structured + out = self._check_task_output(out, task) + + if out is None: + continue + + # Queue outputs for any associated tasks + for key, product in zip(task._out_keys, out): + # Purposefully skip this output. Used if only one output + # needs to be passed + if key == "_": continue - keys = str(task._out_keys) - msg = "%s produced output data product with keys %s." - msg = msg % (task.__class__.__name__, keys) - logger.debug(msg) - for receiving_task in self.tasks: - receiving_task._pipeline_inspect_queue_product(task._out_keys, out) + # Try to pass this product to each task + received = [ + recv._pipeline_queue_product(key, product) + for recv in self.tasks + if recv is not None + ] + + if not any(received): + # Just warn. This probably shouldn't happen, but there + # could be some edge cases to deal with + logger.info( + f"Task {task!s} tried to pass key {key} " + "but no task was found to accept it." + ) + + # Pipeline is done + logger.info("FIN") + + def _next_task(self): + """Get the next task to run from the task list. + + Task is chosen based the following criteria: + - able to do something in its current state + - highest priority + - highest base priority + - next in pipeline config order + + If no task is available to do anything, restart from the + pipeline task head. + """ + # Get a list of tasks which are availble to run + available = [] + + for ii in range(len(self.tasks)): + # Loop through tasks starting at the current index. Including + # the current tasks first ensures we clear out completed + # tasks faster + jj = (ii + self._task_idx) % len(self.tasks) + task = self.tasks[jj] + + if task is None: + continue + + if task._pipeline_is_available: + available.append(jj) + + if not available: + # Nothing is currently available, so fall back to a + # blind loop starting at the first task. If there is + # nothing left, this will raise StopIteration. 
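+            # (running a not-yet-available task this way gives `run` a chance to
+            # advance a stalled task's state via its `_PipelineMissingData` handling)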
+ self._task_idx = self._task_head + return self._iter_tasks() + + # Reverse sort the available tasks first by priority and second + # by base priority such that for any two tasks with equal priority, + # the task with highest base priority will be selected + new_index = sorted( + available, + key=lambda i: (self.tasks[i].priority, self.tasks[i].base_priority), + reverse=True, + )[0] + + # Ensure that all ranks are running the same task. + # This probably should never be needed with the current + # priority selection. Effectively a no-op if no MPI + self._task_idx = mpiutil.bcast(new_index, root=0) + + return self.tasks[self._task_idx] + + def _iter_tasks(self): + """Iterate through tasks in order and return the next in order. + + This method implements the `legacy` execution order, and is used + as a fallback for the `standard` processing order when no task is + available to run. + """ + for ii in range(len(self.tasks)): + # Iterate starting at the next task + jj = (ii + self._task_idx) % len(self.tasks) + task = self.tasks[jj] + + if task is not None: + # Update the task pointer + self._task_idx = jj + return task + + # If all tasks are None, the pipeline is done + raise StopIteration @staticmethod def _check_task_output(out, task): @@ -637,76 +820,95 @@ def _check_task_output(out, task): Returns ------- out : Same as `TaskBase.next` or None - Pipeline product. None if there's no output of that task stage that has to be handled further. + Pipeline product, or None if there is no output of the task stage that + has to be handled further. Raises ------ PipelineRuntimeError If a task stage returns the wrong number of outputs. """ - if out is None: # This iteration supplied no output - return None - - if len(task._out_keys) == 0: # Output not handled by pipeline. + # This iteration supplied no output, or the output is not + # meant to be handled by the pipeline + if out is None or len(task._out_keys) == 0: return None if len(task._out_keys) == 1: - if isinstance(task._out_keys, tuple): - # in config file, written as `out: out_key`, No - # unpacking if `out` is a length 1 sequence. - return (out,) - # `out_keys` is a list. - # In config file, written as `out: [out_key,]`. + # if tuple, in config file written as `out: out_key`, No + # unpacking if `out` is a length 1 sequence. If list, + # in config file written as `out: [out_key,]`. # `out` must be a length 1 sequence. - return out + if isinstance(task._out_keys, tuple): + if not isinstance(out, tuple): + out = (out,) - if len(task._out_keys) != len(out): + elif len(task._out_keys) != len(out): raise PipelineRuntimeError( - f"Found unexpected number of outputs in {task.__class__.__name__}" + f"Found unexpected number of outputs in {task!s} " f"(got {len(out)} expected {len(task._out_keys)})" ) + logger.debug( + f"{task!s} produced output data product with keys {task._out_keys!s}" + ) + return out def _setup_tasks(self): """Create and setup all tasks from the task list.""" - all_out_values = {t.get("out", None) for t in self.task_specs} + # Validate that all inputs have a corresponding output key. 
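+        # (this also rejects duplicate 'out' keys; either problem raises a
+        # CaputConfigError before any task is constructed)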
+ self._validate_task_inputs() # Setup all tasks in the task listk for ii, task_spec in enumerate(self.task_specs): try: - task, key_spec = self._setup_task(task_spec) - requires = key_spec.get("requires", None) - in_ = key_spec.get("in", None) - out = key_spec.get("out", None) - self._validate_task(task, in_, requires, all_out_values) - self.add_task( - task, - requires=requires, - in_=in_, - out=out, - ) + # Load the task instance and add it to the pipeline + task = self._get_task_from_spec(task_spec) + self.add_task(task, task_spec) except config.CaputConfigError as e: - msg = f"Setting up task {ii} caused an error:\n\t{e!s}" raise config.CaputConfigError( - msg, location=task_spec if e.line is None else e.line + f"Setting up task {ii} caused an error:\n\t{traceback.format_exc()}", + location=task_spec if e.line is None else e.line, ) from e - @staticmethod - def _validate_task(task, in_, requires, all_out_values): - # Make sure this tasks in/requires values have corresponding out keys from another task - for key, value in (["in", in_], ["requires", requires]): - if value is not None: + def _validate_task_inputs(self): + # Make sure all tasks' in/requires values have corresponding + # out keys from another task + all_out_values = [] + for t in self.task_specs: + if "out" in t: + if isinstance(t["out"], (list, tuple)): + all_out_values.extend(t["out"]) + else: + all_out_values.append(t["out"]) + + unique_out_values = set(all_out_values) + + # Multiple tasks produce output with the same key + if len(unique_out_values) != len(all_out_values): + dup_keys = [k for k in unique_out_values if all_out_values.count(k) > 1] + raise config.CaputConfigError( + f"Duplicate output keys: outputs {dup_keys} were found " + "to come from multiple tasks." + ) + + for task_spec in self.task_specs: + in_ = task_spec.get("in", None) + requires = task_spec.get("requires", None) + + for key, value in (["in", in_], ["requires", requires]): + if value is None: + continue if not isinstance(value, list): value = [value] for v in value: - if v not in all_out_values: + if v not in unique_out_values: raise config.CaputConfigError( - f"Value '{key}' for task {type(task)} has no corresponding 'out' from another task " - f"(Value {v} is not in {all_out_values})." + f"Value '{key}' for task {task_spec['type']} has no corresponding " + f"`out` from another task (Value {v} is not in {unique_out_values})." ) - def _setup_task(self, task_spec): + def _get_task_from_spec(self, task_spec: dict): """Set up a pipeline task from the spec given in the tasks list.""" # Check that only the expected keys are in the task spec. for key in task_spec.keys(): @@ -729,8 +931,9 @@ def _setup_task(self, task_spec): try: task_cls = misc.import_class(task_path) except (config.CaputConfigError, AttributeError, ModuleNotFoundError) as e: - msg = f"Loading task '{task_path}' caused error {e.__class__.__name__}:\n\t{e!s}" - raise config.CaputConfigError(msg) from e + raise config.CaputConfigError( + f"Loading task `{task_path}` caused an error:\n\t{traceback.format_exc()}" + ) from e # Get the parameters and initialize the class. params = {} @@ -752,55 +955,77 @@ def _setup_task(self, task_spec): try: params.update(self.all_params[param_key]) except KeyError as e: - msg = f"Parameter group {param_key} not found in config." - raise config.CaputConfigError(msg) from e + raise config.CaputConfigError( + f"Parameter group {param_key} not found in config." 
+ ) from e # add global params to params task_params = deepcopy(self.all_tasks_params) task_params.update(params) - # Filter just the specifications for the input/output keys - key_spec = { - k: v for k, v in task_spec.items() if k in ["requires", "in", "out"] - } - # Create and configure the task instance try: task = task_cls._from_config(task_params) except config.CaputConfigError as e: raise config.CaputConfigError( - f"Failed instantiating {task_cls} from config:\n\t{e}", + f"Failed instantiating {task_cls} from config.\n\t{traceback.format_exc()}", location=task_spec.get("params", task_spec), ) from e - return task, key_spec + return task - def add_task(self, task, requires=None, in_=None, out=None): - """Add a task instance to the pipeline. + def add_task(self, task, task_spec: dict = {}, **kwargs): + r"""Add a task instance to the pipeline. Parameters ---------- task : TaskBase A pipeline task instance. - requires, in_, out : list or string + task_spec : dict + include optional argument: requires, in\_, out : list or string The names of the task inputs and outputs. + **kwargs : dict + Included for legacy purposes. Alternative method to provide + `requires`, `in\_`, and `out` arguments. These should *only* + be provided if `task_spec` is not provided - a ValueError + will be raised otherwise. Raises ------ caput.config.CaputConfigError If there was an error in the task configuration. """ + + def _check_duplicate(key0: str, key1: str, d0: dict, d1: dict): + """Check if an argument has been provided twice.""" + val0 = d0.get(key0, d0.get(key1)) + val1 = d1.get(key0, d1.get(key1)) + + # Check that the key has not been provided twice. It's + # ok to return None, we only care if *both* values are + # not None + if val0 is None: + return val1 + + if val1 is None: + return val0 + + raise ValueError(f"Argument `{key0}/{key1}` was provided twice") + + requires = _check_duplicate("requires", "requires", task_spec, kwargs) + in_ = _check_duplicate("in", "in_", task_spec, kwargs) + out = _check_duplicate("out", "out", task_spec, kwargs) + try: - task._setup_keys(requires=requires, in_=in_, out=out) + task._setup_keys(in_, out, requires) + # Want to blindly catch errors except Exception as e: - msg = f"Setting up keys for task {task.__class__.__name__} caused an error:\n\t{e!s}" - raise config.CaputConfigError(msg) from e - - # The tasks own custom validation method - task.validate() + raise config.CaputConfigError( + f"Adding task {task!s} caused an error:\n\t{traceback.format_exc()}" + ) from e self.tasks.append(task) - logger.debug(f"Added {task.__class__.__name__} to task list.") + logger.debug(f"Added {task!s} to task list.") # Pipeline Task Base Classes @@ -819,8 +1044,32 @@ class TaskBase(config.Reader): pipeline yaml file when the pipeline is initialized. The class attributes will be overridden with instance attributes with the same name but with the values specified in the pipeline file. + + Attributes + ---------- + broadcast_inputs : bool + If true, input queues will be broadcast to process all combinations of + entries. Otherwise, items in input queues are removed at equal rate. + NOT CURRENTLY IMPLEMENTED + limit_outputs : int + Limits the number of `next` outputs from this task before finishing. + Default is None, allowing an unlimited number of `next` products. + base_priority : int + Base integer priority. Priority only matters relative to other tasks + in a pipeline, with run order given by `sorted(priorities, reverse=True)`. 
+ Task priority is also adjusted based on net difference in input and output, + which will typically adjust priority by +/- (0 to 2). `base_priority` should + be set accordingly - factors of 10 (i.e. -10, 10, 20, ...) are effective at + forcing a task to have highest/lowest priority relative to other tasks. + `base_priority` should be used sparingly when a user wants to enforce a + specific non-standard pipeline behaviour. See method `priority` for details + about dynamic priority. """ + broadcast_inputs = config.Property(proptype=bool, default=False) + limit_outputs = config.Property(proptype=int, default=None) + base_priority = config.Property(proptype=int, default=0) + # Overridable Attributes # ----------------------- @@ -833,6 +1082,15 @@ def __init__(self): """ pass + def __str__(self): + """Clean string representation of the task and its state. + + If no state has been set yet, the state is None. + """ + state = getattr(self, "_pipeline_state", None) + + return f"{self.__class__.__name__}.{state}" + def setup(self, requires=None): """First analysis stage of pipeline task. @@ -903,71 +1161,176 @@ def cacheable(self): # Pipeline Infrastructure # ----------------------- + @property + def _pipeline_is_available(self): + """True if this task can be run.""" + if not hasattr(self, "_pipeline_state"): + # This task hasn't been initialized + return False + + if self._pipeline_state == "setup": + # True if all `requires` items have been provided + # This also returns True is `self._requires` is empty + return all(r is not None for r in self._requires) + + if self._pipeline_state == "next": + # True if there is at least one input available + # in each input queue. + return bool(min((q.qsize() for q in self._in), default=0)) + + # Otherwise, this task is likely done and can be run to + # see if anything else happens + return True + + @property + def priority(self): + """Return the priority associated with this task. + + If the task is not yet initialized, dynamic priority is zero. + + If the task in in state `setup`, dynamic priority is one + if all `requires` items are stashed and zero otherwise. + + If the task is in state `next`, dynamic priority is the total + net consumption of the task. + + For example: + - A task which consumes 2 items, produces one, and can currently run + once will have priority (2 - 1) * 1 + base = 1 + base + - A task which does not consume anything but produces one item + will have priority (0 - 1) * 1 + base = -1 + base + + In any other state, priority is just net consumption for one + iteration. + + The priority returned is the sum of `base_priority` and the + calculated dynamic priority. + + Returns + ------- + priority : int + `base_priority` plus dynamic priority calculated based on + task state and inputs/outputs + """ + if not hasattr(self, "_pipeline_state"): + # This task hasn't been initialized + p = 0 + + elif self._pipeline_state == "setup": + # 1 if all requirements are available or no requirements, + # zero if requirements are needed but not available + p = int(all(r is not None for r in self._requires)) + + elif self._pipeline_state == "next": + # Calculate the total net consumption of the task + p = len(self._in_keys) - len(self._out_keys) + # How many times can the task run? 
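+            # (scale by the smallest input queue length; `default=1` leaves the
+            # net consumption unchanged for a task with no 'in' keys)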
+ p *= min((q.qsize() for q in self._in), default=1) + + else: + # If a task has passed the above states, it should be + # finished quickly so set a very high priority + p = 1e10 + + return p + self.base_priority + + @property + def mem_used(self): + """Return the approximate total memory referenced by this task.""" + return tools.total_size(self) + @classmethod def _from_config(cls, config): self = cls.__new__(cls) # Check for unused keys, but ignore the ones not put there by the user. self.read_config(config, compare_keys=["versions", "pipeline_config"]) self.__init__() + return self def _setup_keys(self, in_=None, out=None, requires=None): """Setup the 'requires', 'in' and 'out' keys for this task.""" - # Put pipeline in state such that `setup` is the next stage called. - self._pipeline_advance_state() # Parse the task spec. requires = _format_product_keys(requires) in_ = _format_product_keys(in_) out = _format_product_keys(out) + # Inspect the `setup` method to see how many arguments it takes. setup_argspec = inspect.getfullargspec(self.setup) - # Make sure it matches `requires` keys list specified in config. - n_requires = len(requires) + try: len_defaults = len(setup_argspec.defaults) - except TypeError: # defaults is None + # defaults is None + except TypeError: len_defaults = 0 + min_req = len(setup_argspec.args) - len_defaults - 1 + + # Make sure it matches `requires` keys list specified in config. + n_requires = len(requires) + if n_requires < min_req: - msg = ( - "Didn't get enough 'requires' keys. Expected at least" - " %d and only got %d." % (min_req, n_requires) + raise config.CaputConfigError( + "Didn't get enough 'requires' keys. Expected at least " + f"{min_req} and only got {n_requires}." ) - raise config.CaputConfigError(msg) + if n_requires > len(setup_argspec.args) - 1 and setup_argspec.varargs is None: - msg = "Got too many 'requires' keys. Expected at most %d and" " got %d." % ( - len(setup_argspec.args) - 1, - n_requires, + raise config.CaputConfigError( + "Got too many 'requires' keys. Expected at most " + f"{len(setup_argspec.args) - 1} and got {n_requires}." ) - raise config.CaputConfigError(msg) + # Inspect the `next` method to see how many arguments it takes. next_argspec = inspect.getfullargspec(self.next) - # Make sure it matches `in` keys list specified in config. - n_in = len(in_) + try: len_defaults = len(next_argspec.defaults) except TypeError: # defaults is None len_defaults = 0 + min_in = len(next_argspec.args) - len_defaults - 1 + + # Make sure it matches `in` keys list specified in config. + n_in = len(in_) + if n_in < min_in: - msg = ( - "Didn't get enough 'in' keys. Expected at least" - " %d and only got %d." % (min_in, n_in) + raise config.CaputConfigError( + "Didn't get enough 'in' keys. Expected at least " + f"{min_in} and only got {n_in}." ) - raise config.CaputConfigError(msg) + if n_in > len(next_argspec.args) - 1 and next_argspec.varargs is None: - msg = "Got too many 'in' keys. Expected at most %d and" " got %d." % ( - len(next_argspec.args) - 1, - n_in, + raise config.CaputConfigError( + "Got too many 'in' keys. Expected at most " + f"{len(next_argspec.args) - 1} and got {n_in}." ) - raise config.CaputConfigError(msg) + # Now that all data product keys have been verified to be valid, store # them on the instance. 
self._requires_keys = requires + # Set up a list with the number of required entries self._requires = [None] * n_requires + # Store input keys self._in_keys = in_ - self._in = [queue.Queue() for i in range(n_in)] + # Make a list with one queue for each input. Since any given input can + # produce multiple values, queue up items which may be used in the + # future + self._in = [queue.Queue() for _ in range(n_in)] + # Store output keys self._out_keys = out + # Keep track of the number of times this task has produced output + self._num_iters = 0 + + if self.broadcast_inputs: + # Additional queues to help manage inputs when broadcasting + # self._bcast_queue = [queue.Queue() for _ in range(n_in)] + raise NotImplementedError + + # Do any extra validation here + self.validate() + # Put pipeline in state such that `setup` is the next stage called. + self._pipeline_advance_state() def _pipeline_advance_state(self): """Advance this pipeline task to the next stage. @@ -980,27 +1343,42 @@ def _pipeline_advance_state(self): """ if not hasattr(self, "_pipeline_state"): self._pipeline_state = "setup" + elif self._pipeline_state == "setup": - # Delete inputs to free memory. - self._requires = None + # Advance the state to `next` self._pipeline_state = "next" + # Make sure setup received all input. If not, go straight to + # `finish`, because some input requirement was never generated. + for req, req_key in zip(self._requires, self._requires_keys): + if req is None: + warnings.warn( + f"Task {self!s} tried to advance to `next` " + f"without completing `setup`. Input `{req_key}` was never received. " + "Advancing to `finish`." + ) + self._pipeline_state = "finish" + # Overwrite inputs to free memory. + self._requires = None + elif self._pipeline_state == "next": # Make sure input queues are empty then delete them so no more data # can be queued. for in_, in_key in zip(self._in, self._in_keys): if not in_.empty(): - msg = ( - f"Task finished {self.__class__.__name__} iterating `next()` " - f"but input queue '{in_key}' isn't empty." + warnings.warn( + f"Task {self!s} finished iterating `next()` " + f"but input queue `{in_key}` isn't empty." ) - warnings.warn(msg) self._in = None self._pipeline_state = "finish" + elif self._pipeline_state == "finish": self._pipeline_state = "raise" + elif self._pipeline_state == "raise": pass + else: raise PipelineRuntimeError() @@ -1016,10 +1394,12 @@ def _pipeline_next(self): for req in self._requires: if req is None: raise _PipelineMissingData() - msg = f"Task {self.__class__.__name__} calling 'setup()'." - logger.debug(msg) + + logger.debug(f"Task {self!s} calling 'setup()'.") + out = self.setup(*tuple(self._requires)) self._pipeline_advance_state() + return out if self._pipeline_state == "next": @@ -1027,74 +1407,90 @@ def _pipeline_next(self): for in_ in self._in: if in_.empty(): raise _PipelineMissingData() - # Get the next set of data to be run. - args = () - for in_ in self._in: - args += (in_.get(),) + + if self.broadcast_inputs: + raise NotImplementedError + else: # noqa RET506 + # Get the next set of data to be run. + args = tuple(in_.get() for in_ in self._in) + + # Call the next iteration of `next`. If it is done running, + # advance the task state and continue + logger.debug(f"Task {self!s} calling 'next()'.") + try: - msg = f"Task {self.__class__.__name__} calling 'next()'." - logger.debug(msg) - return self.next(*args) + out = self.next(*args) except PipelineStopIteration: # Finished iterating `next()`. 
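+                # (the task will run `finish()` on a later pass; `out` stays None
+                # here, so the output-limit bookkeeping below is skipped)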
self._pipeline_advance_state() - return None + out = None + + if out is not None: + self._num_iters += 1 + # If this task has a restricted number of outputs, it should advance + # if enough output iterations have been executed + if ( + self.limit_outputs is not None + and self._num_iters >= self.limit_outputs + ): + logger.info( + f"Task {self!s} reached maximum number of output " + f"iterations ({self.limit_outputs}). Advancing state." + ) + self._pipeline_advance_state() + + return out + + if self._pipeline_state == "finish": + logger.debug(f"Task {self!s} calling 'finish()'.") - elif self._pipeline_state == "finish": - msg = f"Task {self.__class__.__name__} calling 'finish()'." - logger.debug(msg) out = self.finish() self._pipeline_advance_state() + return out - elif self._pipeline_state == "raise": + + if self._pipeline_state == "raise": raise _PipelineFinished() - else: - raise PipelineRuntimeError() - def _pipeline_inspect_queue_product(self, keys, products): - """Inspect data products and queue them as inputs if applicable. + raise PipelineRuntimeError() + + def _pipeline_queue_product(self, key, product): + """Put a product into an input queue as applicable. - Compare a list of data products keys to the keys expected by this task - as inputs to `setup()` ('requires') and `next()` ('in'). If there is a - match, store the corresponding data product to be used in the next - invocation of these methods. + Add a product to either a `requires` slot or an input queue based + on the associated key. """ - n_keys = len(keys) - for ii in range(n_keys): - key = keys[ii] - product = products[ii] - for jj, requires_key in enumerate(self._requires_keys): - if requires_key == key: - # Make sure that `setup()` hasn't already been run or this - # data product already set. - msg = "%s stowing data product with key %s for 'requires'." - msg = msg % (self.__class__.__name__, key) - logger.debug(msg) - if self._requires is None: - msg = ( - "Tried to set 'requires' data product, but" - "`setup()` already run." - ) - raise PipelineRuntimeError(msg) - if self._requires[jj] is not None: - msg = "'requires' data product set more than once." - raise PipelineRuntimeError(msg) - # Accept the data product and store for later use. - self._requires[jj] = product - for jj, in_key in enumerate(self._in_keys): - if in_key == key: - msg = "%s queue data product with key %s for 'in'." - msg = msg % (self.__class__.__name__, key) - logger.debug(msg) - # Check that task is still accepting inputs. - if self._in is None: - msg = ( - "Tried to queue 'requires' data product, but" - "`next()` iteration already completed." - ) - raise PipelineRuntimeError(msg) - # Accept the data product and store for later use. - self._in[jj].put(product) + # First, check requires keys + if key in self._requires_keys: + ii = self._requires_keys.index(key) + logger.debug( + f"{self!s} stowing data product with key {key} for `requires`." + ) + if self._requires is None: + raise PipelineRuntimeError( + "Tried to set 'requires' data product, but `setup()` already run." + ) + if self._requires[ii] is not None: + raise PipelineRuntimeError( + "'requires' data product set more than once." + ) + self._requires[ii] = product + + return True + + if key in self._in_keys: + ii = self._in_keys.index(key) + logger.debug(f"{self!s} stowing data product with key {key} for `in`.") + if self._in is None: + raise PipelineRuntimeError( + "Tried to queue 'in' data product, but `next()` already run." 
+                )
+
+            self._in[ii].put(product)
+
+            return True
+
+        return False
 
 
 class _OneAndOne(TaskBase):
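The per-task options introduced in this patch (`base_priority`, `limit_outputs`)
are ordinary config properties on `TaskBase`, so they are read from a task's
`params` block. A hedged sketch of a config using them (module path, task names
and parameter group names here are hypothetical, not taken from the patch)::

    limited_config = '''
    pipeline:
        tasks:
            -   type: mymodule.Generate
                params: gen_params
                out: item
            -   type: mymodule.Consume
                params: con_params
                in: item
    gen_params:
        limit_outputs: 2     # finish after two outputs from `next`
    con_params:
        base_priority: 10    # strongly prefer this task whenever it can run
    '''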