Skip to content

Commit

Permalink
Bug fixes for input Parameters and Artifacts in annotations (#811)
Browse files Browse the repository at this point in the history
**Pull Request Checklist**
- [x] Fixes #809, fixes #810
- [x] Tests added
- [x] Documentation/examples added
- [x] [Good commit messages](https://cbea.ms/git-commit/) and/or PR
title

**Description of PR**
Currently, output parameters can be specified without a name within the
function parameters, but input parameters cannot. This PR makes the
runner use the function parameter name as the `name` of a
`hera.workflows.Parameter`.

Input Artifacts can be specified without a path in the annotation, which
creates the correct yaml, seen in #792, but at runtime the runner would
not know what path the artifact is supposed to load from. Fixed and
tested.

This PR also cleans up the tests directory structure, removing a lot of
duplicated code for runner/annotation tests. The distinction between
these is now made clear in docstrings for `test_runner.py` and
`test_script_annotations.py` - the runner tests should run the actual
script function code to emulate the Argo cluster, while the script
annotation tests check that function definitions produce correct yaml
and are valid Python when using the annotations.

---------

Signed-off-by: Elliot Gunton <[email protected]>
  • Loading branch information
elliotgunton authored Oct 24, 2023
1 parent 49ba997 commit b5a260f
Show file tree
Hide file tree
Showing 29 changed files with 602 additions and 563 deletions.
10 changes: 5 additions & 5 deletions docs/examples/workflows/callable_script.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@


@script()
def function_kebab_object(input_values: Annotated[Input, Parameter(name="input-values")]) -> Output:
return Output(output=[input_values])
def function_kebab_object(annotated_input_value: Annotated[Input, Parameter(name="input-value")]) -> Output:
return Output(output=[annotated_input_value])


with Workflow(name="my-workflow") as w:
Expand All @@ -98,7 +98,7 @@
str_function(arguments={"input": Input(a=2, b="bar").json()})
another_function(arguments={"inputs": [Input(a=2, b="bar"), Input(a=2, b="bar")]})
function_kebab(arguments={"a-but-kebab": 3, "b-but-kebab": "bar"})
function_kebab_object(arguments={"input-values": Input(a=3, b="bar")})
function_kebab_object(arguments={"input-value": Input(a=3, b="bar")})
```

=== "YAML"
Expand Down Expand Up @@ -140,7 +140,7 @@
template: function-kebab
- - arguments:
parameters:
- name: input-values
- name: input-value
value: '{"a": 3, "b": "bar"}'
name: function-kebab-object
template: function-kebab-object
Expand Down Expand Up @@ -217,7 +217,7 @@
source: '{{inputs.parameters}}'
- inputs:
parameters:
- name: input-values
- name: input-value
name: function-kebab-object
script:
args:
Expand Down
4 changes: 2 additions & 2 deletions examples/workflows/callable-script.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ spec:
template: function-kebab
- - arguments:
parameters:
- name: input-values
- name: input-value
value: '{"a": 3, "b": "bar"}'
name: function-kebab-object
template: function-kebab-object
Expand Down Expand Up @@ -111,7 +111,7 @@ spec:
source: '{{inputs.parameters}}'
- inputs:
parameters:
- name: input-values
- name: input-value
name: function-kebab-object
script:
args:
Expand Down
6 changes: 3 additions & 3 deletions examples/workflows/callable_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ def function_kebab(


@script()
def function_kebab_object(input_values: Annotated[Input, Parameter(name="input-values")]) -> Output:
return Output(output=[input_values])
def function_kebab_object(annotated_input_value: Annotated[Input, Parameter(name="input-value")]) -> Output:
return Output(output=[annotated_input_value])


with Workflow(name="my-workflow") as w:
Expand All @@ -88,4 +88,4 @@ def function_kebab_object(input_values: Annotated[Input, Parameter(name="input-v
str_function(arguments={"input": Input(a=2, b="bar").json()})
another_function(arguments={"inputs": [Input(a=2, b="bar"), Input(a=2, b="bar")]})
function_kebab(arguments={"a-but-kebab": 3, "b-but-kebab": "bar"})
function_kebab_object(arguments={"input-values": Input(a=3, b="bar")})
function_kebab_object(arguments={"input-value": Input(a=3, b="bar")})
9 changes: 8 additions & 1 deletion src/hera/workflows/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)

_DEFAULT_ARTIFACT_INPUT_DIRECTORY = "/tmp/hera/inputs/artifacts/"


class ArtifactLoader(Enum):
"""Enum for artifact loader options."""
Expand Down Expand Up @@ -78,7 +80,9 @@ class Artifact(BaseModel):
"""allows the specification of an artifact from a subpath within the main source."""

loader: Optional[ArtifactLoader] = None
"""used in Artifact annotations for determining how to load the data"""
"""used for input Artifact annotations for determining how to load the data.
Note: A loader value of 'None' must be used with an underlying type of 'str' or Path-like class."""

output: bool = False
"""used to specify artifact as an output in function signature annotations"""
Expand All @@ -87,6 +91,9 @@ def _check_name(self):
if not self.name:
raise ValueError("name cannot be `None` or empty when used")

def _get_default_inputs_path(self) -> str:
return _DEFAULT_ARTIFACT_INPUT_DIRECTORY + f"{self.name}"

def _build_archive(self) -> Optional[_ModelArchiveStrategy]:
if self.archive is None:
return None
Expand Down
113 changes: 69 additions & 44 deletions src/hera/workflows/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,47 +105,67 @@ def _is_output_kwarg(key, f):
)


def _map_keys(function: Callable, kwargs: dict) -> dict:
"""Change the kwargs's keys to use the Python name instead of the parameter name which could be kebab case.
For Parameters, update their name to not contain kebab-case in Python but allow it in YAML.
For Artifacts, load the Artifact according to the given ArtifactLoader.
def _map_argo_inputs_to_function(function: Callable, kwargs: Dict) -> Dict:
"""Map kwargs from Argo to the function parameters using the function's parameter annotations.
For Parameter inputs:
* if the Parameter has a "name", replace it with the function parameter name
* otherwise use the function parameter name as-is
For Parameter outputs:
* update value to a Path object from the value_from.path value, or the default if not provided
For Artifact inputs:
* load the Artifact according to the given ArtifactLoader
For Artifact outputs:
* update value to a Path object
"""
if os.environ.get("hera__script_annotations", None) is None:
return {key.replace("-", "_"): value for key, value in kwargs.items()}

mapped_kwargs: Dict[str, Any] = {}

def map_annotated_param(param_name: str, param_annotation: Parameter) -> None:
if param_annotation.output:
if param_annotation.value_from and param_annotation.value_from.path:
mapped_kwargs[param_name] = Path(param_annotation.value_from.path)
else:
mapped_kwargs[param_name] = _get_outputs_path(param_annotation)
# Automatically create the parent directory (if required)
mapped_kwargs[param_name].parent.mkdir(parents=True, exist_ok=True)
elif param_annotation.name:
mapped_kwargs[param_name] = kwargs[param_annotation.name]
else:
mapped_kwargs[param_name] = kwargs[param_name]

def map_annotated_artifact(param_name: str, artifact_annotation: Artifact) -> None:
if artifact_annotation.output:
if artifact_annotation.path:
mapped_kwargs[param_name] = Path(artifact_annotation.path)
else:
mapped_kwargs[param_name] = _get_outputs_path(artifact_annotation)
# Automatically create the parent directory (if required)
mapped_kwargs[param_name].parent.mkdir(parents=True, exist_ok=True)
else:
if not artifact_annotation.path:
# Path was added to yaml automatically, we need to add it back in for the runner
artifact_annotation.path = artifact_annotation._get_default_inputs_path()

if artifact_annotation.loader == ArtifactLoader.json.value:
path = Path(artifact_annotation.path)
mapped_kwargs[param_name] = json.load(path.open())
elif artifact_annotation.loader == ArtifactLoader.file.value:
path = Path(artifact_annotation.path)
mapped_kwargs[param_name] = path.read_text()
elif artifact_annotation.loader is None:
mapped_kwargs[param_name] = artifact_annotation.path

for param_name, func_param in inspect.signature(function).parameters.items():
if get_origin(func_param.annotation) is Annotated:
annotated_type = get_args(func_param.annotation)[1]

if isinstance(annotated_type, Parameter):
if annotated_type.output:
if annotated_type.value_from and annotated_type.value_from.path:
mapped_kwargs[param_name] = Path(annotated_type.value_from.path)
else:
mapped_kwargs[param_name] = _get_outputs_path(annotated_type)
# Automatically create the parent directory (if required)
mapped_kwargs[param_name].parent.mkdir(parents=True, exist_ok=True)
else:
mapped_kwargs[param_name] = kwargs[annotated_type.name]
elif isinstance(annotated_type, Artifact):
if annotated_type.output:
if annotated_type.path:
mapped_kwargs[param_name] = Path(annotated_type.path)
else:
mapped_kwargs[param_name] = _get_outputs_path(annotated_type)
# Automatically create the parent directory (if required)
mapped_kwargs[param_name].parent.mkdir(parents=True, exist_ok=True)
elif annotated_type.path:
if annotated_type.loader == ArtifactLoader.json.value:
path = Path(annotated_type.path)
mapped_kwargs[param_name] = json.load(path.open())
elif annotated_type.loader == ArtifactLoader.file.value:
path = Path(annotated_type.path)
mapped_kwargs[param_name] = path.read_text()
elif annotated_type.loader is None:
mapped_kwargs[param_name] = annotated_type.path
func_param_annotation = get_args(func_param.annotation)[1]

if isinstance(func_param_annotation, Parameter):
map_annotated_param(param_name, func_param_annotation)
elif isinstance(func_param_annotation, Artifact):
map_annotated_artifact(param_name, func_param_annotation)
else:
mapped_kwargs[param_name] = kwargs[param_name]
else:
mapped_kwargs[param_name] = kwargs[param_name]

Expand Down Expand Up @@ -212,14 +232,12 @@ def _write_to_path(path: Path, output_value: Any) -> None:
path.write_text(output_string)


def _runner(entrypoint: str, kwargs_list: Any) -> Any:
"""Run a function with a list of kwargs.
def _runner(entrypoint: str, kwargs_list: List) -> Any:
"""Run the function defined by the entrypoint with the given list of kwargs.
Args:
entrypoint: The path to the script within the container to execute.
module: The module path to import the function from.
function_name: The name of the function to run.
kwargs_list: A list of kwargs to pass to the function.
entrypoint: The module path to the script within the container to execute. "package.submodule:function"
kwargs_list: A list of dicts with "name" and "value" keys, representing the kwargs of the function.
Returns:
The result of the function or `None` if the outputs are to be saved.
Expand All @@ -240,7 +258,13 @@ def _runner(entrypoint: str, kwargs_list: Any) -> Any:
key = cast(str, serialize(kwarg["name"]))
value = kwarg["value"]
kwargs[key] = value
kwargs = _map_keys(function, kwargs)

if os.environ.get("hera__script_annotations", None) is None:
# Do a simple replacement for hyphens to get valid Python parameter names.
kwargs = {key.replace("-", "_"): value for key, value in kwargs.items()}
else:
kwargs = _map_argo_inputs_to_function(function, kwargs)

function = validate_arguments(function)
function = _ignore_unmatched_kwargs(function)

Expand Down Expand Up @@ -278,6 +302,7 @@ def _run():
# 2. Protect against files containing `null` as text with outer `or []` (as a result of using
# `{{inputs.parameters}}` where the parameters key doesn't exist in `inputs`)
kwargs_list = json.loads(args.args_path.read_text() or r"[]") or []
assert isinstance(kwargs_list, List)
result = _runner(args.entrypoint, kwargs_list)
if not result:
return
Expand Down
12 changes: 4 additions & 8 deletions src/hera/workflows/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,6 @@
from typing_extensions import Annotated # type: ignore


_DEFAULT_ARTIFACT_INPUT_DIRECTORY = "/tmp/hera/inputs/artifacts/"


class ScriptConstructor(BaseMixin):
"""A ScriptConstructor is responsible for generating the source code for a Script given a python callable.
Expand Down Expand Up @@ -415,7 +412,9 @@ def _get_inputs_from_callable(source: Callable) -> Tuple[List[Parameter], List[A
artifacts = []

for func_param in inspect.signature(source).parameters.values():
if get_origin(func_param.annotation) is not Annotated:
if get_origin(func_param.annotation) is not Annotated or not isinstance(
get_args(func_param.annotation)[1], (Artifact, Parameter)
):
if (
func_param.default != inspect.Parameter.empty
and func_param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
Expand All @@ -428,9 +427,6 @@ def _get_inputs_from_callable(source: Callable) -> Tuple[List[Parameter], List[A
else:
annotation = get_args(func_param.annotation)[1]

if not isinstance(annotation, (Artifact, Parameter)):
raise ValueError(f"The output {type(annotation)} cannot be used as an annotation.")

if annotation.output:
continue

Expand All @@ -441,7 +437,7 @@ def _get_inputs_from_callable(source: Callable) -> Tuple[List[Parameter], List[A

if isinstance(new_object, Artifact):
if new_object.path is None:
new_object.path = _DEFAULT_ARTIFACT_INPUT_DIRECTORY + f"{new_object.name}"
new_object.path = new_object._get_default_inputs_path()

artifacts.append(new_object)
elif isinstance(new_object, Parameter):
Expand Down
12 changes: 0 additions & 12 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import os
from pathlib import Path
from shutil import rmtree

import pytest

Expand All @@ -19,13 +17,3 @@ def environ_annotations_fixture():
os.environ["hera__script_annotations"] = ""
yield
del os.environ["hera__script_annotations"]


@pytest.fixture
def tmp_path_fixture():
# create a temporary directory
path = Path("test_outputs")
path.mkdir(exist_ok=True)
yield path
# destroy the directory
rmtree(path)
8 changes: 0 additions & 8 deletions tests/helper.py
Original file line number Diff line number Diff line change
@@ -1,9 +1 @@
ARTIFACT_PATH = "/tmp/file"


def my_function(a: int, b: str):
return a + b


def no_param_function():
return "Hello there!"
Loading

0 comments on commit b5a260f

Please sign in to comment.