Skip to content

Commit

Permalink
Merge pull request #93 from roman-certn/main
Browse files Browse the repository at this point in the history
Add support for GenericRelation field
  • Loading branch information
j4mie authored Jan 12, 2024
2 parents e57725f + 84b9824 commit c99ae11
Show file tree
Hide file tree
Showing 7 changed files with 205 additions and 10 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added
- Added support for Django's reverse generic relations (`GenericRelation` model field) ([#93](https://github.com/dabapps/django-readers/pull/93)).

### Changed
- Add support for Django 5.0

Expand Down
41 changes: 41 additions & 0 deletions django_readers/qs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from django.contrib.contenttypes.fields import ReverseGenericManyToOneDescriptor
from django.db.models import Prefetch, QuerySet
from django.db.models.constants import LOOKUP_SEP
from django.db.models.fields.related_descriptors import (
Expand Down Expand Up @@ -167,6 +168,37 @@ def prefetch_reverse_relationship(
)


def prefetch_reverse_generic_relationship(
name,
content_type_field_name,
object_id_field_name,
related_queryset,
prepare_related_queryset=noop,
to_attr=None,
):
"""
Efficiently prefetch a reverse generic relationship: one where the field on the "parent"
queryset is a `GenericRelation` field. We need to include this field in the query.
"""
return pipe(
include_fields(name),
prefetch_related(
Prefetch(
name,
pipe(
include_fields(
"pk",
content_type_field_name,
object_id_field_name,
),
prepare_related_queryset,
)(related_queryset),
to_attr,
)
),
)


def prefetch_many_to_many_relationship(
name, related_queryset, prepare_related_queryset=noop, to_attr=None
):
Expand Down Expand Up @@ -246,5 +278,14 @@ def prepare(queryset):
prepare_related_queryset,
to_attr,
)(queryset)
if type(related_descriptor) is ReverseGenericManyToOneDescriptor:
return prefetch_reverse_generic_relationship(
name,
related_descriptor.rel.field.content_type_field_name,
related_descriptor.rel.field.object_id_field_name,
related_descriptor.field.related_model.objects.all(),
prepare_related_queryset,
to_attr,
)(queryset)

return prepare
20 changes: 18 additions & 2 deletions django_readers/rest_framework.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from copy import deepcopy
from django.contrib.contenttypes.fields import ReverseGenericManyToOneDescriptor
from django.core.exceptions import ImproperlyConfigured
from django.utils.functional import cached_property
from django_readers import specs
Expand Down Expand Up @@ -124,10 +125,25 @@ def _get_child_serializer_kwargs(self, rel_info):
kwargs["allow_null"] = True
return kwargs

def _get_rel_info(self, rel_name):
descriptor = getattr(self.model, rel_name)
# Special case for reverse generic relations (GenericRelation field)
# as these don't appear in rest-framework's rel_info
if isinstance(descriptor, ReverseGenericManyToOneDescriptor):
return model_meta.RelationInfo(
model_field=descriptor.field,
related_model=descriptor.field.related_model,
to_many=True,
to_field=None,
has_through_model=False,
reverse=True,
)
return self.info.relations[rel_name]

