Skip to content

Commit

Permalink
Handle .filter_is_sticky (#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisseto authored Aug 29, 2017
1 parent 69bf480 commit 15f9b42
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 3 deletions.
2 changes: 2 additions & 0 deletions include/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ def add_where(self, queryset, host_table, included_table):


class IncludeExpression(Expression):
# No need to use group bys when using .include
contains_aggregate = False

def __init__(self, field, children=None):
expressions = []
Expand Down
17 changes: 14 additions & 3 deletions include/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,13 @@ def include(self, *fields, **kwargs):
"""
clone = self._clone()

# Preserve the stickiness of related querysets
# NOTE: by default _clone will clear this attribute
# .include does not modify the actual query, so we
# should not clear `filter_is_sticky`
if self.query.filter_is_sticky:
clone.query.filter_is_sticky = True

clone._include_limit = kwargs.pop('limit_includes', None)
assert not kwargs, '"limit_includes" is the only accepted kwargs. Eat your heart out 2.7'

Expand All @@ -111,15 +118,15 @@ def include(self, *fields, **kwargs):
# Including multiple child fields ie .include(field1__field2, field1__field3)
# turns into {field1: {field2: {}, field3: {}}
for name in fields:
ctx, model = self._includes, self.model
ctx, model = clone._includes, clone.model
for spl in name.split('__'):
field = model._meta.get_field(spl)
if isinstance(field, ForeignObjectRel) and field.is_hidden():
raise ValueError('Hidden field "{!r}" has no descriptor and therefore cannot be included'.format(field))
model = field.related_model
ctx = ctx.setdefault(field, OrderedDict())

for field in self._includes.keys():
for field in clone._includes.keys():
clone._include(field)

return clone
Expand All @@ -131,7 +138,11 @@ def count(self):

def _clone(self):
clone = super(IncludeQuerySet, self)._clone()
clone._includes = self._includes

clone._includes = self._includes.copy()
for field in self._includes.keys():
clone._include(field)

return clone

def _include(self, field):
Expand Down
43 changes: 43 additions & 0 deletions tests/migrations/0003_auto_20170825_1341.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11.4 on 2017-08-25 13:41
from __future__ import unicode_literals

from django.db import migrations, models
import django.db.models.deletion


class Migration(migrations.Migration):

dependencies = [
('tests', '0002_auto_20170713_1713'),
]

operations = [
migrations.CreateModel(
name='Membership',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('active', models.BooleanField(default=True)),
('joined', models.DateTimeField(auto_now_add=True)),
('member', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='memberships', to='tests.Cat')),
],
),
migrations.CreateModel(
name='Organization',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('title', models.CharField(max_length=75)),
('disliked', models.BooleanField(default=True)),
],
),
migrations.AddField(
model_name='membership',
name='organization',
field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='members', to='tests.Organization'),
),
migrations.AddField(
model_name='cat',
name='organizations',
field=models.ManyToManyField(through='tests.Membership', to='tests.Organization'),
),
]
18 changes: 18 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,23 @@ class Cat(models.Model):
parent = models.ForeignKey('Cat', null=True, related_name='children')
siblings = models.ManyToManyField('Cat', related_name='related_to')
emergency_contact = models.OneToOneField('Cat', null=True, related_name='emergency_contact_for')
organizations = models.ManyToManyField('Organization', through='Membership')

objects = IncludeManager()


class Organization(models.Model):
title = models.CharField(max_length=75)
disliked = models.BooleanField(default=True)

objects = IncludeManager()


class Membership(models.Model):
active = models.BooleanField(default=True)
joined = models.DateTimeField(auto_now_add=True)

organization = models.ForeignKey(Organization, related_name='members')
member = models.ForeignKey(Cat, related_name='memberships')

objects = IncludeManager()
8 changes: 8 additions & 0 deletions tests/test_include.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,3 +275,11 @@ def test_reverse_gfk(self, django_assert_num_queries):
with django_assert_num_queries(1):
ket = models.Cat.objects.include('aliases').get(pk=cat.id)
assert len(ket.aliases.all()) == 7

def test_chaining(self, django_assert_num_queries):
cat = factories.CatFactory()

qs1 = cat.organizations.all().filter(members__active=False).include('members')
qs2 = cat.organizations.all().include('members').filter(members__active=False)

assert str(qs1.query) == str(qs2.query)

0 comments on commit 15f9b42

Please sign in to comment.