From 67d9d3f164f27d663fa3a5f729cb9453f5e3e7f5 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Sat, 21 Dec 2024 00:31:47 -0300 Subject: [PATCH] Fix columns order in aggregation queries --- django_mongodb/compiler.py | 39 ++++++++++++++++++++++---------- django_mongodb/expressions.py | 6 ++++- django_mongodb/features.py | 14 ------------ tests/indexes_/test_condition.py | 2 +- 4 files changed, 33 insertions(+), 28 deletions(-) diff --git a/django_mongodb/compiler.py b/django_mongodb/compiler.py index 47e7b2d3..83190d1b 100644 --- a/django_mongodb/compiler.py +++ b/django_mongodb/compiler.py @@ -403,12 +403,6 @@ def columns(self): columns = ( self.get_default_columns(select_mask) if self.query.default_cols else self.query.select ) - # Populate QuerySet.select_related() data. - related_columns = [] - if self.query.select_related: - self.get_related_selections(related_columns, select_mask) - if related_columns: - related_columns, _ = zip(*related_columns, strict=True) annotation_idx = 1 @@ -427,11 +421,28 @@ def project_field(column): annotation_idx += 1 return target, column - return ( - tuple(map(project_field, columns)) - + tuple(self.annotations.items()) - + tuple(map(project_field, related_columns)) - ) + selected = [] + if self.query.selected is None: + selected = [ + *(project_field(col) for col in columns), + *self.annotations.items(), + ] + else: + for expression in self.query.selected.values(): + # Reference to an annotation. + if isinstance(expression, str): + alias, expression = expression, self.annotations[expression] + # Reference to a column. + elif isinstance(expression, int): + alias, expression = project_field(columns[expression]) + selected.append((alias, expression)) + # Populate QuerySet.select_related() data. + related_columns = [] + if self.query.select_related: + self.get_related_selections(related_columns, select_mask) + if related_columns: + related_columns, _ = zip(*related_columns, strict=True) + return tuple(selected) + tuple(map(project_field, related_columns)) @cached_property def base_table(self): @@ -478,7 +489,11 @@ def get_combinator_queries(self): # If the columns list is limited, then all combined queries # must have the same columns list. Set the selects defined on # the query on all combined queries, if not already set. - if not compiler_.query.values_select and self.query.values_select: + selected = self.query.selected + if selected is not None and compiler_.query.selected is None: + compiler_.query = compiler_.query.clone() + compiler_.query.set_values(selected) + elif not compiler_.query.values_select and self.query.values_select: compiler_.query = compiler_.query.clone() compiler_.query.set_values( ( diff --git a/django_mongodb/expressions.py b/django_mongodb/expressions.py index 957c5f15..2e322d4b 100644 --- a/django_mongodb/expressions.py +++ b/django_mongodb/expressions.py @@ -178,7 +178,11 @@ def ref(self, compiler, connection): # noqa: ARG001 if isinstance(self.source, Col) and self.source.alias != compiler.collection_name else "" ) - return f"${prefix}{self.refs}" + if hasattr(self, "ordinal"): + refs, _ = compiler.columns[self.ordinal - 1] + else: + refs = self.refs + return f"${prefix}{refs}" def star(self, compiler, connection): # noqa: ARG001 diff --git a/django_mongodb/features.py b/django_mongodb/features.py index 53c35df3..36a46b0c 100644 --- a/django_mongodb/features.py +++ b/django_mongodb/features.py @@ -88,20 +88,6 @@ class DatabaseFeatures(BaseDatabaseFeatures): "auth_tests.test_views.LoginTest.test_login_session_without_hash_session_key", # GenericRelation.value_to_string() assumes integer pk. "contenttypes_tests.test_fields.GenericRelationTests.test_value_to_string", - # Broken by https://github.com/django/django/commit/65ad4ade74dc9208b9d686a451cd6045df0c9c3a - "aggregation.tests.AggregateTestCase.test_even_more_aggregate", - "aggregation.tests.AggregateTestCase.test_grouped_annotation_in_group_by", - "aggregation.tests.AggregateTestCase.test_non_grouped_annotation_not_in_group_by", - "aggregation_regress.tests.AggregationTests.test_aggregate_fexpr", - "aggregation_regress.tests.AggregationTests.test_values_list_annotation_args_ordering", - "annotations.tests.NonAggregateAnnotationTestCase.test_annotation_subquery_and_aggregate_values_chaining", - "annotations.tests.NonAggregateAnnotationTestCase.test_values_fields_annotations_order", - "queries.test_qs_combinators.QuerySetSetOperationTests.test_union_multiple_models_with_values_and_datetime_annotations", - "queries.test_qs_combinators.QuerySetSetOperationTests.test_union_multiple_models_with_values_list_and_datetime_annotations", - "queries.test_qs_combinators.QuerySetSetOperationTests.test_union_multiple_models_with_values_list_and_annotations", - "queries.test_qs_combinators.QuerySetSetOperationTests.test_union_with_field_and_annotation_values", - "queries.test_qs_combinators.QuerySetSetOperationTests.test_union_with_two_annotated_values_list", - "queries.tests.Queries1Tests.test_union_values_subquery", # pymongo.errors.WriteError: Performing an update on the path '_id' # would modify the immutable field '_id' "migrations.test_operations.OperationTests.test_composite_pk_operations", diff --git a/tests/indexes_/test_condition.py b/tests/indexes_/test_condition.py index f0d67b36..07415bdd 100644 --- a/tests/indexes_/test_condition.py +++ b/tests/indexes_/test_condition.py @@ -99,7 +99,7 @@ def test_composite_index(self): { "$and": [ {"number": {"$gte": 3}}, - {"$or": [{"body": {"$gt": "test1"}}, {"body": {"$in": ["A", "B"]}}]}, + {"$or": [{"body": {"$gt": "test1"}}, {"body": {"$in": ("A", "B")}}]}, ] }, )