Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issues with with_param #1236

Merged
merged 14 commits into from
Oct 25, 2024
Merged
18 changes: 18 additions & 0 deletions src/hera/shared/_type_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,24 @@
return metadata[0]


def construct_io_from_annotation(python_name: str, annotation: Any) -> Union[Parameter, Artifact]:
"""Constructs a Parameter or Artifact object based on annotations.

If a field has a Parameter or Artifact annotation, a copy will be returned, with missing
fields filled out based on other metadata. Otherwise, a Parameter object will be constructed.

For a function parameter, python_name should be the parameter name.
For a Pydantic Input or Output class, python_name should be the field name.
"""
if annotation := get_workflow_annotation(annotation):
# Copy so as to not modify the fields themselves
annotation_copy = annotation.copy()
annotation_copy.name = annotation.name or python_name
return annotation_copy

Check warning on line 98 in src/hera/shared/_type_util.py

View check run for this annotation

Codecov / codecov/patch

src/hera/shared/_type_util.py#L96-L98

Added lines #L96 - L98 were not covered by tests

return Parameter(name=python_name)

Check warning on line 100 in src/hera/shared/_type_util.py

View check run for this annotation

Codecov / codecov/patch

src/hera/shared/_type_util.py#L100

Added line #L100 was not covered by tests


def get_unsubscripted_type(t: Any) -> Any:
"""Return the origin of t, if subscripted, or t itself.

Expand Down
46 changes: 34 additions & 12 deletions src/hera/workflows/_meta_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from hera.shared import BaseMixin, global_config
from hera.shared._global_config import _DECORATOR_SYNTAX_FLAG, _flag_enabled
from hera.shared._pydantic import BaseModel, get_fields, root_validator
from hera.shared._type_util import get_annotated_metadata
from hera.shared._type_util import construct_io_from_annotation, get_annotated_metadata, unwrap_annotation
from hera.workflows._context import _context
from hera.workflows.exceptions import InvalidTemplateCall
from hera.workflows.io.v1 import (
Expand Down Expand Up @@ -263,6 +263,18 @@
return output


def _get_pydantic_input_type(source: Callable) -> Union[None, Type[InputV1], Type[InputV2]]:
"""Returns a Pydantic Input type for the source, if it is using Pydantic IO."""
function_parameters = inspect.signature(source).parameters

Check warning on line 268 in src/hera/workflows/_meta_mixins.py

View check run for this annotation

Codecov / codecov/patch

src/hera/workflows/_meta_mixins.py#L268

Added line #L268 was not covered by tests
if len(function_parameters) != 1:
return None
parameter = next(iter(function_parameters.values()))
parameter_type = unwrap_annotation(parameter.annotation)

Check warning on line 272 in src/hera/workflows/_meta_mixins.py

View check run for this annotation

Codecov / codecov/patch

src/hera/workflows/_meta_mixins.py#L270-L272

Added lines #L270 - L272 were not covered by tests
if not isinstance(parameter_type, type) or not issubclass(parameter_type, (InputV1, InputV2)):
return None
return parameter_type

Check warning on line 275 in src/hera/workflows/_meta_mixins.py

View check run for this annotation

Codecov / codecov/patch

src/hera/workflows/_meta_mixins.py#L274-L275

Added lines #L274 - L275 were not covered by tests


def _get_param_items_from_source(source: Callable) -> List[Parameter]:
"""Returns a list (possibly empty) of `Parameter` from the specified `source`.

