Skip to content

Commit

Permalink
fix: Add id column to association tables and restore unique constra…
Browse files Browse the repository at this point in the history
…ints (#1818)

Backported-from: main (24.03)
Backported-to: 23.09
Co-authored-by: Joongi Kim <[email protected]>
  • Loading branch information
fregataa and achimnol committed Feb 1, 2024
1 parent 5db8937 commit adfec0a
Show file tree
Hide file tree
Showing 5 changed files with 223 additions and 8 deletions.
1 change: 1 addition & 0 deletions changes/1818.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `id` column and restore incorrectly dropped unique constraints to DB association tables.
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
"""add_unique_constraints_to_association_tables
Revision ID: a5319bfc7d7c
Revises: caf54fcc17ab
Create Date: 2024-01-03 21:43:31.208183
"""

import sqlalchemy as sa
from alembic import op

from ai.backend.manager.models.base import GUID, mapper_registry

# revision identifiers, used by Alembic.
revision = "a5319bfc7d7c"
down_revision = "caf54fcc17ab"
branch_labels = None
depends_on = None


def upgrade():
conn = op.get_bind()

association_groups_users = sa.Table(
"association_groups_users",
mapper_registry.metadata,
sa.Column(
"id",
GUID,
primary_key=True,
nullable=False,
),
sa.Column(
"user_id",
GUID,
sa.ForeignKey("users.uuid", onupdate="CASCADE", ondelete="CASCADE"),
nullable=False,
),
sa.Column(
"group_id",
GUID,
sa.ForeignKey("groups.id", onupdate="CASCADE", ondelete="CASCADE"),
nullable=False,
),
sa.UniqueConstraint("user_id", "group_id", name="uq_association_user_id_group_id"),
extend_existing=True,
)

sgroups_for_domains = sa.Table(
"sgroups_for_domains",
mapper_registry.metadata,
sa.Column(
"id",
GUID,
primary_key=True,
nullable=False,
),
sa.Column(
"scaling_group",
sa.ForeignKey("scaling_groups.name", onupdate="CASCADE", ondelete="CASCADE"),
index=True,
nullable=False,
),
sa.Column(
"domain",
sa.ForeignKey("domains.name", onupdate="CASCADE", ondelete="CASCADE"),
index=True,
nullable=False,
),
extend_existing=True,
)

sgroups_for_groups = sa.Table(
"sgroups_for_groups",
mapper_registry.metadata,
sa.Column(
"id",
GUID,
primary_key=True,
nullable=False,
),
sa.Column(
"scaling_group",
sa.ForeignKey("scaling_groups.name", onupdate="CASCADE", ondelete="CASCADE"),
index=True,
nullable=False,
),
sa.Column(
"group",
sa.ForeignKey("groups.id", onupdate="CASCADE", ondelete="CASCADE"),
index=True,
nullable=False,
),
extend_existing=True,
)

sgroups_for_keypairs = sa.Table(
"sgroups_for_keypairs",
mapper_registry.metadata,
sa.Column(
"id",
GUID,
primary_key=True,
nullable=False,
),
sa.Column(
"scaling_group",
sa.ForeignKey("scaling_groups.name", onupdate="CASCADE", ondelete="CASCADE"),
index=True,
nullable=False,
),
sa.Column(
"access_key",
sa.ForeignKey("keypairs.access_key", onupdate="CASCADE", ondelete="CASCADE"),
index=True,
nullable=False,
),
extend_existing=True,
)

def ensure_unique(table, field_1: str, field_2: str) -> None:
# Leave only one duplicate record and delete all of it
t1 = table.alias("t1")
t2 = table.alias("t2")
subq = (
sa.select([t1.c.id])
.where(t1.c[field_1] == t2.c[field_1])
.where(t1.c[field_2] == t2.c[field_2])
.where(t1.c.id > t2.c.id)
)
delete_stmt = sa.delete(table).where(table.c.id.in_(subq))
conn.execute(delete_stmt)

ensure_unique(association_groups_users, "user_id", "group_id")
ensure_unique(sgroups_for_domains, "scaling_group", "domain")
ensure_unique(sgroups_for_groups, "scaling_group", "group")
ensure_unique(sgroups_for_keypairs, "scaling_group", "access_key")

op.create_unique_constraint(
"uq_association_user_id_group_id", "association_groups_users", ["user_id", "group_id"]
)
op.create_unique_constraint(
"uq_sgroup_domain", "sgroups_for_domains", ["scaling_group", "domain"]
)
op.create_unique_constraint(
"uq_sgroup_ugroup", "sgroups_for_groups", ["scaling_group", "group"]
)
op.create_unique_constraint(
"uq_sgroup_akey", "sgroups_for_keypairs", ["scaling_group", "access_key"]
)


def downgrade():
op.drop_constraint(
"uq_association_user_id_group_id", "association_groups_users", type_="unique"
)
op.drop_constraint("uq_sgroup_domain", "sgroups_for_domains", type_="unique")
op.drop_constraint("uq_sgroup_ugroup", "sgroups_for_groups", type_="unique")
op.drop_constraint("uq_sgroup_akey", "sgroups_for_keypairs", type_="unique")
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""add_id_columns_to_association_tables
Revision ID: caf54fcc17ab
Revises: 8b2ec7e3d22a
Create Date: 2024-01-03 21:39:50.558724
"""

import sqlalchemy as sa
from alembic import op

from ai.backend.manager.models.base import GUID

# revision identifiers, used by Alembic.
revision = "caf54fcc17ab"
down_revision = "8b2ec7e3d22a"
branch_labels = None
depends_on = None


def upgrade():
op.add_column(
"association_groups_users",
sa.Column("id", GUID(), server_default=sa.text("uuid_generate_v4()"), nullable=False),
)
op.add_column(
"sgroups_for_domains",
sa.Column("id", GUID(), server_default=sa.text("uuid_generate_v4()"), nullable=False),
)
op.add_column(
"sgroups_for_groups",
sa.Column("id", GUID(), server_default=sa.text("uuid_generate_v4()"), nullable=False),
)
op.add_column(
"sgroups_for_keypairs",
sa.Column("id", GUID(), server_default=sa.text("uuid_generate_v4()"), nullable=False),
)

op.create_primary_key("pk_association_groups_users", "association_groups_users", ["id"])
op.create_primary_key("pk_sgroups_for_domains", "sgroups_for_domains", ["id"])
op.create_primary_key("pk_sgroups_for_groups", "sgroups_for_groups", ["id"])
op.create_primary_key("pk_sgroups_for_keypairs", "sgroups_for_keypairs", ["id"])


def downgrade():
op.drop_constraint("pk_association_groups_users", "association_groups_users", type_="primary")
op.drop_constraint("pk_sgroups_for_domains", "sgroups_for_domains", type_="primary")
op.drop_constraint("pk_sgroups_for_groups", "sgroups_for_groups", type_="primary")
op.drop_constraint("pk_sgroups_for_keypairs", "sgroups_for_keypairs", type_="primary")

op.drop_column("sgroups_for_keypairs", "id")
op.drop_column("sgroups_for_groups", "id")
op.drop_column("sgroups_for_domains", "id")
op.drop_column("association_groups_users", "id")
4 changes: 2 additions & 2 deletions src/ai/backend/manager/models/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,20 +81,20 @@
association_groups_users = sa.Table(
"association_groups_users",
mapper_registry.metadata,
IDColumn(),
sa.Column(
"user_id",
GUID,
sa.ForeignKey("users.uuid", onupdate="CASCADE", ondelete="CASCADE"),
nullable=False,
primary_key=True,
),
sa.Column(
"group_id",
GUID,
sa.ForeignKey("groups.id", onupdate="CASCADE", ondelete="CASCADE"),
nullable=False,
primary_key=True,
),
sa.UniqueConstraint("user_id", "group_id", name="uq_association_user_id_group_id"),
)


Expand Down
13 changes: 7 additions & 6 deletions src/ai/backend/manager/models/scaling_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

from .base import (
Base,
IDColumn,
StructuredJSONObjectColumn,
batch_multiresult,
batch_result,
Expand Down Expand Up @@ -142,60 +143,60 @@ def as_trafaret(cls) -> t.Trafaret:
sgroups_for_domains = sa.Table(
"sgroups_for_domains",
mapper_registry.metadata,
IDColumn(),
sa.Column(
"scaling_group",
sa.ForeignKey("scaling_groups.name", onupdate="CASCADE", ondelete="CASCADE"),
index=True,
nullable=False,
primary_key=True,
),
sa.Column(
"domain",
sa.ForeignKey("domains.name", onupdate="CASCADE", ondelete="CASCADE"),
index=True,
nullable=False,
primary_key=True,
),
sa.UniqueConstraint("scaling_group", "domain", name="uq_sgroup_domain"),
)


sgroups_for_groups = sa.Table(
"sgroups_for_groups",
mapper_registry.metadata,
IDColumn(),
sa.Column(
"scaling_group",
sa.ForeignKey("scaling_groups.name", onupdate="CASCADE", ondelete="CASCADE"),
index=True,
nullable=False,
primary_key=True,
),
sa.Column(
"group",
sa.ForeignKey("groups.id", onupdate="CASCADE", ondelete="CASCADE"),
index=True,
nullable=False,
primary_key=True,
),
sa.UniqueConstraint("scaling_group", "group", name="uq_sgroup_ugroup"),
)


sgroups_for_keypairs = sa.Table(
"sgroups_for_keypairs",
mapper_registry.metadata,
IDColumn(),
sa.Column(
"scaling_group",
sa.ForeignKey("scaling_groups.name", onupdate="CASCADE", ondelete="CASCADE"),
index=True,
nullable=False,
primary_key=True,
),
sa.Column(
"access_key",
sa.ForeignKey("keypairs.access_key", onupdate="CASCADE", ondelete="CASCADE"),
index=True,
nullable=False,
primary_key=True,
),
sa.UniqueConstraint("scaling_group", "access_key", name="uq_sgroup_akey"),
)


Expand Down

0 comments on commit adfec0a

Please sign in to comment.