diff --git a/docs/changelog/fragments/227.feature.rst b/docs/changelog/fragments/227.feature.rst new file mode 100644 index 00000000..d7461191 --- /dev/null +++ b/docs/changelog/fragments/227.feature.rst @@ -0,0 +1 @@ +Created the correct path for processing Required and NotRequired with stringified annotations or from __future__ import annotations diff --git a/docs/reference/integrations.rst b/docs/reference/integrations.rst index ffc667b0..b84dedaa 100644 --- a/docs/reference/integrations.rst +++ b/docs/reference/integrations.rst @@ -55,3 +55,10 @@ Known limitations: - All input fields of foreign keys and relationships are considered as optional due to user can pass only relationship instance or only foreign key value. + +TypedDict and stringified annotations or ``from __future__ import annotations``: + +Due to the way Python works with annotations, there is a bug, +when field annotation of TypedDict is stringified or ``from __future__ import annotations`` is placed +in file ``Required`` and ``NotRequired`` specifiers is ignored when ``required_keys`` and ``optional_keys`` is calculated. +Adaptix takes this into account and processes it properly. diff --git a/src/adaptix/_internal/model_tools/introspection/typed_dict.py b/src/adaptix/_internal/model_tools/introspection/typed_dict.py index f767c28c..cbc0150d 100644 --- a/src/adaptix/_internal/model_tools/introspection/typed_dict.py +++ b/src/adaptix/_internal/model_tools/introspection/typed_dict.py @@ -1,8 +1,10 @@ +import typing import warnings from types import MappingProxyType +from typing import AbstractSet, Sequence, Set, Tuple -from ...feature_requirement import HAS_PY_39 -from ...type_tools import get_all_type_hints, is_typed_dict_class +from ...feature_requirement import HAS_PY_39, HAS_TYPED_DICT_REQUIRED +from ...type_tools import BaseNormType, get_all_type_hints, is_typed_dict_class, normalize_type from ..definitions import ( FullShape, InputField, @@ -36,12 +38,38 @@ def _get_td_hints(tp): return elements +def _extract_item_type(tp) -> BaseNormType: + if tp.origin is typing.Annotated: + return tp.args[0] + return tp + + +def _fetch_required_keys( + fields: Sequence[Tuple[str, BaseNormType]], + frozen_required_keys: AbstractSet[str], +) -> Set: + required_keys = set(frozen_required_keys) + + for field_name, field_tp in fields: + require_type = _extract_item_type(field_tp) + if require_type.origin is typing.Required and field_name not in required_keys: + required_keys.add(field_name) + elif require_type.origin is typing.NotRequired and field_name in required_keys: + required_keys.remove(field_name) + + return required_keys + + +def _make_requirement_determinant_from_keys(required_fields: set): + return lambda name: name in required_fields + + if HAS_PY_39: - def _make_requirement_determinant(tp): + def _make_requirement_determinant_from_type(tp): required_fields = tp.__required_keys__ return lambda name: name in required_fields else: - def _make_requirement_determinant(tp): + def _make_requirement_determinant_from_type(tp): warnings.warn(TypedDictAt38Warning(), stacklevel=3) is_total = tp.__total__ return lambda name: is_total @@ -53,8 +81,19 @@ def get_typed_dict_shape(tp) -> FullShape: if not is_typed_dict_class(tp): raise IntrospectionImpossible - requirement_determinant = _make_requirement_determinant(tp) type_hints = _get_td_hints(tp) + + if HAS_TYPED_DICT_REQUIRED: + norm_types = [normalize_type(tp) for _, tp in type_hints] + + required_keys = _fetch_required_keys( + [(field_name, field_tp) for (field_name, _), field_tp in zip(type_hints, norm_types)], + tp.__required_keys__, + ) + requirement_determinant = _make_requirement_determinant_from_keys(required_keys) + else: + requirement_determinant = _make_requirement_determinant_from_type(tp) + return Shape( input=InputShape( constructor=tp, diff --git a/tests/unit/model_tools/introspection/test_typed_dict.py b/tests/unit/model_tools/introspection/test_typed_dict.py index ccba913b..b9ad5968 100644 --- a/tests/unit/model_tools/introspection/test_typed_dict.py +++ b/tests/unit/model_tools/introspection/test_typed_dict.py @@ -665,3 +665,163 @@ class Child(Base, total=False): ) ) ) + + +@requires(HAS_TYPED_DICT_REQUIRED) +def test_required_annotated(): + class Base(TypedDict): + f1: int + f2: typing.Annotated[typing.Required[int], "metadata"] + f3: 'typing.NotRequired[typing.Annotated[int, "metadata"]]' + + class Child(Base, total=False): + f4: int + f5: 'typing.Annotated[typing.Required[int], "metadata"]' + f6: typing.NotRequired[typing.Annotated[int, "metadata"]] + + assert ( + get_typed_dict_shape(Child) + == + Shape( + input=InputShape( + constructor=Child, + kwargs=None, + fields=( + InputField( + type=int, + id='f1', + default=NoDefault(), + is_required=True, + metadata=MappingProxyType({}), + original=None, + ), + InputField( + type=typing.Annotated[typing.Required[int], "metadata"], + id='f2', + default=NoDefault(), + is_required=True, + metadata=MappingProxyType({}), + original=None, + ), + InputField( + type=typing.NotRequired[typing.Annotated[int, "metadata"]], + id='f3', + default=NoDefault(), + is_required=False, + metadata=MappingProxyType({}), + original=None, + ), + InputField( + type=int, + id='f4', + default=NoDefault(), + is_required=False, + metadata=MappingProxyType({}), + original=None, + ), + InputField( + type=typing.Annotated[typing.Required[int], "metadata"], + id='f5', + default=NoDefault(), + is_required=True, + metadata=MappingProxyType({}), + original=None, + ), + InputField( + type=typing.NotRequired[typing.Annotated[int, "metadata"]], + id='f6', + default=NoDefault(), + is_required=False, + metadata=MappingProxyType({}), + original=None, + ), + ), + params=( + Param( + field_id='f1', + name='f1', + kind=ParamKind.KW_ONLY, + ), + Param( + field_id='f2', + name='f2', + kind=ParamKind.KW_ONLY, + ), + Param( + field_id='f3', + name='f3', + kind=ParamKind.KW_ONLY, + ), + Param( + field_id='f4', + name='f4', + kind=ParamKind.KW_ONLY, + ), + Param( + field_id='f5', + name='f5', + kind=ParamKind.KW_ONLY, + ), + Param( + field_id='f6', + name='f6', + kind=ParamKind.KW_ONLY, + ), + ), + overriden_types=frozenset({}), + ), + output=OutputShape( + fields=( + OutputField( + type=int, + id='f1', + default=NoDefault(), + accessor=create_key_accessor('f1', access_error=None), + metadata=MappingProxyType({}), + original=None, + ), + OutputField( + type=typing.Annotated[typing.Required[int], "metadata"], + id='f2', + default=NoDefault(), + accessor=create_key_accessor('f2', access_error=None), + metadata=MappingProxyType({}), + original=None, + ), + OutputField( + type=typing.NotRequired[typing.Annotated[int, "metadata"]], + id='f3', + default=NoDefault(), + accessor=create_key_accessor('f3', access_error=KeyError), + metadata=MappingProxyType({}), + original=None, + ), + OutputField( + type=int, + id='f4', + default=NoDefault(), + accessor=create_key_accessor('f4', access_error=KeyError), + metadata=MappingProxyType({}), + original=None, + ), + OutputField( + type=typing.Annotated[typing.Required[int], "metadata"], + id='f5', + default=NoDefault(), + accessor=create_key_accessor('f5', access_error=None), + metadata=MappingProxyType({}), + original=None, + ), + OutputField( + type=typing.NotRequired[typing.Annotated[int, "metadata"]], + id='f6', + default=NoDefault(), + accessor=create_key_accessor('f6', access_error=KeyError), + metadata=MappingProxyType({}), + original=None, + ), + ), + overriden_types=frozenset({}), + ) + ) + )