Skip to content

Commit

Permalink
Fix deserialization for str unions (#1239)
Browse files Browse the repository at this point in the history
**Pull Request Checklist**
- [ ] Fixes #<!--issue number goes here-->
- [X] Tests added
- [ ] Documentation/examples added
- [X] [Good commit messages](https://cbea.ms/git-commit/) and/or PR
title

**Description of PR**
PR #1168 changed the logic of `map_runner_input` and `_parse` to no
longer pass incoming values to `json.loads` if their annotated type was
a union that included str, rather than only when given a subtype of
`Optional[str]`. This makes it impossible to, for instance, pass an int
to a `Union[str, int]`.

This PR splits `origin_type_issubclass` into two functions,
`origin_type_issupertype` (which matches the previous behaviour) and
`origin_type_issubtype`, and use the latter instead to restore the
original behaviour.

Additionally, it fixes a bug where we were passing an annotation to
`issubclass` without first checking if it's a type, resulting in a
`TypeError` (partially addresses #1173).

---------

Signed-off-by: Alice Purcell <[email protected]>
  • Loading branch information
alicederyn authored Oct 21, 2024
1 parent 30cb592 commit 7df9b8e
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 25 deletions.
23 changes: 18 additions & 5 deletions src/hera/shared/_type_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,26 @@ def get_unsubscripted_type(t: Any) -> Any:
return t


def origin_type_issubclass(cls: Any, type_: type) -> bool:
"""Return True if cls can be considered as a subclass of type_."""
unwrapped_type = unwrap_annotation(cls)
def origin_type_issubtype(annotation: Any, type_: Union[type, Tuple[type, ...]]) -> bool:
"""Return True if annotation is a subtype of type_.
type_ may be a tuple of types, in which case return True if annotation is a subtype
of the union of the types in the tuple.
"""
unwrapped_type = unwrap_annotation(annotation)
origin_type = get_unsubscripted_type(unwrapped_type)
if origin_type is Union or origin_type is UnionType:
return all(origin_type_issubtype(arg, type_) for arg in get_args(unwrapped_type))
return isinstance(origin_type, type) and issubclass(origin_type, type_)


def origin_type_issupertype(annotation: Any, type_: type) -> bool:
"""Return True if annotation is a supertype of type_."""
unwrapped_type = unwrap_annotation(annotation)
origin_type = get_unsubscripted_type(unwrapped_type)
if origin_type is Union or origin_type is UnionType:
return any(origin_type_issubclass(arg, type_) for arg in get_args(cls))
return issubclass(origin_type, type_)
return any(origin_type_issupertype(arg, type_) for arg in get_args(unwrapped_type))
return isinstance(origin_type, type) and issubclass(type_, origin_type)


def is_subscripted(t: Any) -> bool:
Expand Down
10 changes: 8 additions & 2 deletions src/hera/workflows/_runner/script_annotations_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,21 @@
import inspect
import json
import os
import sys
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union, cast

if sys.version_info >= (3, 10):
from types import NoneType
else:
NoneType = type(None)

from hera.shared._pydantic import BaseModel, get_field_annotations, get_fields
from hera.shared._type_util import (
get_unsubscripted_type,
get_workflow_annotation,
is_subscripted,
origin_type_issubclass,
origin_type_issubtype,
unwrap_annotation,
)
from hera.shared.serialization import serialize
Expand Down Expand Up @@ -138,7 +144,7 @@ def map_runner_input(
input_model_obj = {}

def load_parameter_value(value: str, value_type: type) -> Any:
if origin_type_issubclass(value_type, str):
if origin_type_issubtype(value_type, (str, NoneType)):
return value

try:
Expand Down
10 changes: 8 additions & 2 deletions src/hera/workflows/_runner/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,19 @@
import inspect
import json
import os
import sys
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, cast

if sys.version_info >= (3, 10):
from types import NoneType
else:
NoneType = type(None)

from hera.shared._pydantic import _PYDANTIC_VERSION
from hera.shared._type_util import (
get_workflow_annotation,
origin_type_issubclass,
origin_type_issubtype,
unwrap_annotation,
)
from hera.shared.serialization import serialize
Expand Down Expand Up @@ -125,7 +131,7 @@ def _get_unannotated_type(key: str, f: Callable) -> Optional[type]:
def _is_str_kwarg_of(key: str, f: Callable) -> bool:
"""Check if param `key` of function `f` has a type annotation that can be interpreted as a subclass of str."""
if func_param_annotation := _get_function_param_annotation(key, f):
return origin_type_issubclass(func_param_annotation, str)
return origin_type_issubtype(func_param_annotation, (str, NoneType))
return False


Expand Down
6 changes: 4 additions & 2 deletions src/hera/workflows/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
_flag_enabled,
)
from hera.shared._pydantic import _PYDANTIC_VERSION, root_validator, validator
from hera.shared._type_util import get_workflow_annotation, is_subscripted, origin_type_issubclass
from hera.shared._type_util import get_workflow_annotation, is_subscripted, origin_type_issupertype
from hera.shared.serialization import serialize
from hera.workflows._context import _context
from hera.workflows._meta_mixins import CallableTemplateMixin
Expand Down Expand Up @@ -540,7 +540,9 @@ class will be used as inputs, rather than the class itself.
else:
default = MISSING

if origin_type_issubclass(func_param.annotation, NoneType) and (default is MISSING or default is not None):
if origin_type_issupertype(func_param.annotation, NoneType) and (
default is MISSING or default is not None
):
raise ValueError(f"Optional parameter '{func_param.name}' must have a default value of None.")

parameters.append(Parameter(name=func_param.name, default=default))
Expand Down
7 changes: 6 additions & 1 deletion tests/script_runner/parameter_inputs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Any, List
from typing import Any, List, Union

try:
from typing import Annotated
Expand Down Expand Up @@ -76,6 +76,11 @@ def no_type_parameter(my_anything) -> Any:
return my_anything


@script()
def str_or_int_parameter(my_str_or_int: Union[int, str]) -> str:
return f"type given: {type(my_str_or_int).__name__}"


@script()
def str_parameter_expects_jsonstr_dict(my_json_str: str) -> dict:
return json.loads(my_json_str)
Expand Down
12 changes: 12 additions & 0 deletions tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,18 @@
{},
id="no-type-dict",
),
pytest.param(
"tests.script_runner.parameter_inputs:str_or_int_parameter",
[{"name": "my_str_or_int", "value": "hi there"}],
"type given: str",
id="str-or-int-given-str",
),
pytest.param(
"tests.script_runner.parameter_inputs:str_or_int_parameter",
[{"name": "my_str_or_int", "value": "3"}],
"type given: int",
id="str-or-int-given-int",
),
pytest.param(
"tests.script_runner.parameter_inputs:str_parameter_expects_jsonstr_dict",
[{"name": "my_json_str", "value": json.dumps({"my": "dict"})}],
Expand Down
54 changes: 41 additions & 13 deletions tests/test_unit/test_shared_type_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
from typing import List, Optional, Union
import sys
from typing import List, NoReturn, Optional, Union

if sys.version_info >= (3, 9):
from typing import Annotated
else:
from typing_extensions import Annotated
if sys.version_info >= (3, 10):
from types import NoneType
else:
NoneType = type(None)

import pytest
from annotated_types import Gt
Expand All @@ -8,16 +18,12 @@
get_unsubscripted_type,
get_workflow_annotation,
is_annotated,
origin_type_issubclass,
origin_type_issubtype,
origin_type_issupertype,
unwrap_annotation,
)
from hera.workflows import Artifact, Parameter

try:
from typing import Annotated
except ImportError:
from typing_extensions import Annotated


@pytest.mark.parametrize("annotation, expected", [[Annotated[str, "some metadata"], True], [str, False]])
def test_is_annotated(annotation, expected):
Expand Down Expand Up @@ -104,11 +110,33 @@ def test_get_unsubscripted_type(annotation, expected):
@pytest.mark.parametrize(
"annotation, target, expected",
[
[List[str], str, False],
[Optional[str], str, True],
[str, str, True],
[Union[int, str], int, True],
pytest.param(List[str], str, False, id="list-str-not-subtype-of-str"),
pytest.param(NoReturn, str, False, id="special-form-does-not-raise-error"),
pytest.param(Optional[str], str, False, id="optional-str-not-subtype-of-str"),
pytest.param(str, str, True, id="str-is-subtype-of-str"),
pytest.param(Union[int, str], int, False, id="union-int-str-not-subtype-of-str"),
pytest.param(Optional[str], (str, NoneType), True, id="optional-str-is-subtype-of-optional-str"),
pytest.param(Annotated[Optional[str], "foo"], (str, NoneType), True, id="annotated-optional"),
pytest.param(str, (str, NoneType), True, id="str-is-subtype-of-optional-str"),
pytest.param(Union[int, str], (str, NoneType), False, id="union-int-str-not-subtype-of-optional-str"),
],
)
def test_origin_type_issubtype(annotation, target, expected):
assert origin_type_issubtype(annotation, target) is expected


@pytest.mark.parametrize(
"annotation, target, expected",
[
pytest.param(List[str], str, False, id="list-str-not-supertype-of-str"),
pytest.param(NoReturn, str, False, id="special-form-does-not-raise-error"),
pytest.param(Optional[str], str, True, id="optional-str-is-supertype-of-str"),
pytest.param(str, str, True, id="str-is-supertype-of-str"),
pytest.param(Union[int, str], int, True, id="union-int-str-is-supertype-of-int"),
pytest.param(Optional[str], NoneType, True, id="optional-str-is-supertype-of-nonetype"),
pytest.param(Annotated[Optional[str], "foo"], NoneType, True, id="annotated-optional"),
pytest.param(str, NoneType, False, id="str-not-supertype-of-nonetype"),
],
)
def test_origin_type_issubclass(annotation, target, expected):
assert origin_type_issubclass(annotation, target) is expected
def test_origin_type_issupertype(annotation, target, expected):
assert origin_type_issupertype(annotation, target) is expected

0 comments on commit 7df9b8e

Please sign in to comment.