Skip to content

Commit

Permalink
refactor code gen
Browse files Browse the repository at this point in the history
  • Loading branch information
zhPavel committed Jan 27, 2024
1 parent 44799a1 commit 0b9cde5
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 65 deletions.
10 changes: 0 additions & 10 deletions src/adaptix/_internal/code_generator.py

This file was deleted.

24 changes: 16 additions & 8 deletions src/adaptix/_internal/morphing/model/basic_gen.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import itertools
import re
import string
from abc import ABC, abstractmethod
from dataclasses import dataclass, replace
from typing import Any, Callable, Collection, Container, Dict, Iterable, List, Mapping, Set, Tuple, TypeVar, Union

Expand Down Expand Up @@ -244,11 +245,10 @@ def sanitize(self, name: str) -> str:
def compile_closure_with_globals_capturing(
compiler: ClosureCompiler,
code_gen_hook: CodeGenHook,
namespace: Dict[str, object],
body_builders: Iterable[CodeBuilder],
namespace: Mapping[str, object],
*,
closure_name: str,
closure_params: str,
closure_code: str,
file_name: str,
):
builder = CodeBuilder()
Expand All @@ -264,11 +264,7 @@ def compile_closure_with_globals_capturing(
builder += f"{name} = {value_literal}"

builder.empty_line()

with builder(f"def {closure_name}({closure_params}):"):
for body_builder in body_builders:
builder.extend(body_builder)

builder += closure_code
builder += f"return {closure_name}"

code_gen_hook(
Expand Down Expand Up @@ -300,3 +296,15 @@ def has_collect_policy(crown: InpCrown) -> bool:
if isinstance(crown, (InpFieldCrown, InpNoneCrown)):
return False
raise TypeError


class ModelLoaderGen(ABC):
@abstractmethod
def produce_code(self, closure_name: str) -> Tuple[str, Mapping[str, object]]:
...


class ModelDumperGen(ABC):
@abstractmethod
def produce_code(self, closure_name: str) -> Tuple[str, Mapping[str, object]]:
...
33 changes: 19 additions & 14 deletions src/adaptix/_internal/morphing/model/dumper_gen.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import contextlib
from string import Template
from typing import Dict, Mapping, NamedTuple
from typing import Dict, Mapping, NamedTuple, Tuple

from ...code_generator import CodeGenerator
from ...code_tools.code_builder import CodeBuilder
from ...code_tools.context_namespace import ContextNamespace
from ...code_tools.context_namespace import BuiltinContextNamespace, ContextNamespace
from ...code_tools.utils import get_literal_expr, get_literal_from_factory, is_singleton
from ...common import Dumper
from ...compat import CompatExceptionGroup
Expand All @@ -20,6 +19,7 @@
)
from ...special_cases_optimization import as_is_stub, get_default_clause
from ...struct_trail import append_trail, extend_trail, render_trail_as_note
from .basic_gen import ModelDumperGen
from .crown_definitions import (
CrownPath,
CrownPathElem,
Expand Down Expand Up @@ -93,7 +93,7 @@ class ElementExpr(NamedTuple):
can_inline: bool


class ModelDumperGen(CodeGenerator):
class BuiltinModelDumperGen(ModelDumperGen):
def __init__(
self,
shape: OutputShape,
Expand All @@ -114,35 +114,36 @@ def __init__(
self._id_to_field: Dict[str, OutputField] = {field.id: field for field in self._shape.fields}
self._model_identity = model_identity

def produce_code(self, ctx_namespace: ContextNamespace) -> CodeBuilder:
builder = CodeBuilder()
def produce_code(self, closure_name: str) -> Tuple[str, Mapping[str, object]]:
body_builder = CodeBuilder()

ctx_namespace = BuiltinContextNamespace()
ctx_namespace.add('CompatExceptionGroup', CompatExceptionGroup)
ctx_namespace.add("append_trail", append_trail)
ctx_namespace.add("extend_trail", extend_trail)
for field_id, dumper in self._fields_dumpers.items():
ctx_namespace.add(self._v_dumper(self._id_to_field[field_id]), dumper)

if self._debug_trail == DebugTrail.ALL:
builder('errors = []')
builder.empty_line()
body_builder('errors = []')
body_builder.empty_line()

if any(field.is_optional for field in self._shape.fields):
builder("opt_fields = {}")
builder.empty_line()
body_builder("opt_fields = {}")
body_builder.empty_line()

for field in self._shape.fields:
if not self._is_extra_target(field):
self._gen_field_extraction(
builder, ctx_namespace, field,
body_builder, ctx_namespace, field,
on_access_error="pass",
on_access_ok_req=f"{self._v_field(field)} = $expr",
on_access_ok_opt=f"opt_fields[{field.id!r}] = $expr",
)

self._gen_extra_extraction(builder, ctx_namespace)
self._gen_extra_extraction(body_builder, ctx_namespace)

state = self._create_state(builder, ctx_namespace)
state = self._create_state(body_builder, ctx_namespace)

if not self._gen_root_crown_dispatch(state, self._name_layout.crown):
raise TypeError
Expand All @@ -155,7 +156,11 @@ def produce_code(self, ctx_namespace: ContextNamespace) -> CodeBuilder:
)

self._gen_header(state)
return builder

builder = CodeBuilder()
with builder(f'def {closure_name}(data):'):
builder.extend(body_builder)
return builder.string(), ctx_namespace.dict

def _is_extra_target(self, field: OutputField) -> bool:
return field.id in self._extra_targets
Expand Down
25 changes: 12 additions & 13 deletions src/adaptix/_internal/morphing/model/dumper_provider.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
from typing import Mapping

from ...code_generator import CodeGenerator
from adaptix._internal.provider.fields import output_field_to_loc_map

from ...code_tools.compiler import BasicClosureCompiler
from ...code_tools.context_namespace import BuiltinContextNamespace
from ...common import Dumper
from ...definitions import DebugTrail
from ...model_tools.definitions import OutputShape
from ...provider.essential import CannotProvide, Mediator
from ...provider.fields import output_field_to_loc_map
from ...provider.request_cls import DebugTrailRequest, TypeHintLoc
from ...provider.shape_provider import OutputShapeRequest, provide_generic_resolved_shape
from ..provider_template import DumperProvider
from ..request_cls import DumperRequest
from .basic_gen import (
CodeGenHookRequest,
ModelDumperGen,
NameSanitizer,
compile_closure_with_globals_capturing,
get_extra_targets_at_crown,
Expand All @@ -24,7 +24,7 @@
stub_code_gen_hook,
)
from .crown_definitions import OutputNameLayout, OutputNameLayoutRequest
from .dumper_gen import ModelDumperGen
from .dumper_gen import BuiltinModelDumperGen


class ModelDumperProvider(DumperProvider):
Expand All @@ -33,8 +33,8 @@ def __init__(self, *, name_sanitizer: NameSanitizer = NameSanitizer()):

def _provide_dumper(self, mediator: Mediator, request: DumperRequest) -> Dumper:
dumper_gen = self._fetch_model_dumper_gen(mediator, request)
ctx_namespace = BuiltinContextNamespace()
dumper_code_builder = dumper_gen.produce_code(ctx_namespace)
closure_name = self._get_closure_name(request)
dumper_code, dumper_namespace = dumper_gen.produce_code(closure_name=closure_name)

try:
code_gen_hook = mediator.delegating_provide(CodeGenHookRequest(loc_stack=request.loc_stack))
Expand All @@ -44,14 +44,13 @@ def _provide_dumper(self, mediator: Mediator, request: DumperRequest) -> Dumper:
return compile_closure_with_globals_capturing(
compiler=self._get_compiler(),
code_gen_hook=code_gen_hook,
namespace=ctx_namespace.dict,
body_builders=[dumper_code_builder],
closure_name=self._get_closure_name(request),
closure_params='data',
namespace=dumper_namespace,
closure_code=dumper_code,
closure_name=closure_name,
file_name=self._get_file_name(request),
)

def _fetch_model_dumper_gen(self, mediator: Mediator, request: DumperRequest) -> CodeGenerator:
def _fetch_model_dumper_gen(self, mediator: Mediator, request: DumperRequest) -> ModelDumperGen:
shape = self._fetch_shape(mediator, request)
name_layout = self._fetch_name_layout(mediator, request, shape)
shape = self._process_shape(shape, name_layout)
Expand Down Expand Up @@ -86,8 +85,8 @@ def _create_model_dumper_gen(
name_layout: OutputNameLayout,
fields_dumpers: Mapping[str, Dumper],
model_identity: str,
) -> CodeGenerator:
return ModelDumperGen(
) -> ModelDumperGen:
return BuiltinModelDumperGen(
shape=shape,
name_layout=name_layout,
debug_trail=debug_trail,
Expand Down
19 changes: 12 additions & 7 deletions src/adaptix/_internal/morphing/model/loader_gen.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import collections.abc
import contextlib
from dataclasses import dataclass
from typing import Dict, List, Mapping, Optional, Set
from typing import Dict, List, Mapping, Optional, Set, Tuple

from ...code_generator import CodeGenerator
from ...code_tools.code_builder import CodeBuilder
from ...code_tools.context_namespace import ContextNamespace
from ...code_tools.context_namespace import BuiltinContextNamespace, ContextNamespace
from ...code_tools.utils import get_literal_expr, get_literal_from_factory
from ...common import Loader
from ...compat import CompatExceptionGroup
Expand All @@ -23,6 +22,7 @@
NoRequiredItemsError,
TypeLoadError,
)
from .basic_gen import ModelLoaderGen
from .crown_definitions import (
BranchInpCrown,
CrownPath,
Expand Down Expand Up @@ -168,8 +168,8 @@ class ModelLoaderProps:
use_default_for_omitted: bool = True


class ModelLoaderGen(CodeGenerator):
"""ModelLoaderGen generates code that extracts raw values from input data,
class BuiltinModelLoaderGen(ModelLoaderGen):
"""BuiltinModelLoaderGen generates code that extracts raw values from input data,
calls loaders and stores results to variables.
"""

Expand Down Expand Up @@ -226,7 +226,8 @@ def _is_packed_field(self, field: InputField) -> bool:
return False
return field.is_optional and not self._is_extra_target(field)

def produce_code(self, ctx_namespace: ContextNamespace) -> CodeBuilder:
def produce_code(self, closure_name: str) -> Tuple[str, Mapping[str, object]]:
ctx_namespace = BuiltinContextNamespace()
state = self._create_state(ctx_namespace)

for field_id, loader in self._field_loaders.items():
Expand Down Expand Up @@ -278,7 +279,11 @@ def produce_code(self, ctx_namespace: ContextNamespace) -> CodeBuilder:

self._gen_constructor_call(state)
self._gen_header(state)
return state.builder

builder = CodeBuilder()
with builder(f'def {closure_name}(data):'):
builder.extend(state.builder)
return builder.string(), ctx_namespace.dict

def _gen_header(self, state: GenState):
header_builder = CodeBuilder()
Expand Down
25 changes: 12 additions & 13 deletions src/adaptix/_internal/morphing/model/loader_provider.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
from typing import Mapping

from ...code_generator import CodeGenerator
from adaptix._internal.provider.fields import input_field_to_loc_map

from ...code_tools.compiler import BasicClosureCompiler
from ...code_tools.context_namespace import BuiltinContextNamespace
from ...common import Loader
from ...definitions import DebugTrail
from ...model_tools.definitions import InputShape
from ...provider.essential import CannotProvide, Mediator
from ...provider.fields import input_field_to_loc_map
from ...provider.request_cls import DebugTrailRequest, StrictCoercionRequest, TypeHintLoc
from ...provider.shape_provider import InputShapeRequest, provide_generic_resolved_shape
from ..model.loader_gen import ModelLoaderGen, ModelLoaderProps
from ..model.loader_gen import BuiltinModelLoaderGen, ModelLoaderProps
from ..provider_template import LoaderProvider
from ..request_cls import LoaderRequest
from .basic_gen import (
CodeGenHookRequest,
ModelLoaderGen,
NameSanitizer,
compile_closure_with_globals_capturing,
get_extra_targets_at_crown,
Expand All @@ -40,8 +40,8 @@ def __init__(

def _provide_loader(self, mediator: Mediator, request: LoaderRequest) -> Loader:
loader_gen = self._fetch_model_loader_gen(mediator, request)
ctx_namespace = BuiltinContextNamespace()
loader_code_builder = loader_gen.produce_code(ctx_namespace)
closure_name = self._get_closure_name(request)
loader_code, loader_namespace = loader_gen.produce_code(closure_name=closure_name)

try:
code_gen_hook = mediator.delegating_provide(CodeGenHookRequest(loc_stack=request.loc_stack))
Expand All @@ -51,14 +51,13 @@ def _provide_loader(self, mediator: Mediator, request: LoaderRequest) -> Loader:
return compile_closure_with_globals_capturing(
compiler=self._get_compiler(),
code_gen_hook=code_gen_hook,
namespace=ctx_namespace.dict,
body_builders=[loader_code_builder],
closure_name=self._get_closure_name(request),
closure_params='data',
namespace=loader_namespace,
closure_code=loader_code,
closure_name=closure_name,
file_name=self._get_file_name(request),
)

def _fetch_model_loader_gen(self, mediator: Mediator, request: LoaderRequest) -> CodeGenerator:
def _fetch_model_loader_gen(self, mediator: Mediator, request: LoaderRequest) -> ModelLoaderGen:
shape = self._fetch_shape(mediator, request)
name_layout = self._fetch_name_layout(mediator, request, shape)
shape = self._process_shape(shape, name_layout)
Expand Down Expand Up @@ -97,8 +96,8 @@ def _create_model_loader_gen(
name_layout: InputNameLayout,
field_loaders: Mapping[str, Loader],
model_identity: str,
) -> CodeGenerator:
return ModelLoaderGen(
) -> ModelLoaderGen:
return BuiltinModelLoaderGen(
shape=shape,
name_layout=name_layout,
debug_trail=debug_trail,
Expand Down

0 comments on commit 0b9cde5

Please sign in to comment.