diff --git a/changes/1705.fix.md b/changes/1705.fix.md new file mode 100644 index 0000000000..ff0b0e20e6 --- /dev/null +++ b/changes/1705.fix.md @@ -0,0 +1 @@ +Refactor group mutation by adding `users_to_add` and `users_to_remove` fields to replace `user_update_mode`. diff --git a/src/ai/backend/manager/api/schema.graphql b/src/ai/backend/manager/api/schema.graphql index 6cd6b7182c..b1fcbcbef6 100644 --- a/src/ai/backend/manager/api/schema.graphql +++ b/src/ai/backend/manager/api/schema.graphql @@ -1139,8 +1139,14 @@ input ModifyGroupInput { is_active: Boolean domain_name: String total_resource_slots: JSONString - user_update_mode: String - user_uuids: [String] + user_update_mode: String @deprecated(reason: "Deprecated since 24.09.0. Use `users_to_add` and `users_to_remove` fields") + user_uuids: [String] @deprecated(reason: "Deprecated since 24.09.0. Use `users_to_add` and `users_to_remove` fields") + + """Added in 24.09.0. ID array of the users to be added to the group.""" + users_to_add: [String] + + """Added in 24.09.0. ID array of the users to be removed from the group.""" + users_to_remove: [String] allowed_vfolder_hosts: JSONString integration_id: String resource_policy: String diff --git a/src/ai/backend/manager/models/group.py b/src/ai/backend/manager/models/group.py index 3ed98d0bd3..a18e7ed200 100644 --- a/src/ai/backend/manager/models/group.py +++ b/src/ai/backend/manager/models/group.py @@ -22,7 +22,6 @@ import sqlalchemy as sa import trafaret as t from graphene.types.datetime import DateTime as GQLDateTime -from graphql import Undefined from sqlalchemy.engine.row import Row from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection from sqlalchemy.orm import relationship @@ -61,7 +60,7 @@ from .minilang.ordering import QueryOrderParser from .minilang.queryfilter import QueryFilterParser from .storage import StorageSessionManager -from .user import ModifyUserInput, UserConnection, UserNode, UserRole +from .user import UserConnection, UserNode, UserRole from .utils import ExtendedAsyncSAEngine, execute_with_retry if TYPE_CHECKING: @@ -451,8 +450,27 @@ class ModifyGroupInput(graphene.InputObjectType): is_active = graphene.Boolean(required=False) domain_name = graphene.String(required=False) total_resource_slots = graphene.JSONString(required=False) - user_update_mode = graphene.String(required=False) - user_uuids = graphene.List(lambda: graphene.String, required=False) + user_update_mode = graphene.String( + deprecation_reason=( + "Deprecated since 24.09.0. Use `users_to_add` and `users_to_remove` fields" + ) + ) + user_uuids = graphene.List( + lambda: graphene.String, + deprecation_reason=( + "Deprecated since 24.09.0. Use `users_to_add` and `users_to_remove` fields" + ), + ) + users_to_add = graphene.List( + lambda: graphene.String, + required=False, + description="Added in 24.09.0. ID array of the users to be added to the group.", + ) + users_to_remove = graphene.List( + lambda: graphene.String, + required=False, + description="Added in 24.09.0. ID array of the users to be removed from the group.", + ) allowed_vfolder_hosts = graphene.JSONString(required=False) integration_id = graphene.String(required=False) resource_policy = graphene.String(required=False) @@ -530,10 +548,11 @@ async def mutate( root, info: graphene.ResolveInfo, gid: uuid.UUID, - props: ModifyUserInput, + props: ModifyGroupInput, ) -> ModifyGroup: graph_ctx: GraphQueryContext = info.context data: Dict[str, Any] = {} + user_data: dict[str, set] = {} set_if_set(props, data, "name") set_if_set(props, data, "description") set_if_set(props, data, "is_active") @@ -549,33 +568,43 @@ async def mutate( set_if_set(props, data, "resource_policy") set_if_set(props, data, "container_registry") + set_if_set(props, user_data, "users_to_add", clean_func=set) + set_if_set(props, user_data, "users_to_remove", clean_func=set) + if "name" in data and _rx_slug.search(data["name"]) is None: raise ValueError("invalid name format. slug format required.") - if props.user_update_mode not in (None, Undefined, "add", "remove"): - raise ValueError("invalid user_update_mode") - if not props.user_uuids: - props.user_update_mode = None - if not data and props.user_update_mode is None: + if not data and not user_data: return cls(ok=False, msg="nothing to update", group=None) async def _do_mutate() -> ModifyGroup: - async with graph_ctx.db.begin() as conn: - # TODO: refactor user addition/removal in groups as separate mutations - # (to apply since 21.09) - if props.user_update_mode == "add": - values = [{"user_id": uuid, "group_id": gid} for uuid in props.user_uuids] - await conn.execute( - sa.insert(association_groups_users).values(values), + async with graph_ctx.db.begin_session() as db_session: + users_to_add = user_data.get("users_to_add", set()) + users_to_remove = user_data.get("users_to_remove", set()) + if union := (users_to_add & users_to_remove): + raise ValueError( + "Should be no user IDs included in both `users_to_add` and `users_to_remove`." + f" (IDs: {list(union)})" ) - elif props.user_update_mode == "remove": - await conn.execute( - sa.delete(association_groups_users).where( - (association_groups_users.c.user_id.in_(props.user_uuids)) - & (association_groups_users.c.group_id == gid), - ), + if users_to_add: + values = [{"user_id": uuid, "group_id": gid} for uuid in users_to_add] + try: + await db_session.execute( + sa.insert(association_groups_users).values(values), + ) + except sa.exc.IntegrityError: + raise ValueError("User already belongs to the given project(user group)") + if users_to_remove: + await db_session.execute( + ( + sa.delete(association_groups_users).where( + (association_groups_users.c.group_id == gid) + & (association_groups_users.c.user_id.in_(users_to_remove)) + ) + ) ) + if data: - result = await conn.execute( + result = await db_session.execute( sa.update(groups).values(data).where(groups.c.id == gid).returning(groups), ) if result.rowcount > 0: