diff --git a/changes/1818.fix.md b/changes/1818.fix.md new file mode 100644 index 0000000000..9f213ddc61 --- /dev/null +++ b/changes/1818.fix.md @@ -0,0 +1 @@ +Add `id` column and restore incorrectly dropped unique constraints to DB association tables. diff --git a/src/ai/backend/manager/models/alembic/versions/a5319bfc7d7c_add_unique_constraints_to_association_.py b/src/ai/backend/manager/models/alembic/versions/a5319bfc7d7c_add_unique_constraints_to_association_.py new file mode 100644 index 0000000000..76768efab5 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/a5319bfc7d7c_add_unique_constraints_to_association_.py @@ -0,0 +1,159 @@ +"""add_unique_constraints_to_association_tables + +Revision ID: a5319bfc7d7c +Revises: caf54fcc17ab +Create Date: 2024-01-03 21:43:31.208183 + +""" + +import sqlalchemy as sa +from alembic import op + +from ai.backend.manager.models.base import GUID, mapper_registry + +# revision identifiers, used by Alembic. +revision = "a5319bfc7d7c" +down_revision = "caf54fcc17ab" +branch_labels = None +depends_on = None + + +def upgrade(): + conn = op.get_bind() + + association_groups_users = sa.Table( + "association_groups_users", + mapper_registry.metadata, + sa.Column( + "id", + GUID, + primary_key=True, + nullable=False, + ), + sa.Column( + "user_id", + GUID, + sa.ForeignKey("users.uuid", onupdate="CASCADE", ondelete="CASCADE"), + nullable=False, + ), + sa.Column( + "group_id", + GUID, + sa.ForeignKey("groups.id", onupdate="CASCADE", ondelete="CASCADE"), + nullable=False, + ), + sa.UniqueConstraint("user_id", "group_id", name="uq_association_user_id_group_id"), + extend_existing=True, + ) + + sgroups_for_domains = sa.Table( + "sgroups_for_domains", + mapper_registry.metadata, + sa.Column( + "id", + GUID, + primary_key=True, + nullable=False, + ), + sa.Column( + "scaling_group", + sa.ForeignKey("scaling_groups.name", onupdate="CASCADE", ondelete="CASCADE"), + index=True, + nullable=False, + ), + sa.Column( + "domain", + sa.ForeignKey("domains.name", onupdate="CASCADE", ondelete="CASCADE"), + index=True, + nullable=False, + ), + extend_existing=True, + ) + + sgroups_for_groups = sa.Table( + "sgroups_for_groups", + mapper_registry.metadata, + sa.Column( + "id", + GUID, + primary_key=True, + nullable=False, + ), + sa.Column( + "scaling_group", + sa.ForeignKey("scaling_groups.name", onupdate="CASCADE", ondelete="CASCADE"), + index=True, + nullable=False, + ), + sa.Column( + "group", + sa.ForeignKey("groups.id", onupdate="CASCADE", ondelete="CASCADE"), + index=True, + nullable=False, + ), + extend_existing=True, + ) + + sgroups_for_keypairs = sa.Table( + "sgroups_for_keypairs", + mapper_registry.metadata, + sa.Column( + "id", + GUID, + primary_key=True, + nullable=False, + ), + sa.Column( + "scaling_group", + sa.ForeignKey("scaling_groups.name", onupdate="CASCADE", ondelete="CASCADE"), + index=True, + nullable=False, + ), + sa.Column( + "access_key", + sa.ForeignKey("keypairs.access_key", onupdate="CASCADE", ondelete="CASCADE"), + index=True, + nullable=False, + ), + extend_existing=True, + ) + + def ensure_unique(table, field_1: str, field_2: str) -> None: + # Leave only one duplicate record and delete all of it + t1 = table.alias("t1") + t2 = table.alias("t2") + subq = ( + sa.select([t1.c.id]) + .where(t1.c[field_1] == t2.c[field_1]) + .where(t1.c[field_2] == t2.c[field_2]) + .where(t1.c.id > t2.c.id) + ) + delete_stmt = sa.delete(table).where(table.c.id.in_(subq)) + conn.execute(delete_stmt) + + ensure_unique(association_groups_users, "user_id", "group_id") + ensure_unique(sgroups_for_domains, "scaling_group", "domain") + ensure_unique(sgroups_for_groups, "scaling_group", "group") + ensure_unique(sgroups_for_keypairs, "scaling_group", "access_key") + + op.create_unique_constraint( + "uq_association_user_id_group_id", "association_groups_users", ["user_id", "group_id"] + ) + op.create_unique_constraint( + "uq_sgroup_domain", "sgroups_for_domains", ["scaling_group", "domain"] + ) + op.create_unique_constraint( + "uq_sgroup_ugroup", "sgroups_for_groups", ["scaling_group", "group"] + ) + op.create_unique_constraint( + "uq_sgroup_akey", "sgroups_for_keypairs", ["scaling_group", "access_key"] + ) + + +def downgrade(): + op.drop_constraint( + "uq_association_user_id_group_id", "association_groups_users", type_="unique" + ) + op.drop_constraint("uq_sgroup_domain", "sgroups_for_domains", type_="unique") + op.drop_constraint("uq_sgroup_ugroup", "sgroups_for_groups", type_="unique") + op.drop_constraint("uq_sgroup_akey", "sgroups_for_keypairs", type_="unique") diff --git a/src/ai/backend/manager/models/alembic/versions/caf54fcc17ab_add_id_columns_to_association_tables.py b/src/ai/backend/manager/models/alembic/versions/caf54fcc17ab_add_id_columns_to_association_tables.py new file mode 100644 index 0000000000..4dd3869ba1 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/caf54fcc17ab_add_id_columns_to_association_tables.py @@ -0,0 +1,54 @@ +"""add_id_columns_to_association_tables + +Revision ID: caf54fcc17ab +Revises: 8b2ec7e3d22a +Create Date: 2024-01-03 21:39:50.558724 + +""" + +import sqlalchemy as sa +from alembic import op + +from ai.backend.manager.models.base import GUID + +# revision identifiers, used by Alembic. +revision = "caf54fcc17ab" +down_revision = "8b2ec7e3d22a" +branch_labels = None +depends_on = None + + +def upgrade(): + op.add_column( + "association_groups_users", + sa.Column("id", GUID(), server_default=sa.text("uuid_generate_v4()"), nullable=False), + ) + op.add_column( + "sgroups_for_domains", + sa.Column("id", GUID(), server_default=sa.text("uuid_generate_v4()"), nullable=False), + ) + op.add_column( + "sgroups_for_groups", + sa.Column("id", GUID(), server_default=sa.text("uuid_generate_v4()"), nullable=False), + ) + op.add_column( + "sgroups_for_keypairs", + sa.Column("id", GUID(), server_default=sa.text("uuid_generate_v4()"), nullable=False), + ) + + op.create_primary_key("pk_association_groups_users", "association_groups_users", ["id"]) + op.create_primary_key("pk_sgroups_for_domains", "sgroups_for_domains", ["id"]) + op.create_primary_key("pk_sgroups_for_groups", "sgroups_for_groups", ["id"]) + op.create_primary_key("pk_sgroups_for_keypairs", "sgroups_for_keypairs", ["id"]) + + +def downgrade(): + op.drop_constraint("pk_association_groups_users", "association_groups_users", type_="primary") + op.drop_constraint("pk_sgroups_for_domains", "sgroups_for_domains", type_="primary") + op.drop_constraint("pk_sgroups_for_groups", "sgroups_for_groups", type_="primary") + op.drop_constraint("pk_sgroups_for_keypairs", "sgroups_for_keypairs", type_="primary") + + op.drop_column("sgroups_for_keypairs", "id") + op.drop_column("sgroups_for_groups", "id") + op.drop_column("sgroups_for_domains", "id") + op.drop_column("association_groups_users", "id") diff --git a/src/ai/backend/manager/models/group.py b/src/ai/backend/manager/models/group.py index b9db683dd1..b39e032841 100644 --- a/src/ai/backend/manager/models/group.py +++ b/src/ai/backend/manager/models/group.py @@ -81,20 +81,20 @@ association_groups_users = sa.Table( "association_groups_users", mapper_registry.metadata, + IDColumn(), sa.Column( "user_id", GUID, sa.ForeignKey("users.uuid", onupdate="CASCADE", ondelete="CASCADE"), nullable=False, - primary_key=True, ), sa.Column( "group_id", GUID, sa.ForeignKey("groups.id", onupdate="CASCADE", ondelete="CASCADE"), nullable=False, - primary_key=True, ), + sa.UniqueConstraint("user_id", "group_id", name="uq_association_user_id_group_id"), ) diff --git a/src/ai/backend/manager/models/scaling_group.py b/src/ai/backend/manager/models/scaling_group.py index 57ea6941e1..522c8870c5 100644 --- a/src/ai/backend/manager/models/scaling_group.py +++ b/src/ai/backend/manager/models/scaling_group.py @@ -30,6 +30,7 @@ from .base import ( Base, + IDColumn, StructuredJSONObjectColumn, batch_multiresult, batch_result, @@ -142,60 +143,60 @@ def as_trafaret(cls) -> t.Trafaret: sgroups_for_domains = sa.Table( "sgroups_for_domains", mapper_registry.metadata, + IDColumn(), sa.Column( "scaling_group", sa.ForeignKey("scaling_groups.name", onupdate="CASCADE", ondelete="CASCADE"), index=True, nullable=False, - primary_key=True, ), sa.Column( "domain", sa.ForeignKey("domains.name", onupdate="CASCADE", ondelete="CASCADE"), index=True, nullable=False, - primary_key=True, ), + sa.UniqueConstraint("scaling_group", "domain", name="uq_sgroup_domain"), ) sgroups_for_groups = sa.Table( "sgroups_for_groups", mapper_registry.metadata, + IDColumn(), sa.Column( "scaling_group", sa.ForeignKey("scaling_groups.name", onupdate="CASCADE", ondelete="CASCADE"), index=True, nullable=False, - primary_key=True, ), sa.Column( "group", sa.ForeignKey("groups.id", onupdate="CASCADE", ondelete="CASCADE"), index=True, nullable=False, - primary_key=True, ), + sa.UniqueConstraint("scaling_group", "group", name="uq_sgroup_ugroup"), ) sgroups_for_keypairs = sa.Table( "sgroups_for_keypairs", mapper_registry.metadata, + IDColumn(), sa.Column( "scaling_group", sa.ForeignKey("scaling_groups.name", onupdate="CASCADE", ondelete="CASCADE"), index=True, nullable=False, - primary_key=True, ), sa.Column( "access_key", sa.ForeignKey("keypairs.access_key", onupdate="CASCADE", ondelete="CASCADE"), index=True, nullable=False, - primary_key=True, ), + sa.UniqueConstraint("scaling_group", "access_key", name="uq_sgroup_akey"), )