Skip to content

Commit

Permalink
fix queries when subquery has a union
Browse files Browse the repository at this point in the history
  • Loading branch information
WaVEV authored and timgraham committed Nov 19, 2024
1 parent b730288 commit c49e345
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 38 deletions.
17 changes: 13 additions & 4 deletions django_mongodb/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,13 +434,16 @@ def project_field(column):
)

@cached_property
def collection_name(self):
base_table = next(
def base_table(self):
return next(
v
for k, v in self.query.alias_map.items()
if isinstance(v, BaseTable) and self.query.alias_refcount[k]
)
return base_table.table_alias or base_table.table_name

@cached_property
def collection_name(self):
return self.base_table.table_alias or self.base_table.table_name

@cached_property
def collection(self):
Expand Down Expand Up @@ -469,6 +472,7 @@ def get_combinator_queries(self):
)
)
compiler_.pre_sql_setup()
compiler_.column_indices = self.column_indices
columns = compiler_.get_columns()
parts.append((compiler_.build_query(columns), compiler_, columns))
except EmptyResultSet:
Expand Down Expand Up @@ -496,7 +500,12 @@ def get_combinator_queries(self):
# Combine query with the current combinator pipeline.
if combinator_pipeline:
combinator_pipeline.append(
{"$unionWith": {"coll": compiler_.collection_name, "pipeline": inner_pipeline}}
{
"$unionWith": {
"coll": compiler_.base_table.table_name,
"pipeline": inner_pipeline,
}
}
)
else:
combinator_pipeline = inner_pipeline
Expand Down
63 changes: 34 additions & 29 deletions django_mongodb/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,39 +122,44 @@ def query(self, compiler, connection, lookup_name=None):
if lookup_name in ("in", "range"):
if subquery.aggregation_pipeline is None:
subquery.aggregation_pipeline = []
subquery.aggregation_pipeline.extend(
[
{
"$facet": {
"group": [
wrapping_result_pipeline = [
{
"$facet": {
"group": [
{
"$group": {
"_id": None,
"tmp_name": {
"$addToSet": expr.as_mql(subquery_compiler, connection)
},
}
}
]
}
},
{
"$project": {
field_name: {
"$ifNull": [
{
"$group": {
"_id": None,
"tmp_name": {
"$addToSet": expr.as_mql(subquery_compiler, connection)
},
"$getField": {
"input": {"$arrayElemAt": ["$group", 0]},
"field": "tmp_name",
}
}
},
[],
]
}
},
{
"$project": {
field_name: {
"$ifNull": [
{
"$getField": {
"input": {"$arrayElemAt": ["$group", 0]},
"field": "tmp_name",
}
},
[],
]
}
}
},
]
)
}
},
]
# If the subquery is a combinator, wrap the result at the end of the
# combinator pipeline...
if subquery.query.combinator:
subquery.combinator_pipeline.extend(wrapping_result_pipeline)
# ... otherwise put at the end of subquery's pipeline.
else:
subquery.aggregation_pipeline.extend(wrapping_result_pipeline)
# Erase project_fields since the required value is projected above.
subquery.project_fields = None
compiler.subqueries.append(subquery)
Expand Down
5 changes: 0 additions & 5 deletions django_mongodb/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,6 @@ class DatabaseFeatures(BaseDatabaseFeatures):
# Connection creation doesn't follow the usual Django API.
"backends.tests.ThreadTests.test_pass_connection_between_threads",
"backends.tests.ThreadTests.test_default_connection_thread_local",
# Union as subquery is not mapping the parent parameter and collections:
# https://github.com/mongodb-labs/django-mongodb/issues/156
"queries.test_qs_combinators.QuerySetSetOperationTests.test_union_in_subquery_related_outerref",
"queries.test_qs_combinators.QuerySetSetOperationTests.test_union_in_subquery",
"queries.test_qs_combinators.QuerySetSetOperationTests.test_union_in_with_ordering",
# ObjectId type mismatch in a subquery:
# https://github.com/mongodb-labs/django-mongodb/issues/161
"queries.tests.RelatedLookupTypeTests.test_values_queryset_lookup",
Expand Down

0 comments on commit c49e345

Please sign in to comment.