From 269a16144cee27540064c54c97991f32ecd76634 Mon Sep 17 00:00:00 2001 From: Alice Purcell Date: Wed, 9 Oct 2024 19:15:56 +0100 Subject: [PATCH 01/11] Extract construct_io_from_annotation function Many places need to inspect a parameter or field to construct a Parameter or Artifact based on the name and other annotations. Extract the single-field logic from _construct_io_from_fields to simplify many callsites that were using get_workflow_annotation and stitching in the name and default Parameter themselves. This will also allow us to reliably provide more features in future, such as inferring an enum parameter for a Literal. Signed-off-by: Alice Purcell --- src/hera/shared/_type_util.py | 18 +++++++++++ src/hera/workflows/_mixins.py | 13 +++----- src/hera/workflows/io/_io_mixins.py | 14 +++------ src/hera/workflows/script.py | 38 +++++++++++------------ tests/test_unit/test_shared_type_utils.py | 32 +++++++++++++++++++ 5 files changed, 77 insertions(+), 38 deletions(-) diff --git a/src/hera/shared/_type_util.py b/src/hera/shared/_type_util.py index 45f4f9582..ca8ee5301 100644 --- a/src/hera/shared/_type_util.py +++ b/src/hera/shared/_type_util.py @@ -82,6 +82,24 @@ def get_workflow_annotation(annotation: Any) -> Optional[Union[Artifact, Paramet return metadata[0] +def construct_io_from_annotation(python_name: str, annotation: Any) -> Union[Parameter, Artifact]: + """Constructs a Parameter or Artifact object based on annotations. + + If a field has a Parameter or Artifact annotation, a copy will be returned, with missing + fields filled out based on other metadata. Otherwise, a Parameter object will be constructed. + + For a function parameter, python_name should be the parameter name. + For a Pydantic Input or Output class, python_name should be the field name. + """ + if annotation := get_workflow_annotation(annotation): + # Copy so as to not modify the fields themselves + annotation_copy = annotation.copy() + annotation_copy.name = annotation.name or python_name + return annotation_copy + else: + return Parameter(name=python_name) + + def get_unsubscripted_type(t: Any) -> Any: """Return the origin of t, if subscripted, or t itself. diff --git a/src/hera/workflows/_mixins.py b/src/hera/workflows/_mixins.py index 9acf61062..d286c8b1d 100644 --- a/src/hera/workflows/_mixins.py +++ b/src/hera/workflows/_mixins.py @@ -17,7 +17,7 @@ from hera.shared import BaseMixin, global_config from hera.shared._pydantic import PrivateAttr, get_field_annotations, get_fields, root_validator, validator -from hera.shared._type_util import get_workflow_annotation +from hera.shared._type_util import construct_io_from_annotation from hera.shared.serialization import serialize from hera.workflows._context import SubNodeMixin, _context from hera.workflows._meta_mixins import CallableTemplateMixin, HeraBuildObj, HookMixin @@ -738,14 +738,9 @@ def __getattribute__(self, name: str) -> Any: result_templated_str = f"{{{{{subnode_type}.{subnode_name}.outputs.result}}}}" return result_templated_str - if param_or_artifact := get_workflow_annotation(annotations[name]): - output_name = param_or_artifact.name or name - if isinstance(param_or_artifact, Parameter): - return "{{" + f"{subnode_type}.{subnode_name}.outputs.parameters.{output_name}" + "}}" - else: - return "{{" + f"{subnode_type}.{subnode_name}.outputs.artifacts.{output_name}" + "}}" - - return "{{" + f"{subnode_type}.{subnode_name}.outputs.parameters.{name}" + "}}" + param_or_artifact = construct_io_from_annotation(name, annotations[name]) + output_type = "parameters" if isinstance(param_or_artifact, Parameter) else "artifacts" + return "{{" + f"{subnode_type}.{subnode_name}.outputs.{output_type}.{param_or_artifact.name}" + "}}" return super().__getattribute__(name) diff --git a/src/hera/workflows/io/_io_mixins.py b/src/hera/workflows/io/_io_mixins.py index 3ef106122..722d4d8bc 100644 --- a/src/hera/workflows/io/_io_mixins.py +++ b/src/hera/workflows/io/_io_mixins.py @@ -11,7 +11,7 @@ from hera.shared import global_config from hera.shared._global_config import _SUPPRESS_PARAMETER_DEFAULT_ERROR_FLAG from hera.shared._pydantic import _PYDANTIC_VERSION, FieldInfo, get_field_annotations, get_fields -from hera.shared._type_util import get_workflow_annotation +from hera.shared._type_util import construct_io_from_annotation, get_workflow_annotation from hera.shared.serialization import MISSING, serialize from hera.workflows._context import _context from hera.workflows.artifact import Artifact @@ -45,18 +45,12 @@ def _construct_io_from_fields(cls: Type[BaseModel]) -> Iterator[Tuple[str, FieldInfo, Union[Parameter, Artifact]]]: """Constructs a Parameter or Artifact object for all Pydantic fields based on their annotations. - If a field has a Parameter or Artifact annotation, a copy will be returned, with name added if missing. - Otherwise, a Parameter object will be constructed. + If a field has a Parameter or Artifact annotation, a copy will be returned, with missing + fields filled out based on other metadata. Otherwise, a Parameter object will be constructed. """ annotations = get_field_annotations(cls) for field, field_info in get_fields(cls).items(): - if annotation := get_workflow_annotation(annotations[field]): - # Copy so as to not modify the fields themselves - annotation_copy = annotation.copy() - annotation_copy.name = annotation.name or field - yield field, field_info, annotation_copy - else: - yield field, field_info, Parameter(name=field) + yield field, field_info, construct_io_from_annotation(field, annotations[field]) class InputMixin(BaseModel): diff --git a/src/hera/workflows/script.py b/src/hera/workflows/script.py index dce234510..c99030753 100644 --- a/src/hera/workflows/script.py +++ b/src/hera/workflows/script.py @@ -47,7 +47,12 @@ _flag_enabled, ) from hera.shared._pydantic import _PYDANTIC_VERSION, root_validator, validator -from hera.shared._type_util import get_workflow_annotation, is_subscripted, origin_type_issubclass +from hera.shared._type_util import ( + construct_io_from_annotation, + get_workflow_annotation, + is_subscripted, + origin_type_issubclass, +) from hera.shared.serialization import serialize from hera.workflows._context import _context from hera.workflows._meta_mixins import CallableTemplateMixin @@ -434,25 +439,19 @@ def _get_outputs_from_parameter_annotations( artifacts: List[Artifact] = [] for name, p in inspect.signature(source).parameters.items(): - annotation = get_workflow_annotation(p.annotation) - if not annotation or not annotation.output: + annotation = construct_io_from_annotation(name, p.annotation) + if not annotation.output: continue - new_object = annotation.copy() - - # use the function parameter name when not provided by user - if not new_object.name: - new_object.name = name + if isinstance(annotation, Parameter) and annotation.value_from is None and outputs_directory is not None: + annotation.value_from = ValueFrom(path=outputs_directory + f"/parameters/{annotation.name}") + elif isinstance(annotation, Artifact) and annotation.path is None and outputs_directory is not None: + annotation.path = outputs_directory + f"/artifacts/{annotation.name}" - if isinstance(new_object, Parameter) and new_object.value_from is None and outputs_directory is not None: - new_object.value_from = ValueFrom(path=outputs_directory + f"/parameters/{new_object.name}") - elif isinstance(new_object, Artifact) and new_object.path is None and outputs_directory is not None: - new_object.path = outputs_directory + f"/artifacts/{new_object.name}" - - if isinstance(new_object, Artifact): - artifacts.append(new_object) - elif isinstance(new_object, Parameter): - parameters.append(new_object) + if isinstance(annotation, Artifact): + artifacts.append(annotation) + elif isinstance(annotation, Parameter): + parameters.append(annotation) return parameters, artifacts @@ -589,8 +588,9 @@ def _extract_all_output_annotations(source: Callable) -> List: output = [] for _, func_param in inspect.signature(source).parameters.items(): - if (annotated := get_workflow_annotation(func_param.annotation)) and annotated.output: - output.append(annotated) + io = construct_io_from_annotation(func_param.name, func_param.annotation) + if io.output: + output.append(io) output.extend(_extract_return_annotation_output(source)) diff --git a/tests/test_unit/test_shared_type_utils.py b/tests/test_unit/test_shared_type_utils.py index d543997d8..af5f64ff8 100644 --- a/tests/test_unit/test_shared_type_utils.py +++ b/tests/test_unit/test_shared_type_utils.py @@ -4,6 +4,7 @@ from annotated_types import Gt from hera.shared._type_util import ( + construct_io_from_annotation, get_annotated_metadata, get_unsubscripted_type, get_workflow_annotation, @@ -90,6 +91,37 @@ def test_get_workflow_annotation_should_raise_error(annotation): get_workflow_annotation(annotation) +@pytest.mark.parametrize( + "annotation, expected", + [ + [str, Parameter(name="python_name")], + [Annotated[str, Parameter(name="a_str")], Parameter(name="a_str")], + [Annotated[str, Artifact(name="a_str")], Artifact(name="a_str")], + [Annotated[int, Gt(10), Artifact(name="a_int")], Artifact(name="a_int")], + [Annotated[int, Artifact(name="a_int"), Gt(30)], Artifact(name="a_int")], + # this can happen when user uses already annotated types. + [Annotated[Annotated[int, Gt(10)], Artifact(name="a_int")], Artifact(name="a_int")], + ], +) +def test_construct_io_from_annotation(annotation, expected): + assert construct_io_from_annotation("python_name", annotation) == expected + + +@pytest.mark.parametrize( + "annotation", + [ + # Duplicated annotation + Annotated[str, Parameter(name="a_str"), Parameter(name="b_str")], + Annotated[str, Parameter(name="a_str"), Artifact(name="a_str")], + # Nested + Annotated[Annotated[str, Parameter(name="a_str")], Artifact(name="b_str")], + ], +) +def test_construct_io_from_annotation_should_raise_error(annotation): + with pytest.raises(ValueError): + construct_io_from_annotation("python_name", annotation) + + @pytest.mark.parametrize( "annotation, expected", [ From 445195992720cf6f48a62fc981fc562d7d749304 Mon Sep 17 00:00:00 2001 From: Alice Purcell Date: Wed, 9 Oct 2024 16:29:56 +0100 Subject: [PATCH 02/11] Add three tests for _get_param_items_from_source Add three unit tests of the working current behaviour to prevent regression. Signed-off-by: Alice Purcell --- tests/test_unit/test_meta_mixins.py | 30 +++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 tests/test_unit/test_meta_mixins.py diff --git a/tests/test_unit/test_meta_mixins.py b/tests/test_unit/test_meta_mixins.py new file mode 100644 index 000000000..18912baae --- /dev/null +++ b/tests/test_unit/test_meta_mixins.py @@ -0,0 +1,30 @@ +from hera.workflows import Parameter +from hera.workflows._meta_mixins import _get_param_items_from_source + + +def test_get_param_items_from_source_simple_function_one_param(): + def function(some_param: str) -> None: ... + + parameters = _get_param_items_from_source(function) + + assert parameters == [Parameter(name="some_param", value="{{item}}")] + + +def test_get_param_items_from_source_simple_function_multiple_params(): + def function(foo: str, bar: int, baz: str) -> None: ... + + parameters = _get_param_items_from_source(function) + + assert parameters == [ + Parameter(name="foo", value="{{item.foo}}"), + Parameter(name="bar", value="{{item.bar}}"), + Parameter(name="baz", value="{{item.baz}}"), + ] + + +def test_get_param_items_from_source_simple_function_defaulted_params_skipped(): + def function(some_param: str, defaulted_param: str = "some value") -> None: ... + + parameters = _get_param_items_from_source(function) + + assert parameters == [Parameter(name="some_param", value="{{item}}")] From 02a0c763a12ac12c45e2915537feb53e012f1419 Mon Sep 17 00:00:00 2001 From: Alice Purcell Date: Wed, 9 Oct 2024 19:31:55 +0100 Subject: [PATCH 03/11] Support IO annotations and with_param Support Parameter/Artifact annotations in _get_param_items_from_source. If a Parameter annotation is provided, the name will override the name of the Python parameter; additionally, if it is an output parameter, or if there is an Artifact annotation, it will be skipped. Signed-off-by: Alice Purcell --- src/hera/workflows/_meta_mixins.py | 14 ++++--- tests/test_unit/test_meta_mixins.py | 60 ++++++++++++++++++++++++++++- 2 files changed, 67 insertions(+), 7 deletions(-) diff --git a/src/hera/workflows/_meta_mixins.py b/src/hera/workflows/_meta_mixins.py index d32dd55b6..badef88b8 100644 --- a/src/hera/workflows/_meta_mixins.py +++ b/src/hera/workflows/_meta_mixins.py @@ -25,7 +25,7 @@ from hera.shared import BaseMixin, global_config from hera.shared._global_config import _DECORATOR_SYNTAX_FLAG, _flag_enabled from hera.shared._pydantic import BaseModel, get_fields, root_validator -from hera.shared._type_util import get_annotated_metadata +from hera.shared._type_util import construct_io_from_annotation, get_annotated_metadata from hera.workflows._context import _context from hera.workflows.exceptions import InvalidTemplateCall from hera.workflows.io.v1 import ( @@ -275,17 +275,19 @@ def _get_param_items_from_source(source: Callable) -> List[Parameter]: List[Parameter] A list of identified parameters (possibly empty). """ - source_signature: List[str] = [] + non_default_parameters: List[Parameter] = [] for p in inspect.signature(source).parameters.values(): if p.default == inspect.Parameter.empty and p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD: # only add positional or keyword arguments that are not set to a default value # as the default value ones are captured by the automatically generated `Parameter` fields for positional # kwargs. Otherwise, we assume that the user sets the value of the parameter via the `with_param` field - source_signature.append(p.name) + io = construct_io_from_annotation(p.name, p.annotation) + if isinstance(io, Parameter) and io.default is None and not io.output: + non_default_parameters.append(io) - if len(source_signature) == 1: - return [Parameter(name=n, value="{{item}}") for n in source_signature] - return [Parameter(name=n, value=f"{{{{item.{n}}}}}") for n in source_signature] + for param in non_default_parameters: + param.value = "{{" + ("item" if len(non_default_parameters) == 1 else f"item.{param.name}") + "}}" + return non_default_parameters def _get_params_from_items(with_items: List[Any]) -> Optional[List[Parameter]]: diff --git a/tests/test_unit/test_meta_mixins.py b/tests/test_unit/test_meta_mixins.py index 18912baae..8c76f2be7 100644 --- a/tests/test_unit/test_meta_mixins.py +++ b/tests/test_unit/test_meta_mixins.py @@ -1,4 +1,7 @@ -from hera.workflows import Parameter +from pathlib import Path +from typing import Annotated + +from hera.workflows import Artifact, Parameter from hera.workflows._meta_mixins import _get_param_items_from_source @@ -28,3 +31,58 @@ def function(some_param: str, defaulted_param: str = "some value") -> None: ... parameters = _get_param_items_from_source(function) assert parameters == [Parameter(name="some_param", value="{{item}}")] + + +def test_get_param_items_from_source_annotated_function_one_param(): + def function(some_param: Annotated[str, Parameter(name="foobar")]) -> None: ... + + parameters = _get_param_items_from_source(function) + + assert parameters == [Parameter(name="foobar", value="{{item}}")] + + +def test_get_param_items_from_source_annotated_function_multiple_params(): + def function( + foo: Annotated[str, Parameter(name="foobar")], + bar: int, + baz: Annotated[str, Parameter(description="some description")], + ) -> None: ... + + parameters = _get_param_items_from_source(function) + + assert parameters == [ + Parameter(name="foobar", value="{{item.foobar}}"), + Parameter(name="bar", value="{{item.bar}}"), + Parameter(name="baz", value="{{item.baz}}", description="some description"), + ] + + +def test_get_param_items_from_source_annotated_function_defaulted_params_skipped(): + def function( + some_param: Annotated[str, Parameter(name="some-param")], + defaulted_param: Annotated[str, Parameter(name="bazbam")] = "some value", + ) -> None: ... + + parameters = _get_param_items_from_source(function) + + assert parameters == [Parameter(name="some-param", value="{{item}}")] + + +def test_get_param_items_from_source_annotated_function_outputs_skipped(): + def function( + some_param: Annotated[str, Parameter(name="some-param")], output_param: Annotated[Path, Parameter(output=True)] + ) -> None: ... + + parameters = _get_param_items_from_source(function) + + assert parameters == [Parameter(name="some-param", value="{{item}}")] + + +def test_get_param_items_from_source_annotated_function_artifacts_skipped(): + def function( + some_param: Annotated[str, Parameter(name="some-param")], some_resource: Annotated[str, Artifact()] + ) -> None: ... + + parameters = _get_param_items_from_source(function) + + assert parameters == [Parameter(name="some-param", value="{{item}}")] From 2f820a0b4d8debb8056ac69a30169620384ac8db Mon Sep 17 00:00:00 2001 From: Alice Purcell Date: Thu, 10 Oct 2024 08:16:01 +0100 Subject: [PATCH 04/11] Support Pydantic inputs and with_param Support Pydantic Input subclasses in _get_param_items_from_source. Signed-off-by: Alice Purcell --- src/hera/workflows/_meta_mixins.py | 34 +++++++++++++----- tests/test_unit/test_meta_mixins.py | 54 ++++++++++++++++++++++++++++- 2 files changed, 78 insertions(+), 10 deletions(-) diff --git a/src/hera/workflows/_meta_mixins.py b/src/hera/workflows/_meta_mixins.py index badef88b8..e6b637e39 100644 --- a/src/hera/workflows/_meta_mixins.py +++ b/src/hera/workflows/_meta_mixins.py @@ -25,7 +25,7 @@ from hera.shared import BaseMixin, global_config from hera.shared._global_config import _DECORATOR_SYNTAX_FLAG, _flag_enabled from hera.shared._pydantic import BaseModel, get_fields, root_validator -from hera.shared._type_util import construct_io_from_annotation, get_annotated_metadata +from hera.shared._type_util import construct_io_from_annotation, get_annotated_metadata, unwrap_annotation from hera.workflows._context import _context from hera.workflows.exceptions import InvalidTemplateCall from hera.workflows.io.v1 import ( @@ -263,6 +263,17 @@ def _dispatch_hooks(self: THookable) -> THookable: return output +def _get_pydantic_input_type(source: Callable) -> Union[None, Type[InputV1], Type[InputV2]]: + function_parameters = inspect.signature(source).parameters + if len(function_parameters) != 1: + return None + parameter = next(iter(function_parameters.values())) + parameter_type = unwrap_annotation(parameter.annotation) + if not isinstance(parameter_type, type) or not issubclass(parameter_type, (InputV1, InputV2)): + return None + return parameter_type + + def _get_param_items_from_source(source: Callable) -> List[Parameter]: """Returns a list (possibly empty) of `Parameter` from the specified `source`. @@ -276,14 +287,19 @@ def _get_param_items_from_source(source: Callable) -> List[Parameter]: A list of identified parameters (possibly empty). """ non_default_parameters: List[Parameter] = [] - for p in inspect.signature(source).parameters.values(): - if p.default == inspect.Parameter.empty and p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD: - # only add positional or keyword arguments that are not set to a default value - # as the default value ones are captured by the automatically generated `Parameter` fields for positional - # kwargs. Otherwise, we assume that the user sets the value of the parameter via the `with_param` field - io = construct_io_from_annotation(p.name, p.annotation) - if isinstance(io, Parameter) and io.default is None and not io.output: - non_default_parameters.append(io) + if pydantic_input := _get_pydantic_input_type(source): + for parameter in pydantic_input._get_parameters(): + if parameter.default is None: + non_default_parameters.append(parameter) + else: + for p in inspect.signature(source).parameters.values(): + if p.default == inspect.Parameter.empty and p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD: + # only add positional or keyword arguments that are not set to a default value + # as the default value ones are captured by the automatically generated `Parameter` fields for positional + # kwargs. Otherwise, we assume that the user sets the value of the parameter via the `with_param` field + io = construct_io_from_annotation(p.name, p.annotation) + if isinstance(io, Parameter) and io.default is None and not io.output: + non_default_parameters.append(io) for param in non_default_parameters: param.value = "{{" + ("item" if len(non_default_parameters) == 1 else f"item.{param.name}") + "}}" diff --git a/tests/test_unit/test_meta_mixins.py b/tests/test_unit/test_meta_mixins.py index 8c76f2be7..feb92cbb8 100644 --- a/tests/test_unit/test_meta_mixins.py +++ b/tests/test_unit/test_meta_mixins.py @@ -1,7 +1,7 @@ from pathlib import Path from typing import Annotated -from hera.workflows import Artifact, Parameter +from hera.workflows import Artifact, Input, Parameter from hera.workflows._meta_mixins import _get_param_items_from_source @@ -86,3 +86,55 @@ def function( parameters = _get_param_items_from_source(function) assert parameters == [Parameter(name="some-param", value="{{item}}")] + + +def test_get_param_items_from_source_pydantic_input_one_param(): + class ExampleInput(Input): + some_param: str + + def function(input: ExampleInput) -> None: ... + + parameters = _get_param_items_from_source(function) + + assert parameters == [Parameter(name="some_param", value="{{item}}")] + + +def test_get_param_items_from_source_pydantic_input_multiple_params(): + class ExampleInput(Input): + foo: Annotated[str, Parameter(name="foobar")] + bar: int + baz: Annotated[str, Parameter(description="some description")] + + def function(input: ExampleInput) -> None: ... + + parameters = _get_param_items_from_source(function) + + assert parameters == [ + Parameter(name="foobar", value="{{item.foobar}}"), + Parameter(name="bar", value="{{item.bar}}"), + Parameter(name="baz", value="{{item.baz}}", description="some description"), + ] + + +def test_get_param_items_from_source_pydantic_input_defaulted_params_skipped(): + class ExampleInput(Input): + some_param: Annotated[str, Parameter(name="some-param")] + defaulted_param: Annotated[str, Parameter(name="bazbam")] = "some value" + + def function(input: ExampleInput) -> None: ... + + parameters = _get_param_items_from_source(function) + + assert parameters == [Parameter(name="some-param", value="{{item}}")] + + +def test_get_param_items_from_source_pydantic_input_artifacts_skipped(): + class ExampleInput(Input): + some_param: Annotated[str, Parameter(name="some-param")] + some_resource: Annotated[str, Artifact()] + + def function(input: ExampleInput) -> None: ... + + parameters = _get_param_items_from_source(function) + + assert parameters == [Parameter(name="some-param", value="{{item}}")] From 3e0177ee798d8ab3aa307d31a4908145addcb16c Mon Sep 17 00:00:00 2001 From: Alice Purcell Date: Thu, 10 Oct 2024 08:48:22 +0100 Subject: [PATCH 05/11] Add examples reproducing raised issues Add examples reproducing issues #861 (using with_param with an annotated input) and #1234 (using with_param with a Pydantic Input type). Signed-off-by: Alice Purcell --- .../script-annotations-dynamic-fanout.yaml | 60 +++++++++++++++++ .../script-runner-io-dynamic-fanout.yaml | 64 +++++++++++++++++++ .../script_annotations_dynamic_fanout.py | 30 +++++++++ .../script_runner_io_dynamic_fanout.py | 38 +++++++++++ 4 files changed, 192 insertions(+) create mode 100644 examples/workflows/experimental/script-annotations-dynamic-fanout.yaml create mode 100644 examples/workflows/experimental/script-runner-io-dynamic-fanout.yaml create mode 100644 examples/workflows/experimental/script_annotations_dynamic_fanout.py create mode 100644 examples/workflows/experimental/script_runner_io_dynamic_fanout.py diff --git a/examples/workflows/experimental/script-annotations-dynamic-fanout.yaml b/examples/workflows/experimental/script-annotations-dynamic-fanout.yaml new file mode 100644 index 000000000..c3df0b42d --- /dev/null +++ b/examples/workflows/experimental/script-annotations-dynamic-fanout.yaml @@ -0,0 +1,60 @@ +apiVersion: argoproj.io/v1alpha1 +kind: Workflow +metadata: + generateName: dynamic-fanout- +spec: + entrypoint: d + templates: + - dag: + tasks: + - name: generate + template: generate + - arguments: + parameters: + - description: this is some value + name: some-value + value: '{{item}}' + depends: generate + name: consume + template: consume + withParam: '{{tasks.generate.outputs.parameters.some-values}}' + name: d + - name: generate + outputs: + parameters: + - name: some-values + valueFrom: + path: /tmp/hera-outputs/parameters/some-values + script: + args: + - -m + - hera.workflows.runner + - -e + - examples.workflows.experimental.script_annotations_dynamic_fanout:generate + command: + - python + env: + - name: hera__script_annotations + value: '' + - name: hera__outputs_directory + value: /tmp/hera-outputs + image: python:3.9 + source: '{{inputs.parameters}}' + - inputs: + parameters: + - description: this is some value + name: some-value + name: consume + script: + args: + - -m + - hera.workflows.runner + - -e + - examples.workflows.experimental.script_annotations_dynamic_fanout:consume + command: + - python + env: + - name: hera__script_annotations + value: '' + image: python:3.9 + source: '{{inputs.parameters}}' diff --git a/examples/workflows/experimental/script-runner-io-dynamic-fanout.yaml b/examples/workflows/experimental/script-runner-io-dynamic-fanout.yaml new file mode 100644 index 000000000..0d8237cc1 --- /dev/null +++ b/examples/workflows/experimental/script-runner-io-dynamic-fanout.yaml @@ -0,0 +1,64 @@ +apiVersion: argoproj.io/v1alpha1 +kind: Workflow +metadata: + generateName: dynamic-fanout- +spec: + entrypoint: d + templates: + - dag: + tasks: + - name: generate + template: generate + - arguments: + parameters: + - description: this is some value + name: some-value + value: '{{item}}' + depends: generate + name: consume + template: consume + withParam: '{{tasks.generate.outputs.parameters.some-values}}' + name: d + - name: generate + outputs: + parameters: + - name: some-values + valueFrom: + path: /tmp/hera-outputs/parameters/some-values + script: + args: + - -m + - hera.workflows.runner + - -e + - examples.workflows.experimental.script_runner_io_dynamic_fanout:generate + command: + - python + env: + - name: hera__script_annotations + value: '' + - name: hera__outputs_directory + value: /tmp/hera-outputs + - name: hera__script_pydantic_io + value: '' + image: python:3.9 + source: '{{inputs.parameters}}' + - inputs: + parameters: + - description: this is some value + name: some-value + name: consume + script: + args: + - -m + - hera.workflows.runner + - -e + - examples.workflows.experimental.script_runner_io_dynamic_fanout:consume + command: + - python + env: + - name: hera__script_annotations + value: '' + - name: hera__script_pydantic_io + value: '' + image: python:3.9 + source: '{{inputs.parameters}}' diff --git a/examples/workflows/experimental/script_annotations_dynamic_fanout.py b/examples/workflows/experimental/script_annotations_dynamic_fanout.py new file mode 100644 index 000000000..0b5e08d19 --- /dev/null +++ b/examples/workflows/experimental/script_annotations_dynamic_fanout.py @@ -0,0 +1,30 @@ +""" +This example showcases how clients can use Hera to dynamically generate tasks that process outputs from one task in +parallel. This is useful for batch jobs and instances where clients do not know ahead of time how many tasks/entities +they may need to process. +""" + +from typing import Annotated, List + +from hera.shared import global_config +from hera.workflows import DAG, Parameter, Workflow, script + +global_config.experimental_features["script_annotations"] = True + + +@script(constructor="runner") +def generate() -> Annotated[List[int], Parameter(name="some-values")]: + return [i for i in range(10)] + + +@script(constructor="runner") +def consume(some_value: Annotated[int, Parameter(name="some-value", description="this is some value")]): + print("Received value: {value}!".format(value=some_value)) + + +# assumes you used `hera.set_global_token` and `hera.set_global_host` so that the workflow can be submitted +with Workflow(generate_name="dynamic-fanout-", entrypoint="d") as w: + with DAG(name="d"): + g = generate(arguments={}) + c = consume(with_param=g.get_parameter("some-values")) + g >> c diff --git a/examples/workflows/experimental/script_runner_io_dynamic_fanout.py b/examples/workflows/experimental/script_runner_io_dynamic_fanout.py new file mode 100644 index 000000000..2f73d803d --- /dev/null +++ b/examples/workflows/experimental/script_runner_io_dynamic_fanout.py @@ -0,0 +1,38 @@ +""" +This example showcases how clients can use Hera to dynamically generate tasks that process outputs from one task in +parallel. This is useful for batch jobs and instances where clients do not know ahead of time how many tasks/entities +they may need to process. +""" + +from typing import Annotated, List + +from hera.shared import global_config +from hera.workflows import DAG, Input, Output, Parameter, Workflow, script + +global_config.experimental_features["script_pydantic_io"] = True + + +class GenerateOutput(Output): + some_values: Annotated[List[int], Parameter(name="some-values")] + + +class ConsumeInput(Input): + some_value: Annotated[int, Parameter(name="some-value", description="this is some value")] + + +@script(constructor="runner") +def generate() -> GenerateOutput: + return GenerateOutput(some_values=[i for i in range(10)]) + + +@script(constructor="runner") +def consume(input: ConsumeInput) -> None: + print("Received value: {value}!".format(value=input.some_value)) + + +# assumes you used `hera.set_global_token` and `hera.set_global_host` so that the workflow can be submitted +with Workflow(generate_name="dynamic-fanout-", entrypoint="d") as w: + with DAG(name="d"): + g = generate(arguments={}) + c = consume(with_param=g.get_parameter("some-values")) + g >> c From c0a3b63a9192d392a8ad2321ff204a159edbd0dd Mon Sep 17 00:00:00 2001 From: Alice Purcell Date: Thu, 10 Oct 2024 09:43:15 +0100 Subject: [PATCH 06/11] Add docstring Signed-off-by: Alice Purcell --- src/hera/workflows/_meta_mixins.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/hera/workflows/_meta_mixins.py b/src/hera/workflows/_meta_mixins.py index e6b637e39..ae99ef293 100644 --- a/src/hera/workflows/_meta_mixins.py +++ b/src/hera/workflows/_meta_mixins.py @@ -264,6 +264,7 @@ def _dispatch_hooks(self: THookable) -> THookable: def _get_pydantic_input_type(source: Callable) -> Union[None, Type[InputV1], Type[InputV2]]: + """Returns a Pydantic Input type for the source, if it is using Pydantic IO.""" function_parameters = inspect.signature(source).parameters if len(function_parameters) != 1: return None From e5708fe79b75c7d28ed0f32588e0c3afe0d52849 Mon Sep 17 00:00:00 2001 From: Alice Date: Thu, 10 Oct 2024 14:19:11 +0100 Subject: [PATCH 07/11] Remove unnecessary else before return Co-authored-by: Ukjae Jeong Signed-off-by: Alice --- src/hera/shared/_type_util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/hera/shared/_type_util.py b/src/hera/shared/_type_util.py index ca8ee5301..2e4efaa34 100644 --- a/src/hera/shared/_type_util.py +++ b/src/hera/shared/_type_util.py @@ -96,8 +96,8 @@ def construct_io_from_annotation(python_name: str, annotation: Any) -> Union[Par annotation_copy = annotation.copy() annotation_copy.name = annotation.name or python_name return annotation_copy - else: - return Parameter(name=python_name) + + return Parameter(name=python_name) def get_unsubscripted_type(t: Any) -> Any: From 246132668dabd1aaf1a1aea6d4e635ff94afe916 Mon Sep 17 00:00:00 2001 From: Alice Purcell Date: Mon, 14 Oct 2024 14:11:26 +0100 Subject: [PATCH 08/11] Move with_param examples into script_annotations tests Signed-off-by: Alice Purcell --- .../script-annotations-dynamic-fanout.yaml | 60 ----------------- .../script-runner-io-dynamic-fanout.yaml | 64 ------------------- .../pydantic_io_with_param.py | 9 +-- .../script_annotations/with_param.py | 9 +-- tests/test_script_annotations.py | 38 +++++++++++ 5 files changed, 40 insertions(+), 140 deletions(-) delete mode 100644 examples/workflows/experimental/script-annotations-dynamic-fanout.yaml delete mode 100644 examples/workflows/experimental/script-runner-io-dynamic-fanout.yaml rename examples/workflows/experimental/script_runner_io_dynamic_fanout.py => tests/script_annotations/pydantic_io_with_param.py (69%) rename examples/workflows/experimental/script_annotations_dynamic_fanout.py => tests/script_annotations/with_param.py (64%) diff --git a/examples/workflows/experimental/script-annotations-dynamic-fanout.yaml b/examples/workflows/experimental/script-annotations-dynamic-fanout.yaml deleted file mode 100644 index c3df0b42d..000000000 --- a/examples/workflows/experimental/script-annotations-dynamic-fanout.yaml +++ /dev/null @@ -1,60 +0,0 @@ -apiVersion: argoproj.io/v1alpha1 -kind: Workflow -metadata: - generateName: dynamic-fanout- -spec: - entrypoint: d - templates: - - dag: - tasks: - - name: generate - template: generate - - arguments: - parameters: - - description: this is some value - name: some-value - value: '{{item}}' - depends: generate - name: consume - template: consume - withParam: '{{tasks.generate.outputs.parameters.some-values}}' - name: d - - name: generate - outputs: - parameters: - - name: some-values - valueFrom: - path: /tmp/hera-outputs/parameters/some-values - script: - args: - - -m - - hera.workflows.runner - - -e - - examples.workflows.experimental.script_annotations_dynamic_fanout:generate - command: - - python - env: - - name: hera__script_annotations - value: '' - - name: hera__outputs_directory - value: /tmp/hera-outputs - image: python:3.9 - source: '{{inputs.parameters}}' - - inputs: - parameters: - - description: this is some value - name: some-value - name: consume - script: - args: - - -m - - hera.workflows.runner - - -e - - examples.workflows.experimental.script_annotations_dynamic_fanout:consume - command: - - python - env: - - name: hera__script_annotations - value: '' - image: python:3.9 - source: '{{inputs.parameters}}' diff --git a/examples/workflows/experimental/script-runner-io-dynamic-fanout.yaml b/examples/workflows/experimental/script-runner-io-dynamic-fanout.yaml deleted file mode 100644 index 0d8237cc1..000000000 --- a/examples/workflows/experimental/script-runner-io-dynamic-fanout.yaml +++ /dev/null @@ -1,64 +0,0 @@ -apiVersion: argoproj.io/v1alpha1 -kind: Workflow -metadata: - generateName: dynamic-fanout- -spec: - entrypoint: d - templates: - - dag: - tasks: - - name: generate - template: generate - - arguments: - parameters: - - description: this is some value - name: some-value - value: '{{item}}' - depends: generate - name: consume - template: consume - withParam: '{{tasks.generate.outputs.parameters.some-values}}' - name: d - - name: generate - outputs: - parameters: - - name: some-values - valueFrom: - path: /tmp/hera-outputs/parameters/some-values - script: - args: - - -m - - hera.workflows.runner - - -e - - examples.workflows.experimental.script_runner_io_dynamic_fanout:generate - command: - - python - env: - - name: hera__script_annotations - value: '' - - name: hera__outputs_directory - value: /tmp/hera-outputs - - name: hera__script_pydantic_io - value: '' - image: python:3.9 - source: '{{inputs.parameters}}' - - inputs: - parameters: - - description: this is some value - name: some-value - name: consume - script: - args: - - -m - - hera.workflows.runner - - -e - - examples.workflows.experimental.script_runner_io_dynamic_fanout:consume - command: - - python - env: - - name: hera__script_annotations - value: '' - - name: hera__script_pydantic_io - value: '' - image: python:3.9 - source: '{{inputs.parameters}}' diff --git a/examples/workflows/experimental/script_runner_io_dynamic_fanout.py b/tests/script_annotations/pydantic_io_with_param.py similarity index 69% rename from examples/workflows/experimental/script_runner_io_dynamic_fanout.py rename to tests/script_annotations/pydantic_io_with_param.py index 2f73d803d..4d509844b 100644 --- a/examples/workflows/experimental/script_runner_io_dynamic_fanout.py +++ b/tests/script_annotations/pydantic_io_with_param.py @@ -1,9 +1,3 @@ -""" -This example showcases how clients can use Hera to dynamically generate tasks that process outputs from one task in -parallel. This is useful for batch jobs and instances where clients do not know ahead of time how many tasks/entities -they may need to process. -""" - from typing import Annotated, List from hera.shared import global_config @@ -30,9 +24,8 @@ def consume(input: ConsumeInput) -> None: print("Received value: {value}!".format(value=input.some_value)) -# assumes you used `hera.set_global_token` and `hera.set_global_host` so that the workflow can be submitted with Workflow(generate_name="dynamic-fanout-", entrypoint="d") as w: - with DAG(name="d"): + with DAG(name="dag"): g = generate(arguments={}) c = consume(with_param=g.get_parameter("some-values")) g >> c diff --git a/examples/workflows/experimental/script_annotations_dynamic_fanout.py b/tests/script_annotations/with_param.py similarity index 64% rename from examples/workflows/experimental/script_annotations_dynamic_fanout.py rename to tests/script_annotations/with_param.py index 0b5e08d19..eb9ab9567 100644 --- a/examples/workflows/experimental/script_annotations_dynamic_fanout.py +++ b/tests/script_annotations/with_param.py @@ -1,9 +1,3 @@ -""" -This example showcases how clients can use Hera to dynamically generate tasks that process outputs from one task in -parallel. This is useful for batch jobs and instances where clients do not know ahead of time how many tasks/entities -they may need to process. -""" - from typing import Annotated, List from hera.shared import global_config @@ -22,9 +16,8 @@ def consume(some_value: Annotated[int, Parameter(name="some-value", description= print("Received value: {value}!".format(value=some_value)) -# assumes you used `hera.set_global_token` and `hera.set_global_host` so that the workflow can be submitted with Workflow(generate_name="dynamic-fanout-", entrypoint="d") as w: - with DAG(name="d"): + with DAG(name="dag"): g = generate(arguments={}) c = consume(with_param=g.get_parameter("some-values")) g >> c diff --git a/tests/test_script_annotations.py b/tests/test_script_annotations.py index 0c7e994ea..314269b0a 100644 --- a/tests/test_script_annotations.py +++ b/tests/test_script_annotations.py @@ -438,3 +438,41 @@ def test_script_pydantic_without_experimental_flag(global_config_fixture): "Unable to instantiate since it is an experimental feature." in str(e.value) ) + + +@pytest.mark.parametrize( + "module_name", + [ + "tests.script_annotations.with_param", # annotated types + "tests.script_annotations.pydantic_io_with_param", # Pydantic IO types + ], +) +def test_script_with_param(global_config_fixture, module_name): + """Test that with_param works correctly with annotated/Pydantic IO types.""" + # GIVEN + global_config_fixture.experimental_features["script_annotations"] = True + global_config_fixture.experimental_features["script_pydantic_io"] = True + # Force a reload of the test module, as the runner performs "importlib.import_module", which + # may fetch a cached version + + module = importlib.import_module(module_name) + importlib.reload(module) + workflow = importlib.import_module(module.__name__).w + + # WHEN + workflow_dict = workflow.to_dict() + assert workflow == Workflow.from_dict(workflow_dict) + assert workflow == Workflow.from_yaml(workflow.to_yaml()) + + # THEN + (dag,) = (t for t in workflow_dict["spec"]["templates"] if t["name"] == "dag") + (consume_task,) = (t for t in dag["dag"]["tasks"] if t["name"] == "consume") + + assert consume_task["arguments"]["parameters"] == [ + { + "name": "some-value", + "value": "{{item}}", + "description": "this is some value", + } + ] + assert consume_task["withParam"] == "{{tasks.generate.outputs.parameters.some-values}}" From 4bb789d1c51fd71588d52baf398b7f9613936958 Mon Sep 17 00:00:00 2001 From: Alice Purcell Date: Mon, 14 Oct 2024 14:16:13 +0100 Subject: [PATCH 09/11] Rewrite final for/if in _get_param_items_from_source Signed-off-by: Alice Purcell --- src/hera/workflows/_meta_mixins.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/hera/workflows/_meta_mixins.py b/src/hera/workflows/_meta_mixins.py index ae99ef293..8b30cc1d5 100644 --- a/src/hera/workflows/_meta_mixins.py +++ b/src/hera/workflows/_meta_mixins.py @@ -302,8 +302,11 @@ def _get_param_items_from_source(source: Callable) -> List[Parameter]: if isinstance(io, Parameter) and io.default is None and not io.output: non_default_parameters.append(io) - for param in non_default_parameters: - param.value = "{{" + ("item" if len(non_default_parameters) == 1 else f"item.{param.name}") + "}}" + if len(non_default_parameters) == 1: + non_default_parameters[0].value == "{{item}}" + else: + for param in non_default_parameters: + param.value = "{{item." + param.name + "}}" return non_default_parameters From 12cfa635ebb36bc993895caf9560db8bc2284988 Mon Sep 17 00:00:00 2001 From: Alice Purcell Date: Mon, 14 Oct 2024 14:26:05 +0100 Subject: [PATCH 10/11] Fix lint issue Signed-off-by: Alice Purcell --- src/hera/workflows/_meta_mixins.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hera/workflows/_meta_mixins.py b/src/hera/workflows/_meta_mixins.py index 8b30cc1d5..bf765bf41 100644 --- a/src/hera/workflows/_meta_mixins.py +++ b/src/hera/workflows/_meta_mixins.py @@ -306,7 +306,7 @@ def _get_param_items_from_source(source: Callable) -> List[Parameter]: non_default_parameters[0].value == "{{item}}" else: for param in non_default_parameters: - param.value = "{{item." + param.name + "}}" + param.value = "{{item." + str(param.name) + "}}" return non_default_parameters From ca64ea52c58972774a53f53284345c79b09e00bb Mon Sep 17 00:00:00 2001 From: Alice Purcell Date: Mon, 14 Oct 2024 14:35:20 +0100 Subject: [PATCH 11/11] Correct == to = Signed-off-by: Alice Purcell --- src/hera/workflows/_meta_mixins.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hera/workflows/_meta_mixins.py b/src/hera/workflows/_meta_mixins.py index bf765bf41..162913ea8 100644 --- a/src/hera/workflows/_meta_mixins.py +++ b/src/hera/workflows/_meta_mixins.py @@ -303,7 +303,7 @@ def _get_param_items_from_source(source: Callable) -> List[Parameter]: non_default_parameters.append(io) if len(non_default_parameters) == 1: - non_default_parameters[0].value == "{{item}}" + non_default_parameters[0].value = "{{item}}" else: for param in non_default_parameters: param.value = "{{item." + str(param.name) + "}}"