-
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add ListField and EmbeddedModelField tests
- Loading branch information
Showing
4 changed files
with
539 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
from django_mongodb.fields import EmbeddedModelField, ListField | ||
|
||
from django.db import models | ||
|
||
|
||
def count_calls(func): | ||
|
||
def wrapper(*args, **kwargs): | ||
wrapper.calls += 1 | ||
return func(*args, **kwargs) | ||
|
||
wrapper.calls = 0 | ||
|
||
return wrapper | ||
|
||
|
||
class ReferenceList(models.Model): | ||
keys = ListField(models.ForeignKey("Model", models.CASCADE)) | ||
|
||
|
||
class Model(models.Model): | ||
pass | ||
|
||
|
||
class Target(models.Model): | ||
index = models.IntegerField() | ||
|
||
|
||
class DecimalModel(models.Model): | ||
decimal = models.DecimalField(max_digits=9, decimal_places=2) | ||
|
||
|
||
class DecimalKey(models.Model): | ||
decimal = models.DecimalField(max_digits=9, decimal_places=2, primary_key=True) | ||
|
||
|
||
class DecimalParent(models.Model): | ||
child = models.ForeignKey(DecimalKey, models.CASCADE) | ||
|
||
|
||
class DecimalsList(models.Model): | ||
decimals = ListField(models.ForeignKey(DecimalKey, models.CASCADE)) | ||
|
||
|
||
class OrderedListModel(models.Model): | ||
ordered_ints = ListField( | ||
models.IntegerField(max_length=500), | ||
default=[], | ||
ordering=count_calls(lambda x: x), | ||
null=True, | ||
) | ||
ordered_nullable = ListField(ordering=lambda x: x, null=True) | ||
|
||
|
||
class ListModel(models.Model): | ||
integer = models.IntegerField(primary_key=True) | ||
floating_point = models.FloatField() | ||
names = ListField(models.CharField) | ||
names_with_default = ListField(models.CharField(max_length=500), default=[]) | ||
names_nullable = ListField(models.CharField(max_length=500), null=True) | ||
|
||
|
||
class EmbeddedModelFieldModel(models.Model): | ||
simple = EmbeddedModelField("EmbeddedModel", null=True) | ||
simple_untyped = EmbeddedModelField(null=True) | ||
decimal_parent = EmbeddedModelField(DecimalParent, null=True) | ||
# typed_list = ListField(EmbeddedModelField('SetModel')) | ||
typed_list2 = ListField(EmbeddedModelField("EmbeddedModel")) | ||
untyped_list = ListField(EmbeddedModelField()) | ||
# untyped_dict = DictField(EmbeddedModelField()) | ||
ordered_list = ListField(EmbeddedModelField(), ordering=lambda obj: obj.index) | ||
|
||
|
||
class EmbeddedModel(models.Model): | ||
some_relation = models.ForeignKey(Target, models.CASCADE, null=True) | ||
someint = models.IntegerField(db_column="custom") | ||
auto_now = models.DateTimeField(auto_now=True) | ||
auto_now_add = models.DateTimeField(auto_now_add=True) | ||
|
||
|
||
class Child(models.Model): | ||
pass | ||
|
||
|
||
class Parent(models.Model): | ||
id = models.IntegerField(primary_key=True) | ||
integer_list = ListField(models.IntegerField) | ||
|
||
# integer_dict = DictField(models.IntegerField) | ||
embedded_list = ListField(EmbeddedModelField(Child)) | ||
|
||
|
||
# embedded_dict = DictField(EmbeddedModelField(Child)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,189 @@ | ||
import time | ||
from decimal import Decimal | ||
|
||
from django.db import models | ||
from django.test import TestCase | ||
|
||
from .models import ( | ||
Child, | ||
DecimalKey, | ||
DecimalParent, | ||
EmbeddedModel, | ||
EmbeddedModelFieldModel, | ||
OrderedListModel, | ||
Parent, | ||
Target, | ||
) | ||
|
||
|
||
class EmbeddedModelFieldTests(TestCase): | ||
|
||
def assertEqualDatetime(self, d1, d2): | ||
"""Compares d1 and d2, ignoring microseconds.""" | ||
self.assertEqual(d1.replace(microsecond=0), d2.replace(microsecond=0)) | ||
|
||
def assertNotEqualDatetime(self, d1, d2): | ||
self.assertNotEqual(d1.replace(microsecond=0), d2.replace(microsecond=0)) | ||
|
||
def _simple_instance(self): | ||
EmbeddedModelFieldModel.objects.create(simple=EmbeddedModel(someint="5")) | ||
return EmbeddedModelFieldModel.objects.get() | ||
|
||
def test_simple(self): | ||
instance = self._simple_instance() | ||
self.assertIsInstance(instance.simple, EmbeddedModel) | ||
# Make sure get_prep_value is called. | ||
self.assertEqual(instance.simple.someint, 5) | ||
# Primary keys should not be populated... | ||
self.assertEqual(instance.simple.id, None) | ||
# ... unless set explicitly. | ||
instance.simple.id = instance.id | ||
instance.save() | ||
instance = EmbeddedModelFieldModel.objects.get() | ||
self.assertEqual(instance.simple.id, instance.id) | ||
|
||
def _test_pre_save(self, instance, get_field): | ||
# Make sure field.pre_save is called for embedded objects. | ||
|
||
instance.save() | ||
auto_now = get_field(instance).auto_now | ||
auto_now_add = get_field(instance).auto_now_add | ||
self.assertNotEqual(auto_now, None) | ||
self.assertNotEqual(auto_now_add, None) | ||
|
||
time.sleep(1) # FIXME | ||
instance.save() | ||
self.assertNotEqualDatetime( | ||
get_field(instance).auto_now, get_field(instance).auto_now_add | ||
) | ||
|
||
instance = EmbeddedModelFieldModel.objects.get() | ||
instance.save() | ||
# auto_now_add shouldn't have changed now, but auto_now should. | ||
self.assertEqualDatetime(get_field(instance).auto_now_add, auto_now_add) | ||
self.assertGreater(get_field(instance).auto_now, auto_now) | ||
|
||
def test_pre_save(self): | ||
obj = EmbeddedModelFieldModel(simple=EmbeddedModel()) | ||
self._test_pre_save(obj, lambda instance: instance.simple) | ||
|
||
def test_pre_save_untyped(self): | ||
obj = EmbeddedModelFieldModel(simple_untyped=EmbeddedModel()) | ||
self._test_pre_save(obj, lambda instance: instance.simple_untyped) | ||
|
||
def test_pre_save_in_list(self): | ||
obj = EmbeddedModelFieldModel(untyped_list=[EmbeddedModel()]) | ||
self._test_pre_save(obj, lambda instance: instance.untyped_list[0]) | ||
|
||
def _test_pre_save_in_dict(self): | ||
obj = EmbeddedModelFieldModel(untyped_dict={"a": EmbeddedModel()}) | ||
self._test_pre_save(obj, lambda instance: instance.untyped_dict["a"]) | ||
|
||
def test_pre_save_list(self): | ||
# Also make sure auto_now{,add} works for embedded object *lists*. | ||
EmbeddedModelFieldModel.objects.create(typed_list2=[EmbeddedModel()]) | ||
instance = EmbeddedModelFieldModel.objects.get() | ||
|
||
auto_now = instance.typed_list2[0].auto_now | ||
auto_now_add = instance.typed_list2[0].auto_now_add | ||
self.assertNotEqual(auto_now, None) | ||
self.assertNotEqual(auto_now_add, None) | ||
|
||
instance.typed_list2.append(EmbeddedModel()) | ||
instance.save() | ||
instance = EmbeddedModelFieldModel.objects.get() | ||
|
||
self.assertEqualDatetime(instance.typed_list2[0].auto_now_add, auto_now_add) | ||
self.assertGreater(instance.typed_list2[0].auto_now, auto_now) | ||
self.assertNotEqual(instance.typed_list2[1].auto_now, None) | ||
self.assertNotEqual(instance.typed_list2[1].auto_now_add, None) | ||
|
||
def test_error_messages(self): | ||
for kwargs, expected in ( | ||
({"simple": 42}, EmbeddedModel), | ||
({"simple_untyped": 42}, models.Model), | ||
# ({"typed_list": [EmbeddedModel()]},), # SetModel), | ||
): | ||
self.assertRaisesMessage( | ||
TypeError, | ||
"Expected instance of type %r" % expected, | ||
EmbeddedModelFieldModel(**kwargs).save, | ||
) | ||
|
||
def test_typed_listfield(self): | ||
EmbeddedModelFieldModel.objects.create( | ||
# typed_list=[SetModel(setfield=range(3)), SetModel(setfield=range(9))], | ||
ordered_list=[Target(index=i) for i in range(5, 0, -1)], | ||
) | ||
obj = EmbeddedModelFieldModel.objects.get() | ||
# self.assertIn(5, obj.typed_list[1].setfield) | ||
self.assertEqual([target.index for target in obj.ordered_list], range(1, 6)) | ||
|
||
def test_untyped_listfield(self): | ||
EmbeddedModelFieldModel.objects.create( | ||
untyped_list=[ | ||
EmbeddedModel(someint=7), | ||
OrderedListModel(ordered_ints=list(range(5, 0, -1))), | ||
# SetModel(setfield=[1, 2, 2, 3]), | ||
] | ||
) | ||
instances = EmbeddedModelFieldModel.objects.get().untyped_list | ||
for instance, cls in zip( | ||
instances, [EmbeddedModel, OrderedListModel] # SetModel] | ||
): | ||
self.assertIsInstance(instance, cls) | ||
self.assertNotEqual(instances[0].auto_now, None) | ||
self.assertEqual(instances[1].ordered_ints, range(1, 6)) | ||
|
||
def _test_untyped_dict(self): | ||
EmbeddedModelFieldModel.objects.create( | ||
untyped_dict={ | ||
# "a": SetModel(setfield=range(3)), | ||
# "b": DictModel(dictfield={"a": 1, "b": 2}), | ||
# "c": DictModel(dictfield={}, auto_now={"y": 1}), | ||
} | ||
) | ||
# data = EmbeddedModelFieldModel.objects.get().untyped_dict | ||
# self.assertIsInstance(data["a"], SetModel) | ||
# self.assertNotEqual(data["c"].auto_now["y"], None) | ||
|
||
def test_foreignkey_in_embedded_object(self): | ||
simple = EmbeddedModel(some_relation=Target.objects.create(index=1)) | ||
obj = EmbeddedModelFieldModel.objects.create(simple=simple) | ||
simple = EmbeddedModelFieldModel.objects.get().simple | ||
self.assertNotIn("some_relation", simple.__dict__) | ||
self.assertIsInstance(simple.__dict__["some_relation_id"], type(obj.id)) | ||
self.assertIsInstance(simple.some_relation, Target) | ||
|
||
def test_embedded_field_with_foreign_conversion(self): | ||
decimal = DecimalKey.objects.create(decimal=Decimal("1.5")) | ||
decimal_parent = DecimalParent.objects.create(child=decimal) | ||
EmbeddedModelFieldModel.objects.create(decimal_parent=decimal_parent) | ||
|
||
def test_update(self): | ||
""" | ||
Test that update can be used on an a subset of objects | ||
containing collections of embedded instances; see issue #13. | ||
Also ensure that updated values are coerced according to | ||
collection field. | ||
""" | ||
child1 = Child.objects.create() | ||
child2 = Child.objects.create() | ||
parent = Parent.objects.create( | ||
pk=1, | ||
integer_list=[1], | ||
# integer_dict={"a": 2}, | ||
embedded_list=[child1], | ||
# embedded_dict={"a": child2}, | ||
) | ||
Parent.objects.filter(pk=1).update( | ||
integer_list=["3"], | ||
# integer_dict={"b": "3"}, | ||
embedded_list=[child2], | ||
# embedded_dict={"b": child1}, | ||
) | ||
parent = Parent.objects.get() | ||
self.assertEqual(parent.integer_list, [3]) | ||
# self.assertEqual(parent.integer_dict, {"b": 3}) | ||
self.assertEqual(parent.embedded_list, [child2]) | ||
# self.assertEqual(parent.embedded_dict, {"b": child1}) |
Oops, something went wrong.