Expand All @@ -275,17 +287,27 @@
List[Parameter]
A list of identified parameters (possibly empty).
"""
source_signature: List[str] = []
for p in inspect.signature(source).parameters.values():
if p.default == inspect.Parameter.empty and p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:
# only add positional or keyword arguments that are not set to a default value
# as the default value ones are captured by the automatically generated `Parameter` fields for positional
# kwargs. Otherwise, we assume that the user sets the value of the parameter via the `with_param` field
source_signature.append(p.name)

if len(source_signature) == 1:
return [Parameter(name=n, value="{{item}}") for n in source_signature]
return [Parameter(name=n, value=f"{{{{item.{n}}}}}") for n in source_signature]
non_default_parameters: List[Parameter] = []

Check warning on line 290 in src/hera/workflows/_meta_mixins.py

View check run for this annotation

Codecov / codecov/patch

src/hera/workflows/_meta_mixins.py#L290

Added line #L290 was not covered by tests
if pydantic_input := _get_pydantic_input_type(source):
for parameter in pydantic_input._get_parameters():
if parameter.default is None:
non_default_parameters.append(parameter)

Check warning on line 294 in src/hera/workflows/_meta_mixins.py

View check run for this annotation

Codecov / codecov/patch

src/hera/workflows/_meta_mixins.py#L294

Added line #L294 was not covered by tests
else:
for p in inspect.signature(source).parameters.values():
if p.default == inspect.Parameter.empty and p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:
# only add positional or keyword arguments that are not set to a default value
# as the default value ones are captured by the automatically generated `Parameter` fields for positional
# kwargs. Otherwise, we assume that the user sets the value of the parameter via the `with_param` field
io = construct_io_from_annotation(p.name, p.annotation)

Check warning on line 301 in src/hera/workflows/_meta_mixins.py

View check run for this annotation

Codecov / codecov/patch

src/hera/workflows/_meta_mixins.py#L301

Added line #L301 was not covered by tests
if isinstance(io, Parameter) and io.default is None and not io.output:
alicederyn marked this conversation as resolved.
Show resolved Hide resolved
non_default_parameters.append(io)

Check warning on line 303 in src/hera/workflows/_meta_mixins.py

View check run for this annotation

Codecov / codecov/patch

src/hera/workflows/_meta_mixins.py#L303

Added line #L303 was not covered by tests

if len(non_default_parameters) == 1:
non_default_parameters[0].value = "{{item}}"

Check warning on line 306 in src/hera/workflows/_meta_mixins.py

View check run for this annotation

Codecov / codecov/patch

src/hera/workflows/_meta_mixins.py#L306

Added line #L306 was not covered by tests
else:
for param in non_default_parameters:
param.value = "{{item." + str(param.name) + "}}"
return non_default_parameters

Check warning on line 310 in src/hera/workflows/_meta_mixins.py

View check run for this annotation

Codecov / codecov/patch

src/hera/workflows/_meta_mixins.py#L309-L310

Added lines #L309 - L310 were not covered by tests


def _get_params_from_items(with_items: List[Any]) -> Optional[List[Parameter]]:
Expand Down
13 changes: 4 additions & 9 deletions src/hera/workflows/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from hera.shared import BaseMixin, global_config
from hera.shared._pydantic import PrivateAttr, get_field_annotations, get_fields, root_validator, validator
from hera.shared._type_util import get_workflow_annotation
from hera.shared._type_util import construct_io_from_annotation
from hera.shared.serialization import serialize
from hera.workflows._context import SubNodeMixin, _context
from hera.workflows._meta_mixins import CallableTemplateMixin, HeraBuildObj, HookMixin
Expand Down Expand Up @@ -738,14 +738,9 @@
result_templated_str = f"{{{{{subnode_type}.{subnode_name}.outputs.result}}}}"
return result_templated_str

if param_or_artifact := get_workflow_annotation(annotations[name]):
output_name = param_or_artifact.name or name
if isinstance(param_or_artifact, Parameter):
return "{{" + f"{subnode_type}.{subnode_name}.outputs.parameters.{output_name}" + "}}"
else:
return "{{" + f"{subnode_type}.{subnode_name}.outputs.artifacts.{output_name}" + "}}"

return "{{" + f"{subnode_type}.{subnode_name}.outputs.parameters.{name}" + "}}"
param_or_artifact = construct_io_from_annotation(name, annotations[name])
output_type = "parameters" if isinstance(param_or_artifact, Parameter) else "artifacts"
return "{{" + f"{subnode_type}.{subnode_name}.outputs.{output_type}.{param_or_artifact.name}" + "}}"

Check warning on line 743 in src/hera/workflows/_mixins.py

View check run for this annotation

Codecov / codecov/patch

src/hera/workflows/_mixins.py#L741-L743

Added lines #L741 - L743 were not covered by tests

return super().__getattribute__(name)

Expand Down
14 changes: 4 additions & 10 deletions src/hera/workflows/io/_io_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from hera.shared import global_config
from hera.shared._global_config import _SUPPRESS_PARAMETER_DEFAULT_ERROR_FLAG
from hera.shared._pydantic import _PYDANTIC_VERSION, FieldInfo, get_field_annotations, get_fields
from hera.shared._type_util import get_workflow_annotation
from hera.shared._type_util import construct_io_from_annotation, get_workflow_annotation
from hera.shared.serialization import MISSING, serialize
from hera.workflows._context import _context
from hera.workflows.artifact import Artifact
Expand Down Expand Up @@ -45,18 +45,12 @@
def _construct_io_from_fields(cls: Type[BaseModel]) -> Iterator[Tuple[str, FieldInfo, Union[Parameter, Artifact]]]:
"""Constructs a Parameter or Artifact object for all Pydantic fields based on their annotations.

