Skip to content

Commit

Permalink
Merge pull request #235 from reagento/fix/optional-field-skipping
Browse files Browse the repository at this point in the history
allow skipping optional fields using polices
  • Loading branch information
zhPavel authored Feb 11, 2024
2 parents a0e42b3 + fb09de1 commit c712b81
Show file tree
Hide file tree
Showing 13 changed files with 267 additions and 77 deletions.
5 changes: 2 additions & 3 deletions src/adaptix/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from adaptix._internal.definitions import DebugTrail

from ._internal.common import Dumper, Loader, TypeHint
from ._internal.definitions import DebugTrail
from ._internal.model_tools.introspection.typed_dict import TypedDictAt38Warning
from ._internal.morphing.facade.func import dump, load
from ._internal.morphing.facade.provider import (
as_is_dumper,
as_is_loader,
bound,
constructor,
default_dict,
dumper,
Expand All @@ -31,6 +29,7 @@
)
from ._internal.morphing.name_layout.base import ExtraIn, ExtraOut
from ._internal.name_style import NameStyle
from ._internal.provider.facade.provider import bound
from ._internal.utils import Omittable, Omitted
from .provider import (
AggregateCannotProvide,
Expand Down
17 changes: 9 additions & 8 deletions src/adaptix/_internal/conversion/broaching/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,18 @@ class UnpackMapping(Generic[PlanT]):
element: PlanT


FuncCallArg = Union[
PositionalArg[PlanT],
KeywordArg[PlanT],
UnpackIterable[PlanT],
UnpackMapping[PlanT],
]


@dataclass(frozen=True)
class FunctionElement(BasePlanElement, Generic[PlanT]):
func: Callable[..., Any]
args: VarTuple[
Union[
PositionalArg[PlanT],
KeywordArg[PlanT],
UnpackIterable[PlanT],
UnpackMapping[PlanT],
]
]
args: VarTuple[FuncCallArg[PlanT]]


@dataclass(frozen=True)
Expand Down
116 changes: 87 additions & 29 deletions src/adaptix/_internal/conversion/converter_provider.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from abc import ABC, abstractmethod
from functools import reduce
from inspect import Parameter, Signature
from typing import Any, Iterable, Optional, Sequence, cast, final
from typing import Any, Iterable, List, Mapping, Optional, Sequence, Tuple, cast, final

from ..code_tools.compiler import BasicClosureCompiler, ClosureCompiler
from ..common import Converter, TypeHint
from ..conversion.broaching.code_generator import BroachingCodeGenerator, BroachingPlan, BuiltinBroachingCodeGenerator
from ..conversion.broaching.definitions import (
AccessorElement,
FuncCallArg,
FunctionElement,
KeywordArg,
ParameterElement,
Expand All @@ -20,6 +21,7 @@
BindingSource,
CoercerRequest,
ConverterRequest,
UnboundOptionalPolicyRequest,
)
from ..model_tools.definitions import BaseField, DefaultValue, InputField, InputShape, NoDefault, OutputShape, ParamKind
from ..morphing.model.basic_gen import NameSanitizer, compile_closure_with_globals_capturing, fetch_code_gen_hook
Expand All @@ -28,6 +30,7 @@
from ..provider.request_cls import LocMap, LocStack, TypeHintLoc
from ..provider.shape_provider import InputShapeRequest, OutputShapeRequest, provide_generic_resolved_shape
from ..provider.static_provider import StaticProvider, static_provision_action
from ..utils import add_note


class ConverterProvider(StaticProvider, ABC):
Expand Down Expand Up @@ -168,25 +171,46 @@ def _fetch_bindings(
owner_binding_src: BindingSource,
owner_binding_dst: BindingDest,
extra_params: Sequence[BindingSource],
) -> Iterable[BindingResult]:
) -> Iterable[Tuple[InputField, Optional[BindingResult]]]:
model_binding_sources = tuple(
owner_binding_src.append_with(src_field)
for src_field in src_shape.fields
)
bindings = mediator.mandatory_provide_by_iterable(
[
BindingRequest(
sources=(
model_binding_sources,
*extra_params,
),
destination=owner_binding_dst.append_with(dst_field),
sources = (model_binding_sources, *extra_params)

def fetch_field_binding(dst_field: InputField) -> Tuple[InputField, Optional[BindingResult]]:
destination = owner_binding_dst.append_with(dst_field)
try:
binding = mediator.provide(
BindingRequest(
sources=sources, # type: ignore[arg-type]
destination=destination,
)
)
for dst_field in dst_shape.fields
],
except CannotProvide as e:
if dst_field.is_required:
add_note(e, 'Note: This is a required filed, so it must take value')
raise

policy = mediator.mandatory_provide(
UnboundOptionalPolicyRequest(loc_stack=destination.to_loc_stack())
)
if policy.is_allowed:
return dst_field, None
add_note(
e,
'Note: Current policy limits unbound optional fields,'
' so you need to link it to another field'
' or explicitly confirm the desire to skipping using `allow_unbound_optional`'
)
raise
return dst_field, binding

return mandatory_apply_by_iterable(
fetch_field_binding,
zip(dst_shape.fields),
lambda: 'Bindings for some fields are not found',
)
return bindings

def _get_nested_models_sub_plan(
self,
Expand Down Expand Up @@ -249,14 +273,13 @@ def _get_coercer_sub_plan(
),
)

def _generate_binding_sub_plans(
def _generate_field_to_sub_plan(
self,
mediator: Mediator,
dst_shape: InputShape,
extra_params: Sequence[BindingSource],
bindings: Iterable[BindingResult],
field_bindings: Iterable[Tuple[InputField, BindingResult]],
owner_binding_dst: BindingDest,
) -> Iterable[BroachingPlan]:
) -> Mapping[InputField, BroachingPlan]:
def generate_sub_plan(input_field: InputField, binding: BindingResult):
binding_dst = owner_binding_dst.append_with(input_field)
try:
Expand All @@ -276,11 +299,15 @@ def generate_sub_plan(input_field: InputField, binding: BindingResult):
return result
raise e

return mandatory_apply_by_iterable(
coercers = mandatory_apply_by_iterable(
generate_sub_plan,
zip(dst_shape.fields, bindings),
field_bindings,
lambda: 'Coercers for some bindings are not found',
)
return {
dst_field: coercer
for (dst_field, binding), coercer in zip(field_bindings, coercers)
}

def _make_broaching_plan(
self,
Expand All @@ -291,27 +318,58 @@ def _make_broaching_plan(
owner_binding_src: BindingSource,
owner_binding_dst: BindingDest,
) -> BroachingPlan:
bindings = self._fetch_bindings(
field_bindings = self._fetch_bindings(
mediator=mediator,
dst_shape=dst_shape,
src_shape=src_shape,
extra_params=extra_params,
owner_binding_src=owner_binding_src,
owner_binding_dst=owner_binding_dst,
)
sub_plans = self._generate_binding_sub_plans(
field_to_sub_plan = self._generate_field_to_sub_plan(
mediator=mediator,
dst_shape=dst_shape,
bindings=bindings,
field_bindings=[
(dst_field, binding)
for dst_field, binding in field_bindings
if binding is not None
],
extra_params=extra_params,
owner_binding_dst=owner_binding_dst,
)
return self._make_constructor_call(
dst_shape=dst_shape,
field_to_binding=dict(field_bindings),
field_to_sub_plan=field_to_sub_plan,
)

def _make_constructor_call(
self,
dst_shape: InputShape,
field_to_binding: Mapping[InputField, Optional[BindingResult]],
field_to_sub_plan: Mapping[InputField, BroachingPlan],
) -> BroachingPlan:
args: List[FuncCallArg[BroachingPlan]] = []
has_skipped_params = False
for param in dst_shape.params:
field = dst_shape.fields_dict[param.field_id]

if field_to_binding[field] is None:
has_skipped_params = True
continue

sub_plan = field_to_sub_plan[field]
if param.kind == ParamKind.KW_ONLY or has_skipped_params:
args.append(KeywordArg(param.name, sub_plan))
elif param.kind == ParamKind.POS_ONLY and has_skipped_params:
raise CannotProvide(
'Can not generate consistent constructor call,'
' positional-only parameter is skipped',
is_demonstrative=True,
)
else:
args.append(PositionalArg(sub_plan))

return FunctionElement(
func=dst_shape.constructor,
args=tuple(
KeywordArg(param.name, sub_plan)
if param.kind == ParamKind.KW_ONLY else
PositionalArg(sub_plan)
for param, binding, sub_plan in zip(dst_shape.params, bindings, sub_plans)
),
args=tuple(args),
)
28 changes: 26 additions & 2 deletions src/adaptix/_internal/conversion/facade/provider.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
from ...common import Coercer
from ...provider.essential import Provider
from ...provider.facade.provider import bound_by_any
from ...provider.loc_stack_filtering import Pred, create_loc_stack_checker
from ..binding_provider import MatchingBindingProvider
from ..coercer_provider import MatchingCoercerProvider
from ..policy_provider import UnboundOptionalPolicyProvider


def bind(src: Pred, dst: Pred) -> Provider:
"""Basic provider to define custom binding between fields.
:param src: Predicate specifying source point of binding. See :ref:`predicate-system` for details.
:param dst: Predicate specifying destination point of binding. See :ref:`predicate-system` for details.
:return: desired provider
:return: Desired provider
"""
return MatchingBindingProvider(
src_lsc=create_loc_stack_checker(src),
Expand All @@ -24,10 +26,32 @@ def coercer(src: Pred, dst: Pred, func: Coercer) -> Provider:
:param src: Predicate specifying source point of binding. See :ref:`predicate-system` for details.
:param dst: Predicate specifying destination point of binding. See :ref:`predicate-system` for details.
:param func: The function is used to transform input data to a destination type.
:return: desired provider
:return: Desired provider
"""
return MatchingCoercerProvider(
src_lsc=create_loc_stack_checker(src),
dst_lsc=create_loc_stack_checker(dst),
coercer=func,
)


def allow_unbound_optional(*preds: Pred) -> Provider:
"""Sets policy to permit optional fields that does not bound to any source field.
:param preds: Predicate specifying target of policy.
Each predicate is merged via ``|`` operator.
See :ref:`predicate-system` for details.
:return: Desired provider.
"""
return bound_by_any(preds, UnboundOptionalPolicyProvider(is_allowed=True))


def forbid_unbound_optional(*preds: Pred) -> Provider:
"""Sets policy to prohibit optional fields that does not bound to any source field.
:param preds: Predicate specifying target of policy.
Each predicate is merged via ``|`` operator.
See :ref:`predicate-system` for details.
:return: Desired provider.
"""
return bound_by_any(preds, UnboundOptionalPolicyProvider(is_allowed=False))
7 changes: 5 additions & 2 deletions src/adaptix/_internal/conversion/facade/retort.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from inspect import Signature
from typing import Any, Callable, Iterable, Optional, TypeVar

from adaptix import Provider

from ...provider.essential import Provider
from ...provider.loc_stack_filtering import P
from ...provider.shape_provider import BUILTIN_SHAPE_PROVIDER
from ...retort.operating_retort import OperatingRetort
from ..binding_provider import SameNameBindingProvider
from ..coercer_provider import DstAnyCoercerProvider, SameTypeCoercerProvider, SubclassCoercerProvider
from ..converter_provider import BuiltinConverterProvider
from ..request_cls import ConverterRequest
from .provider import forbid_unbound_optional


class FilledConverterRetort(OperatingRetort):
Expand All @@ -22,6 +23,8 @@ class FilledConverterRetort(OperatingRetort):
SameTypeCoercerProvider(),
DstAnyCoercerProvider(),
SubclassCoercerProvider(),

forbid_unbound_optional(P.ANY),
]


Expand Down
16 changes: 16 additions & 0 deletions src/adaptix/_internal/conversion/policy_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from ..provider.essential import Mediator
from ..provider.static_provider import StaticProvider, static_provision_action
from .request_cls import UnboundOptionalPolicy, UnboundOptionalPolicyRequest


class UnboundOptionalPolicyProvider(StaticProvider):
def __init__(self, is_allowed: bool):
self._is_allowed = is_allowed

@static_provision_action
def _outer_unbound_optional_policy(
self,
mediator: Mediator,
request: UnboundOptionalPolicyRequest,
) -> UnboundOptionalPolicy:
return UnboundOptionalPolicy(is_allowed=self._is_allowed)
12 changes: 11 additions & 1 deletion src/adaptix/_internal/conversion/request_cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ..model_tools.definitions import BaseField, InputField, OutputField
from ..provider.essential import Request
from ..provider.fields import base_field_to_loc_map, input_field_to_loc_map, output_field_to_loc_map
from ..provider.request_cls import LocStack
from ..provider.request_cls import LocatedRequest, LocStack

BindingSourceItem = Union[OutputField, BaseField]

Expand Down Expand Up @@ -57,6 +57,16 @@ class CoercerRequest(Request[Coercer]):
dst: BindingDest


@dataclass(frozen=True)
class UnboundOptionalPolicy:
is_allowed: bool


@dataclass(frozen=True)
class UnboundOptionalPolicyRequest(LocatedRequest):
pass


@dataclass(frozen=True)
class ConverterRequest(Request):
signature: Signature
Expand Down
Loading

0 comments on commit c712b81

Please sign in to comment.