From 52437bec0ab479ce1347d7df459ffa3a0b10c66b Mon Sep 17 00:00:00 2001 From: Alice Purcell Date: Fri, 11 Oct 2024 12:13:10 +0100 Subject: [PATCH 1/6] Fix deserialization for str unions PR #1168 changed the logic of map_runner_input and _parse to no longer pass incoming values to json.loads if their annotated type was a union that included str, rather than only when given a subtype of `Optional[str]`. Split `origin_type_issubclass` into two functions, `origin_type_issupertype` (which matches the previous behaviour) and `origin_type_issubtype`, and use the latter instead to restore the original behaviour. Add a runner check which verifies this behaviour. Signed-off-by: Alice Purcell --- src/hera/shared/_type_util.py | 17 +++++-- .../_runner/script_annotations_util.py | 10 +++- src/hera/workflows/_runner/util.py | 10 +++- src/hera/workflows/script.py | 6 ++- tests/script_runner/parameter_inputs.py | 7 ++- tests/test_runner.py | 12 +++++ tests/test_unit/test_shared_type_utils.py | 48 ++++++++++++++----- 7 files changed, 87 insertions(+), 23 deletions(-) diff --git a/src/hera/shared/_type_util.py b/src/hera/shared/_type_util.py index 45f4f9582..de2e0ab9b 100644 --- a/src/hera/shared/_type_util.py +++ b/src/hera/shared/_type_util.py @@ -92,15 +92,24 @@ def get_unsubscripted_type(t: Any) -> Any: return t -def origin_type_issubclass(cls: Any, type_: type) -> bool: - """Return True if cls can be considered as a subclass of type_.""" - unwrapped_type = unwrap_annotation(cls) +def origin_type_issubtype(annotation: Any, type_: Union[type, Tuple[type, ...]]) -> bool: + """Return True if annotation is a subtype of type_.""" + unwrapped_type = unwrap_annotation(annotation) origin_type = get_unsubscripted_type(unwrapped_type) if origin_type is Union or origin_type is UnionType: - return any(origin_type_issubclass(arg, type_) for arg in get_args(cls)) + return all(origin_type_issubtype(arg, type_) for arg in get_args(annotation)) return issubclass(origin_type, type_) +def origin_type_issupertype(annotation: Any, type_: type) -> bool: + """Return True if annotation is a supertype of type_.""" + unwrapped_type = unwrap_annotation(annotation) + origin_type = get_unsubscripted_type(unwrapped_type) + if origin_type is Union or origin_type is UnionType: + return any(origin_type_issupertype(arg, type_) for arg in get_args(annotation)) + return issubclass(type_, origin_type) + + def is_subscripted(t: Any) -> bool: """Check if given type is subscripted, i.e. a typing object of the form X[Y, Z, ...]. diff --git a/src/hera/workflows/_runner/script_annotations_util.py b/src/hera/workflows/_runner/script_annotations_util.py index fbdf162ae..d560fe4dd 100644 --- a/src/hera/workflows/_runner/script_annotations_util.py +++ b/src/hera/workflows/_runner/script_annotations_util.py @@ -3,15 +3,21 @@ import inspect import json import os +import sys from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union, cast +if sys.version_info >= (3, 10): + from types import NoneType +else: + NoneType = type(None) + from hera.shared._pydantic import BaseModel, get_field_annotations, get_fields from hera.shared._type_util import ( get_unsubscripted_type, get_workflow_annotation, is_subscripted, - origin_type_issubclass, + origin_type_issubtype, unwrap_annotation, ) from hera.shared.serialization import serialize @@ -138,7 +144,7 @@ def map_runner_input( input_model_obj = {} def load_parameter_value(value: str, value_type: type) -> Any: - if origin_type_issubclass(value_type, str): + if origin_type_issubtype(value_type, (str, NoneType)): return value try: diff --git a/src/hera/workflows/_runner/util.py b/src/hera/workflows/_runner/util.py index 5acc32ab3..b949d4558 100644 --- a/src/hera/workflows/_runner/util.py +++ b/src/hera/workflows/_runner/util.py @@ -6,13 +6,19 @@ import inspect import json import os +import sys from pathlib import Path from typing import Any, Callable, Dict, List, Optional, cast +if sys.version_info >= (3, 10): + from types import NoneType +else: + NoneType = type(None) + from hera.shared._pydantic import _PYDANTIC_VERSION from hera.shared._type_util import ( get_workflow_annotation, - origin_type_issubclass, + origin_type_issubtype, unwrap_annotation, ) from hera.shared.serialization import serialize @@ -125,7 +131,7 @@ def _get_unannotated_type(key: str, f: Callable) -> Optional[type]: def _is_str_kwarg_of(key: str, f: Callable) -> bool: """Check if param `key` of function `f` has a type annotation that can be interpreted as a subclass of str.""" if func_param_annotation := _get_function_param_annotation(key, f): - return origin_type_issubclass(func_param_annotation, str) + return origin_type_issubtype(func_param_annotation, (str, NoneType)) return False diff --git a/src/hera/workflows/script.py b/src/hera/workflows/script.py index dce234510..cd9d419a4 100644 --- a/src/hera/workflows/script.py +++ b/src/hera/workflows/script.py @@ -47,7 +47,7 @@ _flag_enabled, ) from hera.shared._pydantic import _PYDANTIC_VERSION, root_validator, validator -from hera.shared._type_util import get_workflow_annotation, is_subscripted, origin_type_issubclass +from hera.shared._type_util import get_workflow_annotation, is_subscripted, origin_type_issupertype from hera.shared.serialization import serialize from hera.workflows._context import _context from hera.workflows._meta_mixins import CallableTemplateMixin @@ -540,7 +540,9 @@ class will be used as inputs, rather than the class itself. else: default = MISSING - if origin_type_issubclass(func_param.annotation, NoneType) and (default is MISSING or default is not None): + if origin_type_issupertype(func_param.annotation, NoneType) and ( + default is MISSING or default is not None + ): raise ValueError(f"Optional parameter '{func_param.name}' must have a default value of None.") parameters.append(Parameter(name=func_param.name, default=default)) diff --git a/tests/script_runner/parameter_inputs.py b/tests/script_runner/parameter_inputs.py index 980d756a9..4710dedd8 100644 --- a/tests/script_runner/parameter_inputs.py +++ b/tests/script_runner/parameter_inputs.py @@ -1,5 +1,5 @@ import json -from typing import Any, List +from typing import Any, List, Union try: from typing import Annotated @@ -76,6 +76,11 @@ def no_type_parameter(my_anything) -> Any: return my_anything +@script() +def str_or_int_parameter(my_str_or_int: Union[str, int]) -> str: + return f"type given: {type(my_str_or_int).__name__}" + + @script() def str_parameter_expects_jsonstr_dict(my_json_str: str) -> dict: return json.loads(my_json_str) diff --git a/tests/test_runner.py b/tests/test_runner.py index f9fac159f..f61546695 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -65,6 +65,18 @@ {}, id="no-type-dict", ), + pytest.param( + "tests.script_runner.parameter_inputs:str_or_int_parameter", + [{"name": "my_str_or_int", "value": "hi there"}], + "type given: str", + id="str-or-int-given-str", + ), + pytest.param( + "tests.script_runner.parameter_inputs:str_or_int_parameter", + [{"name": "my_str_or_int", "value": "3"}], + "type given: int", + id="str-or-int-given-int", + ), pytest.param( "tests.script_runner.parameter_inputs:str_parameter_expects_jsonstr_dict", [{"name": "my_json_str", "value": json.dumps({"my": "dict"})}], diff --git a/tests/test_unit/test_shared_type_utils.py b/tests/test_unit/test_shared_type_utils.py index d543997d8..799c1cb7a 100644 --- a/tests/test_unit/test_shared_type_utils.py +++ b/tests/test_unit/test_shared_type_utils.py @@ -1,5 +1,15 @@ +import sys from typing import List, Optional, Union +if sys.version_info >= (3, 9): + from typing import Annotated +else: + from typing_extensions import Annotated +if sys.version_info >= (3, 10): + from types import NoneType +else: + NoneType = type(None) + import pytest from annotated_types import Gt @@ -8,16 +18,12 @@ get_unsubscripted_type, get_workflow_annotation, is_annotated, - origin_type_issubclass, + origin_type_issubtype, + origin_type_issupertype, unwrap_annotation, ) from hera.workflows import Artifact, Parameter -try: - from typing import Annotated -except ImportError: - from typing_extensions import Annotated - @pytest.mark.parametrize("annotation, expected", [[Annotated[str, "some metadata"], True], [str, False]]) def test_is_annotated(annotation, expected): @@ -104,11 +110,29 @@ def test_get_unsubscripted_type(annotation, expected): @pytest.mark.parametrize( "annotation, target, expected", [ - [List[str], str, False], - [Optional[str], str, True], - [str, str, True], - [Union[int, str], int, True], + pytest.param(List[str], str, False, id="list-str-not-subtype-of-str"), + pytest.param(Optional[str], str, False, id="optional-str-not-subtype-of-str"), + pytest.param(str, str, True, id="str-is-subtype-of-str"), + pytest.param(Union[int, str], int, False, id="union-int-str-not-subtype-of-str"), + pytest.param(Optional[str], (str, NoneType), True, id="optional-str-is-subtype-of-optional-str"), + pytest.param(str, (str, NoneType), True, id="str-is-subtype-of-optional-str"), + pytest.param(Union[int, str], (str, NoneType), False, id="union-int-str-not-subtype-of-optional-str"), + ], +) +def test_origin_type_issubtype(annotation, target, expected): + assert origin_type_issubtype(annotation, target) is expected + + +@pytest.mark.parametrize( + "annotation, target, expected", + [ + pytest.param(List[str], str, False, id="list-str-not-supertype-of-str"), + pytest.param(Optional[str], str, True, id="optional-str-is-supertype-of-str"), + pytest.param(str, str, True, id="str-is-supertype-of-str"), + pytest.param(Union[int, str], int, True, id="union-int-str-is-supertype-of-int"), + pytest.param(Optional[str], NoneType, True, id="optional-str-is-supertype-of-nonetype"), + pytest.param(str, NoneType, False, id="str-not-supertype-of-nonetype"), ], ) -def test_origin_type_issubclass(annotation, target, expected): - assert origin_type_issubclass(annotation, target) is expected +def test_origin_type_issupertype(annotation, target, expected): + assert origin_type_issupertype(annotation, target) is expected From 83eb2ee677ce987418a4d490770cc6069a767367 Mon Sep 17 00:00:00 2001 From: Alice Purcell Date: Tue, 15 Oct 2024 12:49:58 +0100 Subject: [PATCH 2/6] Skip union tests for Pydantic v1 parse_obj_as does not appear to support union types (it is converting ints to strings). Signed-off-by: Alice Purcell --- tests/test_runner.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/tests/test_runner.py b/tests/test_runner.py index f61546695..2449ed5a6 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -65,17 +65,23 @@ {}, id="no-type-dict", ), - pytest.param( - "tests.script_runner.parameter_inputs:str_or_int_parameter", - [{"name": "my_str_or_int", "value": "hi there"}], - "type given: str", - id="str-or-int-given-str", - ), - pytest.param( - "tests.script_runner.parameter_inputs:str_or_int_parameter", - [{"name": "my_str_or_int", "value": "3"}], - "type given: int", - id="str-or-int-given-int", + *( + [ + pytest.param( + "tests.script_runner.parameter_inputs:str_or_int_parameter", + [{"name": "my_str_or_int", "value": "hi there"}], + "type given: str", + id="str-or-int-given-str", + ), + pytest.param( + "tests.script_runner.parameter_inputs:str_or_int_parameter", + [{"name": "my_str_or_int", "value": "3"}], + "type given: int", + id="str-or-int-given-int", + ), + ] + if _PYDANTIC_VERSION > 1 + else [] ), pytest.param( "tests.script_runner.parameter_inputs:str_parameter_expects_jsonstr_dict", From f0e1ddc381bca780a2587ea938117e5faa1700bd Mon Sep 17 00:00:00 2001 From: Alice Purcell Date: Wed, 16 Oct 2024 09:52:46 +0100 Subject: [PATCH 3/6] Document behaviour when given a tuple Signed-off-by: Alice Purcell --- src/hera/shared/_type_util.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/hera/shared/_type_util.py b/src/hera/shared/_type_util.py index de2e0ab9b..94501cab4 100644 --- a/src/hera/shared/_type_util.py +++ b/src/hera/shared/_type_util.py @@ -93,7 +93,11 @@ def get_unsubscripted_type(t: Any) -> Any: def origin_type_issubtype(annotation: Any, type_: Union[type, Tuple[type, ...]]) -> bool: - """Return True if annotation is a subtype of type_.""" + """Return True if annotation is a subtype of type_. + + type_ may be a tuple of types, in which case return True if annotation is a subtype + of the union of the types in the tuple. + """ unwrapped_type = unwrap_annotation(annotation) origin_type = get_unsubscripted_type(unwrapped_type) if origin_type is Union or origin_type is UnionType: From 22d6bdc01fefc96afb284d499d1e16f8f4209a0a Mon Sep 17 00:00:00 2001 From: Alice Purcell Date: Fri, 11 Oct 2024 11:13:59 +0100 Subject: [PATCH 4/6] Fix runtime error in origin_type_is*type functions If origin_type_issubtype or origin_type_issupertype are passed a special form annotation, they will raise a TypeError due to passing it into issubclass. Fix this issue by first checking if the annotation is a type, and returning False if not. Unit test this with NoReturn, as we are very unlikely to ever create a special case for this type in the function. Signed-off-by: Alice Purcell --- src/hera/shared/_type_util.py | 4 ++-- tests/test_unit/test_shared_type_utils.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/hera/shared/_type_util.py b/src/hera/shared/_type_util.py index 94501cab4..e703c18ae 100644 --- a/src/hera/shared/_type_util.py +++ b/src/hera/shared/_type_util.py @@ -102,7 +102,7 @@ def origin_type_issubtype(annotation: Any, type_: Union[type, Tuple[type, ...]]) origin_type = get_unsubscripted_type(unwrapped_type) if origin_type is Union or origin_type is UnionType: return all(origin_type_issubtype(arg, type_) for arg in get_args(annotation)) - return issubclass(origin_type, type_) + return isinstance(origin_type, type) and issubclass(origin_type, type_) def origin_type_issupertype(annotation: Any, type_: type) -> bool: @@ -111,7 +111,7 @@ def origin_type_issupertype(annotation: Any, type_: type) -> bool: origin_type = get_unsubscripted_type(unwrapped_type) if origin_type is Union or origin_type is UnionType: return any(origin_type_issupertype(arg, type_) for arg in get_args(annotation)) - return issubclass(type_, origin_type) + return isinstance(origin_type, type) and issubclass(type_, origin_type) def is_subscripted(t: Any) -> bool: diff --git a/tests/test_unit/test_shared_type_utils.py b/tests/test_unit/test_shared_type_utils.py index 799c1cb7a..f15d41740 100644 --- a/tests/test_unit/test_shared_type_utils.py +++ b/tests/test_unit/test_shared_type_utils.py @@ -1,5 +1,5 @@ import sys -from typing import List, Optional, Union +from typing import List, NoReturn, Optional, Union if sys.version_info >= (3, 9): from typing import Annotated @@ -111,6 +111,7 @@ def test_get_unsubscripted_type(annotation, expected): "annotation, target, expected", [ pytest.param(List[str], str, False, id="list-str-not-subtype-of-str"), + pytest.param(NoReturn, str, False, id="special-form-does-not-raise-error"), pytest.param(Optional[str], str, False, id="optional-str-not-subtype-of-str"), pytest.param(str, str, True, id="str-is-subtype-of-str"), pytest.param(Union[int, str], int, False, id="union-int-str-not-subtype-of-str"), @@ -127,6 +128,7 @@ def test_origin_type_issubtype(annotation, target, expected): "annotation, target, expected", [ pytest.param(List[str], str, False, id="list-str-not-supertype-of-str"), + pytest.param(NoReturn, str, False, id="special-form-does-not-raise-error"), pytest.param(Optional[str], str, True, id="optional-str-is-supertype-of-str"), pytest.param(str, str, True, id="str-is-supertype-of-str"), pytest.param(Union[int, str], int, True, id="union-int-str-is-supertype-of-int"), From c434384babbfd7d3e9eecd05eb078774fb8f7b30 Mon Sep 17 00:00:00 2001 From: Alice Purcell Date: Fri, 11 Oct 2024 12:56:03 +0100 Subject: [PATCH 5/6] Support annotated unions The origin_type_is* functions were accidentally calling get_args on the original annotation rather than the unwrapped type. Signed-off-by: Alice Purcell --- src/hera/shared/_type_util.py | 4 ++-- tests/test_unit/test_shared_type_utils.py | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/hera/shared/_type_util.py b/src/hera/shared/_type_util.py index e703c18ae..b8a65ee6f 100644 --- a/src/hera/shared/_type_util.py +++ b/src/hera/shared/_type_util.py @@ -101,7 +101,7 @@ def origin_type_issubtype(annotation: Any, type_: Union[type, Tuple[type, ...]]) unwrapped_type = unwrap_annotation(annotation) origin_type = get_unsubscripted_type(unwrapped_type) if origin_type is Union or origin_type is UnionType: - return all(origin_type_issubtype(arg, type_) for arg in get_args(annotation)) + return all(origin_type_issubtype(arg, type_) for arg in get_args(unwrapped_type)) return isinstance(origin_type, type) and issubclass(origin_type, type_) @@ -110,7 +110,7 @@ def origin_type_issupertype(annotation: Any, type_: type) -> bool: unwrapped_type = unwrap_annotation(annotation) origin_type = get_unsubscripted_type(unwrapped_type) if origin_type is Union or origin_type is UnionType: - return any(origin_type_issupertype(arg, type_) for arg in get_args(annotation)) + return any(origin_type_issupertype(arg, type_) for arg in get_args(unwrapped_type)) return isinstance(origin_type, type) and issubclass(type_, origin_type) diff --git a/tests/test_unit/test_shared_type_utils.py b/tests/test_unit/test_shared_type_utils.py index f15d41740..c4fda261a 100644 --- a/tests/test_unit/test_shared_type_utils.py +++ b/tests/test_unit/test_shared_type_utils.py @@ -116,6 +116,7 @@ def test_get_unsubscripted_type(annotation, expected): pytest.param(str, str, True, id="str-is-subtype-of-str"), pytest.param(Union[int, str], int, False, id="union-int-str-not-subtype-of-str"), pytest.param(Optional[str], (str, NoneType), True, id="optional-str-is-subtype-of-optional-str"), + pytest.param(Annotated[Optional[str], "foo"], (str, NoneType), True, id="annotated-optional"), pytest.param(str, (str, NoneType), True, id="str-is-subtype-of-optional-str"), pytest.param(Union[int, str], (str, NoneType), False, id="union-int-str-not-subtype-of-optional-str"), ], @@ -133,6 +134,7 @@ def test_origin_type_issubtype(annotation, target, expected): pytest.param(str, str, True, id="str-is-supertype-of-str"), pytest.param(Union[int, str], int, True, id="union-int-str-is-supertype-of-int"), pytest.param(Optional[str], NoneType, True, id="optional-str-is-supertype-of-nonetype"), + pytest.param(Annotated[Optional[str], "foo"], NoneType, True, id="annotated-optional"), pytest.param(str, NoneType, False, id="str-not-supertype-of-nonetype"), ], ) From f3782e8c1e6c958be341df74da8ad229b2158c9f Mon Sep 17 00:00:00 2001 From: Alice Purcell Date: Mon, 21 Oct 2024 13:47:06 +0100 Subject: [PATCH 6/6] Fix str-or-int-given-int for Pydantic v1 Pydantic v1 checks union types in declaration order, returning as soon as a coercion succeeds; see https://docs.pydantic.dev/1.10/usage/model_config/#smart-union for more. Change the test str_or_int_parameter signature from `str | int` to `int | str` to work with this legacy mode. Signed-off-by: Alice Purcell --- tests/script_runner/parameter_inputs.py | 2 +- tests/test_runner.py | 28 ++++++++++--------------- 2 files changed, 12 insertions(+), 18 deletions(-) diff --git a/tests/script_runner/parameter_inputs.py b/tests/script_runner/parameter_inputs.py index 4710dedd8..43cff08b3 100644 --- a/tests/script_runner/parameter_inputs.py +++ b/tests/script_runner/parameter_inputs.py @@ -77,7 +77,7 @@ def no_type_parameter(my_anything) -> Any: @script() -def str_or_int_parameter(my_str_or_int: Union[str, int]) -> str: +def str_or_int_parameter(my_str_or_int: Union[int, str]) -> str: return f"type given: {type(my_str_or_int).__name__}" diff --git a/tests/test_runner.py b/tests/test_runner.py index 2449ed5a6..f61546695 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -65,23 +65,17 @@ {}, id="no-type-dict", ), - *( - [ - pytest.param( - "tests.script_runner.parameter_inputs:str_or_int_parameter", - [{"name": "my_str_or_int", "value": "hi there"}], - "type given: str", - id="str-or-int-given-str", - ), - pytest.param( - "tests.script_runner.parameter_inputs:str_or_int_parameter", - [{"name": "my_str_or_int", "value": "3"}], - "type given: int", - id="str-or-int-given-int", - ), - ] - if _PYDANTIC_VERSION > 1 - else [] + pytest.param( + "tests.script_runner.parameter_inputs:str_or_int_parameter", + [{"name": "my_str_or_int", "value": "hi there"}], + "type given: str", + id="str-or-int-given-str", + ), + pytest.param( + "tests.script_runner.parameter_inputs:str_or_int_parameter", + [{"name": "my_str_or_int", "value": "3"}], + "type given: int", + id="str-or-int-given-int", ), pytest.param( "tests.script_runner.parameter_inputs:str_parameter_expects_jsonstr_dict",