Skip to content

Commit

Permalink
refactor functional of enums in LiteralProvider
Browse files Browse the repository at this point in the history
  • Loading branch information
andiserg committed Jan 29, 2024
1 parent 346bf5b commit 9cc3999
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 44 deletions.
123 changes: 81 additions & 42 deletions src/adaptix/_internal/morphing/generic_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from enum import Enum
from os import PathLike
from pathlib import Path
from typing import Any, Callable, Collection, Iterable, Literal, Optional, Type, Union
from typing import Any, Collection, Dict, Iterable, Literal, Sequence, Type, Union

from ..common import Dumper, Loader
from ..compat import CompatExceptionGroup
Expand All @@ -17,6 +17,7 @@
GenericParamLoc,
LocatedRequest,
LocMap,
LocStack,
StrictCoercionRequest,
TypeHintLoc,
get_type_from_request,
Expand Down Expand Up @@ -99,51 +100,78 @@ def _get_allowed_values_collection(self, args: Collection) -> Collection:
return set(args)
return tuple(args)

def _get_enum_loader(
self, mediator: Mediator, request: LoaderRequest, enum_class: Type[Enum]
) -> Callable[[Any], Enum]:
return mediator.mandatory_provide(
def _get_allowed_values_repr(self, args: Collection, mediator: Mediator, lock_stack: LocStack) -> Collection:
enum_cases = [arg for arg in args if isinstance(arg, Enum)]
if not enum_cases:
return set(args)

literal_dumper = self._provide_dumper(mediator, DumperRequest(lock_stack))
return {literal_dumper(arg) if isinstance(arg, Enum) else arg for arg in args}

def _fetch_enum_loaders(
self, mediator: Mediator, request: LoaderRequest, enum_classes: Iterable[Type[Enum]]
) -> Iterable[Loader[Enum]]:
requests = [
LoaderRequest(
loc_stack=request.loc_stack.append_with(
LocMap(
TypeHintLoc(type=enum_class),
TypeHintLoc(type=enum_cls),
)
)
),
lambda x: f'Cannot create loader for {enum_class}. Loader for literal cannot be created',
) for enum_cls in enum_classes
]
return mediator.mandatory_provide_by_iterable(
requests,
lambda: 'Cannot create loaders for enum. Loader for literal cannot be created',
)

def _get_enum_dumper(
self, mediator: Mediator, request: DumperRequest, enum_class: Type[Enum]
) -> Callable[[Any], Enum]:
return mediator.mandatory_provide(
def _fetch_enum_dumpers(
self, mediator: Mediator, request: DumperRequest, enum_classes: Iterable[Type[Enum]]
) -> Dict[Type[Enum], Dumper[Enum]]:
requests = [
DumperRequest(
loc_stack=request.loc_stack.append_with(
LocMap(
TypeHintLoc(type=enum_class),
TypeHintLoc(type=enum_cls),
)
)
),
lambda x: f'Cannot create dumper for {enum_class}. Dumper for literal cannot be created',
) for enum_cls in enum_classes
]
dumpers = mediator.mandatory_provide_by_iterable(
requests,
lambda: 'Cannot create loaders for enum. Loader for literal cannot be created',
)
return dict(zip(enum_classes, dumpers))

def _literal_loader_with_enum(
self, basic_loader: Loader, enum_loaders: Sequence[Loader[Enum]], allowed_values: Collection
) -> Loader:
if not enum_loaders:
return basic_loader

def validate_enum(enum_value):
return enum_value is not None and enum_value in allowed_values

def _combined_enums_loader(self, data, loaders: Collection, allowed_values: Collection) -> Optional[Enum]:
for loader in loaders:
def process_enum(data, loader):
try:
result = loader(data)
if result in allowed_values:
return result
except BadVariantError:
pass
return None

def _with_enum_loader(self, func: Callable, loaders: Collection, allowed_values: Collection):
def wrapped(data):
enum_data = self._combined_enums_loader(data, loaders, allowed_values)
if enum_data:
return enum_data
return func(data)
return wrapped
return loader(data)
except LoadError:
return None

def wrapped_loader(data):
for loader in enum_loaders:
enum_value = process_enum(data, loader)
if validate_enum(enum_value):
return enum_value
return basic_loader(data)

def wrapped_loader_with_single_enum(data):
enum_value = process_enum(data, enum_loaders[0])
if validate_enum(enum_value):
return enum_value
return basic_loader(data)

return wrapped_loader_with_single_enum if len(enum_loaders) == 1 else wrapped_loader

def _provide_loader(self, mediator: Mediator, request: LoaderRequest) -> Loader:
norm = try_normalize_type(get_type_from_request(request))
Expand All @@ -152,8 +180,10 @@ def _provide_loader(self, mediator: Mediator, request: LoaderRequest) -> Loader:
cleaned_args = [strip_annotated(arg) for arg in norm.args]

