diff --git a/CHANGES.md b/CHANGES.md index 3b65ced..e731e52 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -3,6 +3,8 @@ ## Unreleased - Added/reactivated documentation as `sqlalchemy-cratedb` +- Added `CrateIdentifierPreparer`, in order to quote reserved words + like `object` properly, for example when used as column names. ## 2024/06/13 0.37.0 - Added support for CrateDB's [FLOAT_VECTOR] data type and its accompanying diff --git a/src/sqlalchemy_cratedb/compiler.py b/src/sqlalchemy_cratedb/compiler.py index d3a6618..6ce1b7e 100644 --- a/src/sqlalchemy_cratedb/compiler.py +++ b/src/sqlalchemy_cratedb/compiler.py @@ -25,6 +25,7 @@ import sqlalchemy as sa from sqlalchemy.dialects.postgresql.base import PGCompiler +from sqlalchemy.dialects.postgresql.base import RESERVED_WORDS as POSTGRESQL_RESERVED_WORDS from sqlalchemy.sql import compiler from sqlalchemy.types import String from .type.geo import Geopoint, Geoshape @@ -323,3 +324,44 @@ def for_update_clause(self, select, **kw): warnings.warn("CrateDB does not support the 'INSERT ... FOR UPDATE' clause, " "it will be omitted when generating SQL statements.") return '' + + +CRATEDB_RESERVED_WORDS = \ + "cross, current_date, intersect, else, end, except, using, case, and, current_schema, any, " \ + "all, set, limit, input, natural, cast, directory, is, when, if, table, right, outer, full, " \ + "order, select, join, add, session_user, current_time, grant, true, left, into, try_cast, " \ + "current_role, insert, some, exists, update, false, create, reset, offset, object, " \ + "transient, current_user, in, or, for, alter, asc, function, null, from, default, not, " \ + "like, union, distinct, nulls, having, inner, by, persistent, stratify, array, revoke, " \ + "match, drop, escape, where, costs, with, group, index, delete, column, on, unbounded, " \ + "returns, then, last, user, called, recursive, between, describe, as, extract, " \ + "current_timestamp, deny, first, constraint, desc".split(", ") + + +class CrateIdentifierPreparer(sa.sql.compiler.IdentifierPreparer): + """ + Define CrateDB's reserved words to be quoted properly. + """ + reserved_words = set(list(POSTGRESQL_RESERVED_WORDS) + CRATEDB_RESERVED_WORDS) + + def _unquote_identifier(self, value): + if value[0] == self.initial_quote: + value = value[1:-1].replace( + self.escape_to_quote, self.escape_quote + ) + return value + + def format_type(self, type_, use_schema=True): + if not type_.name: + raise sa.exc.CompileError("Type requires a name.") + + name = self.quote(type_.name) + effective_schema = self.schema_for_object(type_) + + if ( + not self.omit_schema + and use_schema + and effective_schema is not None + ): + name = self.quote_schema(effective_schema) + "." + name + return name diff --git a/src/sqlalchemy_cratedb/dialect.py b/src/sqlalchemy_cratedb/dialect.py index 43af2fc..c5cc2de 100644 --- a/src/sqlalchemy_cratedb/dialect.py +++ b/src/sqlalchemy_cratedb/dialect.py @@ -29,7 +29,8 @@ from .compiler import ( CrateTypeCompiler, - CrateDDLCompiler + CrateDDLCompiler, + CrateIdentifierPreparer, ) from crate.client.exceptions import TimezoneUnawareException from .sa_version import SA_VERSION, SA_1_4, SA_2_0 @@ -174,6 +175,7 @@ class CrateDialect(default.DefaultDialect): statement_compiler = statement_compiler ddl_compiler = CrateDDLCompiler type_compiler = CrateTypeCompiler + preparer = CrateIdentifierPreparer use_insertmanyvalues = True use_insertmanyvalues_wo_returning = True supports_multivalues_insert = True diff --git a/tests/compiler_test.py b/tests/compiler_test.py index a40ebb0..91f7d8d 100644 --- a/tests/compiler_test.py +++ b/tests/compiler_test.py @@ -432,3 +432,31 @@ class FooBar(Base): self.assertIsSubclass(w[-1].category, UserWarning) self.assertIn("CrateDB does not support unique constraints, " "they will be omitted when generating DDL statements.", str(w[-1].message)) + + def test_ddl_with_reserved_words(self): + """ + Verify CrateDB's reserved words like `object` are quoted properly. + """ + + Base = declarative_base(metadata=self.metadata) + + class FooBar(Base): + """The entity.""" + + __tablename__ = "foobar" + + index = sa.Column(sa.Integer, primary_key=True) + array = sa.Column(sa.String) + object = sa.Column(sa.String) + + # Verify SQL DDL statement. + self.metadata.create_all(self.engine, tables=[FooBar.__table__], checkfirst=False) + self.assertEqual(self.executed_statement, dedent(""" + CREATE TABLE testdrive.foobar ( + \t"index" INT NOT NULL, + \t"array" STRING, + \t"object" STRING, + \tPRIMARY KEY ("index") + ) + + """)) # noqa: W291, W293