diff --git a/django_mongodb/compiler.py b/django_mongodb/compiler.py index 55749ede..1608d7d7 100644 --- a/django_mongodb/compiler.py +++ b/django_mongodb/compiler.py @@ -241,11 +241,10 @@ def execute_sql( self, result_type=MULTI, chunked_fetch=False, chunk_size=GET_ITERATOR_CHUNK_SIZE ): self.pre_sql_setup() - columns = self.get_columns() try: query = self.build_query( # Avoid $project (columns=None) if unneeded. - columns + self.columns if self.query.annotations or not self.query.default_cols or self.query.distinct else None ) @@ -259,10 +258,10 @@ def execute_sql( except StopIteration: return None # No result else: - return self._make_result(obj, columns) + return self._make_result(obj, self.columns) # result_type is MULTI cursor.batch_size(chunk_size) - result = self.cursor_iter(cursor, chunk_size, columns) + result = self.cursor_iter(cursor, chunk_size, self.columns) if not chunked_fetch: # If using non-chunked reads, read data into memory. return list(result) @@ -394,7 +393,8 @@ def build_query(self, columns=None): query.subqueries = self.subqueries return query - def get_columns(self): + @cached_property + def columns(self): """ Return a tuple of (name, expression) with the columns and annotations which should be loaded by the query. @@ -472,8 +472,7 @@ def get_combinator_queries(self): query.get_compiler(self.using, self.connection, self.elide_empty) for query in self.query.combined_queries ] - main_query_columns = self.get_columns() - main_query_fields, _ = zip(*main_query_columns, strict=True) + main_query_fields, _ = zip(*self.columns, strict=True) for compiler_ in compilers: try: # If the columns list is limited, then all combined queries @@ -490,7 +489,7 @@ def get_combinator_queries(self): ) compiler_.pre_sql_setup() compiler_.column_indices = self.column_indices - columns = compiler_.get_columns() + columns = compiler_.columns parts.append((compiler_.build_query(columns), compiler_, columns)) except EmptyResultSet: # Omit the empty queryset with UNION. @@ -528,7 +527,7 @@ def get_combinator_queries(self): combinator_pipeline = inner_pipeline if not self.query.combinator_all: ids = defaultdict(dict) - for alias, expr in main_query_columns: + for alias, expr in self.columns: # Unfold foreign fields. if isinstance(expr, Col) and expr.alias != self.collection_name: ids[expr.alias][expr.target.column] = expr.as_mql(self, self.connection) @@ -633,10 +632,9 @@ def explain_query(self): ) # Build the query pipeline. self.pre_sql_setup() - columns = self.get_columns() query = self.build_query( # Avoid $project (columns=None) if unneeded. - columns if self.query.annotations or not self.query.default_cols else None + self.columns if self.query.annotations or not self.query.default_cols else None ) pipeline = query.get_pipeline() # Explain the pipeline. @@ -796,7 +794,7 @@ def build_query(self, columns=None): compiler.pre_sql_setup(with_col_aliases=False) # Avoid $project (columns=None) if unneeded. columns = ( - compiler.get_columns() + compiler.columns if self.query.annotations or not self.query.default_cols or self.query.distinct else None ) diff --git a/django_mongodb/expressions.py b/django_mongodb/expressions.py index bf4bf23b..957c5f15 100644 --- a/django_mongodb/expressions.py +++ b/django_mongodb/expressions.py @@ -98,10 +98,9 @@ def order_by(self, compiler, connection): def query(self, compiler, connection, lookup_name=None): subquery_compiler = self.get_compiler(connection=connection) subquery_compiler.pre_sql_setup(with_col_aliases=False) - columns = subquery_compiler.get_columns() - field_name, expr = columns[0] + field_name, expr = subquery_compiler.columns[0] subquery = subquery_compiler.build_query( - columns + subquery_compiler.columns if subquery_compiler.query.annotations or not subquery_compiler.query.default_cols else None )