def visit_dict_item_list(self, key, value):
# This is a relationship, so we recurse and create
# a nested serializer to represent it
rel_info = self.info.relations[key]
rel_info = self._get_rel_info(key)
capfirst = self._lowercase_with_underscores_to_capitalized_words(key)
child_serializer_class = serializer_class_for_spec(
f"{self.name}{capfirst}",
Expand All @@ -143,7 +159,7 @@ def visit_dict_item_dict(self, key, value):
# do the same as the previous case, but handled
# slightly differently to set the `source` correctly
relationship_name, relationship_spec = next(iter(value.items()))
rel_info = self.info.relations[relationship_name]
rel_info = self._get_rel_info(relationship_name)
capfirst = self._lowercase_with_underscores_to_capitalized_words(key)
child_serializer_class = serializer_class_for_spec(
f"{self.name}{capfirst}",
Expand Down
14 changes: 14 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
from django.contrib.contenttypes.fields import GenericRelation
from django.db import models


class LogEntry(models.Model):
content_type = models.ForeignKey(
to="contenttypes.ContentType",
on_delete=models.CASCADE,
related_name="+",
)
object_pk = models.CharField(max_length=255)
event = models.CharField(max_length=100)


class Group(models.Model):
name = models.CharField(max_length=100)

Expand All @@ -15,6 +26,9 @@ class Widget(models.Model):
value = models.PositiveIntegerField(default=0)
other = models.CharField(max_length=100, null=True)
owner = models.ForeignKey(Owner, null=True, on_delete=models.SET_NULL)
logs = GenericRelation(
LogEntry, content_type_field="content_type", object_id_field="object_pk"
)


class Thing(models.Model):
Expand Down
87 changes: 86 additions & 1 deletion tests/test_qs.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from django.contrib.contenttypes.models import ContentType
from django.db import connection
from django.db.models import Count
from django.test import TestCase
from django.test.utils import CaptureQueriesContext
from django_readers import qs
from tests.models import Category, Owner, Widget
from tests.models import Category, LogEntry, Owner, Widget
from unittest import mock


Expand Down Expand Up @@ -188,6 +189,84 @@ def test_prefetch_reverse_relationship(self):
with self.assertNumQueries(0):
self.assertEqual(owners[0].widget_set.all()[0].name, "test widget")

def test_prefetch_reverse_generic_relationship(self):
widget = Widget.objects.create(name="test widget")
LogEntry.objects.create(
content_type=ContentType.objects.get_for_model(widget),
object_pk=widget.id,
event="CREATED",
)

prepare = qs.pipe(
qs.include_fields("name"),
qs.prefetch_reverse_generic_relationship(
"logs",
"content_type",
"object_pk",
LogEntry.objects.all(),
qs.include_fields("event"),
),
)

with CaptureQueriesContext(connection) as capture:
widgets = list(prepare(Widget.objects.all()))

self.assertEqual(len(capture.captured_queries), 2)

self.assertEqual(
capture.captured_queries[0]["sql"],
"SELECT "
'"tests_widget"."id", '
'"tests_widget"."name" '
"FROM "
'"tests_widget"',
)

content_type_id = ContentType.objects.get_for_model(Widget).pk

self.assertEqual(
capture.captured_queries[1]["sql"],
"SELECT "
'"tests_logentry"."id", '
'"tests_logentry"."content_type_id", '
'"tests_logentry"."object_pk", '
'"tests_logentry"."event" '
"FROM "
'"tests_logentry" '
"WHERE "
f'("tests_logentry"."content_type_id" = {content_type_id} AND '
'"tests_logentry"."object_pk" IN '
"('1'))",
)

with self.assertNumQueries(0):
self.assertEqual(widgets[0].logs.all()[0].event, "CREATED")

def test_prefetch_reverse_generic_relationship_with_to_attr(self):
widget = Widget.objects.create(name="test widget")
LogEntry.objects.create(
content_type=ContentType.objects.get_for_model(widget),
object_pk=widget.id,
event="CREATED",
)

prepare = qs.pipe(
qs.include_fields("name"),
qs.prefetch_reverse_generic_relationship(
"logs",
"content_type",
"object_pk",
LogEntry.objects.all(),
qs.include_fields("event"),
to_attr="history",
),
)

widgets = list(prepare(Widget.objects.all()))

with self.assertNumQueries(0):
self.assertEqual(widgets[0].history[0].event, "CREATED")

def test_prefetch_reverse_relationship_only_loads_pk_and_related_name_by_default(
self,
):
Expand Down Expand Up @@ -358,6 +437,12 @@ def test_auto_prefetch_relationship(self):
qs.auto_prefetch_relationship("category_set")(Widget.objects.all())
mock_fn.assert_called_once()

with mock.patch(
"django_readers.qs.prefetch_reverse_generic_relationship"
) as mock_fn:
qs.auto_prefetch_relationship("logs")(Widget.objects.all())
mock_fn.assert_called_once()

def test_annotate_only_includes_fk_by_default(self):
owner = Owner.objects.create(name="test owner")
Widget.objects.create(name="test 1", owner=owner)
Expand Down
38 changes: 33 additions & 5 deletions tests/test_rest_framework.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from django.contrib.contenttypes.models import ContentType
from django.core.exceptions import ImproperlyConfigured
from django.test import TestCase
from django_readers import pairs, qs
Expand All @@ -10,7 +11,7 @@
from rest_framework import serializers
from rest_framework.generics import ListAPIView, RetrieveAPIView
from rest_framework.test import APIRequestFactory
from tests.models import Category, Group, Owner, Widget
from tests.models import Category, Group, LogEntry, Owner, Widget
from textwrap import dedent


Expand All @@ -28,6 +29,7 @@ class WidgetListView(SpecMixin, ListAPIView):
},
]
},
{"logs": ["event"]},
]


