Skip to content

Commit

Permalink
Merge pull request #244 from reagento/fix/lambda-coercer
Browse files Browse the repository at this point in the history
Fix SyntaxError with lambda in coercer (#243)
  • Loading branch information
zhPavel authored Mar 2, 2024
2 parents ce12803 + 4d61d94 commit 79ffaa5
Show file tree
Hide file tree
Showing 8 changed files with 50 additions and 27 deletions.
1 change: 1 addition & 0 deletions docs/changelog/fragments/243.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix SyntaxError with lambda in :func:`.coercer`
21 changes: 21 additions & 0 deletions src/adaptix/_internal/code_tools/name_sanitizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import re
import string
from abc import ABC, abstractmethod


class NameSanitizer(ABC):
@abstractmethod
def sanitize(self, name: str) -> str:
...


class BuiltinNameSanitizer(NameSanitizer):
_BAD_CHARS = re.compile(r'\W')
_TRANSLATE_MAP = str.maketrans({'.': '_', '[': '_'})

def sanitize(self, name: str) -> str:
if name == "":
return ""

first_letter = name[0] if name[0] in string.ascii_letters else '_'
return first_letter + self._BAD_CHARS.sub('', name[1:].translate(self._TRANSLATE_MAP))
9 changes: 7 additions & 2 deletions src/adaptix/_internal/conversion/broaching/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ...code_tools.ast_templater import ast_substitute
from ...code_tools.cascade_namespace import BuiltinCascadeNamespace, CascadeNamespace
from ...code_tools.code_builder import CodeBuilder
from ...code_tools.name_sanitizer import NameSanitizer
from ...code_tools.utils import get_literal_expr
from ...compat import compat_ast_unparse
from ...model_tools.definitions import DescriptorAccessor, ItemAccessor
Expand All @@ -34,8 +35,9 @@


class GenState:
def __init__(self, namespace: CascadeNamespace):
def __init__(self, namespace: CascadeNamespace, name_sanitizer: NameSanitizer):
self._namespace = namespace
self._name_sanitizer = name_sanitizer
self._prefix_counter: DefaultDict[str, int] = defaultdict(lambda: 0)

def register_next_id(self, prefix: str, obj: object) -> str:
Expand All @@ -45,6 +47,7 @@ def register_next_id(self, prefix: str, obj: object) -> str:
return self.register_mangled(name, obj)

def register_mangled(self, base: str, obj: object) -> str:
base = self._name_sanitizer.sanitize(base)
if self._namespace.try_add_constant(base, obj):
return base

Expand All @@ -67,12 +70,14 @@ def produce_code(


class BuiltinBroachingCodeGenerator(BroachingCodeGenerator):
def __init__(self, plan: BroachingPlan):
def __init__(self, plan: BroachingPlan, name_sanitizer: NameSanitizer):
self._plan = plan
self._name_sanitizer = name_sanitizer

def _create_state(self, namespace: CascadeNamespace) -> GenState:
return GenState(
namespace=namespace,
name_sanitizer=self._name_sanitizer,
)

def produce_code(
Expand Down
7 changes: 4 additions & 3 deletions src/adaptix/_internal/conversion/converter_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, Iterable, List, Mapping, Optional, Sequence, Tuple, cast, final

from ..code_tools.compiler import BasicClosureCompiler, ClosureCompiler
from ..code_tools.name_sanitizer import BuiltinNameSanitizer, NameSanitizer
from ..common import Converter, TypeHint
from ..conversion.broaching.code_generator import BroachingCodeGenerator, BroachingPlan, BuiltinBroachingCodeGenerator
from ..conversion.broaching.definitions import (
Expand All @@ -24,7 +25,7 @@
UnlinkedOptionalPolicyRequest,
)
from ..model_tools.definitions import BaseField, DefaultValue, InputField, InputShape, NoDefault, OutputShape, ParamKind
from ..morphing.model.basic_gen import NameSanitizer, compile_closure_with_globals_capturing, fetch_code_gen_hook
from ..morphing.model.basic_gen import compile_closure_with_globals_capturing, fetch_code_gen_hook
from ..provider.essential import CannotProvide, Mediator, mandatory_apply_by_iterable
from ..provider.fields import base_field_to_loc_map, input_field_to_loc_map
from ..provider.request_cls import LocMap, LocStack, TypeHintLoc
Expand All @@ -45,7 +46,7 @@ def _provide_converter(self, mediator: Mediator, request: ConverterRequest) -> C


class BuiltinConverterProvider(ConverterProvider):
def __init__(self, *, name_sanitizer: NameSanitizer = NameSanitizer()):
def __init__(self, *, name_sanitizer: NameSanitizer = BuiltinNameSanitizer()):
self._name_sanitizer = name_sanitizer

def _provide_converter(self, mediator: Mediator, request: ConverterRequest) -> Converter:
Expand Down Expand Up @@ -161,7 +162,7 @@ def _get_compiler(self) -> ClosureCompiler:
return BasicClosureCompiler()

def _create_broaching_code_gen(self, plan: BroachingPlan) -> BroachingCodeGenerator:
return BuiltinBroachingCodeGenerator(plan=plan)
return BuiltinBroachingCodeGenerator(plan=plan, name_sanitizer=self._name_sanitizer)

def _fetch_linkings(
self,
Expand Down
14 changes: 0 additions & 14 deletions src/adaptix/_internal/morphing/model/basic_gen.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import itertools
import re
import string
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import (
Expand Down Expand Up @@ -210,18 +208,6 @@ def get_wild_extra_targets(shape: BaseShape, extra_move: Union[InpExtraMove, Out
]


class NameSanitizer:
_BAD_CHARS = re.compile(r'\W')
_TRANSLATE_MAP = str.maketrans({'.': '_', '[': '_'})

def sanitize(self, name: str) -> str:
if name == "":
return ""

first_letter = name[0] if name[0] in string.ascii_letters else '_'
return first_letter + self._BAD_CHARS.sub('', name[1:].translate(self._TRANSLATE_MAP))


def compile_closure_with_globals_capturing(
compiler: ClosureCompiler,
code_gen_hook: CodeGenHook,
Expand Down
4 changes: 2 additions & 2 deletions src/adaptix/_internal/morphing/model/dumper_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from adaptix._internal.provider.fields import output_field_to_loc_map

from ...code_tools.compiler import BasicClosureCompiler, ClosureCompiler
from ...code_tools.name_sanitizer import BuiltinNameSanitizer, NameSanitizer
from ...common import Dumper
from ...definitions import DebugTrail
from ...model_tools.definitions import OutputShape
Expand All @@ -13,7 +14,6 @@
from ..request_cls import DumperRequest
from .basic_gen import (
ModelDumperGen,
NameSanitizer,
compile_closure_with_globals_capturing,
fetch_code_gen_hook,
get_extra_targets_at_crown,
Expand All @@ -25,7 +25,7 @@


class ModelDumperProvider(DumperProvider):
def __init__(self, *, name_sanitizer: NameSanitizer = NameSanitizer()):
def __init__(self, *, name_sanitizer: NameSanitizer = BuiltinNameSanitizer()):
self._name_sanitizer = name_sanitizer

def _provide_dumper(self, mediator: Mediator, request: DumperRequest) -> Dumper:
Expand Down
6 changes: 3 additions & 3 deletions src/adaptix/_internal/morphing/model/loader_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from adaptix._internal.provider.fields import input_field_to_loc_map

from ...code_tools.compiler import BasicClosureCompiler, ClosureCompiler
from ...code_tools.name_sanitizer import BuiltinNameSanitizer, NameSanitizer
from ...common import Loader
from ...definitions import DebugTrail
from ...model_tools.definitions import InputShape
Expand All @@ -14,7 +15,6 @@
from ..request_cls import LoaderRequest
from .basic_gen import (
ModelLoaderGen,
NameSanitizer,
compile_closure_with_globals_capturing,
fetch_code_gen_hook,
get_extra_targets_at_crown,
Expand All @@ -30,7 +30,7 @@ class ModelLoaderProvider(LoaderProvider):
def __init__(
self,
*,
name_sanitizer: NameSanitizer = NameSanitizer(),
name_sanitizer: NameSanitizer = BuiltinNameSanitizer(),
props: ModelLoaderProps = ModelLoaderProps(),
):
self._name_sanitizer = name_sanitizer
Expand Down Expand Up @@ -201,7 +201,7 @@ class InlinedShapeModelLoaderProvider(ModelLoaderProvider):
def __init__(
self,
*,
name_sanitizer: NameSanitizer = NameSanitizer(),
name_sanitizer: NameSanitizer = BuiltinNameSanitizer(),
props: ModelLoaderProps = ModelLoaderProps(),
shape: InputShape,
):
Expand Down
15 changes: 12 additions & 3 deletions tests/integration/conversion/test_coercer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
from typing import Any

import pytest

from adaptix.conversion import coercer, get_converter, impl_converter

from .local_helpers import FactoryWay


def test_simple(src_model_spec, dst_model_spec, factory_way):
@pytest.mark.parametrize(
'func',
[
pytest.param(int, id='int'),
pytest.param(lambda x: int(x), id='lambda'),
],
)
def test_simple(src_model_spec, dst_model_spec, factory_way, func):
@src_model_spec.decorator
class SourceModel(*src_model_spec.bases):
field1: str
Expand All @@ -17,11 +26,11 @@ class DestModel(*dst_model_spec.bases):
field2: int

if factory_way == FactoryWay.IMPL_CONVERTER:
@impl_converter(recipe=[coercer(str, int, func=int)])
@impl_converter(recipe=[coercer(str, int, func=func)])
def convert(a: SourceModel) -> DestModel:
...
else:
convert = get_converter(SourceModel, DestModel, recipe=[coercer(str, int, func=int)])
convert = get_converter(SourceModel, DestModel, recipe=[coercer(str, int, func=func)])

assert convert(SourceModel(field1='1', field2='2')) == DestModel(field1=1, field2=2)

Expand Down

0 comments on commit 79ffaa5

Please sign in to comment.