Skip to content

Commit

Permalink
add support for partial indexes
Browse files Browse the repository at this point in the history
  • Loading branch information
timgraham committed Oct 14, 2024
1 parent 85d32a5 commit 07e5a17
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 9 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test-python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ jobs:
from_db_value
generic_relations
generic_relations_regress
indexes
introspection
known_related_objects
lookup
Expand Down
2 changes: 2 additions & 0 deletions django_mongodb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@
from .expressions import register_expressions # noqa: E402
from .fields import register_fields # noqa: E402
from .functions import register_functions # noqa: E402
from .indexes import register_indexes # noqa: E402
from .lookups import register_lookups # noqa: E402
from .query import register_nodes # noqa: E402

register_aggregates()
register_expressions()
register_fields()
register_functions()
register_indexes()
register_lookups()
register_nodes()
3 changes: 2 additions & 1 deletion django_mongodb/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def case(self, compiler, connection):

def col(self, compiler, connection): # noqa: ARG001
# Add the column's collection's alias for columns in joined collections.
prefix = f"{self.alias}." if self.alias != compiler.collection_name else ""
has_alias = self.alias and self.alias != compiler.collection_name
prefix = f"{self.alias}." if has_alias else ""
return f"${prefix}{self.target.column}"


Expand Down
5 changes: 3 additions & 2 deletions django_mongodb/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ class DatabaseFeatures(BaseDatabaseFeatures):
# BSON Date type doesn't support microsecond precision.
supports_microsecond_precision = False
supports_paramstyle_pyformat = False
# Not implemented.
supports_partial_indexes = False
supports_select_difference = False
supports_select_intersection = False
supports_sequence_reset = False
Expand Down Expand Up @@ -72,6 +70,9 @@ class DatabaseFeatures(BaseDatabaseFeatures):
"backends.tests.ThreadTests.test_pass_connection_between_threads",
"backends.tests.ThreadTests.test_closing_non_shared_connections",
"backends.tests.ThreadTests.test_default_connection_thread_local",
# TODO:
"indexes.tests.PartialIndexTests.test_is_null_condition",
"indexes.tests.PartialIndexTests.test_multiple_conditions",
}
# $bitAnd, #bitOr, and $bitXor are new in MongoDB 6.3.
_django_test_expected_failures_bitwise = {
Expand Down
20 changes: 20 additions & 0 deletions django_mongodb/indexes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from django.db.models import Index
from django.db.models.sql.query import Query


def _get_condition_mql(self, model, schema_editor):
"""Analogous to Index._get_condition_sql()."""
query = Query(model=model, alias_cols=False)
where = query.build_where(self.condition)
compiler = query.get_compiler(connection=schema_editor.connection)
mql_ = where.as_mql(compiler, schema_editor.connection)
# Transform aggregate() query syntax into find() syntax.
mql = {}
for key in mql_:
col, value = mql_[key]
mql[col.lstrip("$")] = {key: value}
return mql


def register_indexes():
Index._get_condition_mql = _get_condition_mql
28 changes: 22 additions & 6 deletions django_mongodb/schema.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from collections import defaultdict

from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from django.db.models import Index, UniqueConstraint
from pymongo import ASCENDING, DESCENDING
Expand Down Expand Up @@ -166,17 +168,23 @@ def add_index(self, model, index, field=None, unique=False):
if index.contains_expressions:
return
kwargs = {}
filter_expression = defaultdict(dict)
if index.condition:
filter_expression.update(index._get_condition_mql(model, self))
if unique:
filter_expression = {}
kwargs["unique"] = True
if field:
filter_expression[field.column] = {"$type": field.db_type(self.connection)}
filter_expression[field.column].update({"$type": field.db_type(self.connection)})
else:
for field_name, _ in index.fields_orders:
field_ = model._meta.get_field(field_name)
filter_expression[field_.column] = {"$type": field_.db_type(self.connection)}
filter_expression[field_.column].update(
{"$type": field_.db_type(self.connection)}
)
# Use partialFilterExpression to allow multiple null values for a
# unique constraint.
kwargs = {"partialFilterExpression": filter_expression, "unique": True}
if filter_expression:
kwargs["partialFilterExpression"] = filter_expression
index_orders = (
[(field.column, ASCENDING)]
if field
Expand Down Expand Up @@ -260,7 +268,11 @@ def add_constraint(self, model, constraint, field=None):
expressions=constraint.expressions,
nulls_distinct=constraint.nulls_distinct,
):
idx = Index(fields=constraint.fields, name=constraint.name)
idx = Index(
fields=constraint.fields,
condition=constraint.condition,
name=constraint.name,
)
self.add_index(model, idx, field=field, unique=True)

def _add_field_unique(self, model, field):
Expand All @@ -276,7 +288,11 @@ def remove_constraint(self, model, constraint):
expressions=constraint.expressions,
nulls_distinct=constraint.nulls_distinct,
):
idx = Index(fields=constraint.fields, name=constraint.name)
idx = Index(
fields=constraint.fields,
condition=constraint.condition,
name=constraint.name,
)
self.remove_index(model, idx)

def _remove_field_unique(self, model, field):
Expand Down

0 comments on commit 07e5a17

Please sign in to comment.