Expand All @@ -53,17 +55,32 @@ class CategoryDetailView(SpecMixin, RetrieveAPIView):

class RESTFrameworkTestCase(TestCase):
def test_list(self):
Widget.objects.create(
widget = Widget.objects.create(
name="test widget",
owner=Owner.objects.create(
name="test owner", group=Group.objects.create(name="test group")
),
)
LogEntry.objects.create(
content_type=ContentType.objects.get_for_model(widget),
object_pk=widget.id,
event="CREATED",
)
LogEntry.objects.create(
content_type=ContentType.objects.get_for_model(widget),
object_pk=widget.id,
event="UPDATED",
)
LogEntry.objects.create(
content_type=ContentType.objects.get_for_model(widget),
object_pk=widget.id,
event="DELETED",
)

request = APIRequestFactory().get("/")
view = WidgetListView.as_view()

with self.assertNumQueries(3):
with self.assertNumQueries(4):
response = view(request)

self.assertEqual(
Expand All @@ -77,6 +94,11 @@ def test_list(self):
"name": "test group",
},
},
"logs": [
{"event": "CREATED"},
{"event": "UPDATED"},
{"event": "DELETED"},
],
}
],
)
Expand Down Expand Up @@ -180,12 +202,16 @@ def test_all_relationship_types(self):
},
]
},
{
"logs": [
"event",
]
},
]
},
]

cls = serializer_class_for_spec("Owner", Owner, spec)

expected = dedent(
"""\
OwnerSerializer():
Expand All @@ -199,7 +225,9 @@ def test_all_relationship_types(self):
thing = OwnerWidgetSetThingSerializer(read_only=True):
name = CharField(max_length=100, read_only=True)
related_widget = OwnerWidgetSetThingRelatedWidgetSerializer(allow_null=True, read_only=True, source='widget'):
name = CharField(allow_null=True, max_length=100, read_only=True, required=False)"""
name = CharField(allow_null=True, max_length=100, read_only=True, required=False)
logs = OwnerWidgetSetLogsSerializer(allow_null=True, many=True, read_only=True):
event = CharField(max_length=100, read_only=True)"""
)
self.assertEqual(repr(cls()), expected)

Expand Down
12 changes: 10 additions & 2 deletions tests/test_specs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from django.contrib.contenttypes.models import ContentType
from django.test import TestCase
from django_readers import specs
from tests.models import Category, Group, Owner, Thing, Widget
from tests.models import Category, Group, LogEntry, Owner, Thing, Widget


class SpecTestCase(TestCase):
Expand Down Expand Up @@ -28,20 +29,26 @@ def test_relationships(self):
category = Category.objects.create(name="test category")
category.widget_set.add(widget)
Thing.objects.create(name="test thing", widget=widget)
LogEntry.objects.create(
content_type=ContentType.objects.get_for_model(widget),
object_pk=widget.id,
event="CREATED",
)

prepare, project = specs.process(
[
"name",
{"owner": ["name", {"widget_set": ["name"]}]},
{"category_set": ["name", {"widget_set": ["name"]}]},
{"thing": ["name", {"widget": ["name"]}]},
{"logs": ["event"]},
]
)

with self.assertNumQueries(0):
queryset = prepare(Widget.objects.all())

with self.assertNumQueries(7):
with self.assertNumQueries(8):
instance = queryset.first()

with self.assertNumQueries(0):
Expand All @@ -59,6 +66,7 @@ def test_relationships(self):
{"name": "test category", "widget_set": [{"name": "test widget"}]},
],
"thing": {"name": "test thing", "widget": {"name": "test widget"}},
"logs": [{"event": "CREATED"}],
},
)

Expand Down

0 comments on commit c99ae11

Please sign in to comment.