Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed parameter shuffling on skipping optional field #234

Merged
merged 1 commit into from
Feb 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/changelog/fragments/229.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed parameter shuffling on skipping optional field
67 changes: 20 additions & 47 deletions src/adaptix/_internal/morphing/model/basic_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,27 @@
import re
import string
from abc import ABC, abstractmethod
from dataclasses import dataclass, replace
from typing import Any, Callable, Collection, Container, Dict, Iterable, List, Mapping, Set, Tuple, TypeVar, Union
from dataclasses import dataclass
from typing import (
AbstractSet,
Any,
Callable,
Collection,
Container,
Dict,
Iterable,
List,
Mapping,
Set,
Tuple,
TypeVar,
Union,
)

from ...code_tools.code_builder import CodeBuilder
from ...code_tools.compiler import ClosureCompiler
from ...code_tools.utils import get_literal_expr
from ...model_tools.definitions import InputField, InputShape, OutputField, OutputShape
from ...model_tools.definitions import InputField, OutputField
from ...provider.essential import CannotProvide, Mediator
from ...provider.request_cls import LocatedRequest, LocStack, TypeHintLoc
from ...provider.static_provider import StaticProvider, static_provision_action
Expand Down Expand Up @@ -123,17 +137,17 @@ def _collect_used_direct_fields(crown: BaseCrown) -> Set[str]:
return used_set


def get_skipped_fields(shape: BaseShape, name_layout: BaseNameLayout) -> Collection[str]:
def get_skipped_fields(shape: BaseShape, name_layout: BaseNameLayout) -> AbstractSet[str]:
used_direct_fields = _collect_used_direct_fields(name_layout.crown)
if isinstance(name_layout.extra_move, ExtraTargets):
extra_targets = name_layout.extra_move.fields
else:
extra_targets = ()

return [
return {
field.id for field in shape.fields
if field.id not in used_direct_fields and field.id not in extra_targets
]
}


