Skip to content

Commit

Permalink
sqlalchemy instrospecion initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
daler-sz committed Jan 28, 2024
1 parent 7ff21d7 commit 2f311e7
Show file tree
Hide file tree
Showing 3 changed files with 361 additions and 0 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ classifiers = [
[project.optional-dependencies]
attrs = ['attrs >= 21.3.0']
attrs-strict = ['attrs >= 21.3.0, <= 23.1.0']
sqlalchemy = ['sqlalchemy >= 2.0.0']

[project.urls]
'Homepage' = 'https://github.com/reagento/dataclass_factory'
Expand Down
152 changes: 152 additions & 0 deletions src/adaptix/_internal/model_tools/introspection/sqlalchemy_tables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
from inspect import getfullargspec
from typing import Any

from sqlalchemy import inspect
from sqlalchemy.sql.schema import CallableColumnDefault, ScalarElementColumnDefault

from adaptix._internal.model_tools.definitions import (
DefaultFactory,
DefaultValue,
FullShape,
InputField,
InputShape,
NoDefault,
OutputField,
OutputShape,
Param,
ParamKind,
Shape,
create_attr_accessor,
)
from adaptix._internal.type_tools import get_all_type_hints


class ColumnPropertyWrapper:
def __init__(self, column_property):
self.column_property = column_property


def _is_context_sensitive(default):
try:
wrapped_callable = default.arg.__wrapped__
except AttributeError:
return True

spec = getfullargspec(wrapped_callable)
return len(spec.args) > 0


def _get_sqlalchemy_type_for_column(column, type_hints):
try:
return type_hints[column.name].__args__[0]
except KeyError:
return column.type.python_type


def _get_sqlalchemy_type_for_relationship(relationship, type_hints):
try:
return type_hints[str(relationship).split(".")[1]].__args__[0]
except KeyError:
return Any


def _get_sqlalchemy_default(column_default):
if not column_default:
return NoDefault()
if isinstance(column_default, CallableColumnDefault) and not _is_context_sensitive(column_default):
return DefaultFactory(factory=column_default.arg.__wrapped__)
if isinstance(column_default, ScalarElementColumnDefault):
return DefaultValue(value=column_default.arg)
return NoDefault()


def _get_sqlalchemy_required(column):
if column.default or column.nullable or column.server_default:
return False
if column.primary_key and column.autoincrement and column.type.python_type is int:
return False
return True


def _get_sqlalchemy_input_shape(tp, columns, relationships, type_hints) -> InputShape:
param_name_to_field = {
column.name: InputField(
id=column.name,
type=_get_sqlalchemy_type_for_column(column, type_hints),
default=_get_sqlalchemy_default(column.default),
is_required=_get_sqlalchemy_required(column),
metadata=column.info,
original=ColumnPropertyWrapper(column_property=column)
)
for column in columns
}

for relationship in relationships:
name = str(relationship).split(".")[1]
param_name_to_field[name] = InputField(
id=name,
type=_get_sqlalchemy_type_for_relationship(relationship, type_hints),
default=NoDefault(),
is_required=False,
metadata={},
original=relationship
)

fields = tuple(param_name_to_field.values())

return InputShape(
constructor=tp,
fields=fields,
overriden_types=frozenset(),
kwargs=None,
params=tuple(
Param(
field_id=column.name,
name=column.name,
kind=ParamKind.KW_ONLY
)
for column in columns
)
)


def _get_sqlalchemy_output_shape(columns, relationships, type_hints) -> OutputShape:
output_fields = [
OutputField(
id=column.name,
type=_get_sqlalchemy_type_for_column(column, type_hints),
default=_get_sqlalchemy_default(column.default),
metadata=column.info,
original=ColumnPropertyWrapper(column_property=column),
accessor=create_attr_accessor(column.name, is_required=True)
)
for column in columns
]

for relationship in relationships:
name = str(relationship).split(".")[1]
output_fields.append(
OutputField(
id=name,
type=_get_sqlalchemy_type_for_relationship(relationship, type_hints),
default=NoDefault(),
metadata={},
original=relationship,
accessor=create_attr_accessor(name, is_required=False)
)
)

return OutputShape(
fields=tuple(output_fields),
overriden_types=frozenset()
)


def get_sqlalchemy_shape(tp) -> FullShape:
columns = inspect(tp).columns
relationships = inspect(tp).relationships
type_hints = get_all_type_hints(tp)
return Shape(
input=_get_sqlalchemy_input_shape(tp, columns, relationships, type_hints),
output=_get_sqlalchemy_output_shape(columns, relationships, type_hints)
)
208 changes: 208 additions & 0 deletions tests/unit/model_tools/introspection/test_sqlalchemy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
from typing import Union
from unittest.mock import ANY

from sqlalchemy import Column, String
from sqlalchemy.orm import Mapped, declarative_base, mapped_column

from adaptix._internal.model_tools.definitions import (
DefaultFactory,
DefaultValue,
InputField,
InputShape,
NoDefault,
OutputField,
OutputShape,
Param,
ParamKind,
Shape,
create_attr_accessor,
)
from adaptix._internal.model_tools.introspection.sqlalchemy_tables import get_sqlalchemy_shape

Base = declarative_base()


def default_factory():
return 2


class MyTable(Base):
__tablename__ = "MyTable"

id: Mapped[int] = mapped_column(primary_key=True)
text: Mapped[str]
nullable_field: Mapped[Union[int, None]]
field_with_default: Mapped[int] = mapped_column(default=2)
field_with_default_factory: Mapped[int] = mapped_column(default=default_factory)
field_with_default_context_factory: Mapped[int] = mapped_column(default=lambda ctx: 2)
field_with_old_syntax = Column(String())


def test_shape_getter():
assert (
get_sqlalchemy_shape(MyTable)
==
Shape(
input=InputShape(
constructor=MyTable,
kwargs=None,
fields=(
InputField(
type=int,
id="id",
default=NoDefault(),
is_required=False,
metadata={},
original=ANY
),
InputField(
type=str,
id="text",
default=NoDefault(),
is_required=True,
metadata={},
original=ANY
),
InputField(
type=Union[int, None],
id="nullable_field",
default=NoDefault(),
is_required=False,
metadata={},
original=ANY
),
InputField(
type=int,
id="field_with_default",
default=DefaultValue(2),
is_required=False,
metadata={},
original=ANY
),
InputField(
type=int,
id="field_with_default_factory",
default=DefaultFactory(default_factory),
is_required=False,
metadata={},
original=ANY
),
InputField(
type=int,
id="field_with_default_context_factory",
default=NoDefault(),
is_required=False,
metadata={},
original=ANY
),
InputField(
type=str,
id="field_with_old_syntax",
default=NoDefault(),
is_required=False,
metadata={},
original=ANY
),
),
overriden_types=frozenset(),
params=(
Param(
field_id='id',
name='id',
kind=ParamKind.KW_ONLY,
),
Param(
field_id='text',
name='text',
kind=ParamKind.KW_ONLY,
),
Param(
field_id='nullable_field',
name='nullable_field',
kind=ParamKind.KW_ONLY,
),
Param(
field_id='field_with_default',
name='field_with_default',
kind=ParamKind.KW_ONLY,
),
Param(
field_id='field_with_default_factory',
name='field_with_default_factory',
kind=ParamKind.KW_ONLY,
),
Param(
field_id='field_with_default_context_factory',
name='field_with_default_context_factory',
kind=ParamKind.KW_ONLY,
),
Param(
field_id='field_with_old_syntax',
name='field_with_old_syntax',
kind=ParamKind.KW_ONLY,
),
)
),
output=OutputShape(
fields=(
OutputField(
type=int,
id="id",
default=NoDefault(),
metadata={},
original=ANY,
accessor=create_attr_accessor('id', is_required=True),
),
OutputField(
type=str,
id="text",
default=NoDefault(),
metadata={},
original=ANY,
accessor=create_attr_accessor('text', is_required=True),
),
OutputField(
type=Union[int, None],
id="nullable_field",
default=NoDefault(),
metadata={},
original=ANY,
accessor=create_attr_accessor('nullable_field', is_required=True),
),
OutputField(
type=int,
id="field_with_default",
default=DefaultValue(2),
metadata={},
original=ANY,
accessor=create_attr_accessor('field_with_default', is_required=True),
),
OutputField(
type=int,
id="field_with_default_factory",
default=DefaultFactory(default_factory),
metadata={},
original=ANY,
accessor=create_attr_accessor('field_with_default_factory', is_required=True),
),
OutputField(
type=int,
id="field_with_default_context_factory",
default=NoDefault(),
metadata={},
original=ANY,
accessor=create_attr_accessor('field_with_default_context_factory', is_required=True),
),
OutputField(
type=str,
id="field_with_old_syntax",
default=NoDefault(),
metadata={},
original=ANY,
accessor=create_attr_accessor('field_with_old_syntax', is_required=True),
),
),
overriden_types=frozenset()
)
)
)

0 comments on commit 2f311e7

Please sign in to comment.