enum_cases = [arg for arg in cleaned_args if isinstance(arg, Enum)]
enum_loaders = [self._get_enum_loader(mediator, request, type(case)) for case in enum_cases]
with_enum_loader = self._with_enum_loader
enum_loaders = list(
self._fetch_enum_loaders(mediator, request, [type(case) for case in enum_cases])
) if enum_cases else []
allowed_values_repr = self._get_allowed_values_repr(cleaned_args, mediator, request.loc_stack)

if strict_coercion and any(
isinstance(arg, bool) or _is_exact_zero_or_one(arg)
Expand All @@ -162,42 +192,51 @@ def _provide_loader(self, mediator: Mediator, request: LoaderRequest) -> Loader:
allowed_values_with_types = self._get_allowed_values_collection(
[(type(el), el) for el in cleaned_args]
)
allowed_values_repr = set(cleaned_args)

# since True == 1 and False == 0
def literal_loader_sc(data):
if (type(data), data) in allowed_values_with_types:
return data
raise BadVariantError(allowed_values_repr, data)

return with_enum_loader(
return self._literal_loader_with_enum(
literal_loader_sc, enum_loaders, allowed_values_with_types
) if enum_cases else literal_loader_sc
)

allowed_values = self._get_allowed_values_collection(cleaned_args)
allowed_values_repr = set(cleaned_args)

def literal_loader(data):
if data in allowed_values:
return data

raise BadVariantError(allowed_values_repr, data)

return with_enum_loader(literal_loader, enum_loaders, allowed_values) if enum_cases else literal_loader
return self._literal_loader_with_enum(literal_loader, enum_loaders, allowed_values)

def _provide_dumper(self, mediator: Mediator, request: DumperRequest) -> Dumper:
norm = try_normalize_type(get_type_from_request(request))
cleaned_args = [strip_annotated(arg) for arg in norm.args]
enum_cases = [arg for arg in cleaned_args if isinstance(arg, Enum)]
enum_dumper_factory = self._get_enum_dumper

if not enum_cases:
return as_is_stub

enum_dumpers = self._fetch_enum_dumpers(
mediator, request, [type(case) for case in enum_cases]
)

def literal_dumper(data):
if isinstance(data, Enum):
enum_dumper = enum_dumper_factory(mediator, request, type(data))
return enum_dumpers[type(data)](data)
return data

enum_dumper = list(enum_dumpers.values())[0]

def literal_dumper_with_single_enum(data):
if isinstance(data, Enum):
return enum_dumper(data)
return data

return literal_dumper if enum_cases else as_is_stub
return literal_dumper_with_single_enum if len(enum_cases) == 1 else literal_dumper


@for_predicate(Union)
Expand Down
48 changes: 46 additions & 2 deletions tests/unit/morphing/generic_provider/test_literal_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from tests_helpers import TestRetort, raises_exc

from adaptix._internal.morphing.enum_provider import EnumExactValueProvider
from adaptix._internal.morphing.generic_provider import LiteralProvider
from adaptix._internal.morphing.generic_provider import LiteralProvider, UnionProvider
from adaptix._internal.morphing.load_error import BadVariantError


Expand All @@ -15,7 +15,8 @@ def retort():
return TestRetort(
recipe=[
LiteralProvider(),
EnumExactValueProvider()
EnumExactValueProvider(),
UnionProvider()
]
)

Expand Down Expand Up @@ -102,6 +103,17 @@ class Enum2(Enum):
CASE1 = 1
CASE2 = 2

loader = retort.replace(
strict_coercion=strict_coercion,
debug_trail=debug_trail,
).get_loader(
Literal["a", Enum1.CASE1, 5]
)

assert loader("a") == "a"
assert loader(1) == Enum1.CASE1
assert loader(5) == 5

loader = retort.replace(
strict_coercion=strict_coercion,
debug_trail=debug_trail,
Expand All @@ -112,3 +124,35 @@ class Enum2(Enum):
assert loader(1) == Enum1.CASE1
assert loader(2) == Enum2.CASE2
assert loader(10) == 10


def test_dumper_with_enums(retort, strict_coercion, debug_trail):
class Enum1(Enum):
CASE1 = 1
CASE2 = 2

class Enum2(Enum):
CASE1 = 1
CASE2 = 2

dumper = retort.replace(
strict_coercion=strict_coercion,
debug_trail=debug_trail,
).get_dumper(
Literal["a", Enum1.CASE1, 5]
)

assert dumper("a") == "a"
assert dumper(Enum1.CASE1) == 1
assert dumper(5) == 5

dumper = retort.replace(
strict_coercion=strict_coercion,
debug_trail=debug_trail,
).get_dumper(
Literal[Enum1.CASE1, Enum2.CASE2, 10]
)

assert dumper(Enum1.CASE1) == 1
assert dumper(Enum1.CASE2) == 2
assert dumper(10) == 10

0 comments on commit 9cc3999

Please sign in to comment.