def _inner_get_extra_targets_at_crown(extra_targets: Container[str], crown: BaseCrown) -> Collection[str]:
Expand Down Expand Up @@ -196,47 +210,6 @@ def get_wild_extra_targets(shape: BaseShape, extra_move: Union[InpExtraMove, Out
]


def strip_input_shape_fields(shape: InputShape, skipped_fields: Collection[str]) -> InputShape:
skipped_required_fields = [
field.id
for field in shape.fields
if field.is_required and field.id in skipped_fields
]
if skipped_required_fields:
raise ValueError(
f"Required fields {skipped_required_fields} are skipped"
)
return replace(
shape,
fields=tuple(
field for field in shape.fields
if field.id not in skipped_fields
),
params=tuple(
param for param in shape.params
if param.field_id not in skipped_fields
),
overriden_types=frozenset(
field.id for field in shape.fields
if field.id not in skipped_fields
),
)


def strip_output_shape_fields(shape: OutputShape, skipped_fields: Collection[str]) -> OutputShape:
return replace(
shape,
fields=tuple(
field for field in shape.fields
if field.id not in skipped_fields
),
overriden_types=frozenset(
field.id for field in shape.fields
if field.id not in skipped_fields
)
)


class NameSanitizer:
_BAD_CHARS = re.compile(r'\W')
_TRANSLATE_MAP = str.maketrans({'.': '_', '[': '_'})
Expand Down
5 changes: 4 additions & 1 deletion src/adaptix/_internal/morphing/model/dumper_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
)
from ...special_cases_optimization import as_is_stub, get_default_clause
from ...struct_trail import append_trail, extend_trail, render_trail_as_note
from .basic_gen import ModelDumperGen
from .basic_gen import ModelDumperGen, get_skipped_fields
from .crown_definitions import (
CrownPath,
CrownPathElem,
Expand Down Expand Up @@ -132,7 +132,10 @@ def produce_code(self, closure_name: str) -> Tuple[str, Mapping[str, object]]:
body_builder("opt_fields = {}")
body_builder.empty_line()

skipped_fields = get_skipped_fields(self._shape, self._name_layout)
for field in self._shape.fields:
if field.id in skipped_fields:
continue
if not self._is_extra_target(field):
self._gen_field_extraction(
body_builder, namespace, field,
Expand Down
8 changes: 2 additions & 6 deletions src/adaptix/_internal/morphing/model/dumper_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@
fetch_code_gen_hook,
get_extra_targets_at_crown,
get_optional_fields_at_list_crown,
get_skipped_fields,
get_wild_extra_targets,
strip_output_shape_fields,
)
from .crown_definitions import OutputNameLayout, OutputNameLayoutRequest
from .dumper_gen import BuiltinModelDumperGen
Expand All @@ -46,7 +44,7 @@ def _provide_dumper(self, mediator: Mediator, request: DumperRequest) -> Dumper:
def _fetch_model_dumper_gen(self, mediator: Mediator, request: DumperRequest) -> ModelDumperGen:
shape = self._fetch_shape(mediator, request)
name_layout = self._fetch_name_layout(mediator, request, shape)
shape = self._process_shape(shape, name_layout)
self._validate_params(shape, name_layout)

fields_dumpers = self._fetch_field_dumpers(mediator, request, shape)
debug_trail = mediator.mandatory_provide(DebugTrailRequest(loc_stack=request.loc_stack))
Expand Down Expand Up @@ -137,7 +135,7 @@ def _fetch_field_dumpers(
)
return {field.id: dumper for field, dumper in zip(shape.fields, dumpers)}

def _process_shape(self, shape: OutputShape, name_layout: OutputNameLayout) -> OutputShape:
def _validate_params(self, shape: OutputShape, name_layout: OutputNameLayout) -> None:
optional_fields_at_list_crown = get_optional_fields_at_list_crown(
{field.id: field for field in shape.fields},
name_layout.crown,
Expand All @@ -158,5 +156,3 @@ def _process_shape(self, shape: OutputShape, name_layout: OutputNameLayout) -> O
raise ValueError(
f"Extra targets {extra_targets_at_crown} are found at crown"
)

return strip_output_shape_fields(shape, get_skipped_fields(shape, name_layout))
18 changes: 14 additions & 4 deletions src/adaptix/_internal/morphing/model/loader_gen.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import collections.abc
import contextlib
from dataclasses import dataclass
from typing import Dict, List, Mapping, Optional, Set, Tuple
from typing import AbstractSet, Dict, List, Mapping, Optional, Set, Tuple

from ...code_tools.cascade_namespace import BuiltinCascadeNamespace, CascadeNamespace
from ...code_tools.code_builder import CodeBuilder
Expand Down Expand Up @@ -180,6 +180,7 @@ def __init__(
debug_trail: DebugTrail,
strict_coercion: bool,
field_loaders: Mapping[str, Loader],
skipped_fields: AbstractSet[str],
model_identity: str,
props: ModelLoaderProps,
):
Expand All @@ -194,6 +195,7 @@ def __init__(
param.field_id: param for param in self._shape.params
}
self._field_loaders = field_loaders
self._skipped_fields = skipped_fields
self._model_identity = model_identity
self._props = props

Expand Down Expand Up @@ -303,21 +305,29 @@ def _gen_header(self, state: GenState):

state.builder.extend_above(header_builder)

def _gen_constructor_call(self, state: GenState) -> None:
def _gen_constructor_call(self, state: GenState) -> None: # noqa: CCR001
state.namespace.add_constant('constructor', self._shape.constructor)

constructor_builder = CodeBuilder()
has_skipped_params = False
with constructor_builder("constructor("):
for param in self._shape.params:
field = self._shape.fields_dict[param.field_id]

if field.id in self._skipped_fields:
has_skipped_params = True
continue
if self._is_packed_field(field):
continue

value = state.v_field(field)

if param.kind == ParamKind.KW_ONLY:
if param.kind == ParamKind.KW_ONLY or has_skipped_params:
constructor_builder(f"{param.name}={value},")
elif param.kind == ParamKind.POS_ONLY and has_skipped_params:
raise ValueError(
'Can not generate consistent constructor call,'
' positional-only parameter is skipped'
)
else:
constructor_builder(f"{value},")

Expand Down
37 changes: 26 additions & 11 deletions src/adaptix/_internal/morphing/model/loader_provider.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Mapping
from typing import AbstractSet, Mapping

from adaptix._internal.provider.fields import input_field_to_loc_map

Expand All @@ -22,7 +22,6 @@
get_skipped_fields,
get_wild_extra_targets,
has_collect_policy,
strip_input_shape_fields,
)
from .crown_definitions import InputNameLayout, InputNameLayoutRequest

Expand Down Expand Up @@ -53,8 +52,8 @@ def _provide_loader(self, mediator: Mediator, request: LoaderRequest) -> Loader:
def _fetch_model_loader_gen(self, mediator: Mediator, request: LoaderRequest) -> ModelLoaderGen:
shape = self._fetch_shape(mediator, request)
name_layout = self._fetch_name_layout(mediator, request, shape)
shape = self._process_shape(shape, name_layout)
self._validate_params(shape, name_layout)
skipped_fields = get_skipped_fields(shape, name_layout)
self._validate_params(shape, name_layout, skipped_fields)

field_loaders = self._fetch_field_loaders(mediator, request, shape)
strict_coercion = mediator.mandatory_provide(StrictCoercionRequest(loc_stack=request.loc_stack))
Expand All @@ -65,6 +64,7 @@ def _fetch_model_loader_gen(self, mediator: Mediator, request: LoaderRequest) ->
shape=shape,
name_layout=name_layout,
field_loaders=field_loaders,
skipped_fields=skipped_fields,
model_identity=self._fetch_model_identity(mediator, request, shape, name_layout),
)

