Skip to content

Commit

Permalink
more edits
Browse files Browse the repository at this point in the history
  • Loading branch information
timgraham committed Jan 14, 2025
1 parent fae14f2 commit 07525fe
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 85 deletions.
13 changes: 4 additions & 9 deletions django_mongodb_backend/fields/embedded_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@ class EmbeddedModelField(models.Field):

def __init__(self, embedded_model, *args, **kwargs):
"""
`embedded_model` is the model class of the instance that will be
stored. Like other relational fields, it may also be passed as a
string.
`embedded_model` is the model class of the instance to be stored.
Like other relational fields, it may also be passed as a string.
"""
self.embedded_model = embedded_model
super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -48,8 +47,8 @@ def get_internal_type(self):
def _set_model(self, model):
"""
Resolve embedded model class once the field knows the model it belongs
to. If __init__()'s embedded_model argument is a string, resolve it to the
corresponding model class, similar to relation fields.
to. If __init__()'s embedded_model argument is a string, resolve it to
the actual model class, similar to relation fields.
"""
self._model = model
if model is not None and isinstance(self.embedded_model, str):
Expand Down Expand Up @@ -98,9 +97,6 @@ def get_db_prep_save(self, embedded_instance, connection):
f"Expected instance of type {self.embedded_model!r}, not "
f"{type(embedded_instance)!r}."
)
# Apply pre_save() and get_db_prep_save() of embedded instance
# fields, create the field => value mapping to be passed to
# storage preprocessing.
field_values = {}
add = embedded_instance._state.adding
for field in embedded_instance._meta.fields:
Expand All @@ -112,7 +108,6 @@ def get_db_prep_save(self, embedded_instance, connection):
continue
field_values[field.attname] = value
# This instance will exist in the database soon.
# TODO: Ensure that this doesn't cause race conditions.
embedded_instance._state.adding = False
return field_values

Expand Down
3 changes: 2 additions & 1 deletion django_mongodb_backend/forms/fields/embedded_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ class EmbeddedModelField(forms.MultiValueField):

def __init__(self, model, prefix, *args, **kwargs):
form_kwargs = {}
# The field must be prefixed with the name of the field.
# To avoid collisions with other fields on the form, # each subfield
# must be prefixed with the name of the field.
form_kwargs["prefix"] = prefix
self.form_kwargs = form_kwargs
self.model_form_cls = modelform_factory(model, fields="__all__")
Expand Down
75 changes: 0 additions & 75 deletions tests/schema_/test_embedded_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,25 +59,6 @@ def delete_tables(self):
table_names.remove(tbl)
connection.enable_constraint_checking()

def get_indexes(self, table):
"""
Get the indexes on the table using a new cursor.
"""
with connection.cursor() as cursor:
return [
c["columns"][0]
for c in connection.introspection.get_constraints(cursor, table).values()
if c["index"] and len(c["columns"]) == 1
]

def get_uniques(self, table):
with connection.cursor() as cursor:
return [
c["columns"][0]
for c in connection.introspection.get_constraints(cursor, table).values()
if c["unique"] and len(c["columns"]) == 1
]

def get_constraints(self, table):
"""
Get the constraints on a table using a new cursor.
Expand All @@ -93,41 +74,6 @@ def get_constraints_for_columns(self, model, columns):
constraints_for_column.append(name)
return sorted(constraints_for_column)

def check_added_field_default(
self,
schema_editor,
model,
field,
field_name,
expected_default,
cast_function=None,
):
schema_editor.add_field(model, field)
database_default = connection.database[model._meta.db_table].find_one().get(field_name)
if cast_function and type(database_default) is not type(expected_default):
database_default = cast_function(database_default)
self.assertEqual(database_default, expected_default)

def get_constraints_count(self, table, column, fk_to):
"""
Return a dict with keys 'fks', 'uniques, and 'indexes' indicating the
number of foreign keys, unique constraints, and indexes on
`table`.`column`. The `fk_to` argument is a 2-tuple specifying the
expected foreign key relationship's (table, column).
"""
with connection.cursor() as cursor:
constraints = connection.introspection.get_constraints(cursor, table)
counts = {"fks": 0, "uniques": 0, "indexes": 0}
for c in constraints.values():
if c["columns"] == [column]:
if c["foreign_key"] == fk_to:
counts["fks"] += 1
if c["unique"]:
counts["uniques"] += 1
elif c["index"]:
counts["indexes"] += 1
return counts

def assertIndexOrder(self, table, index, order):
constraints = self.get_constraints(table)
self.assertIn(index, constraints)
Expand All @@ -136,27 +82,6 @@ def assertIndexOrder(self, table, index, order):
all(val == expected for val, expected in zip(index_orders, order, strict=True))
)

def assertForeignKeyExists(self, model, column, expected_fk_table, field="id"):
"""
Fail if the FK constraint on `model.Meta.db_table`.`column` to
`expected_fk_table`.id doesn't exist.
"""
if not connection.features.can_introspect_foreign_keys:
return
constraints = self.get_constraints(model._meta.db_table)
constraint_fk = None
for details in constraints.values():
if details["columns"] == [column] and details["foreign_key"]:
constraint_fk = details["foreign_key"]
break
self.assertEqual(constraint_fk, (expected_fk_table, field))

def assertForeignKeyNotExists(self, model, column, expected_fk_table):
if not connection.features.can_introspect_foreign_keys:
return
with self.assertRaises(AssertionError):
self.assertForeignKeyExists(model, column, expected_fk_table)

def assertTableExists(self, model):
self.assertIn(model._meta.db_table, connection.introspection.table_names())

Expand Down

0 comments on commit 07525fe

Please sign in to comment.