Skip to content

Commit

Permalink
Refactor data type implementations to use their own file each
Browse files Browse the repository at this point in the history
  • Loading branch information
amotl committed Dec 21, 2023
1 parent af5f967 commit 70bdfa1
Show file tree
Hide file tree
Showing 11 changed files with 154 additions and 146 deletions.
5 changes: 3 additions & 2 deletions src/sqlalchemy_cratedb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@
from .dialect import CrateDialect
from .sa_version import SA_1_4, SA_2_0, SA_VERSION
from .support import insert_bulk
from .types import Geopoint, Geoshape, ObjectArray, ObjectType

from .type.array import ObjectArray
from .type.geo import Geopoint, Geoshape
from .type.object import ObjectType

if SA_VERSION < SA_1_4:
import textwrap
Expand Down
3 changes: 2 additions & 1 deletion src/sqlalchemy_cratedb/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
from sqlalchemy.dialects.postgresql.base import PGCompiler
from sqlalchemy.sql import compiler
from sqlalchemy.types import String
from .types import MutableDict, ObjectTypeImpl, Geopoint, Geoshape
from .type.geo import Geopoint, Geoshape
from .type.object import MutableDict, ObjectTypeImpl
from .sa_version import SA_VERSION, SA_1_4


Expand Down
2 changes: 1 addition & 1 deletion src/sqlalchemy_cratedb/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
)
from crate.client.exceptions import TimezoneUnawareException
from .sa_version import SA_VERSION, SA_1_4, SA_2_0
from .types import ObjectType, ObjectArray
from .type import ObjectArray, ObjectType

