Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correcting required/optional keys fields of TypedDict #230

Merged
merged 6 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/changelog/fragments/227.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Created the correct path for processing Required and NotRequired with stringified annotations or from __future__ import annotations
7 changes: 7 additions & 0 deletions docs/reference/integrations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,10 @@ Known limitations:

- All input fields of foreign keys and relationships are considered as optional
due to user can pass only relationship instance or only foreign key value.

TypedDict and stringified annotations or ``from __future__ import annotations``:

Due to the way Python works with annotations, there is a bug,
when field annotation of TypedDict is stringified or ``from __future__ import annotations`` is placed
in file ``Required`` and ``NotRequired`` specifiers is ignored when ``required_keys`` and ``optional_keys`` is calculated.
Adaptix takes this into account and processes it properly.
49 changes: 44 additions & 5 deletions src/adaptix/_internal/model_tools/introspection/typed_dict.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import typing
import warnings
from types import MappingProxyType
from typing import AbstractSet, Sequence, Set, Tuple

from ...feature_requirement import HAS_PY_39
from ...type_tools import get_all_type_hints, is_typed_dict_class
from ...feature_requirement import HAS_PY_39, HAS_TYPED_DICT_REQUIRED
from ...type_tools import BaseNormType, get_all_type_hints, is_typed_dict_class, normalize_type
from ..definitions import (
FullShape,
InputField,
Expand Down Expand Up @@ -36,12 +38,38 @@ def _get_td_hints(tp):
return elements


def _extract_item_type(tp) -> BaseNormType:
if tp.origin is typing.Annotated:
return tp.args[0]
return tp


def _fetch_required_keys(
fields: Sequence[Tuple[str, BaseNormType]],
frozen_required_keys: AbstractSet[str],
) -> Set:
required_keys = set(frozen_required_keys)

for field_name, field_tp in fields:
require_type = _extract_item_type(field_tp)
if require_type.origin is typing.Required and field_name not in required_keys:
required_keys.add(field_name)
elif require_type.origin is typing.NotRequired and field_name in required_keys:
required_keys.remove(field_name)

return required_keys


def _make_requirement_determinant_from_keys(required_fields: set):
return lambda name: name in required_fields


if HAS_PY_39:
def _make_requirement_determinant(tp):
def _make_requirement_determinant_from_type(tp):
required_fields = tp.__required_keys__
return lambda name: name in required_fields
else:
def _make_requirement_determinant(tp):
def _make_requirement_determinant_from_type(tp):
warnings.warn(TypedDictAt38Warning(), stacklevel=3)
is_total = tp.__total__
return lambda name: is_total
Expand All @@ -53,8 +81,19 @@ def get_typed_dict_shape(tp) -> FullShape:
if not is_typed_dict_class(tp):
raise IntrospectionImpossible

requirement_determinant = _make_requirement_determinant(tp)
type_hints = _get_td_hints(tp)

if HAS_TYPED_DICT_REQUIRED:
norm_types = [normalize_type(tp) for _, tp in type_hints]

required_keys = _fetch_required_keys(
[(field_name, field_tp) for (field_name, _), field_tp in zip(type_hints, norm_types)],
tp.__required_keys__,
)
requirement_determinant = _make_requirement_determinant_from_keys(required_keys)
else:
requirement_determinant = _make_requirement_determinant_from_type(tp)

return Shape(
input=InputShape(
constructor=tp,
Expand Down
160 changes: 160 additions & 0 deletions tests/unit/model_tools/introspection/test_typed_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,3 +665,163 @@ class Child(Base, total=False):
)
)
)


@requires(HAS_TYPED_DICT_REQUIRED)
def test_required_annotated():
class Base(TypedDict):
f1: int
f2: typing.Annotated[typing.Required[int], "metadata"]
f3: 'typing.NotRequired[typing.Annotated[int, "metadata"]]'

class Child(Base, total=False):
f4: int
f5: 'typing.Annotated[typing.Required[int], "metadata"]'
f6: typing.NotRequired[typing.Annotated[int, "metadata"]]

assert (
get_typed_dict_shape(Child)
==
Shape(
input=InputShape(
constructor=Child,
kwargs=None,
fields=(
InputField(
type=int,
id='f1',
default=NoDefault(),
is_required=True,
metadata=MappingProxyType({}),
original=None,
),
InputField(
type=typing.Annotated[typing.Required[int], "metadata"],
id='f2',
default=NoDefault(),
is_required=True,
metadata=MappingProxyType({}),
original=None,
),
InputField(
type=typing.NotRequired[typing.Annotated[int, "metadata"]],
id='f3',
default=NoDefault(),
is_required=False,
metadata=MappingProxyType({}),
original=None,
),
InputField(
type=int,
id='f4',
default=NoDefault(),
is_required=False,
metadata=MappingProxyType({}),
original=None,
),
InputField(
type=typing.Annotated[typing.Required[int], "metadata"],
id='f5',
default=NoDefault(),
is_required=True,
metadata=MappingProxyType({}),
original=None,
),
InputField(
type=typing.NotRequired[typing.Annotated[int, "metadata"]],
id='f6',
default=NoDefault(),
is_required=False,
metadata=MappingProxyType({}),
original=None,
),
),
params=(
Param(
field_id='f1',
name='f1',
kind=ParamKind.KW_ONLY,
),
Param(
field_id='f2',
name='f2',
kind=ParamKind.KW_ONLY,
),
Param(
field_id='f3',
name='f3',
kind=ParamKind.KW_ONLY,
),
Param(
field_id='f4',
name='f4',
kind=ParamKind.KW_ONLY,
),
Param(
field_id='f5',
name='f5',
kind=ParamKind.KW_ONLY,
),
Param(
field_id='f6',
name='f6',
kind=ParamKind.KW_ONLY,
),
),
overriden_types=frozenset({}),
),
output=OutputShape(
fields=(
OutputField(
type=int,
id='f1',
default=NoDefault(),
accessor=create_key_accessor('f1', access_error=None),
metadata=MappingProxyType({}),
original=None,
),
OutputField(
type=typing.Annotated[typing.Required[int], "metadata"],
id='f2',
default=NoDefault(),
accessor=create_key_accessor('f2', access_error=None),
metadata=MappingProxyType({}),
original=None,
),
OutputField(
type=typing.NotRequired[typing.Annotated[int, "metadata"]],
id='f3',
default=NoDefault(),
accessor=create_key_accessor('f3', access_error=KeyError),
metadata=MappingProxyType({}),
original=None,
),
OutputField(
type=int,
id='f4',
default=NoDefault(),
accessor=create_key_accessor('f4', access_error=KeyError),
metadata=MappingProxyType({}),
original=None,
),
OutputField(
type=typing.Annotated[typing.Required[int], "metadata"],
id='f5',
default=NoDefault(),
accessor=create_key_accessor('f5', access_error=None),
metadata=MappingProxyType({}),
original=None,
),
OutputField(
type=typing.NotRequired[typing.Annotated[int, "metadata"]],
id='f6',
default=NoDefault(),
accessor=create_key_accessor('f6', access_error=KeyError),
metadata=MappingProxyType({}),
original=None,
),
),
overriden_types=frozenset({}),
)
)
)
Loading