Skip to content

Commit

Permalink
Merge pull request #873 from deepmodeling/zjgemi
Browse files Browse the repository at this point in the history
add sort_by_generation to query_step; support getting sub-steps of a specific step; fix hooks of task
  • Loading branch information
zjgemi authored Oct 25, 2024
2 parents 1330301 + 67b5876 commit ec67085
Show file tree
Hide file tree
Showing 6 changed files with 516 additions and 9 deletions.
50 changes: 47 additions & 3 deletions src/dflow/argo_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import time
from collections import UserDict, UserList
from copy import deepcopy
from typing import Any, List, Union
from typing import Any, List, Optional, Union

from .common import jsonpickle
from .config import config, s3_config
Expand Down Expand Up @@ -359,6 +359,8 @@ def get_step(
phase: Union[str, List[str]] = None,
id: Union[str, List[str]] = None,
type: Union[str, List[str]] = None,
parent_id: Optional[str] = None,
sort_by_generation: bool = False,
) -> List[ArgoStep]:
if name is not None and not isinstance(name, list):
name = [name]
Expand All @@ -372,7 +374,11 @@ def get_step(
type = [type]
step_list = []
if hasattr(self.status, "nodes"):
for step in self.status.nodes.values():
if parent_id is not None:
nodes = self.get_sub_nodes(parent_id)
else:
nodes = self.status.nodes.values()
for step in nodes:
if step["startedAt"] is None:
continue
if name is not None and not match(step["displayName"], name):
Expand All @@ -395,9 +401,47 @@ def get_step(
continue
step = ArgoStep(step, self.metadata.name)
step_list.append(step)
step_list.sort(key=lambda x: x["startedAt"])
else:
return []
if sort_by_generation:
self.generation = {}
self.record_generation(self.id, 0)
step_list.sort(key=lambda x: self.generation.get(
x["id"], len(self.status.nodes)))
else:
step_list.sort(key=lambda x: x["startedAt"])
return step_list

def get_sub_nodes(self, node_id):
    """
    Collect the nodes located below a given node, ordered by generation.

    Args:
        node_id: key of the starting node in ``self.status.nodes``

    Returns:
        A list of node records. If the starting node is not of type
        "Steps" or "DAG", or if its memoization status reports a hit,
        the list contains only the starting node itself. Otherwise the
        list contains its descendants in breadth-first order (the
        starting node is not included); the walk does not descend past
        the ids listed in the starting node's "outboundNodes". Each
        descendant appears exactly once, at the first generation in
        which it is reached.

    Raises:
        AssertionError: if ``node_id`` is not a key of
            ``self.status.nodes``
    """
    assert node_id in self.status.nodes
    all_nodes = self.status.nodes
    node = all_nodes[node_id]
    if node["type"] not in ["Steps", "DAG"]:
        return [node]
    if (node.get("memoizationStatus") or {}).get("hit", False):
        return [node]

    # A set gives O(1) membership tests while walking the tree.
    outbound_nodes = set(node.get("outboundNodes") or [])
    sub_nodes = []
    # A node can be listed as a child of several parents; without this
    # set it would be emitted (and its subtree re-expanded) once per
    # parent.
    visited = set()
    # order by generation (BFS)
    current_generation = node.get("children") or []
    while current_generation:
        next_generation = []
        for child_id in current_generation:
            if child_id in visited:
                continue
            visited.add(child_id)
            child = all_nodes[child_id]
            sub_nodes.append(child)
            # Children of an outbound node are not followed.
            if child_id not in outbound_nodes:
                next_generation += child.get("children") or []
        current_generation = next_generation
    return sub_nodes

def record_generation(self, node_id, generation):
    """
    Record the generation of a node and of every node reachable from it.

    The traversal is depth-first in the order the children are listed.
    A node keeps the generation assigned on its first visit; nodes that
    already have an entry in ``self.generation`` are neither updated
    nor descended into again.

    Args:
        node_id: key of the starting node in ``self.status.nodes``
        generation: generation number assigned to the starting node;
            each child level adds one

    The result is written to ``self.generation`` (a mapping from node
    id to generation), which must exist before the call.
    """
    nodes = self.status.nodes
    self.generation[node_id] = generation
    # An explicit stack of (children iterator, generation of those
    # children) replaces recursion, so that long chains of nodes cannot
    # exceed the interpreter's recursion limit.
    stack = [(iter(nodes[node_id].get("children") or ()), generation + 1)]
    while stack:
        children, child_generation = stack[-1]
        for child in children:
            if child in self.generation:
                continue
            self.generation[child] = child_generation
            stack.append((iter(nodes[child].get("children") or ()),
                          child_generation + 1))
            # Descend into this child first; the remaining siblings are
            # resumed once its subtree has been handled.
            break
        else:
            # Every child at this level has been processed.
            stack.pop()

def get_duration(self) -> datetime.timedelta:
    """
    Return the duration derived from this object's status.

    The computation is delegated to the module-level ``get_duration``
    helper, which receives ``self.status``.
    """
    status = self.status
    return get_duration(status)

Expand Down
3 changes: 2 additions & 1 deletion src/dflow/client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .v1alpha1_artifact import V1alpha1Artifact
from .v1alpha1_dag_task import V1alpha1DAGTask
from .v1alpha1_lifecycle_hook import V1alpha1LifecycleHook
from .v1alpha1_parameter import V1alpha1Parameter
from .v1alpha1_retry_strategy import V1alpha1RetryStrategy
Expand All @@ -8,4 +9,4 @@

__all__ = ["V1alpha1Artifact", "V1alpha1LifecycleHook", "V1alpha1Parameter",
"V1alpha1RetryStrategy", "V1alpha1Sequence", "V1alpha1ValueFrom",
"V1alpha1WorkflowStep"]
"V1alpha1WorkflowStep", "V1alpha1DAGTask"]
Loading

0 comments on commit ec67085

Please sign in to comment.