Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

prevent the creation of embedded models #227

Merged
merged 2 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion django_mongodb_backend/fields/embedded_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,18 @@ def __init__(self, embedded_model, *args, **kwargs):
super().__init__(*args, **kwargs)

def check(self, **kwargs):
from ..models import EmbeddedModel

errors = super().check(**kwargs)
if not issubclass(self.embedded_model, EmbeddedModel):
return [
checks.Error(
"Embedded models must be a subclass of "
"django_mongodb_backend.models.EmbeddedModel.",
obj=self,
id="django_mongodb_backend.embedded_model.E002",
)
]
for field in self.embedded_model._meta.fields:
if field.remote_field:
errors.append(
Expand All @@ -27,7 +38,7 @@ def check(self, **kwargs):
f"({self.embedded_model().__class__.__name__}.{field.name} "
f"is a {field.__class__.__name__}).",
obj=self,
id="django_mongodb.embedded_model.E001",
id="django_mongodb_backend.embedded_model.E001",
)
)
return errors
Expand Down
30 changes: 30 additions & 0 deletions django_mongodb_backend/managers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,37 @@
from django.db import NotSupportedError
from django.db.models.manager import BaseManager

from .queryset import MongoQuerySet


class MongoManager(BaseManager.from_queryset(MongoQuerySet)):
pass


class EmbeddedModelManager(BaseManager):
"""
Prevent all queryset operations on embedded models since they don't have
their own collection.

Raise a helpful error message for some basic QuerySet methods. Subclassing
BaseManager means that other methods raise, e.g. AttributeError:
'EmbeddedModelManager' object has no attribute 'update_or_create'".
"""

def all(self):
raise NotSupportedError("EmbeddedModels cannot be queried.")

def get(self, *args, **kwargs):
raise NotSupportedError("EmbeddedModels cannot be queried.")

def filter(self, *args, **kwargs):
raise NotSupportedError("EmbeddedModels cannot be queried.")

def create(self, **kwargs):
raise NotSupportedError("EmbeddedModels cannot be created.")

def update(self, *args, **kwargs):
raise NotSupportedError("EmbeddedModels cannot be updated.")

def delete(self):
raise NotSupportedError("EmbeddedModels cannot be deleted.")
16 changes: 16 additions & 0 deletions django_mongodb_backend/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from django.db import NotSupportedError, models

from .managers import EmbeddedModelManager


class EmbeddedModel(models.Model):
objects = EmbeddedModelManager()

class Meta:
abstract = True

def delete(self, *args, **kwargs):
raise NotSupportedError("EmbeddedModels cannot be deleted.")

def save(self, *args, **kwargs):
raise NotSupportedError("EmbeddedModels cannot be saved.")
30 changes: 30 additions & 0 deletions django_mongodb_backend/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,24 @@
from .utils import OperationCollector


def ignore_embedded_models(func):
"""
Make a SchemaEditor method a no-op if model is an EmbeddedModel (unless
parent_model isn't None, in which case this is a valid recursive operation
such as adding an index on an embedded model's field).
"""

def wrapper(self, model, *args, **kwargs):
parent_model = kwargs.get("parent_model")
from .models import EmbeddedModel

if issubclass(model, EmbeddedModel) and parent_model is None:
return
func(self, model, *args, **kwargs)

return wrapper


class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
def get_collection(self, name):
if self.collect_sql:
Expand All @@ -22,6 +40,7 @@ def get_database(self):
return self.connection.get_database()

@wrap_database_errors
@ignore_embedded_models
def create_model(self, model):
self.get_database().create_collection(model._meta.db_table)
self._create_model_indexes(model)
Expand Down Expand Up @@ -75,13 +94,15 @@ def _create_model_indexes(self, model, column_prefix="", parent_model=None):
for index in model._meta.indexes:
self.add_index(model, index, column_prefix=column_prefix, parent_model=parent_model)

@ignore_embedded_models
def delete_model(self, model):
# Delete implicit M2m tables.
for field in model._meta.local_many_to_many:
if field.remote_field.through._meta.auto_created:
self.delete_model(field.remote_field.through)
self.get_collection(model._meta.db_table).drop()

@ignore_embedded_models
def add_field(self, model, field):
# Create implicit M2M tables.
if field.many_to_many and field.remote_field.through._meta.auto_created:
Expand All @@ -103,6 +124,7 @@ def add_field(self, model, field):
elif self._field_should_have_unique(field):
self._add_field_unique(model, field)

