From 419b97e713c242fb5108a4ad70316597b61b1377 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Thu, 9 Jan 2025 16:00:54 -0300 Subject: [PATCH] refactor subquery wrapping pipeline --- django_mongodb_backend/expressions.py | 86 ++++---------------------- django_mongodb_backend/fields/array.py | 35 +++++++++++ django_mongodb_backend/lookups.py | 33 ++++++++++ django_mongodb_backend/query_utils.py | 6 +- 4 files changed, 83 insertions(+), 77 deletions(-) diff --git a/django_mongodb_backend/expressions.py b/django_mongodb_backend/expressions.py index 3273ecdb..8e8c1815 100644 --- a/django_mongodb_backend/expressions.py +++ b/django_mongodb_backend/expressions.py @@ -95,7 +95,7 @@ def order_by(self, compiler, connection): return self.expression.as_mql(compiler, connection) -def query(self, compiler, connection, lookup_name=None): +def query(self, compiler, connection, get_wrapping_pipeline=None): subquery_compiler = self.get_compiler(connection=connection) subquery_compiler.pre_sql_setup(with_col_aliases=False) field_name, expr = subquery_compiler.columns[0] @@ -119,76 +119,12 @@ def query(self, compiler, connection, lookup_name=None): for col, i in subquery_compiler.column_indices.items() }, } - wrapping_result_pipeline = None - # The result must be a list of values. The output is compressed with an - # aggregation pipeline. - if lookup_name in ("in", "range"): - wrapping_result_pipeline = [ - { - "$facet": { - "group": [ - { - "$group": { - "_id": None, - "tmp_name": { - "$addToSet": expr.as_mql(subquery_compiler, connection) - }, - } - } - ] - } - }, - { - "$project": { - field_name: { - "$ifNull": [ - { - "$getField": { - "input": {"$arrayElemAt": ["$group", 0]}, - "field": "tmp_name", - } - }, - [], - ] - } - } - }, - ] - if lookup_name == "overlap": - wrapping_result_pipeline = [ - { - "$facet": { - "group": [ - {"$project": {"tmp_name": expr.as_mql(subquery_compiler, connection)}}, - { - "$unwind": "$tmp_name", - }, - { - "$group": { - "_id": None, - "tmp_name": {"$addToSet": "$tmp_name"}, - } - }, - ] - } - }, - { - "$project": { - field_name: { - "$ifNull": [ - { - "$getField": { - "input": {"$arrayElemAt": ["$group", 0]}, - "field": "tmp_name", - } - }, - [], - ] - } - } - }, - ] - if wrapping_result_pipeline: + if get_wrapping_pipeline: + # The results from some lookups must be converted to a list of values. + # The output is compressed with an aggregation pipeline. + wrapping_result_pipeline = get_wrapping_pipeline( + subquery_compiler, connection, field_name, expr + ) # If the subquery is a combinator, wrap the result at the end of the # combinator pipeline... if subquery.query.combinator: @@ -221,13 +157,13 @@ def star(self, compiler, connection): # noqa: ARG001 return {"$literal": True} -def subquery(self, compiler, connection, lookup_name=None): - return self.query.as_mql(compiler, connection, lookup_name=lookup_name) +def subquery(self, compiler, connection, get_wrapping_pipeline=None): + return self.query.as_mql(compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline) -def exists(self, compiler, connection, lookup_name=None): +def exists(self, compiler, connection, get_wrapping_pipeline=None): try: - lhs_mql = subquery(self, compiler, connection, lookup_name=lookup_name) + lhs_mql = subquery(self, compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline) except EmptyResultSet: return Value(False).as_mql(compiler, connection) return connection.mongo_operators["isnull"](lhs_mql, False) diff --git a/django_mongodb_backend/fields/array.py b/django_mongodb_backend/fields/array.py index 860233be..49c9e6ad 100644 --- a/django_mongodb_backend/fields/array.py +++ b/django_mongodb_backend/fields/array.py @@ -278,6 +278,41 @@ class ArrayExact(ArrayRHSMixin, Exact): class ArrayOverlap(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup): lookup_name = "overlap" + def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr): + return [ + { + "$facet": { + "group": [ + {"$project": {"tmp_name": expr.as_mql(compiler, connection)}}, + { + "$unwind": "$tmp_name", + }, + { + "$group": { + "_id": None, + "tmp_name": {"$addToSet": "$tmp_name"}, + } + }, + ] + } + }, + { + "$project": { + field_name: { + "$ifNull": [ + { + "$getField": { + "input": {"$arrayElemAt": ["$group", 0]}, + "field": "tmp_name", + } + }, + [], + ] + } + } + }, + ] + def as_mql(self, compiler, connection): lhs_mql = process_lhs(self, compiler, connection) value = process_rhs(self, compiler, connection) diff --git a/django_mongodb_backend/lookups.py b/django_mongodb_backend/lookups.py index c651dd6a..519a03c9 100644 --- a/django_mongodb_backend/lookups.py +++ b/django_mongodb_backend/lookups.py @@ -45,6 +45,38 @@ def in_(self, compiler, connection): return builtin_lookup(self, compiler, connection) +def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr): # noqa: ARG001 + return [ + { + "$facet": { + "group": [ + { + "$group": { + "_id": None, + "tmp_name": {"$addToSet": expr.as_mql(compiler, connection)}, + } + } + ] + } + }, + { + "$project": { + field_name: { + "$ifNull": [ + { + "$getField": { + "input": {"$arrayElemAt": ["$group", 0]}, + "field": "tmp_name", + } + }, + [], + ] + } + } + }, + ] + + def is_null(self, compiler, connection): if not isinstance(self.rhs, bool): raise ValueError("The QuerySet value for an isnull lookup must be True or False.") @@ -97,6 +129,7 @@ def register_lookups(): field_resolve_expression_parameter ) In.as_mql = RelatedIn.as_mql = in_ + In.get_subquery_wrapping_pipeline = get_subquery_wrapping_pipeline IsNull.as_mql = is_null PatternLookup.prep_lookup_value_mongo = pattern_lookup_prep_lookup_value UUIDTextMixin.as_mql = uuid_text_mixin diff --git a/django_mongodb_backend/query_utils.py b/django_mongodb_backend/query_utils.py index ff98a1ed..dd7042c7 100644 --- a/django_mongodb_backend/query_utils.py +++ b/django_mongodb_backend/query_utils.py @@ -28,8 +28,10 @@ def process_lhs(node, compiler, connection): def process_rhs(node, compiler, connection): rhs = node.rhs if hasattr(rhs, "as_mql"): - if getattr(rhs, "subquery", False): - value = rhs.as_mql(compiler, connection, lookup_name=node.lookup_name) + if getattr(rhs, "subquery", False) and hasattr(node, "get_subquery_wrapping_pipeline"): + value = rhs.as_mql( + compiler, connection, get_wrapping_pipeline=node.get_subquery_wrapping_pipeline + ) else: value = rhs.as_mql(compiler, connection) else: