Skip to content

Commit

Permalink
Build aggregation pipeline from find query without fetch (#770)
Browse files Browse the repository at this point in the history
* Build aggregation pipeline from find query without fetch

* fix test to work with all python versions
  • Loading branch information
roman-right authored Nov 9, 2023
1 parent 12f4a14 commit d88d118
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 13 deletions.
24 changes: 15 additions & 9 deletions beanie/odm/queries/find.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,22 +84,25 @@ def __init__(self, document_model: Type["DocType"]):
self.pymongo_kwargs: Dict[str, Any] = {}
self.lazy_parse = False

def prepare_find_expressions(self):
def prepare_find_expressions(self, for_aggregation: bool = False):
for_aggregation = for_aggregation or self.fetch_links
if self.document_model.get_link_fields() is not None:
for i, query in enumerate(self.find_expressions):
self.find_expressions[i] = convert_ids(
query,
doc=self.document_model,
fetch_links=self.fetch_links,
doc=self.document_model, # type: ignore
for_aggregation=for_aggregation,
)

def get_filter_query(self) -> Mapping[str, Any]:
def get_filter_query(
self, for_aggregation: bool = False
) -> Mapping[str, Any]:
"""
Returns: MongoDB filter query
"""
self.prepare_find_expressions()
self.prepare_find_expressions(for_aggregation=for_aggregation)
if self.find_expressions:
return Encoder(custom_encoders=self.encoders).encode(
And(*self.find_expressions).query
Expand Down Expand Up @@ -599,10 +602,13 @@ def _set_cache(self, data):
)

def build_aggregation_pipeline(self, *extra_stages):
aggregation_pipeline: List[Dict[str, Any]] = construct_lookup_queries(
self.document_model
)
filter_query = self.get_filter_query()
if self.fetch_links:
aggregation_pipeline: List[
Dict[str, Any]
] = construct_lookup_queries(self.document_model)
else:
aggregation_pipeline = []
filter_query = self.get_filter_query(for_aggregation=True)

if filter_query:
text_queries, non_text_queries = split_text_query(filter_query)
Expand Down
8 changes: 4 additions & 4 deletions beanie/odm/utils/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


def convert_ids(
query: MappingType[str, Any], doc: "Document", fetch_links: bool
query: MappingType[str, Any], doc: "Document", for_aggregation: bool
) -> Dict[str, Any]:
# TODO add all the cases
new_query = {}
Expand All @@ -27,18 +27,18 @@ def convert_ids(
and k_splitted[0] in doc.get_link_fields().keys() # type: ignore
and k_splitted[1] == "id"
):
if fetch_links:
if for_aggregation:
new_k = f"{k_splitted[0]}._id"
else:
new_k = f"{k_splitted[0]}.$id"
else:
new_k = k
new_v: Any
if isinstance(v, Mapping):
new_v = convert_ids(v, doc, fetch_links)
new_v = convert_ids(v, doc, for_aggregation)
elif isinstance(v, list):
new_v = [
convert_ids(ele, doc, fetch_links)
convert_ids(ele, doc, for_aggregation)
if isinstance(ele, Mapping)
else ele
for ele in v
Expand Down
25 changes: 25 additions & 0 deletions tests/odm/test_relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,3 +822,28 @@ async def test_init_reversed_order(self, db):
PersonForReversedOrderInit,
],
)


class TestBuildAggregations:
    """Checks the aggregation pipeline that find(...).aggregate(...) builds."""

    async def test_find_aggregate_without_fetch_links(self, houses):
        # Without fetch_links the pipeline should contain only the $match
        # stage (keyed on "door._id") followed by the user-supplied stages.
        door = await Door.find_one()
        group_stage = {"$group": {"_id": "$height", "count": {"$sum": 1}}}
        query = House.find(House.door.id == door.id).aggregate([group_stage])
        expected_pipeline = [
            {"$match": {"door._id": door.id}},
            group_stage,
        ]
        assert query.get_aggregation_pipeline() == expected_pipeline

    async def test_find_aggregate_with_fetch_links(self, houses):
        # With fetch_links the link-lookup stages are prepended, so only the
        # total number of stages is pinned here, not their exact contents.
        door = await Door.find_one()
        query = House.find(
            House.door.id == door.id, fetch_links=True
        ).aggregate(
            [{"$group": {"_id": "$height", "count": {"$sum": 1}}}]
        )
        assert len(query.get_aggregation_pipeline()) == 12

0 comments on commit d88d118

Please sign in to comment.