From 15f9b422a9af54d5ef2f114cf84619e9ca5b2e18 Mon Sep 17 00:00:00 2001 From: Chris Seto Date: Tue, 29 Aug 2017 09:34:39 -0400 Subject: [PATCH] Handle .filter_is_sticky (#17) --- include/expressions.py | 2 + include/query.py | 17 ++++++-- tests/migrations/0003_auto_20170825_1341.py | 43 +++++++++++++++++++++ tests/models.py | 18 +++++++++ tests/test_include.py | 8 ++++ 5 files changed, 85 insertions(+), 3 deletions(-) create mode 100644 tests/migrations/0003_auto_20170825_1341.py diff --git a/include/expressions.py b/include/expressions.py index e235b5d..7add904 100644 --- a/include/expressions.py +++ b/include/expressions.py @@ -174,6 +174,8 @@ def add_where(self, queryset, host_table, included_table): class IncludeExpression(Expression): + # No need to use group bys when using .include + contains_aggregate = False def __init__(self, field, children=None): expressions = [] diff --git a/include/query.py b/include/query.py index 36ff74a..e6e9345 100644 --- a/include/query.py +++ b/include/query.py @@ -99,6 +99,13 @@ def include(self, *fields, **kwargs): """ clone = self._clone() + # Preserve the stickiness of related querysets + # NOTE: by default _clone will clear this attribute + # .include does not modify the actual query, so we + # should not clear `filter_is_sticky` + if self.query.filter_is_sticky: + clone.query.filter_is_sticky = True + clone._include_limit = kwargs.pop('limit_includes', None) assert not kwargs, '"limit_includes" is the only accepted kwargs. Eat your heart out 2.7' @@ -111,7 +118,7 @@ def include(self, *fields, **kwargs): # Including multiple child fields ie .include(field1__field2, field1__field3) # turns into {field1: {field2: {}, field3: {}} for name in fields: - ctx, model = self._includes, self.model + ctx, model = clone._includes, clone.model for spl in name.split('__'): field = model._meta.get_field(spl) if isinstance(field, ForeignObjectRel) and field.is_hidden(): @@ -119,7 +126,7 @@ def include(self, *fields, **kwargs): model = field.related_model ctx = ctx.setdefault(field, OrderedDict()) - for field in self._includes.keys(): + for field in clone._includes.keys(): clone._include(field) return clone @@ -131,7 +138,11 @@ def count(self): def _clone(self): clone = super(IncludeQuerySet, self)._clone() - clone._includes = self._includes + + clone._includes = self._includes.copy() + for field in self._includes.keys(): + clone._include(field) + return clone def _include(self, field): diff --git a/tests/migrations/0003_auto_20170825_1341.py b/tests/migrations/0003_auto_20170825_1341.py new file mode 100644 index 0000000..d6a1844 --- /dev/null +++ b/tests/migrations/0003_auto_20170825_1341.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.11.4 on 2017-08-25 13:41 +from __future__ import unicode_literals + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('tests', '0002_auto_20170713_1713'), + ] + + operations = [ + migrations.CreateModel( + name='Membership', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('active', models.BooleanField(default=True)), + ('joined', models.DateTimeField(auto_now_add=True)), + ('member', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='memberships', to='tests.Cat')), + ], + ), + migrations.CreateModel( + name='Organization', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('title', models.CharField(max_length=75)), + ('disliked', models.BooleanField(default=True)), + ], + ), + migrations.AddField( + model_name='membership', + name='organization', + field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='members', to='tests.Organization'), + ), + migrations.AddField( + model_name='cat', + name='organizations', + field=models.ManyToManyField(through='tests.Membership', to='tests.Organization'), + ), + ] diff --git a/tests/models.py b/tests/models.py index 90ad298..02baac7 100644 --- a/tests/models.py +++ b/tests/models.py @@ -31,5 +31,23 @@ class Cat(models.Model): parent = models.ForeignKey('Cat', null=True, related_name='children') siblings = models.ManyToManyField('Cat', related_name='related_to') emergency_contact = models.OneToOneField('Cat', null=True, related_name='emergency_contact_for') + organizations = models.ManyToManyField('Organization', through='Membership') + + objects = IncludeManager() + + +class Organization(models.Model): + title = models.CharField(max_length=75) + disliked = models.BooleanField(default=True) + + objects = IncludeManager() + + +class Membership(models.Model): + active = models.BooleanField(default=True) + joined = models.DateTimeField(auto_now_add=True) + + organization = models.ForeignKey(Organization, related_name='members') + member = models.ForeignKey(Cat, related_name='memberships') objects = IncludeManager() diff --git a/tests/test_include.py b/tests/test_include.py index 86f38e1..93f847b 100644 --- a/tests/test_include.py +++ b/tests/test_include.py @@ -275,3 +275,11 @@ def test_reverse_gfk(self, django_assert_num_queries): with django_assert_num_queries(1): ket = models.Cat.objects.include('aliases').get(pk=cat.id) assert len(ket.aliases.all()) == 7 + + def test_chaining(self, django_assert_num_queries): + cat = factories.CatFactory() + + qs1 = cat.organizations.all().filter(members__active=False).include('members') + qs2 = cat.organizations.all().include('members').filter(members__active=False) + + assert str(qs1.query) == str(qs2.query)