TYPES_MAP = {
"boolean": sqltypes.Boolean,
Expand Down
3 changes: 3 additions & 0 deletions src/sqlalchemy_cratedb/type/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .array import ObjectArray
from .geo import Geopoint, Geoshape
from .object import ObjectType
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,12 @@
# However, if you have executed another commercial license agreement
# with Crate these terms will supersede the license and you may use the
# software solely pursuant to the terms of the relevant commercial agreement.
import warnings

import sqlalchemy.types as sqltypes
from sqlalchemy.sql import operators, expression
from sqlalchemy.sql import default_comparator
from sqlalchemy.ext.mutable import Mutable

import geojson


class MutableList(Mutable, list):

Expand Down Expand Up @@ -74,91 +71,6 @@ def remove(self, item):
self.changed()


class MutableDict(Mutable, dict):

@classmethod
def coerce(cls, key, value):
"Convert plain dictionaries to MutableDict."

if not isinstance(value, MutableDict):
if isinstance(value, dict):
return MutableDict(value)

# this call will raise ValueError
return Mutable.coerce(key, value)
else:
return value

def __init__(self, initval=None, to_update=None, root_change_key=None):
initval = initval or {}
self._changed_keys = set()
self._deleted_keys = set()
self._overwrite_key = root_change_key
self.to_update = self if to_update is None else to_update
for k in initval:
initval[k] = self._convert_dict(initval[k],
overwrite_key=k if self._overwrite_key is None else self._overwrite_key
)
dict.__init__(self, initval)

def __setitem__(self, key, value):
value = self._convert_dict(value, key if self._overwrite_key is None else self._overwrite_key)
dict.__setitem__(self, key, value)
self.to_update.on_key_changed(
key if self._overwrite_key is None else self._overwrite_key
)

def __delitem__(self, key):
dict.__delitem__(self, key)
# add the key to the deleted keys if this is the root object
# otherwise update on root object
if self._overwrite_key is None:
self._deleted_keys.add(key)
self.changed()
else:
self.to_update.on_key_changed(self._overwrite_key)

def on_key_changed(self, key):
self._deleted_keys.discard(key)
self._changed_keys.add(key)
self.changed()

def _convert_dict(self, value, overwrite_key):
if isinstance(value, dict) and not isinstance(value, MutableDict):
return MutableDict(value, self.to_update, overwrite_key)
return value

def __eq__(self, other):
return dict.__eq__(self, other)


class ObjectTypeImpl(sqltypes.UserDefinedType, sqltypes.JSON):

__visit_name__ = "OBJECT"

cache_ok = False
none_as_null = False


# Designated name to refer to. `Object` is too ambiguous.
ObjectType = MutableDict.as_mutable(ObjectTypeImpl)

# Backward-compatibility aliases.
_deprecated_Craty = ObjectType
_deprecated_Object = ObjectType

# https://www.lesinskis.com/deprecating-module-scope-variables.html
deprecated_names = ["Craty", "Object"]


def __getattr__(name):
if name in deprecated_names:
warnings.warn(f"{name} is deprecated and will be removed in future releases. "
f"Please use ObjectType instead.", DeprecationWarning)
return globals()[f"_deprecated_{name}"]
raise AttributeError(f"module {__name__} has no attribute {name}")


class Any(expression.ColumnElement):
"""Represent the clause ``left operator ANY (right)``. ``right`` must be
an array expression.
Expand Down Expand Up @@ -230,48 +142,3 @@ def get_col_spec(self, **kws):


ObjectArray = MutableList.as_mutable(_ObjectArray)


class Geopoint(sqltypes.UserDefinedType):
cache_ok = True

class Comparator(sqltypes.TypeEngine.Comparator):

def __getitem__(self, key):
return default_comparator._binary_operate(self.expr,
operators.getitem,
key)

def get_col_spec(self):
return 'GEO_POINT'

def bind_processor(self, dialect):
def process(value):
if isinstance(value, geojson.Point):
return value.coordinates
return value
return process

def result_processor(self, dialect, coltype):
return tuple

comparator_factory = Comparator


class Geoshape(sqltypes.UserDefinedType):
cache_ok = True

class Comparator(sqltypes.TypeEngine.Comparator):

def __getitem__(self, key):
return default_comparator._binary_operate(self.expr,
operators.getitem,
key)

def get_col_spec(self):
return 'GEO_SHAPE'

def result_processor(self, dialect, coltype):
return geojson.GeoJSON.to_instance

comparator_factory = Comparator
48 changes: 48 additions & 0 deletions src/sqlalchemy_cratedb/type/geo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import geojson
from sqlalchemy import types as sqltypes
from sqlalchemy.sql import default_comparator, operators


class Geopoint(sqltypes.UserDefinedType):
cache_ok = True

class Comparator(sqltypes.TypeEngine.Comparator):

def __getitem__(self, key):
return default_comparator._binary_operate(self.expr,
operators.getitem,
key)

def get_col_spec(self):
return 'GEO_POINT'

def bind_processor(self, dialect):
def process(value):
if isinstance(value, geojson.Point):
return value.coordinates
return value
return process

def result_processor(self, dialect, coltype):
return tuple

comparator_factory = Comparator


class Geoshape(sqltypes.UserDefinedType):
cache_ok = True

class Comparator(sqltypes.TypeEngine.Comparator):

def __getitem__(self, key):
return default_comparator._binary_operate(self.expr,
operators.getitem,
key)

def get_col_spec(self):
return 'GEO_SHAPE'

def result_processor(self, dialect, coltype):
return geojson.GeoJSON.to_instance

comparator_factory = Comparator
92 changes: 92 additions & 0 deletions src/sqlalchemy_cratedb/type/object.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import warnings

from sqlalchemy import types as sqltypes
from sqlalchemy.ext.mutable import Mutable


class MutableDict(Mutable, dict):

@classmethod
def coerce(cls, key, value):
"Convert plain dictionaries to MutableDict."

if not isinstance(value, MutableDict):
if isinstance(value, dict):
return MutableDict(value)

# this call will raise ValueError
return Mutable.coerce(key, value)
else:
return value

def __init__(self, initval=None, to_update=None, root_change_key=None):
initval = initval or {}
self._changed_keys = set()
self._deleted_keys = set()
self._overwrite_key = root_change_key
self.to_update = self if to_update is None else to_update
for k in initval:
initval[k] = self._convert_dict(initval[k],
overwrite_key=k if self._overwrite_key is None else self._overwrite_key
)
dict.__init__(self, initval)

def __setitem__(self, key, value):
value = self._convert_dict(value, key if self._overwrite_key is None else self._overwrite_key)
dict.__setitem__(self, key, value)
self.to_update.on_key_changed(
key if self._overwrite_key is None else self._overwrite_key
)

def __delitem__(self, key):
dict.__delitem__(self, key)
# add the key to the deleted keys if this is the root object
# otherwise update on root object
if self._overwrite_key is None:
self._deleted_keys.add(key)
self.changed()
else:
self.to_update.on_key_changed(self._overwrite_key)

def on_key_changed(self, key):
self._deleted_keys.discard(key)
self._changed_keys.add(key)
self.changed()

def _convert_dict(self, value, overwrite_key):
if isinstance(value, dict) and not isinstance(value, MutableDict):
return MutableDict(value, self.to_update, overwrite_key)
return value

def __eq__(self, other):
return dict.__eq__(self, other)


class ObjectTypeImpl(sqltypes.UserDefinedType, sqltypes.JSON):

__visit_name__ = "OBJECT"

cache_ok = False
none_as_null = False


# Designated name to refer to. `Object` is too ambiguous.
ObjectType = MutableDict.as_mutable(ObjectTypeImpl)

# Backward-compatibility aliases.
_deprecated_Craty = ObjectType
_deprecated_Object = ObjectType

# https://www.lesinskis.com/deprecating-module-scope-variables.html
deprecated_names = ["Craty", "Object"]


def __getattr__(name):
if name in deprecated_names:
warnings.warn(f"{name} is deprecated and will be removed in future releases. "
f"Please use ObjectType instead.", DeprecationWarning)
return globals()[f"_deprecated_{name}"]
raise AttributeError(f"module {__name__} has no attribute {name}")


__all__ = deprecated_names
3 changes: 1 addition & 2 deletions tests/compiler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@
except ImportError:
from sqlalchemy.ext.declarative import declarative_base

from sqlalchemy_cratedb import SA_VERSION, SA_1_4, SA_2_0
from sqlalchemy_cratedb import ObjectType
from sqlalchemy_cratedb import SA_VERSION, SA_1_4, SA_2_0, ObjectType
from crate.client.test_util import ParametrizedTestCase


Expand Down
3 changes: 1 addition & 2 deletions tests/dialect_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,8 @@
import sqlalchemy as sa

from crate.client.cursor import Cursor
from sqlalchemy_cratedb import SA_VERSION
from sqlalchemy_cratedb import SA_VERSION, ObjectType
from sqlalchemy_cratedb import SA_1_4, SA_2_0
from sqlalchemy_cratedb import ObjectType
from sqlalchemy import inspect
from sqlalchemy.orm import Session
try:
Expand Down
4 changes: 1 addition & 3 deletions tests/query_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,14 @@
from sqlalchemy.orm import Session
from sqlalchemy.sql.operators import eq

from sqlalchemy_cratedb import SA_VERSION, SA_1_4
from sqlalchemy_cratedb import SA_VERSION, SA_1_4, ObjectArray, ObjectType
from crate.testing.settings import crate_host

try:
from sqlalchemy.orm import declarative_base
except ImportError:
from sqlalchemy.ext.declarative import declarative_base

from sqlalchemy_cratedb import ObjectType, ObjectArray


class SqlAlchemyQueryCompilationCaching(TestCase):

Expand Down
4 changes: 2 additions & 2 deletions tests/warnings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_craty_object_deprecation_warning(self):
with warnings.catch_warnings(record=True) as w:

# Import the deprecated symbol.
from sqlalchemy_cratedb.types import Craty # noqa: F401
from sqlalchemy_cratedb.type.object import Craty # noqa: F401

# Verify details of the deprecation warning.
self.assertEqual(len(w), 1)
Expand All @@ -55,7 +55,7 @@ def test_craty_object_deprecation_warning(self):
with warnings.catch_warnings(record=True) as w:

# Import the deprecated symbol.
from sqlalchemy_cratedb.types import Object # noqa: F401
from sqlalchemy_cratedb.type.object import Object # noqa: F401

# Verify details of the deprecation warning.
self.assertEqual(len(w), 1)
Expand Down

0 comments on commit 70bdfa1

Please sign in to comment.