@ignore_embedded_models
def _alter_field(
self,
model,
Expand Down Expand Up @@ -149,6 +171,7 @@ def _alter_field(
if not old_field_unique and new_field_unique:
self._add_field_unique(model, new_field)

@ignore_embedded_models
def remove_field(self, model, field):
# Remove implicit M2M tables.
if field.many_to_many and field.remote_field.through._meta.auto_created:
Expand Down Expand Up @@ -210,6 +233,7 @@ def _remove_model_indexes(self, model, column_prefix="", parent_model=None):
for index in model._meta.indexes:
self.remove_index(parent_model or model, index)

@ignore_embedded_models
def alter_index_together(self, model, old_index_together, new_index_together, column_prefix=""):
olds = {tuple(fields) for fields in old_index_together}
news = {tuple(fields) for fields in new_index_together}
Expand All @@ -222,6 +246,7 @@ def alter_index_together(self, model, old_index_together, new_index_together, co
for field_names in news.difference(olds):
self._add_composed_index(model, field_names, column_prefix=column_prefix)

@ignore_embedded_models
def alter_unique_together(
self, model, old_unique_together, new_unique_together, column_prefix="", parent_model=None
):
Expand Down Expand Up @@ -249,6 +274,7 @@ def alter_unique_together(
model, constraint, parent_model=parent_model, column_prefix=column_prefix
)

@ignore_embedded_models
def add_index(
self, model, index, *, field=None, unique=False, column_prefix="", parent_model=None
):
Expand Down Expand Up @@ -302,6 +328,7 @@ def _add_field_index(self, model, field, *, column_prefix=""):
index.name = self._create_index_name(model._meta.db_table, [column_prefix + field.column])
self.add_index(model, index, field=field, column_prefix=column_prefix)

@ignore_embedded_models
def remove_index(self, model, index):
if index.contains_expressions:
return
Expand Down Expand Up @@ -355,6 +382,7 @@ def _remove_field_index(self, model, field, column_prefix=""):
)
collection.drop_index(index_names[0])

@ignore_embedded_models
def add_constraint(self, model, constraint, field=None, column_prefix="", parent_model=None):
if isinstance(constraint, UniqueConstraint) and self._unique_supported(
condition=constraint.condition,
Expand Down Expand Up @@ -384,6 +412,7 @@ def _add_field_unique(self, model, field, column_prefix=""):
constraint = UniqueConstraint(fields=[field.name], name=name)
self.add_constraint(model, constraint, field=field, column_prefix=column_prefix)

@ignore_embedded_models
def remove_constraint(self, model, constraint):
if isinstance(constraint, UniqueConstraint) and self._unique_supported(
condition=constraint.condition,
Expand Down Expand Up @@ -417,6 +446,7 @@ def _remove_field_unique(self, model, field, column_prefix=""):
)
self.get_collection(model._meta.db_table).drop_index(constraint_names[0])

@ignore_embedded_models
def alter_db_table(self, model, old_db_table, new_db_table):
if old_db_table == new_db_table:
return
Expand Down
3 changes: 2 additions & 1 deletion docs/source/embedded-models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@ The basics
Let's consider this example::

from django_mongodb_backend.fields import EmbeddedModelField
from django_mongodb_backend.models import EmbeddedModel

class Customer(models.Model):
name = models.CharField(...)
address = EmbeddedModelField("Address")
...

class Address(models.Model):
class Address(EmbeddedModel):
...
city = models.CharField(...)

Expand Down
12 changes: 8 additions & 4 deletions docs/source/fields.rst
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,11 @@ Stores a model of type ``embedded_model``.

This is a required argument.

Specifies the model class to embed. It can be either a concrete model
class or a :ref:`lazy reference <lazy-relationships>` to a model class.
Specifies the model class to embed. It must be a subclass of
:class:`django_mongodb_backend.models.EmbeddedModel`.

It can be either a concrete model class or a :ref:`lazy reference
<lazy-relationships>` to a model class.

The embedded model cannot have relational fields
(:class:`~django.db.models.ForeignKey`,
Expand All @@ -234,11 +237,12 @@ Stores a model of type ``embedded_model``.

from django.db import models
from django_mongodb_backend.fields import EmbeddedModelField
from django_mongodb_backend.models import EmbeddedModel

class Address(models.Model):
class Address(EmbeddedModel):
...

class Author(models.Model):
class Author(EmbeddedModel):
address = EmbeddedModelField(Address)

class Book(models.Model):
Expand Down
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ django-mongodb-backend 5.0.x documentation
fields
querysets
forms
models
embedded-models

Indices and tables
Expand Down
15 changes: 15 additions & 0 deletions docs/source/models.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
Model reference
===============

.. module:: django_mongodb_backend.models

One MongoDB-specific model is available in ``django_mongodb_backend.models``.

.. class:: EmbeddedModel

An abstract model which all :doc:`embedded models <embedded-models>` must
subclass.

Since these models are not stored in their own collection, they do not have
any of the normal ``QuerySet`` methods (``all()``, ``filter()``, ``delete()``,
etc.) You also cannot call ``Model.save()`` and ``delete()`` on them.
7 changes: 4 additions & 3 deletions tests/model_fields_/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from django.db import models

from django_mongodb_backend.fields import ArrayField, EmbeddedModelField, ObjectIdField
from django_mongodb_backend.models import EmbeddedModel


# ObjectIdField
Expand Down Expand Up @@ -98,19 +99,19 @@ class Holder(models.Model):
data = EmbeddedModelField("Data", null=True, blank=True)


class Data(models.Model):
class Data(EmbeddedModel):
integer = models.IntegerField(db_column="custom_column")
auto_now = models.DateTimeField(auto_now=True)
auto_now_add = models.DateTimeField(auto_now_add=True)


class Address(models.Model):
class Address(EmbeddedModel):
city = models.CharField(max_length=20)
state = models.CharField(max_length=2)
zip_code = models.IntegerField(db_index=True)


class Author(models.Model):
class Author(EmbeddedModel):
name = models.CharField(max_length=10)
age = models.IntegerField()
address = EmbeddedModelField(Address)
Expand Down
25 changes: 20 additions & 5 deletions tests/model_fields_/test_embedded_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from django.test.utils import isolate_apps

from django_mongodb_backend.fields import EmbeddedModelField
from django_mongodb_backend.models import EmbeddedModel

from .models import (
Address,
Expand Down Expand Up @@ -108,18 +109,32 @@ def test_nested(self):
@isolate_apps("model_fields_")
class CheckTests(SimpleTestCase):
def test_no_relational_fields(self):
class Target(models.Model):
class Target(EmbeddedModel):
key = models.ForeignKey("MyModel", models.CASCADE)

class MyModel(models.Model):
field = EmbeddedModelField(Target)

model = MyModel()
errors = model.check()
errors = MyModel().check()
self.assertEqual(len(errors), 1)
# The inner CharField has a non-positive max_length.
self.assertEqual(errors[0].id, "django_mongodb.embedded_model.E001")
self.assertEqual(errors[0].id, "django_mongodb_backend.embedded_model.E001")
msg = errors[0].msg
self.assertEqual(
msg, "Embedded models cannot have relational fields (Target.key is a ForeignKey)."
)

def test_embedded_model_subclass(self):
class Target(models.Model):
pass

class MyModel(models.Model):
field = EmbeddedModelField(Target)

errors = MyModel().check()
self.assertEqual(len(errors), 1)
self.assertEqual(errors[0].id, "django_mongodb_backend.embedded_model.E002")
msg = errors[0].msg
self.assertEqual(
msg,
"Embedded models must be a subclass of django_mongodb_backend.models.EmbeddedModel.",
)
8 changes: 2 additions & 6 deletions tests/model_forms_/models.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from django.db import models

from django_mongodb_backend.fields import EmbeddedModelField
from django_mongodb_backend.models import EmbeddedModel


class Address(models.Model):
class Address(EmbeddedModel):
po_box = models.CharField(max_length=50, blank=True, verbose_name="PO Box")
city = models.CharField(max_length=20)
state = models.CharField(max_length=2)
Expand All @@ -15,8 +16,3 @@ class Author(models.Model):
age = models.IntegerField()
address = EmbeddedModelField(Address)
billing_address = EmbeddedModelField(Address, blank=True, null=True)


class Book(models.Model):
name = models.CharField(max_length=100)
author = EmbeddedModelField(Author)
Empty file added tests/models_/__init__.py
Empty file.
5 changes: 5 additions & 0 deletions tests/models_/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from django_mongodb_backend.models import EmbeddedModel


class Embed(EmbeddedModel):
pass
Loading
Loading