diff --git a/pyproject.toml b/pyproject.toml index c6c937d3..b0f4db5d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ classifiers = [ [project.optional-dependencies] attrs = ['attrs >= 21.3.0'] attrs-strict = ['attrs >= 21.3.0, <= 23.1.0'] +sqlalchemy = ['sqlalchemy >= 2.0.0'] [project.urls] 'Homepage' = 'https://github.com/reagento/dataclass_factory' diff --git a/src/adaptix/_internal/model_tools/introspection/sqlalchemy_tables.py b/src/adaptix/_internal/model_tools/introspection/sqlalchemy_tables.py new file mode 100644 index 00000000..1d4b059d --- /dev/null +++ b/src/adaptix/_internal/model_tools/introspection/sqlalchemy_tables.py @@ -0,0 +1,152 @@ +from inspect import getfullargspec +from typing import Any + +from sqlalchemy import inspect +from sqlalchemy.sql.schema import CallableColumnDefault, ScalarElementColumnDefault + +from adaptix._internal.model_tools.definitions import ( + DefaultFactory, + DefaultValue, + FullShape, + InputField, + InputShape, + NoDefault, + OutputField, + OutputShape, + Param, + ParamKind, + Shape, + create_attr_accessor, +) +from adaptix._internal.type_tools import get_all_type_hints + + +class ColumnPropertyWrapper: + def __init__(self, column_property): + self.column_property = column_property + + +def _is_context_sensitive(default): + try: + wrapped_callable = default.arg.__wrapped__ + except AttributeError: + return True + + spec = getfullargspec(wrapped_callable) + return len(spec.args) > 0 + + +def _get_sqlalchemy_type_for_column(column, type_hints): + try: + return type_hints[column.name].__args__[0] + except KeyError: + return column.type.python_type + + +def _get_sqlalchemy_type_for_relationship(relationship, type_hints): + try: + return type_hints[str(relationship).split(".")[1]].__args__[0] + except KeyError: + return Any + + +def _get_sqlalchemy_default(column_default): + if not column_default: + return NoDefault() + if isinstance(column_default, CallableColumnDefault) and not _is_context_sensitive(column_default): + return DefaultFactory(factory=column_default.arg.__wrapped__) + if isinstance(column_default, ScalarElementColumnDefault): + return DefaultValue(value=column_default.arg) + return NoDefault() + + +def _get_sqlalchemy_required(column): + if column.default or column.nullable or column.server_default: + return False + if column.primary_key and column.autoincrement and column.type.python_type is int: + return False + return True + + +def _get_sqlalchemy_input_shape(tp, columns, relationships, type_hints) -> InputShape: + param_name_to_field = { + column.name: InputField( + id=column.name, + type=_get_sqlalchemy_type_for_column(column, type_hints), + default=_get_sqlalchemy_default(column.default), + is_required=_get_sqlalchemy_required(column), + metadata=column.info, + original=ColumnPropertyWrapper(column_property=column) + ) + for column in columns + } + + for relationship in relationships: + name = str(relationship).split(".")[1] + param_name_to_field[name] = InputField( + id=name, + type=_get_sqlalchemy_type_for_relationship(relationship, type_hints), + default=NoDefault(), + is_required=False, + metadata={}, + original=relationship + ) + + fields = tuple(param_name_to_field.values()) + + return InputShape( + constructor=tp, + fields=fields, + overriden_types=frozenset(), + kwargs=None, + params=tuple( + Param( + field_id=column.name, + name=column.name, + kind=ParamKind.KW_ONLY + ) + for column in columns + ) + ) + + +def _get_sqlalchemy_output_shape(columns, relationships, type_hints) -> OutputShape: + output_fields = [ + OutputField( + id=column.name, + type=_get_sqlalchemy_type_for_column(column, type_hints), + default=_get_sqlalchemy_default(column.default), + metadata=column.info, + original=ColumnPropertyWrapper(column_property=column), + accessor=create_attr_accessor(column.name, is_required=True) + ) + for column in columns + ] + + for relationship in relationships: + name = str(relationship).split(".")[1] + output_fields.append( + OutputField( + id=name, + type=_get_sqlalchemy_type_for_relationship(relationship, type_hints), + default=NoDefault(), + metadata={}, + original=relationship, + accessor=create_attr_accessor(name, is_required=False) + ) + ) + + return OutputShape( + fields=tuple(output_fields), + overriden_types=frozenset() + ) + + +def get_sqlalchemy_shape(tp) -> FullShape: + columns = inspect(tp).columns + relationships = inspect(tp).relationships + type_hints = get_all_type_hints(tp) + return Shape( + input=_get_sqlalchemy_input_shape(tp, columns, relationships, type_hints), + output=_get_sqlalchemy_output_shape(columns, relationships, type_hints) + ) diff --git a/tests/unit/model_tools/introspection/test_sqlalchemy.py b/tests/unit/model_tools/introspection/test_sqlalchemy.py new file mode 100644 index 00000000..0878a85d --- /dev/null +++ b/tests/unit/model_tools/introspection/test_sqlalchemy.py @@ -0,0 +1,208 @@ +from typing import Union +from unittest.mock import ANY + +from sqlalchemy import Column, String +from sqlalchemy.orm import Mapped, declarative_base, mapped_column + +from adaptix._internal.model_tools.definitions import ( + DefaultFactory, + DefaultValue, + InputField, + InputShape, + NoDefault, + OutputField, + OutputShape, + Param, + ParamKind, + Shape, + create_attr_accessor, +) +from adaptix._internal.model_tools.introspection.sqlalchemy_tables import get_sqlalchemy_shape + +Base = declarative_base() + + +def default_factory(): + return 2 + + +class MyTable(Base): + __tablename__ = "MyTable" + + id: Mapped[int] = mapped_column(primary_key=True) + text: Mapped[str] + nullable_field: Mapped[Union[int, None]] + field_with_default: Mapped[int] = mapped_column(default=2) + field_with_default_factory: Mapped[int] = mapped_column(default=default_factory) + field_with_default_context_factory: Mapped[int] = mapped_column(default=lambda ctx: 2) + field_with_old_syntax = Column(String()) + + +def test_shape_getter(): + assert ( + get_sqlalchemy_shape(MyTable) + == + Shape( + input=InputShape( + constructor=MyTable, + kwargs=None, + fields=( + InputField( + type=int, + id="id", + default=NoDefault(), + is_required=False, + metadata={}, + original=ANY + ), + InputField( + type=str, + id="text", + default=NoDefault(), + is_required=True, + metadata={}, + original=ANY + ), + InputField( + type=Union[int, None], + id="nullable_field", + default=NoDefault(), + is_required=False, + metadata={}, + original=ANY + ), + InputField( + type=int, + id="field_with_default", + default=DefaultValue(2), + is_required=False, + metadata={}, + original=ANY + ), + InputField( + type=int, + id="field_with_default_factory", + default=DefaultFactory(default_factory), + is_required=False, + metadata={}, + original=ANY + ), + InputField( + type=int, + id="field_with_default_context_factory", + default=NoDefault(), + is_required=False, + metadata={}, + original=ANY + ), + InputField( + type=str, + id="field_with_old_syntax", + default=NoDefault(), + is_required=False, + metadata={}, + original=ANY + ), + ), + overriden_types=frozenset(), + params=( + Param( + field_id='id', + name='id', + kind=ParamKind.KW_ONLY, + ), + Param( + field_id='text', + name='text', + kind=ParamKind.KW_ONLY, + ), + Param( + field_id='nullable_field', + name='nullable_field', + kind=ParamKind.KW_ONLY, + ), + Param( + field_id='field_with_default', + name='field_with_default', + kind=ParamKind.KW_ONLY, + ), + Param( + field_id='field_with_default_factory', + name='field_with_default_factory', + kind=ParamKind.KW_ONLY, + ), + Param( + field_id='field_with_default_context_factory', + name='field_with_default_context_factory', + kind=ParamKind.KW_ONLY, + ), + Param( + field_id='field_with_old_syntax', + name='field_with_old_syntax', + kind=ParamKind.KW_ONLY, + ), + ) + ), + output=OutputShape( + fields=( + OutputField( + type=int, + id="id", + default=NoDefault(), + metadata={}, + original=ANY, + accessor=create_attr_accessor('id', is_required=True), + ), + OutputField( + type=str, + id="text", + default=NoDefault(), + metadata={}, + original=ANY, + accessor=create_attr_accessor('text', is_required=True), + ), + OutputField( + type=Union[int, None], + id="nullable_field", + default=NoDefault(), + metadata={}, + original=ANY, + accessor=create_attr_accessor('nullable_field', is_required=True), + ), + OutputField( + type=int, + id="field_with_default", + default=DefaultValue(2), + metadata={}, + original=ANY, + accessor=create_attr_accessor('field_with_default', is_required=True), + ), + OutputField( + type=int, + id="field_with_default_factory", + default=DefaultFactory(default_factory), + metadata={}, + original=ANY, + accessor=create_attr_accessor('field_with_default_factory', is_required=True), + ), + OutputField( + type=int, + id="field_with_default_context_factory", + default=NoDefault(), + metadata={}, + original=ANY, + accessor=create_attr_accessor('field_with_default_context_factory', is_required=True), + ), + OutputField( + type=str, + id="field_with_old_syntax", + default=NoDefault(), + metadata={}, + original=ANY, + accessor=create_attr_accessor('field_with_old_syntax', is_required=True), + ), + ), + overriden_types=frozenset() + ) + ) + )