Expand All @@ -88,6 +88,7 @@ def _create_model_loader_gen(
shape: InputShape,
name_layout: InputNameLayout,
field_loaders: Mapping[str, Loader],
skipped_fields: AbstractSet[str],
model_identity: str,
) -> ModelLoaderGen:
return BuiltinModelLoaderGen(
Expand All @@ -96,6 +97,7 @@ def _create_model_loader_gen(
debug_trail=debug_trail,
strict_coercion=strict_coercion,
field_loaders=field_loaders,
skipped_fields=skipped_fields,
model_identity=model_identity,
props=self._props,
)
Expand Down Expand Up @@ -151,15 +153,22 @@ def _fetch_field_loaders(
)
return {field.id: loader for field, loader in zip(shape.fields, loaders)}

def _process_shape(self, shape: InputShape, name_layout: InputNameLayout) -> InputShape:
wild_extra_targets = get_wild_extra_targets(shape, name_layout.extra_move)
if wild_extra_targets:
def _validate_params(
self,
shape: InputShape,
name_layout: InputNameLayout,
skipped_fields: AbstractSet[str],
) -> None:
skipped_required_fields = [
field.id
for field in shape.fields
if field.is_required and field.id in skipped_fields
]
if skipped_required_fields:
raise ValueError(
f"ExtraTargets {wild_extra_targets} are attached to non-existing fields"
f"Required fields {skipped_required_fields} are skipped"
)
return strip_input_shape_fields(shape, get_skipped_fields(shape, name_layout))

def _validate_params(self, processed_shape: InputShape, name_layout: InputNameLayout) -> None:
if name_layout.extra_move is None and has_collect_policy(name_layout.crown):
raise ValueError(
"Cannot create loader that collect extra data"
Expand All @@ -173,14 +182,20 @@ def _validate_params(self, processed_shape: InputShape, name_layout: InputNameLa
)

optional_fields_at_list_crown = get_optional_fields_at_list_crown(
{field.id: field for field in processed_shape.fields},
{field.id: field for field in shape.fields},
name_layout.crown,
)
if optional_fields_at_list_crown:
raise ValueError(
f"Optional fields {optional_fields_at_list_crown} are found at list crown"
)

wild_extra_targets = get_wild_extra_targets(shape, name_layout.extra_move)
if wild_extra_targets:
raise ValueError(
f"ExtraTargets {wild_extra_targets} are attached to non-existing fields"
)


class InlinedShapeModelLoaderProvider(ModelLoaderProvider):
def __init__(
Expand Down
26 changes: 26 additions & 0 deletions tests/unit/morphing/model/test_loader_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from adaptix._internal.common import VarTuple
from adaptix._internal.model_tools.definitions import (
Default,
DefaultValue,
InputField,
InputShape,
NoDefault,
Expand Down Expand Up @@ -1174,3 +1175,28 @@ def test_empty_list(debug_ctx, debug_trail, extra_policy, trail_select, strict_c
),
lambda: loader('abc'),
)


def test_skipped_pos_optional_pos_field(debug_ctx, extra_policy):
loader_getter = make_loader_getter(
shape=shape(
TestField('a', ParamKind.POS_OR_KW, is_required=True),
TestField('b', ParamKind.POS_OR_KW, is_required=False, default=DefaultValue(10)),
TestField('c', ParamKind.POS_OR_KW, is_required=False, default=DefaultValue(20)),
),
name_layout=InputNameLayout(
crown=InpDictCrown(
{
'a': InpFieldCrown('a'),
'c': InpFieldCrown('c'),
},
extra_policy=ExtraForbid(),
),
extra_move=None,
),
debug_trail=DebugTrail.ALL,
debug_ctx=debug_ctx,
)
loader = loader_getter()

assert loader({'a': 1, 'c': 3}) == gauge(1, c=3)
Loading