From 5c77a6d3e652997d11aca5b739937ba967eece03 Mon Sep 17 00:00:00 2001 From: pavel Date: Sun, 11 Feb 2024 00:46:00 +0300 Subject: [PATCH] fix/optional-field-skipping --- docs/changelog/fragments/229.bugfix.rst | 1 + .../_internal/morphing/model/basic_gen.py | 67 ++++++------------- .../_internal/morphing/model/dumper_gen.py | 5 +- .../morphing/model/dumper_provider.py | 8 +-- .../_internal/morphing/model/loader_gen.py | 18 +++-- .../morphing/model/loader_provider.py | 37 +++++++--- .../morphing/model/test_loader_provider.py | 26 +++++++ 7 files changed, 93 insertions(+), 69 deletions(-) create mode 100644 docs/changelog/fragments/229.bugfix.rst diff --git a/docs/changelog/fragments/229.bugfix.rst b/docs/changelog/fragments/229.bugfix.rst new file mode 100644 index 00000000..0b9458ef --- /dev/null +++ b/docs/changelog/fragments/229.bugfix.rst @@ -0,0 +1 @@ +Fixed parameter shuffling on skipping optional field diff --git a/src/adaptix/_internal/morphing/model/basic_gen.py b/src/adaptix/_internal/morphing/model/basic_gen.py index 72d26825..67f6caa1 100644 --- a/src/adaptix/_internal/morphing/model/basic_gen.py +++ b/src/adaptix/_internal/morphing/model/basic_gen.py @@ -2,13 +2,27 @@ import re import string from abc import ABC, abstractmethod -from dataclasses import dataclass, replace -from typing import Any, Callable, Collection, Container, Dict, Iterable, List, Mapping, Set, Tuple, TypeVar, Union +from dataclasses import dataclass +from typing import ( + AbstractSet, + Any, + Callable, + Collection, + Container, + Dict, + Iterable, + List, + Mapping, + Set, + Tuple, + TypeVar, + Union, +) from ...code_tools.code_builder import CodeBuilder from ...code_tools.compiler import ClosureCompiler from ...code_tools.utils import get_literal_expr -from ...model_tools.definitions import InputField, InputShape, OutputField, OutputShape +from ...model_tools.definitions import InputField, OutputField from ...provider.essential import CannotProvide, Mediator from ...provider.request_cls import LocatedRequest, LocStack, TypeHintLoc from ...provider.static_provider import StaticProvider, static_provision_action @@ -123,17 +137,17 @@ def _collect_used_direct_fields(crown: BaseCrown) -> Set[str]: return used_set -def get_skipped_fields(shape: BaseShape, name_layout: BaseNameLayout) -> Collection[str]: +def get_skipped_fields(shape: BaseShape, name_layout: BaseNameLayout) -> AbstractSet[str]: used_direct_fields = _collect_used_direct_fields(name_layout.crown) if isinstance(name_layout.extra_move, ExtraTargets): extra_targets = name_layout.extra_move.fields else: extra_targets = () - return [ + return { field.id for field in shape.fields if field.id not in used_direct_fields and field.id not in extra_targets - ] + } def _inner_get_extra_targets_at_crown(extra_targets: Container[str], crown: BaseCrown) -> Collection[str]: @@ -196,47 +210,6 @@ def get_wild_extra_targets(shape: BaseShape, extra_move: Union[InpExtraMove, Out ] -def strip_input_shape_fields(shape: InputShape, skipped_fields: Collection[str]) -> InputShape: - skipped_required_fields = [ - field.id - for field in shape.fields - if field.is_required and field.id in skipped_fields - ] - if skipped_required_fields: - raise ValueError( - f"Required fields {skipped_required_fields} are skipped" - ) - return replace( - shape, - fields=tuple( - field for field in shape.fields - if field.id not in skipped_fields - ), - params=tuple( - param for param in shape.params - if param.field_id not in skipped_fields - ), - overriden_types=frozenset( - field.id for field in shape.fields - if field.id not in skipped_fields - ), - ) - - -def strip_output_shape_fields(shape: OutputShape, skipped_fields: Collection[str]) -> OutputShape: - return replace( - shape, - fields=tuple( - field for field in shape.fields - if field.id not in skipped_fields - ), - overriden_types=frozenset( - field.id for field in shape.fields - if field.id not in skipped_fields - ) - ) - - class NameSanitizer: _BAD_CHARS = re.compile(r'\W') _TRANSLATE_MAP = str.maketrans({'.': '_', '[': '_'}) diff --git a/src/adaptix/_internal/morphing/model/dumper_gen.py b/src/adaptix/_internal/morphing/model/dumper_gen.py index f46b4188..927292a7 100644 --- a/src/adaptix/_internal/morphing/model/dumper_gen.py +++ b/src/adaptix/_internal/morphing/model/dumper_gen.py @@ -19,7 +19,7 @@ ) from ...special_cases_optimization import as_is_stub, get_default_clause from ...struct_trail import append_trail, extend_trail, render_trail_as_note -from .basic_gen import ModelDumperGen +from .basic_gen import ModelDumperGen, get_skipped_fields from .crown_definitions import ( CrownPath, CrownPathElem, @@ -132,7 +132,10 @@ def produce_code(self, closure_name: str) -> Tuple[str, Mapping[str, object]]: body_builder("opt_fields = {}") body_builder.empty_line() + skipped_fields = get_skipped_fields(self._shape, self._name_layout) for field in self._shape.fields: + if field.id in skipped_fields: + continue if not self._is_extra_target(field): self._gen_field_extraction( body_builder, namespace, field, diff --git a/src/adaptix/_internal/morphing/model/dumper_provider.py b/src/adaptix/_internal/morphing/model/dumper_provider.py index 0fc1019a..fa4300db 100644 --- a/src/adaptix/_internal/morphing/model/dumper_provider.py +++ b/src/adaptix/_internal/morphing/model/dumper_provider.py @@ -18,9 +18,7 @@ fetch_code_gen_hook, get_extra_targets_at_crown, get_optional_fields_at_list_crown, - get_skipped_fields, get_wild_extra_targets, - strip_output_shape_fields, ) from .crown_definitions import OutputNameLayout, OutputNameLayoutRequest from .dumper_gen import BuiltinModelDumperGen @@ -46,7 +44,7 @@ def _provide_dumper(self, mediator: Mediator, request: DumperRequest) -> Dumper: def _fetch_model_dumper_gen(self, mediator: Mediator, request: DumperRequest) -> ModelDumperGen: shape = self._fetch_shape(mediator, request) name_layout = self._fetch_name_layout(mediator, request, shape) - shape = self._process_shape(shape, name_layout) + self._validate_params(shape, name_layout) fields_dumpers = self._fetch_field_dumpers(mediator, request, shape) debug_trail = mediator.mandatory_provide(DebugTrailRequest(loc_stack=request.loc_stack)) @@ -137,7 +135,7 @@ def _fetch_field_dumpers( ) return {field.id: dumper for field, dumper in zip(shape.fields, dumpers)} - def _process_shape(self, shape: OutputShape, name_layout: OutputNameLayout) -> OutputShape: + def _validate_params(self, shape: OutputShape, name_layout: OutputNameLayout) -> None: optional_fields_at_list_crown = get_optional_fields_at_list_crown( {field.id: field for field in shape.fields}, name_layout.crown, @@ -158,5 +156,3 @@ def _process_shape(self, shape: OutputShape, name_layout: OutputNameLayout) -> O raise ValueError( f"Extra targets {extra_targets_at_crown} are found at crown" ) - - return strip_output_shape_fields(shape, get_skipped_fields(shape, name_layout)) diff --git a/src/adaptix/_internal/morphing/model/loader_gen.py b/src/adaptix/_internal/morphing/model/loader_gen.py index e007ae5e..94f0d02d 100644 --- a/src/adaptix/_internal/morphing/model/loader_gen.py +++ b/src/adaptix/_internal/morphing/model/loader_gen.py @@ -1,7 +1,7 @@ import collections.abc import contextlib from dataclasses import dataclass -from typing import Dict, List, Mapping, Optional, Set, Tuple +from typing import AbstractSet, Dict, List, Mapping, Optional, Set, Tuple from ...code_tools.cascade_namespace import BuiltinCascadeNamespace, CascadeNamespace from ...code_tools.code_builder import CodeBuilder @@ -180,6 +180,7 @@ def __init__( debug_trail: DebugTrail, strict_coercion: bool, field_loaders: Mapping[str, Loader], + skipped_fields: AbstractSet[str], model_identity: str, props: ModelLoaderProps, ): @@ -194,6 +195,7 @@ def __init__( param.field_id: param for param in self._shape.params } self._field_loaders = field_loaders + self._skipped_fields = skipped_fields self._model_identity = model_identity self._props = props @@ -303,21 +305,29 @@ def _gen_header(self, state: GenState): state.builder.extend_above(header_builder) - def _gen_constructor_call(self, state: GenState) -> None: + def _gen_constructor_call(self, state: GenState) -> None: # noqa: CCR001 state.namespace.add_constant('constructor', self._shape.constructor) constructor_builder = CodeBuilder() + has_skipped_params = False with constructor_builder("constructor("): for param in self._shape.params: field = self._shape.fields_dict[param.field_id] + if field.id in self._skipped_fields: + has_skipped_params = True + continue if self._is_packed_field(field): continue value = state.v_field(field) - - if param.kind == ParamKind.KW_ONLY: + if param.kind == ParamKind.KW_ONLY or has_skipped_params: constructor_builder(f"{param.name}={value},") + elif param.kind == ParamKind.POS_ONLY and has_skipped_params: + raise ValueError( + 'Can not generate consistent constructor call,' + ' positional-only parameter is skipped' + ) else: constructor_builder(f"{value},") diff --git a/src/adaptix/_internal/morphing/model/loader_provider.py b/src/adaptix/_internal/morphing/model/loader_provider.py index aadb2752..2b65fda9 100644 --- a/src/adaptix/_internal/morphing/model/loader_provider.py +++ b/src/adaptix/_internal/morphing/model/loader_provider.py @@ -1,4 +1,4 @@ -from typing import Mapping +from typing import AbstractSet, Mapping from adaptix._internal.provider.fields import input_field_to_loc_map @@ -22,7 +22,6 @@ get_skipped_fields, get_wild_extra_targets, has_collect_policy, - strip_input_shape_fields, ) from .crown_definitions import InputNameLayout, InputNameLayoutRequest @@ -53,8 +52,8 @@ def _provide_loader(self, mediator: Mediator, request: LoaderRequest) -> Loader: def _fetch_model_loader_gen(self, mediator: Mediator, request: LoaderRequest) -> ModelLoaderGen: shape = self._fetch_shape(mediator, request) name_layout = self._fetch_name_layout(mediator, request, shape) - shape = self._process_shape(shape, name_layout) - self._validate_params(shape, name_layout) + skipped_fields = get_skipped_fields(shape, name_layout) + self._validate_params(shape, name_layout, skipped_fields) field_loaders = self._fetch_field_loaders(mediator, request, shape) strict_coercion = mediator.mandatory_provide(StrictCoercionRequest(loc_stack=request.loc_stack)) @@ -65,6 +64,7 @@ def _fetch_model_loader_gen(self, mediator: Mediator, request: LoaderRequest) -> shape=shape, name_layout=name_layout, field_loaders=field_loaders, + skipped_fields=skipped_fields, model_identity=self._fetch_model_identity(mediator, request, shape, name_layout), ) @@ -88,6 +88,7 @@ def _create_model_loader_gen( shape: InputShape, name_layout: InputNameLayout, field_loaders: Mapping[str, Loader], + skipped_fields: AbstractSet[str], model_identity: str, ) -> ModelLoaderGen: return BuiltinModelLoaderGen( @@ -96,6 +97,7 @@ def _create_model_loader_gen( debug_trail=debug_trail, strict_coercion=strict_coercion, field_loaders=field_loaders, + skipped_fields=skipped_fields, model_identity=model_identity, props=self._props, ) @@ -151,15 +153,22 @@ def _fetch_field_loaders( ) return {field.id: loader for field, loader in zip(shape.fields, loaders)} - def _process_shape(self, shape: InputShape, name_layout: InputNameLayout) -> InputShape: - wild_extra_targets = get_wild_extra_targets(shape, name_layout.extra_move) - if wild_extra_targets: + def _validate_params( + self, + shape: InputShape, + name_layout: InputNameLayout, + skipped_fields: AbstractSet[str], + ) -> None: + skipped_required_fields = [ + field.id + for field in shape.fields + if field.is_required and field.id in skipped_fields + ] + if skipped_required_fields: raise ValueError( - f"ExtraTargets {wild_extra_targets} are attached to non-existing fields" + f"Required fields {skipped_required_fields} are skipped" ) - return strip_input_shape_fields(shape, get_skipped_fields(shape, name_layout)) - def _validate_params(self, processed_shape: InputShape, name_layout: InputNameLayout) -> None: if name_layout.extra_move is None and has_collect_policy(name_layout.crown): raise ValueError( "Cannot create loader that collect extra data" @@ -173,7 +182,7 @@ def _validate_params(self, processed_shape: InputShape, name_layout: InputNameLa ) optional_fields_at_list_crown = get_optional_fields_at_list_crown( - {field.id: field for field in processed_shape.fields}, + {field.id: field for field in shape.fields}, name_layout.crown, ) if optional_fields_at_list_crown: @@ -181,6 +190,12 @@ def _validate_params(self, processed_shape: InputShape, name_layout: InputNameLa f"Optional fields {optional_fields_at_list_crown} are found at list crown" ) + wild_extra_targets = get_wild_extra_targets(shape, name_layout.extra_move) + if wild_extra_targets: + raise ValueError( + f"ExtraTargets {wild_extra_targets} are attached to non-existing fields" + ) + class InlinedShapeModelLoaderProvider(ModelLoaderProvider): def __init__( diff --git a/tests/unit/morphing/model/test_loader_provider.py b/tests/unit/morphing/model/test_loader_provider.py index 58577fd3..3931dcd9 100644 --- a/tests/unit/morphing/model/test_loader_provider.py +++ b/tests/unit/morphing/model/test_loader_provider.py @@ -10,6 +10,7 @@ from adaptix._internal.common import VarTuple from adaptix._internal.model_tools.definitions import ( Default, + DefaultValue, InputField, InputShape, NoDefault, @@ -1174,3 +1175,28 @@ def test_empty_list(debug_ctx, debug_trail, extra_policy, trail_select, strict_c ), lambda: loader('abc'), ) + + +def test_skipped_pos_optional_pos_field(debug_ctx, extra_policy): + loader_getter = make_loader_getter( + shape=shape( + TestField('a', ParamKind.POS_OR_KW, is_required=True), + TestField('b', ParamKind.POS_OR_KW, is_required=False, default=DefaultValue(10)), + TestField('c', ParamKind.POS_OR_KW, is_required=False, default=DefaultValue(20)), + ), + name_layout=InputNameLayout( + crown=InpDictCrown( + { + 'a': InpFieldCrown('a'), + 'c': InpFieldCrown('c'), + }, + extra_policy=ExtraForbid(), + ), + extra_move=None, + ), + debug_trail=DebugTrail.ALL, + debug_ctx=debug_ctx, + ) + loader = loader_getter() + + assert loader({'a': 1, 'c': 3}) == gauge(1, c=3)