-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add the first implementation (POC) of model conversion (the first tha…
…t I decided to commit)
- Loading branch information
Showing
23 changed files
with
926 additions
and
31 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import ast | ||
from ast import AST, NodeTransformer | ||
from typing import Mapping | ||
|
||
|
||
class Substitutor(NodeTransformer): | ||
__slots__ = ('substitution', ) | ||
|
||
def __init__(self, substitution: Mapping[str, AST]): | ||
self._substitution = substitution | ||
|
||
def visit_Name(self, node: ast.Name): # pylint: disable=invalid-name | ||
if node.id in self._substitution: | ||
return self._substitution[node.id] | ||
return node | ||
|
||
|
||
def ast_substitute(template: str, **kwargs: AST) -> AST: | ||
substitution = {f"__{key}__": value for key, value in kwargs.items()} | ||
return Substitutor(substitution).generic_visit(ast.parse(template)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,24 +1,41 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import Dict, Optional | ||
from typing import AbstractSet, Dict, Optional | ||
|
||
|
||
class ContextNamespace(ABC): | ||
@abstractmethod | ||
def add(self, name: str, value: object) -> None: | ||
... | ||
|
||
@abstractmethod | ||
def __contains__(self, item: str) -> bool: | ||
... | ||
|
||
|
||
class BuiltinContextNamespace(ContextNamespace): | ||
def __init__(self, namespace: Optional[Dict[str, object]] = None): | ||
__slots__ = ('dict', '_occupied') | ||
|
||
def __init__( | ||
self, | ||
namespace: Optional[Dict[str, object]] = None, | ||
occupied: Optional[AbstractSet[str]] = None, | ||
): | ||
if namespace is None: | ||
namespace = {} | ||
if occupied is None: | ||
occupied = set() | ||
|
||
self.dict = namespace | ||
self._occupied = occupied | ||
|
||
def add(self, name: str, value: object) -> None: | ||
if name in self._occupied: | ||
raise KeyError(f"Key {name} is duplicated") | ||
if name in self.dict: | ||
if value is self.dict[name]: | ||
return | ||
raise KeyError(f"Key {name} is duplicated") | ||
|
||
self.dict[name] = value | ||
|
||
def __contains__(self, item: str) -> bool: | ||
return item in self.dict or item in self._occupied |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import Iterable | ||
|
||
from ..provider.essential import CannotProvide, Mediator | ||
from ..provider.loc_stack_filtering import LocStackChecker | ||
from ..provider.static_provider import StaticProvider, static_provision_action | ||
from .request_cls import BindingRequest, BindingResult, BindingSource, SourceCandidates | ||
|
||
|
||
class BindingProvider(StaticProvider, ABC): | ||
@static_provision_action | ||
@abstractmethod | ||
def _provide_binder(self, mediator: Mediator, request: BindingRequest) -> BindingResult: | ||
... | ||
|
||
|
||
def iterate_source_candidates(candidates: SourceCandidates) -> Iterable[BindingSource]: | ||
for source in reversed(candidates): | ||
if isinstance(source, tuple): | ||
yield from source | ||
else: | ||
yield source | ||
|
||
|
||
class SameNameBindingProvider(BindingProvider): | ||
def __init__(self, is_default: bool): | ||
self._is_default = is_default | ||
|
||
def _provide_binder(self, mediator: Mediator, request: BindingRequest) -> BindingResult: | ||
target_field_id = request.destination.last.id | ||
for source in iterate_source_candidates(request.sources): | ||
if source.last.id == target_field_id: | ||
return BindingResult(source=source, is_default=self._is_default) | ||
raise CannotProvide | ||
|
||
|
||
class MatchingBindingProvider(BindingProvider): | ||
def __init__(self, src_lsc: LocStackChecker, dst_lsc: LocStackChecker): | ||
self._src_lsc = src_lsc | ||
self._dst_lsc = dst_lsc | ||
|
||
def _provide_binder(self, mediator: Mediator, request: BindingRequest) -> BindingResult: | ||
if not self._dst_lsc.check_loc_stack(mediator, request.destination.to_loc_stack()): | ||
raise CannotProvide | ||
|
||
for source in iterate_source_candidates(request.sources): | ||
if self._src_lsc.check_loc_stack(mediator, source.to_loc_stack()): | ||
return BindingResult(source=source) | ||
raise CannotProvide |
Empty file.
175 changes: 175 additions & 0 deletions
175
src/adaptix/_internal/conversion/broaching/code_generator.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
import ast | ||
import itertools | ||
from abc import ABC, abstractmethod | ||
from ast import AST | ||
from collections import defaultdict | ||
from inspect import Signature | ||
from typing import DefaultDict, Mapping, Tuple, Union | ||
|
||
from ...code_tools.ast_templater import ast_substitute | ||
from ...code_tools.code_builder import CodeBuilder | ||
from ...code_tools.context_namespace import BuiltinContextNamespace, ContextNamespace | ||
from ...code_tools.utils import get_literal_expr | ||
from ...compat import compat_ast_unparse | ||
from ...model_tools.definitions import DescriptorAccessor, ItemAccessor | ||
from ...special_cases_optimization import as_is_stub | ||
from .definitions import ( | ||
AccessorElement, | ||
ConstantElement, | ||
FunctionElement, | ||
KeywordArg, | ||
ParameterElement, | ||
PositionalArg, | ||
UnpackIterable, | ||
UnpackMapping, | ||
) | ||
|
||
BroachingPlan = Union[ | ||
ParameterElement, | ||
ConstantElement, | ||
FunctionElement['BroachingPlan'], | ||
AccessorElement['BroachingPlan'], | ||
] | ||
|
||
|
||
class GenState: | ||
def __init__(self, ctx_namespace: ContextNamespace): | ||
self._ctx_namespace = ctx_namespace | ||
self._prefix_counter: DefaultDict[str, int] = defaultdict(lambda: 0) | ||
|
||
def register_next_id(self, prefix: str, obj: object) -> str: | ||
number = self._prefix_counter[prefix] | ||
self._prefix_counter[prefix] += 1 | ||
name = f"{prefix}_{number}" | ||
return self.register_mangled(name, obj) | ||
|
||
def register_mangled(self, base: str, obj: object) -> str: | ||
if base not in self._ctx_namespace: | ||
self._ctx_namespace.add(base, obj) | ||
return base | ||
|
||
for i in itertools.count(1): | ||
name = f'{base}_{i}' | ||
if name not in self._ctx_namespace: | ||
self._ctx_namespace.add(base, obj) | ||
return name | ||
raise RuntimeError | ||
|
||
|
||
class BroachingCodeGenerator(ABC): | ||
@abstractmethod | ||
def produce_code(self, closure_name: str, signature: Signature) -> Tuple[str, Mapping[str, object]]: | ||
... | ||
|
||
|
||
class BuiltinBroachingCodeGenerator(BroachingCodeGenerator): | ||
def __init__(self, plan: BroachingPlan): | ||
self._plan = plan | ||
|
||
def _create_state(self, ctx_namespace: ContextNamespace) -> GenState: | ||
return GenState( | ||
ctx_namespace=ctx_namespace, | ||
) | ||
|
||
def produce_code(self, closure_name: str, signature: Signature) -> Tuple[str, Mapping[str, object]]: | ||
builder = CodeBuilder() | ||
ctx_namespace = BuiltinContextNamespace(occupied=signature.parameters.keys()) | ||
state = self._create_state(ctx_namespace=ctx_namespace) | ||
|
||
ctx_namespace.add('_closure_signature', signature) | ||
no_types_signature = signature.replace( | ||
parameters=[param.replace(annotation=Signature.empty) for param in signature.parameters.values()], | ||
return_annotation=Signature.empty, | ||
) | ||
with builder(f'def {closure_name}{no_types_signature}:'): | ||
body = self._gen_plan_element_dispatch(state, self._plan) | ||
builder += 'return ' + compat_ast_unparse(body) | ||
|
||
builder += f'{closure_name}.__signature__ = _closure_signature' | ||
return builder.string(), ctx_namespace.dict | ||
|
||
def _gen_plan_element_dispatch(self, state: GenState, element: BroachingPlan) -> AST: | ||
if isinstance(element, ParameterElement): | ||
return self._gen_parameter_element(state, element) | ||
if isinstance(element, ConstantElement): | ||
return self._gen_constant_element(state, element) | ||
if isinstance(element, FunctionElement): | ||
return self._gen_function_element(state, element) | ||
if isinstance(element, AccessorElement): | ||
return self._gen_accessor_element(state, element) | ||
raise TypeError | ||
|
||
def _gen_parameter_element(self, state: GenState, element: ParameterElement) -> AST: | ||
return ast.Name(id=element.name, ctx=ast.Load()) | ||
|
||
def _gen_constant_element(self, state: GenState, element: ConstantElement) -> AST: | ||
expr = get_literal_expr(element.value) | ||
if expr is not None: | ||
return ast.parse(expr) | ||
|
||
name = state.register_next_id('constant', element.value) | ||
return ast.Name(id=name, ctx=ast.Load()) | ||
|
||
def _gen_function_element(self, state: GenState, element: FunctionElement[BroachingPlan]) -> AST: | ||
if ( | ||
element.func == as_is_stub | ||
and len(element.args) == 1 | ||
and isinstance(element.args[0], PositionalArg) | ||
): | ||
return self._gen_plan_element_dispatch(state, element.args[0].element) | ||
|
||
if getattr(element.func, '__name__', None) is not None: | ||
name = state.register_mangled(element.func.__name__, element.func) | ||
else: | ||
name = state.register_next_id('func', element.func) | ||
|
||
args = [] | ||
keywords = [] | ||
for arg in element.args: | ||
if isinstance(arg, PositionalArg): | ||
sub_ast = self._gen_plan_element_dispatch(state, arg.element) | ||
args.append(sub_ast) | ||
elif isinstance(arg, KeywordArg): | ||
sub_ast = self._gen_plan_element_dispatch(state, arg.element) | ||
keywords.append(ast.keyword(arg=arg.key, value=sub_ast)) | ||
elif isinstance(arg, UnpackMapping): | ||
sub_ast = self._gen_plan_element_dispatch(state, arg.element) | ||
keywords.append(ast.keyword(value=sub_ast)) | ||
elif isinstance(arg, UnpackIterable): | ||
sub_ast = self._gen_plan_element_dispatch(state, arg.element) | ||
args.append(ast.Starred(value=sub_ast, ctx=ast.Load())) | ||
else: | ||
raise TypeError | ||
|
||
return ast.Call( | ||
func=ast.Name(name, ast.Load()), | ||
args=args, | ||
keywords=keywords, | ||
) | ||
|
||
def _gen_accessor_element(self, state: GenState, element: AccessorElement[BroachingPlan]) -> AST: | ||
target_expr = self._gen_plan_element_dispatch(state, element.target) | ||
if isinstance(element.accessor, DescriptorAccessor): | ||
if element.accessor.attr_name.isidentifier(): | ||
return ast_substitute( | ||
f'__target_expr__.{element.accessor.attr_name}', | ||
target_expr=target_expr, | ||
) | ||
return ast_substitute( | ||
f"getattr(__target_expr__, {element.accessor.attr_name!r})", | ||
target_expr=target_expr, | ||
) | ||
|
||
if isinstance(element.accessor, ItemAccessor): | ||
literal_expr = get_literal_expr(element.accessor.key) | ||
if literal_expr is not None: | ||
return ast_substitute( | ||
f"__target_expr__[{literal_expr!r}]", | ||
target_expr=target_expr, | ||
) | ||
|
||
name = state.register_next_id('accessor', element.accessor.getter) | ||
return ast_substitute( | ||
f"{name}(__target_expr__)", | ||
target_expr=target_expr, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
from abc import ABC | ||
from dataclasses import dataclass | ||
from typing import Any, Callable, Generic, TypeVar, Union | ||
|
||
from adaptix._internal.common import VarTuple | ||
from adaptix._internal.model_tools.definitions import Accessor | ||
|
||
|
||
class BasePlanElement(ABC): | ||
pass | ||
|
||
|
||
PlanT = TypeVar('PlanT', bound=BasePlanElement) | ||
|
||
|
||
@dataclass(frozen=True) | ||
class ParameterElement(BasePlanElement): | ||
name: str | ||
|
||
|
||
@dataclass(frozen=True) | ||
class ConstantElement(BasePlanElement): | ||
value: Any | ||
|
||
|
||
@dataclass(frozen=True) | ||
class PositionalArg(Generic[PlanT]): | ||
element: PlanT | ||
|
||
|
||
@dataclass(frozen=True) | ||
class KeywordArg(Generic[PlanT]): | ||
key: str | ||
element: PlanT | ||
|
||
|
||
@dataclass(frozen=True) | ||
class UnpackIterable(Generic[PlanT]): | ||
element: PlanT | ||
|
||
|
||
@dataclass(frozen=True) | ||
class UnpackMapping(Generic[PlanT]): | ||
element: PlanT | ||
|
||
|
||
@dataclass(frozen=True) | ||
class FunctionElement(BasePlanElement, Generic[PlanT]): | ||
func: Callable[..., Any] | ||
args: VarTuple[ | ||
Union[ | ||
PositionalArg[PlanT], | ||
KeywordArg[PlanT], | ||
UnpackIterable[PlanT], | ||
UnpackMapping[PlanT], | ||
] | ||
] | ||
|
||
|
||
@dataclass(frozen=True) | ||
class AccessorElement(BasePlanElement, Generic[PlanT]): | ||
target: PlanT | ||
accessor: Accessor |
Oops, something went wrong.