If a field has a Parameter or Artifact annotation, a copy will be returned, with name added if missing.
Otherwise, a Parameter object will be constructed.
If a field has a Parameter or Artifact annotation, a copy will be returned, with missing
fields filled out based on other metadata. Otherwise, a Parameter object will be constructed.
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: I changed this docstring because the core logic is now in a different function, so it will be easy to miss changes in future and bitrot. I intend to change the behaviour as part of fixing #1173, for instance, to set the enum field for Literals.

"""
annotations = get_field_annotations(cls)
for field, field_info in get_fields(cls).items():
if annotation := get_workflow_annotation(annotations[field]):
# Copy so as to not modify the fields themselves
annotation_copy = annotation.copy()
annotation_copy.name = annotation.name or field
yield field, field_info, annotation_copy
else:
yield field, field_info, Parameter(name=field)
yield field, field_info, construct_io_from_annotation(field, annotations[field])

Check warning on line 53 in src/hera/workflows/io/_io_mixins.py

View check run for this annotation

Codecov / codecov/patch

src/hera/workflows/io/_io_mixins.py#L53

Added line #L53 was not covered by tests
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍



class InputMixin(BaseModel):
Expand Down
38 changes: 19 additions & 19 deletions src/hera/workflows/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,12 @@
_flag_enabled,
)
from hera.shared._pydantic import _PYDANTIC_VERSION, root_validator, validator
from hera.shared._type_util import get_workflow_annotation, is_subscripted, origin_type_issupertype
from hera.shared._type_util import (
construct_io_from_annotation,
get_workflow_annotation,
is_subscripted,
origin_type_issupertype,
)
from hera.shared.serialization import serialize
from hera.workflows._context import _context
from hera.workflows._meta_mixins import CallableTemplateMixin
Expand Down Expand Up @@ -434,25 +439,19 @@
artifacts: List[Artifact] = []

for name, p in inspect.signature(source).parameters.items():
annotation = get_workflow_annotation(p.annotation)
if not annotation or not annotation.output:
annotation = construct_io_from_annotation(name, p.annotation)

Check warning on line 442 in src/hera/workflows/script.py

View check run for this annotation

Codecov / codecov/patch

src/hera/workflows/script.py#L442

Added line #L442 was not covered by tests
Comment on lines 441 to +442

This comment was marked as resolved.

if not annotation.output:
continue

new_object = annotation.copy()

# use the function parameter name when not provided by user
if not new_object.name:
new_object.name = name
if isinstance(annotation, Parameter) and annotation.value_from is None and outputs_directory is not None:
annotation.value_from = ValueFrom(path=outputs_directory + f"/parameters/{annotation.name}")

Check warning on line 447 in src/hera/workflows/script.py

View check run for this annotation

Codecov / codecov/patch

src/hera/workflows/script.py#L447

Added line #L447 was not covered by tests
elif isinstance(annotation, Artifact) and annotation.path is None and outputs_directory is not None:
annotation.path = outputs_directory + f"/artifacts/{annotation.name}"

Check warning on line 449 in src/hera/workflows/script.py

View check run for this annotation

Codecov / codecov/patch

src/hera/workflows/script.py#L449

Added line #L449 was not covered by tests

if isinstance(new_object, Parameter) and new_object.value_from is None and outputs_directory is not None:
new_object.value_from = ValueFrom(path=outputs_directory + f"/parameters/{new_object.name}")
elif isinstance(new_object, Artifact) and new_object.path is None and outputs_directory is not None:
new_object.path = outputs_directory + f"/artifacts/{new_object.name}"

if isinstance(new_object, Artifact):
artifacts.append(new_object)
elif isinstance(new_object, Parameter):
parameters.append(new_object)
if isinstance(annotation, Artifact):
artifacts.append(annotation)

Check warning on line 452 in src/hera/workflows/script.py

View check run for this annotation

Codecov / codecov/patch

src/hera/workflows/script.py#L452

Added line #L452 was not covered by tests
elif isinstance(annotation, Parameter):
parameters.append(annotation)

Check warning on line 454 in src/hera/workflows/script.py

View check run for this annotation

Codecov / codecov/patch

src/hera/workflows/script.py#L454

Added line #L454 was not covered by tests

return parameters, artifacts

Expand Down Expand Up @@ -591,8 +590,9 @@
output = []

for _, func_param in inspect.signature(source).parameters.items():
if (annotated := get_workflow_annotation(func_param.annotation)) and annotated.output:
output.append(annotated)
io = construct_io_from_annotation(func_param.name, func_param.annotation)

Check warning on line 593 in src/hera/workflows/script.py

View check run for this annotation

Codecov / codecov/patch

src/hera/workflows/script.py#L593

Added line #L593 was not covered by tests
if io.output:
output.append(io)

Check warning on line 595 in src/hera/workflows/script.py

View check run for this annotation

Codecov / codecov/patch

src/hera/workflows/script.py#L595

Added line #L595 was not covered by tests

output.extend(_extract_return_annotation_output(source))

Expand Down
31 changes: 31 additions & 0 deletions tests/script_annotations/pydantic_io_with_param.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import Annotated, List

from hera.shared import global_config
from hera.workflows import DAG, Input, Output, Parameter, Workflow, script

global_config.experimental_features["script_pydantic_io"] = True


class GenerateOutput(Output):
some_values: Annotated[List[int], Parameter(name="some-values")]


class ConsumeInput(Input):
some_value: Annotated[int, Parameter(name="some-value", description="this is some value")]


@script(constructor="runner")
def generate() -> GenerateOutput:
return GenerateOutput(some_values=[i for i in range(10)])


@script(constructor="runner")
def consume(input: ConsumeInput) -> None:
print("Received value: {value}!".format(value=input.some_value))


with Workflow(generate_name="dynamic-fanout-", entrypoint="d") as w:
with DAG(name="dag"):
g = generate(arguments={})
c = consume(with_param=g.get_parameter("some-values"))
g >> c
23 changes: 23 additions & 0 deletions tests/script_annotations/with_param.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from typing import Annotated, List

from hera.shared import global_config
from hera.workflows import DAG, Parameter, Workflow, script

global_config.experimental_features["script_annotations"] = True


@script(constructor="runner")
def generate() -> Annotated[List[int], Parameter(name="some-values")]:
return [i for i in range(10)]


@script(constructor="runner")
def consume(some_value: Annotated[int, Parameter(name="some-value", description="this is some value")]):
print("Received value: {value}!".format(value=some_value))


with Workflow(generate_name="dynamic-fanout-", entrypoint="d") as w:
with DAG(name="dag"):
g = generate(arguments={})
c = consume(with_param=g.get_parameter("some-values"))
g >> c
38 changes: 38 additions & 0 deletions tests/test_script_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,3 +438,41 @@ def test_script_pydantic_without_experimental_flag(global_config_fixture):
"Unable to instantiate <class 'tests.script_annotations.pydantic_io_v1.ParamOnlyInput'> since it is an experimental feature."
in str(e.value)
)


@pytest.mark.parametrize(
"module_name",
[
"tests.script_annotations.with_param", # annotated types
"tests.script_annotations.pydantic_io_with_param", # Pydantic IO types
],
)
def test_script_with_param(global_config_fixture, module_name):
"""Test that with_param works correctly with annotated/Pydantic IO types."""
# GIVEN
global_config_fixture.experimental_features["script_annotations"] = True
global_config_fixture.experimental_features["script_pydantic_io"] = True
# Force a reload of the test module, as the runner performs "importlib.import_module", which
# may fetch a cached version

module = importlib.import_module(module_name)
importlib.reload(module)
workflow = importlib.import_module(module.__name__).w

# WHEN
workflow_dict = workflow.to_dict()
assert workflow == Workflow.from_dict(workflow_dict)
assert workflow == Workflow.from_yaml(workflow.to_yaml())

# THEN
(dag,) = (t for t in workflow_dict["spec"]["templates"] if t["name"] == "dag")
(consume_task,) = (t for t in dag["dag"]["tasks"] if t["name"] == "consume")

assert consume_task["arguments"]["parameters"] == [
{
"name": "some-value",
"value": "{{item}}",
"description": "this is some value",
}
]
assert consume_task["withParam"] == "{{tasks.generate.outputs.parameters.some-values}}"
Loading
Loading