diff --git a/.github/workflows/test-python.yml b/.github/workflows/test-python.yml index d8fc02ea..4168ef34 100644 --- a/.github/workflows/test-python.yml +++ b/.github/workflows/test-python.yml @@ -86,6 +86,7 @@ jobs: from_db_value generic_relations generic_relations_regress + indexes introspection known_related_objects lookup diff --git a/django_mongodb/__init__.py b/django_mongodb/__init__.py index 7994999d..31d8f2d3 100644 --- a/django_mongodb/__init__.py +++ b/django_mongodb/__init__.py @@ -10,6 +10,7 @@ from .expressions import register_expressions # noqa: E402 from .fields import register_fields # noqa: E402 from .functions import register_functions # noqa: E402 +from .indexes import register_indexes # noqa: E402 from .lookups import register_lookups # noqa: E402 from .query import register_nodes # noqa: E402 @@ -17,5 +18,6 @@ register_expressions() register_fields() register_functions() +register_indexes() register_lookups() register_nodes() diff --git a/django_mongodb/expressions.py b/django_mongodb/expressions.py index 7af3d71e..13ab606f 100644 --- a/django_mongodb/expressions.py +++ b/django_mongodb/expressions.py @@ -51,7 +51,8 @@ def case(self, compiler, connection): def col(self, compiler, connection): # noqa: ARG001 # Add the column's collection's alias for columns in joined collections. - prefix = f"{self.alias}." if self.alias != compiler.collection_name else "" + has_alias = self.alias and self.alias != compiler.collection_name + prefix = f"{self.alias}." if has_alias else "" return f"${prefix}{self.target.column}" diff --git a/django_mongodb/features.py b/django_mongodb/features.py index 7a30baf2..454b1409 100644 --- a/django_mongodb/features.py +++ b/django_mongodb/features.py @@ -23,8 +23,6 @@ class DatabaseFeatures(BaseDatabaseFeatures): # BSON Date type doesn't support microsecond precision. supports_microsecond_precision = False supports_paramstyle_pyformat = False - # Not implemented. - supports_partial_indexes = False supports_select_difference = False supports_select_intersection = False supports_sequence_reset = False @@ -72,6 +70,9 @@ class DatabaseFeatures(BaseDatabaseFeatures): "backends.tests.ThreadTests.test_pass_connection_between_threads", "backends.tests.ThreadTests.test_closing_non_shared_connections", "backends.tests.ThreadTests.test_default_connection_thread_local", + # TODO: + "indexes.tests.PartialIndexTests.test_is_null_condition", + "indexes.tests.PartialIndexTests.test_multiple_conditions", } # $bitAnd, #bitOr, and $bitXor are new in MongoDB 6.3. _django_test_expected_failures_bitwise = { diff --git a/django_mongodb/indexes.py b/django_mongodb/indexes.py new file mode 100644 index 00000000..255b2983 --- /dev/null +++ b/django_mongodb/indexes.py @@ -0,0 +1,20 @@ +from django.db.models import Index +from django.db.models.sql.query import Query + + +def _get_condition_mql(self, model, schema_editor): + """Analogous to Index._get_condition_sql().""" + query = Query(model=model, alias_cols=False) + where = query.build_where(self.condition) + compiler = query.get_compiler(connection=schema_editor.connection) + mql_ = where.as_mql(compiler, schema_editor.connection) + # Transform aggregate() query syntax into find() syntax. + mql = {} + for key in mql_: + col, value = mql_[key] + mql[col.lstrip("$")] = {key: value} + return mql + + +def register_indexes(): + Index._get_condition_mql = _get_condition_mql diff --git a/django_mongodb/schema.py b/django_mongodb/schema.py index 02028cbc..8e85a695 100644 --- a/django_mongodb/schema.py +++ b/django_mongodb/schema.py @@ -1,3 +1,5 @@ +from collections import defaultdict + from django.db.backends.base.schema import BaseDatabaseSchemaEditor from django.db.models import Index, UniqueConstraint from pymongo import ASCENDING, DESCENDING @@ -166,17 +168,23 @@ def add_index(self, model, index, field=None, unique=False): if index.contains_expressions: return kwargs = {} + filter_expression = defaultdict(dict) + if index.condition: + filter_expression.update(index._get_condition_mql(model, self)) if unique: - filter_expression = {} + kwargs["unique"] = True if field: - filter_expression[field.column] = {"$type": field.db_type(self.connection)} + filter_expression[field.column].update({"$type": field.db_type(self.connection)}) else: for field_name, _ in index.fields_orders: field_ = model._meta.get_field(field_name) - filter_expression[field_.column] = {"$type": field_.db_type(self.connection)} + filter_expression[field_.column].update( + {"$type": field_.db_type(self.connection)} + ) # Use partialFilterExpression to allow multiple null values for a # unique constraint. - kwargs = {"partialFilterExpression": filter_expression, "unique": True} + if filter_expression: + kwargs["partialFilterExpression"] = filter_expression index_orders = ( [(field.column, ASCENDING)] if field @@ -260,7 +268,11 @@ def add_constraint(self, model, constraint, field=None): expressions=constraint.expressions, nulls_distinct=constraint.nulls_distinct, ): - idx = Index(fields=constraint.fields, name=constraint.name) + idx = Index( + fields=constraint.fields, + condition=constraint.condition, + name=constraint.name, + ) self.add_index(model, idx, field=field, unique=True) def _add_field_unique(self, model, field): @@ -276,7 +288,11 @@ def remove_constraint(self, model, constraint): expressions=constraint.expressions, nulls_distinct=constraint.nulls_distinct, ): - idx = Index(fields=constraint.fields, name=constraint.name) + idx = Index( + fields=constraint.fields, + condition=constraint.condition, + name=constraint.name, + ) self.remove_index(model, idx) def _remove_field_unique(self, model, field):