diff --git a/CHANGELOG.md b/CHANGELOG.md index 7f667dd..66eb37e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added +- Added support for Django's reverse generic relations (`GenericRelation` model field) ([#93](https://github.com/dabapps/django-readers/pull/93)). + ### Changed - Add support for Django 5.0 diff --git a/django_readers/qs.py b/django_readers/qs.py index 59038f4..3cad769 100644 --- a/django_readers/qs.py +++ b/django_readers/qs.py @@ -1,3 +1,4 @@ +from django.contrib.contenttypes.fields import ReverseGenericManyToOneDescriptor from django.db.models import Prefetch, QuerySet from django.db.models.constants import LOOKUP_SEP from django.db.models.fields.related_descriptors import ( @@ -167,6 +168,37 @@ def prefetch_reverse_relationship( ) +def prefetch_reverse_generic_relationship( + name, + content_type_field_name, + object_id_field_name, + related_queryset, + prepare_related_queryset=noop, + to_attr=None, +): + """ + Efficiently prefetch a reverse generic relationship: one where the field on the "parent" + queryset is a `GenericRelation` field. We need to include this field in the query. + """ + return pipe( + include_fields(name), + prefetch_related( + Prefetch( + name, + pipe( + include_fields( + "pk", + content_type_field_name, + object_id_field_name, + ), + prepare_related_queryset, + )(related_queryset), + to_attr, + ) + ), + ) + + def prefetch_many_to_many_relationship( name, related_queryset, prepare_related_queryset=noop, to_attr=None ): @@ -246,5 +278,14 @@ def prepare(queryset): prepare_related_queryset, to_attr, )(queryset) + if type(related_descriptor) is ReverseGenericManyToOneDescriptor: + return prefetch_reverse_generic_relationship( + name, + related_descriptor.rel.field.content_type_field_name, + related_descriptor.rel.field.object_id_field_name, + related_descriptor.field.related_model.objects.all(), + prepare_related_queryset, + to_attr, + )(queryset) return prepare diff --git a/django_readers/rest_framework.py b/django_readers/rest_framework.py index 29b791c..de1eb22 100644 --- a/django_readers/rest_framework.py +++ b/django_readers/rest_framework.py @@ -1,4 +1,5 @@ from copy import deepcopy +from django.contrib.contenttypes.fields import ReverseGenericManyToOneDescriptor from django.core.exceptions import ImproperlyConfigured from django.utils.functional import cached_property from django_readers import specs @@ -124,10 +125,25 @@ def _get_child_serializer_kwargs(self, rel_info): kwargs["allow_null"] = True return kwargs + def _get_rel_info(self, rel_name): + descriptor = getattr(self.model, rel_name) + # Special case for reverse generic relations (GenericRelation field) + # as these don't appear in rest-framework's rel_info + if isinstance(descriptor, ReverseGenericManyToOneDescriptor): + return model_meta.RelationInfo( + model_field=descriptor.field, + related_model=descriptor.field.related_model, + to_many=True, + to_field=None, + has_through_model=False, + reverse=True, + ) + return self.info.relations[rel_name] + def visit_dict_item_list(self, key, value): # This is a relationship, so we recurse and create # a nested serializer to represent it - rel_info = self.info.relations[key] + rel_info = self._get_rel_info(key) capfirst = self._lowercase_with_underscores_to_capitalized_words(key) child_serializer_class = serializer_class_for_spec( f"{self.name}{capfirst}", @@ -143,7 +159,7 @@ def visit_dict_item_dict(self, key, value): # do the same as the previous case, but handled # slightly differently to set the `source` correctly relationship_name, relationship_spec = next(iter(value.items())) - rel_info = self.info.relations[relationship_name] + rel_info = self._get_rel_info(relationship_name) capfirst = self._lowercase_with_underscores_to_capitalized_words(key) child_serializer_class = serializer_class_for_spec( f"{self.name}{capfirst}", diff --git a/tests/models.py b/tests/models.py index 72b0494..12e9fbd 100644 --- a/tests/models.py +++ b/tests/models.py @@ -1,6 +1,17 @@ +from django.contrib.contenttypes.fields import GenericRelation from django.db import models +class LogEntry(models.Model): + content_type = models.ForeignKey( + to="contenttypes.ContentType", + on_delete=models.CASCADE, + related_name="+", + ) + object_pk = models.CharField(max_length=255) + event = models.CharField(max_length=100) + + class Group(models.Model): name = models.CharField(max_length=100) @@ -15,6 +26,9 @@ class Widget(models.Model): value = models.PositiveIntegerField(default=0) other = models.CharField(max_length=100, null=True) owner = models.ForeignKey(Owner, null=True, on_delete=models.SET_NULL) + logs = GenericRelation( + LogEntry, content_type_field="content_type", object_id_field="object_pk" + ) class Thing(models.Model): diff --git a/tests/test_qs.py b/tests/test_qs.py index af32920..969b940 100644 --- a/tests/test_qs.py +++ b/tests/test_qs.py @@ -1,9 +1,10 @@ +from django.contrib.contenttypes.models import ContentType from django.db import connection from django.db.models import Count from django.test import TestCase from django.test.utils import CaptureQueriesContext from django_readers import qs -from tests.models import Category, Owner, Widget +from tests.models import Category, LogEntry, Owner, Widget from unittest import mock @@ -188,6 +189,84 @@ def test_prefetch_reverse_relationship(self): with self.assertNumQueries(0): self.assertEqual(owners[0].widget_set.all()[0].name, "test widget") + def test_prefetch_reverse_generic_relationship(self): + widget = Widget.objects.create(name="test widget") + LogEntry.objects.create( + content_type=ContentType.objects.get_for_model(widget), + object_pk=widget.id, + event="CREATED", + ) + + prepare = qs.pipe( + qs.include_fields("name"), + qs.prefetch_reverse_generic_relationship( + "logs", + "content_type", + "object_pk", + LogEntry.objects.all(), + qs.include_fields("event"), + ), + ) + + with CaptureQueriesContext(connection) as capture: + widgets = list(prepare(Widget.objects.all())) + + self.assertEqual(len(capture.captured_queries), 2) + + self.assertEqual( + capture.captured_queries[0]["sql"], + "SELECT " + '"tests_widget"."id", ' + '"tests_widget"."name" ' + "FROM " + '"tests_widget"', + ) + + content_type_id = ContentType.objects.get_for_model(Widget).pk + + self.assertEqual( + capture.captured_queries[1]["sql"], + "SELECT " + '"tests_logentry"."id", ' + '"tests_logentry"."content_type_id", ' + '"tests_logentry"."object_pk", ' + '"tests_logentry"."event" ' + "FROM " + '"tests_logentry" ' + "WHERE " + f'("tests_logentry"."content_type_id" = {content_type_id} AND ' + '"tests_logentry"."object_pk" IN ' + "('1'))", + ) + + with self.assertNumQueries(0): + self.assertEqual(widgets[0].logs.all()[0].event, "CREATED") + + def test_prefetch_reverse_generic_relationship_with_to_attr(self): + widget = Widget.objects.create(name="test widget") + LogEntry.objects.create( + content_type=ContentType.objects.get_for_model(widget), + object_pk=widget.id, + event="CREATED", + ) + + prepare = qs.pipe( + qs.include_fields("name"), + qs.prefetch_reverse_generic_relationship( + "logs", + "content_type", + "object_pk", + LogEntry.objects.all(), + qs.include_fields("event"), + to_attr="history", + ), + ) + + widgets = list(prepare(Widget.objects.all())) + + with self.assertNumQueries(0): + self.assertEqual(widgets[0].history[0].event, "CREATED") + def test_prefetch_reverse_relationship_only_loads_pk_and_related_name_by_default( self, ): @@ -358,6 +437,12 @@ def test_auto_prefetch_relationship(self): qs.auto_prefetch_relationship("category_set")(Widget.objects.all()) mock_fn.assert_called_once() + with mock.patch( + "django_readers.qs.prefetch_reverse_generic_relationship" + ) as mock_fn: + qs.auto_prefetch_relationship("logs")(Widget.objects.all()) + mock_fn.assert_called_once() + def test_annotate_only_includes_fk_by_default(self): owner = Owner.objects.create(name="test owner") Widget.objects.create(name="test 1", owner=owner) diff --git a/tests/test_rest_framework.py b/tests/test_rest_framework.py index 872ebf7..546f43d 100644 --- a/tests/test_rest_framework.py +++ b/tests/test_rest_framework.py @@ -1,3 +1,4 @@ +from django.contrib.contenttypes.models import ContentType from django.core.exceptions import ImproperlyConfigured from django.test import TestCase from django_readers import pairs, qs @@ -10,7 +11,7 @@ from rest_framework import serializers from rest_framework.generics import ListAPIView, RetrieveAPIView from rest_framework.test import APIRequestFactory -from tests.models import Category, Group, Owner, Widget +from tests.models import Category, Group, LogEntry, Owner, Widget from textwrap import dedent @@ -28,6 +29,7 @@ class WidgetListView(SpecMixin, ListAPIView): }, ] }, + {"logs": ["event"]}, ] @@ -53,17 +55,32 @@ class CategoryDetailView(SpecMixin, RetrieveAPIView): class RESTFrameworkTestCase(TestCase): def test_list(self): - Widget.objects.create( + widget = Widget.objects.create( name="test widget", owner=Owner.objects.create( name="test owner", group=Group.objects.create(name="test group") ), ) + LogEntry.objects.create( + content_type=ContentType.objects.get_for_model(widget), + object_pk=widget.id, + event="CREATED", + ) + LogEntry.objects.create( + content_type=ContentType.objects.get_for_model(widget), + object_pk=widget.id, + event="UPDATED", + ) + LogEntry.objects.create( + content_type=ContentType.objects.get_for_model(widget), + object_pk=widget.id, + event="DELETED", + ) request = APIRequestFactory().get("/") view = WidgetListView.as_view() - with self.assertNumQueries(3): + with self.assertNumQueries(4): response = view(request) self.assertEqual( @@ -77,6 +94,11 @@ def test_list(self): "name": "test group", }, }, + "logs": [ + {"event": "CREATED"}, + {"event": "UPDATED"}, + {"event": "DELETED"}, + ], } ], ) @@ -180,12 +202,16 @@ def test_all_relationship_types(self): }, ] }, + { + "logs": [ + "event", + ] + }, ] }, ] cls = serializer_class_for_spec("Owner", Owner, spec) - expected = dedent( """\ OwnerSerializer(): @@ -199,7 +225,9 @@ def test_all_relationship_types(self): thing = OwnerWidgetSetThingSerializer(read_only=True): name = CharField(max_length=100, read_only=True) related_widget = OwnerWidgetSetThingRelatedWidgetSerializer(allow_null=True, read_only=True, source='widget'): - name = CharField(allow_null=True, max_length=100, read_only=True, required=False)""" + name = CharField(allow_null=True, max_length=100, read_only=True, required=False) + logs = OwnerWidgetSetLogsSerializer(allow_null=True, many=True, read_only=True): + event = CharField(max_length=100, read_only=True)""" ) self.assertEqual(repr(cls()), expected) diff --git a/tests/test_specs.py b/tests/test_specs.py index 31d6196..8805133 100644 --- a/tests/test_specs.py +++ b/tests/test_specs.py @@ -1,6 +1,7 @@ +from django.contrib.contenttypes.models import ContentType from django.test import TestCase from django_readers import specs -from tests.models import Category, Group, Owner, Thing, Widget +from tests.models import Category, Group, LogEntry, Owner, Thing, Widget class SpecTestCase(TestCase): @@ -28,6 +29,11 @@ def test_relationships(self): category = Category.objects.create(name="test category") category.widget_set.add(widget) Thing.objects.create(name="test thing", widget=widget) + LogEntry.objects.create( + content_type=ContentType.objects.get_for_model(widget), + object_pk=widget.id, + event="CREATED", + ) prepare, project = specs.process( [ @@ -35,13 +41,14 @@ def test_relationships(self): {"owner": ["name", {"widget_set": ["name"]}]}, {"category_set": ["name", {"widget_set": ["name"]}]}, {"thing": ["name", {"widget": ["name"]}]}, + {"logs": ["event"]}, ] ) with self.assertNumQueries(0): queryset = prepare(Widget.objects.all()) - with self.assertNumQueries(7): + with self.assertNumQueries(8): instance = queryset.first() with self.assertNumQueries(0): @@ -59,6 +66,7 @@ def test_relationships(self): {"name": "test category", "widget_set": [{"name": "test widget"}]}, ], "thing": {"name": "test thing", "widget": {"name": "test widget"}}, + "logs": [{"event": "CREATED"}], }, )