Skip to content

Commit

Permalink
feat(pipeline): restrict how many times a task can be repeated without interruption
Browse files Browse the repository at this point in the history
  • Loading branch information
ljgray committed Dec 21, 2023
1 parent a7a71f7 commit 42e8cfb
Showing 1 changed file with 24 additions and 2 deletions.
26 changes: 24 additions & 2 deletions caput/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,9 @@ class Manager(config.Reader):
tasks as soon as possible. `tree` walks through the tree of associated input-output
keys and tries to run each branch to completion to minimize the time for which
any given data product will exist. Default is `loop`.
max_repeat : int
Maximum number of times that a task can run in a row before the pipeline will
run another task, even if the current task has top priority. Default is 20.
save_versions : list
Module names (str). This list together with the version strings from these
modules are attached to output metadata. Default is [].
Expand All @@ -481,6 +484,7 @@ class Manager(config.Reader):
cluster = config.Property(default={}, proptype=dict)
task_specs = config.Property(default=[], proptype=list, key="tasks")
execution_order = config.enum(["standard", "legacy"], default="standard")
max_repeat = config.Property(proptype=int, default=20)

# Options to be stored in self.all_tasks_params
versions = config.Property(default=[], proptype=_get_versions, key="save_versions")
Expand Down Expand Up @@ -616,6 +620,9 @@ def run(self):

# Index of first available task in the list
self._task_head = 0
# Number of times the current task has repeated without another task
# running
self._current_task_count = 0

# Choose how to order tasks based on the execution order
if self.execution_order == "legacy":
Expand Down Expand Up @@ -741,14 +748,29 @@ def _next_task(self):
task = sorted(available, key=lambda x: round(x.mem_used / 1e9))[-1]

# Update the task pointer
self._task_idx = self.tasks.index(task)
new_idx = self.tasks.index(task)

if self._mpi_enabled:
# Ensure that all ranks are running the same task. This is
# relevant because a task could be using a different amount
# of memory for different processes, so it's assumed that
# no rank will use _more_ memory than rank 0
self._task_idx = mpiutil.bcast(self._task_idx, root=0)
new_idx = mpiutil.bcast(new_idx, root=0)

# Keep track of how many times in a row this task has been run
if new_idx == self._task_idx:
if self._current_task_count >= self.max_repeat:
# Iterate over all tasks starting at the next one in the
# task order
return self._iter_tasks()
# Otherwise, just increment the counter
self._current_task_count += 1
else:
# Reset the counter since there's a new task
self._current_task_count = 0

# Update the current task pointer
self._task_idx = new_idx

return task

Expand Down

0 comments on commit 42e8cfb

Please sign in to comment.