diff --git a/django_mongodb_backend/fields/embedded_model.py b/django_mongodb_backend/fields/embedded_model.py index 214dcf4b..01be7da0 100644 --- a/django_mongodb_backend/fields/embedded_model.py +++ b/django_mongodb_backend/fields/embedded_model.py @@ -11,9 +11,8 @@ class EmbeddedModelField(models.Field): def __init__(self, embedded_model, *args, **kwargs): """ - `embedded_model` is the model class of the instance that will be - stored. Like other relational fields, it may also be passed as a - string. + `embedded_model` is the model class of the instance to be stored. + Like other relational fields, it may also be passed as a string. """ self.embedded_model = embedded_model super().__init__(*args, **kwargs) @@ -48,8 +47,8 @@ def get_internal_type(self): def _set_model(self, model): """ Resolve embedded model class once the field knows the model it belongs - to. If __init__()'s embedded_model argument is a string, resolve it to the - corresponding model class, similar to relation fields. + to. If __init__()'s embedded_model argument is a string, resolve it to + the actual model class, similar to relation fields. """ self._model = model if model is not None and isinstance(self.embedded_model, str): @@ -98,9 +97,6 @@ def get_db_prep_save(self, embedded_instance, connection): f"Expected instance of type {self.embedded_model!r}, not " f"{type(embedded_instance)!r}." ) - # Apply pre_save() and get_db_prep_save() of embedded instance - # fields, create the field => value mapping to be passed to - # storage preprocessing. field_values = {} add = embedded_instance._state.adding for field in embedded_instance._meta.fields: @@ -112,7 +108,6 @@ def get_db_prep_save(self, embedded_instance, connection): continue field_values[field.attname] = value # This instance will exist in the database soon. - # TODO: Ensure that this doesn't cause race conditions. embedded_instance._state.adding = False return field_values diff --git a/django_mongodb_backend/forms/fields/embedded_model.py b/django_mongodb_backend/forms/fields/embedded_model.py index 2481b7f1..4bf211d2 100644 --- a/django_mongodb_backend/forms/fields/embedded_model.py +++ b/django_mongodb_backend/forms/fields/embedded_model.py @@ -34,7 +34,8 @@ class EmbeddedModelField(forms.MultiValueField): def __init__(self, model, prefix, *args, **kwargs): form_kwargs = {} - # The field must be prefixed with the name of the field. + # To avoid collisions with other fields on the form, # each subfield + # must be prefixed with the name of the field. form_kwargs["prefix"] = prefix self.form_kwargs = form_kwargs self.model_form_cls = modelform_factory(model, fields="__all__") diff --git a/tests/schema_/test_embedded_model.py b/tests/schema_/test_embedded_model.py index bfab04d4..faf58ce2 100644 --- a/tests/schema_/test_embedded_model.py +++ b/tests/schema_/test_embedded_model.py @@ -59,25 +59,6 @@ def delete_tables(self): table_names.remove(tbl) connection.enable_constraint_checking() - def get_indexes(self, table): - """ - Get the indexes on the table using a new cursor. - """ - with connection.cursor() as cursor: - return [ - c["columns"][0] - for c in connection.introspection.get_constraints(cursor, table).values() - if c["index"] and len(c["columns"]) == 1 - ] - - def get_uniques(self, table): - with connection.cursor() as cursor: - return [ - c["columns"][0] - for c in connection.introspection.get_constraints(cursor, table).values() - if c["unique"] and len(c["columns"]) == 1 - ] - def get_constraints(self, table): """ Get the constraints on a table using a new cursor. @@ -93,41 +74,6 @@ def get_constraints_for_columns(self, model, columns): constraints_for_column.append(name) return sorted(constraints_for_column) - def check_added_field_default( - self, - schema_editor, - model, - field, - field_name, - expected_default, - cast_function=None, - ): - schema_editor.add_field(model, field) - database_default = connection.database[model._meta.db_table].find_one().get(field_name) - if cast_function and type(database_default) is not type(expected_default): - database_default = cast_function(database_default) - self.assertEqual(database_default, expected_default) - - def get_constraints_count(self, table, column, fk_to): - """ - Return a dict with keys 'fks', 'uniques, and 'indexes' indicating the - number of foreign keys, unique constraints, and indexes on - `table`.`column`. The `fk_to` argument is a 2-tuple specifying the - expected foreign key relationship's (table, column). - """ - with connection.cursor() as cursor: - constraints = connection.introspection.get_constraints(cursor, table) - counts = {"fks": 0, "uniques": 0, "indexes": 0} - for c in constraints.values(): - if c["columns"] == [column]: - if c["foreign_key"] == fk_to: - counts["fks"] += 1 - if c["unique"]: - counts["uniques"] += 1 - elif c["index"]: - counts["indexes"] += 1 - return counts - def assertIndexOrder(self, table, index, order): constraints = self.get_constraints(table) self.assertIn(index, constraints) @@ -136,27 +82,6 @@ def assertIndexOrder(self, table, index, order): all(val == expected for val, expected in zip(index_orders, order, strict=True)) ) - def assertForeignKeyExists(self, model, column, expected_fk_table, field="id"): - """ - Fail if the FK constraint on `model.Meta.db_table`.`column` to - `expected_fk_table`.id doesn't exist. - """ - if not connection.features.can_introspect_foreign_keys: - return - constraints = self.get_constraints(model._meta.db_table) - constraint_fk = None - for details in constraints.values(): - if details["columns"] == [column] and details["foreign_key"]: - constraint_fk = details["foreign_key"] - break - self.assertEqual(constraint_fk, (expected_fk_table, field)) - - def assertForeignKeyNotExists(self, model, column, expected_fk_table): - if not connection.features.can_introspect_foreign_keys: - return - with self.assertRaises(AssertionError): - self.assertForeignKeyExists(model, column, expected_fk_table) - def assertTableExists(self, model): self.assertIn(model._meta.db_table, connection.introspection.table_names())