Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Add users_to_add and users_to_remove fields to replace user update mode #1705

Draft
wants to merge 33 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
28bf838
revert functioned description and deprecated msg
fregataa Nov 8, 2023
55d74b8
impl user add/delete group mutation
fregataa Nov 8, 2023
b0cd2a4
replace input type of ID from UUID to String
fregataa Nov 8, 2023
e1528e8
keep the modify function alive
fregataa Nov 8, 2023
4f19cbc
change mutation name typo
fregataa Nov 8, 2023
00a9b3e
Merge branch 'main' into fix/refactor-group-graphql
fregataa Nov 14, 2023
175a5fd
Merge branch 'main' into fix/refactor-group-graphql
fregataa Nov 14, 2023
c75db1e
impl user add/remove in ModifyGroup mutation and revert description m…
fregataa Nov 14, 2023
fb68733
Merge branch 'main' into fix/refactor-group-graphql
fregataa Nov 23, 2023
b0c85d3
add news fragment
fregataa Nov 24, 2023
0760819
change added versions of some fields
fregataa Nov 24, 2023
fed0bfe
update news fragment
fregataa Nov 24, 2023
b446de3
Merge branch 'main' into fix/refactor-group-graphql
fregataa Jan 3, 2024
b5c76cc
fix wrong code and change naming
fregataa Jan 3, 2024
84a8f13
Merge branch 'main' into fix/refactor-group-graphql
fregataa Apr 18, 2024
7d8c88f
update version notation
fregataa Apr 18, 2024
20b3787
remove use of deprecated fields
fregataa Apr 18, 2024
c6f9116
little change of error msg
fregataa Apr 18, 2024
999736b
chore: update GraphQL schema dump
fregataa Apr 18, 2024
f62c10e
Merge branch 'main' into fix/refactor-group-graphql
fregataa Apr 23, 2024
a4f5e04
Merge remote-tracking branch 'origin/fix/refactor-group-graphql' into…
fregataa Apr 23, 2024
278d878
update gql fields description
fregataa Apr 23, 2024
0820ce5
chore: update GraphQL schema dump
fregataa Apr 23, 2024
81c8cdb
Merge branch 'main' into fix/refactor-group-graphql
fregataa May 1, 2024
f4ee3bb
Merge remote-tracking branch 'origin/fix/refactor-group-graphql' into…
fregataa May 1, 2024
f39454a
Merge branch 'main' into fix/refactor-group-graphql
fregataa May 8, 2024
4352885
rename fields
fregataa May 8, 2024
a9c05ed
chore: update GraphQL schema dump
fregataa May 8, 2024
4a17f83
Merge branch 'main' into fix/refactor-group-graphql
fregataa May 9, 2024
66a6880
minor update for pythonic code
fregataa May 9, 2024
62c65a7
Merge remote-tracking branch 'origin/fix/refactor-group-graphql' into…
fregataa May 9, 2024
9e068fd
update news fragment
fregataa May 10, 2024
5621850
handle duplicate user associating to group
fregataa May 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/1705.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor group mutation by adding `added_users` and `removed_users` fields to replace `user_update_mode`.
63 changes: 52 additions & 11 deletions src/ai/backend/manager/models/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
simple_db_mutate_returning_item,
)
from .storage import StorageSessionManager
from .user import ModifyUserInput, UserRole
from .user import UserRole
from .utils import ExtendedAsyncSAEngine, execute_with_retry

if TYPE_CHECKING:
Expand Down Expand Up @@ -393,8 +393,27 @@ class ModifyGroupInput(graphene.InputObjectType):
is_active = graphene.Boolean(required=False)
domain_name = graphene.String(required=False)
total_resource_slots = graphene.JSONString(required=False)
user_update_mode = graphene.String(required=False)
user_uuids = graphene.List(lambda: graphene.String, required=False)
user_update_mode = graphene.String(
deprecation_reason=(
"Deprecated since 24.03.0. Recommend to use `added_users` and `removed_users` fields"
)
)
user_uuids = graphene.List(
lambda: graphene.String,
deprecation_reason=(
"Deprecated since 24.03.0. Recommend to use `added_users` and `removed_users` fields"
),
)
added_users = graphene.List(
lambda: graphene.String,
required=False,
description="Added since 24.03.0. ID array of the users to be added to the group.",
)
removed_users = graphene.List(
lambda: graphene.String,
required=False,
description="Added since 24.03.0. ID array of the users to be removed from the group.",
)
allowed_vfolder_hosts = graphene.JSONString(required=False)
integration_id = graphene.String(required=False)
resource_policy = graphene.String(required=False)
Expand Down Expand Up @@ -467,10 +486,11 @@ async def mutate(
root,
info: graphene.ResolveInfo,
gid: uuid.UUID,
props: ModifyUserInput,
props: ModifyGroupInput,
) -> ModifyGroup:
graph_ctx: GraphQueryContext = info.context
data: Dict[str, Any] = {}
user_data: dict[str, set] = {}
set_if_set(props, data, "name")
set_if_set(props, data, "description")
set_if_set(props, data, "is_active")
Expand All @@ -485,33 +505,54 @@ async def mutate(
set_if_set(props, data, "integration_id")
set_if_set(props, data, "resource_policy")

set_if_set(props, user_data, "added_users", clean_func=set)
set_if_set(props, user_data, "removed_users", clean_func=set)

if "name" in data and _rx_slug.search(data["name"]) is None:
raise ValueError("invalid name format. slug format required.")
if props.user_update_mode not in (None, Undefined, "add", "remove"):
raise ValueError("invalid user_update_mode")
if not props.user_uuids:
props.user_update_mode = None
if not data and props.user_update_mode is None:
if not data and not user_data and props.user_update_mode in (None, Undefined):
return cls(ok=False, msg="nothing to update", group=None)

async def _do_mutate() -> ModifyGroup:
async with graph_ctx.db.begin() as conn:
# TODO: refactor user addition/removal in groups as separate mutations
# (to apply since 21.09)
async with graph_ctx.db.begin_session() as db_session:
# Using `user_update_mode` and `user_uuids` is deprecated
if props.user_update_mode == "add":
values = [{"user_id": uuid, "group_id": gid} for uuid in props.user_uuids]
await conn.execute(
await db_session.execute(
sa.insert(association_groups_users).values(values),
)
elif props.user_update_mode == "remove":
await conn.execute(
await db_session.execute(
sa.delete(association_groups_users).where(
(association_groups_users.c.user_id.in_(props.user_uuids))
& (association_groups_users.c.group_id == gid),
),
)

added_users = user_data.get("added_users") or set()
removed_users = user_data.get("removed_users") or set()
if union := (added_users & removed_users):
raise ValueError(
"Should be no duplicate user id in `added_users` and `removed_users`."
f" (ids: {list(union)})"
)
if added_users:
values = [{"user_id": uuid, "group_id": gid} for uuid in added_users]
await db_session.execute(
sa.insert(association_groups_users).values(values),
)
if removed_users:
values = [{"user_id": uuid, "group_id": gid} for uuid in removed_users]
await db_session.execute(
sa.insert(association_groups_users).values(values),
)

if data:
result = await conn.execute(
result = await db_session.execute(
sa.update(groups).values(data).where(groups.c.id == gid).returning(groups),
)
if result.rowcount > 0:
Expand Down
111 changes: 59 additions & 52 deletions src/ai/backend/manager/models/resource_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
)
from .keypair import keypairs
from .user import UserRole
from .utils import deprecation_reason_msg, description_msg

if TYPE_CHECKING:
from .gql import GraphQueryContext
Expand Down Expand Up @@ -58,34 +57,6 @@
)


def user_max_vfolder_count(required: bool = False):
return graphene.Int(
required=required,
description=description_msg("24.03.1", "Limitation of the number of user vfolders."),
)


def user_max_quota_scope_size(required: bool = False):
return BigInt(
required=required,
description=description_msg("24.03.1", "Limitation of the quota size of user vfolders."),
)


def project_max_vfolder_count(required: bool = False):
return graphene.Int(
required=required,
description=description_msg("24.03.1", "Limitation of the number of project vfolders."),
)


def project_max_quota_scope_size(required: bool = False):
return BigInt(
required=required,
description=description_msg("24.03.1", "Limitation of the quota size of project vfolders."),
)


keypair_resource_policies = sa.Table(
"keypair_resource_policies",
mapper_registry.metadata,
Expand Down Expand Up @@ -172,9 +143,9 @@ class KeyPairResourcePolicy(graphene.ObjectType):
idle_timeout = BigInt()
allowed_vfolder_hosts = graphene.JSONString()

max_vfolder_count = graphene.Int(deprecation_reason=deprecation_reason_msg("23.09.4"))
max_vfolder_size = BigInt(deprecation_reason=deprecation_reason_msg("23.09.4"))
max_quota_scope_size = BigInt(deprecation_reason=deprecation_reason_msg("23.09.4"))
max_vfolder_count = graphene.Int(deprecation_reason="Deprecated since 23.09.4")
max_vfolder_size = BigInt(deprecation_reason="Deprecated since 23.09.4")
max_quota_scope_size = BigInt(deprecation_reason="Deprecated since 23.09.4")

@classmethod
def from_row(
Expand Down Expand Up @@ -326,15 +297,15 @@ class CreateKeyPairResourcePolicyInput(graphene.InputObjectType):
allowed_vfolder_hosts = graphene.JSONString(required=False)
max_vfolder_count = graphene.Int(
required=False,
deprecation_reason=deprecation_reason_msg("23.09.4"),
deprecation_reason="Deprecated since 23.09.4",
)
max_vfolder_size = BigInt(
required=False,
deprecation_reason=deprecation_reason_msg("23.09.4"),
deprecation_reason="Deprecated since 23.09.4",
)
max_quota_scope_size = BigInt(
required=False,
deprecation_reason=deprecation_reason_msg("23.09.4"),
deprecation_reason="Deprecated since 23.09.4",
)


Expand All @@ -349,15 +320,15 @@ class ModifyKeyPairResourcePolicyInput(graphene.InputObjectType):
allowed_vfolder_hosts = graphene.JSONString(required=False)
max_vfolder_count = graphene.Int(
required=False,
deprecation_reason=deprecation_reason_msg("23.09.4"),
deprecation_reason="Deprecated since 23.09.4",
)
max_vfolder_size = BigInt(
required=False,
deprecation_reason=deprecation_reason_msg("23.09.4"),
deprecation_reason="Deprecated since 23.09.4",
)
max_quota_scope_size = BigInt(
required=False,
deprecation_reason=deprecation_reason_msg("23.09.4"),
deprecation_reason="Deprecated since 23.09.4",
)


Expand Down Expand Up @@ -471,9 +442,15 @@ class UserResourcePolicy(graphene.ObjectType):
id = graphene.ID(required=True)
name = graphene.String(required=True)
created_at = GQLDateTime(required=True)
max_vfolder_count = user_max_vfolder_count()
max_quota_scope_size = user_max_quota_scope_size()
max_vfolder_size = BigInt(deprecation_reason=deprecation_reason_msg("23.09.1"))
max_vfolder_count = graphene.Int(
required=False,
description="Added since 24.03.0. Limitation of the number of user vfolders.",
)
max_quota_scope_size = BigInt(
required=False,
description="Added since 24.03.0. Limitation of the quota size of user vfolders.",
)
max_vfolder_size = BigInt(deprecation_reason="Deprecated since 23.09.1")

@classmethod
def from_row(
Expand Down Expand Up @@ -545,13 +522,25 @@ async def batch_load_by_user(


class CreateUserResourcePolicyInput(graphene.InputObjectType):
max_vfolder_count = user_max_vfolder_count()
max_quota_scope_size = user_max_quota_scope_size()
max_vfolder_count = graphene.Int(
required=False,
description="Added since 24.03.0. Limitation of the number of user vfolders.",
)
max_quota_scope_size = BigInt(
required=False,
description="Added since 24.03.0. Limitation of the quota size of user vfolders.",
)


class ModifyUserResourcePolicyInput(graphene.InputObjectType):
max_vfolder_count = user_max_vfolder_count()
max_quota_scope_size = user_max_quota_scope_size()
max_vfolder_count = graphene.Int(
required=False,
description="Added since 24.03.0. Limitation of the number of user vfolders.",
)
max_quota_scope_size = BigInt(
required=False,
description="Added since 24.03.0. Limitation of the quota size of user vfolders.",
)


class CreateUserResourcePolicy(graphene.Mutation):
Expand Down Expand Up @@ -648,9 +637,15 @@ class ProjectResourcePolicy(graphene.ObjectType):
id = graphene.ID(required=True)
name = graphene.String(required=True)
created_at = GQLDateTime(required=True)
max_vfolder_count = project_max_vfolder_count()
max_quota_scope_size = project_max_quota_scope_size()
max_vfolder_size = BigInt(deprecation_reason=deprecation_reason_msg("23.09.1"))
max_vfolder_count = graphene.Int(
required=False,
description="Added since 24.03.0. Limitation of the number of project vfolders.",
)
max_quota_scope_size = BigInt(
required=False,
description="Added since 24.03.0. Limitation of the quota size of project vfolders.",
)
max_vfolder_size = BigInt(deprecation_reason="Deprecated since 23.09.1")

@classmethod
def from_row(
Expand Down Expand Up @@ -723,13 +718,25 @@ async def batch_load_by_project(


class CreateProjectResourcePolicyInput(graphene.InputObjectType):
max_vfolder_count = project_max_vfolder_count()
max_quota_scope_size = project_max_quota_scope_size()
max_vfolder_count = graphene.Int(
required=False,
description="Added since 24.03.0. Limitation of the number of project vfolders.",
)
max_quota_scope_size = BigInt(
required=False,
description="Added since 24.03.0. Limitation of the quota size of project vfolders.",
)


class ModifyProjectResourcePolicyInput(graphene.InputObjectType):
max_vfolder_count = project_max_vfolder_count()
max_quota_scope_size = project_max_quota_scope_size()
max_vfolder_count = graphene.Int(
required=False,
description="Added since 24.03.0. Limitation of the number of project vfolders.",
)
max_quota_scope_size = BigInt(
required=False,
description="Added since 24.03.0. Limitation of the quota size of project vfolders.",
)


class CreateProjectResourcePolicy(graphene.Mutation):
Expand Down
5 changes: 2 additions & 3 deletions src/ai/backend/manager/models/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from ..api.exceptions import InvalidAPIParameters, VFolderOperationFailed
from ..exceptions import InvalidArgument
from .base import Item, PaginatedList
from .utils import description_msg

if TYPE_CHECKING:
from .gql import GraphQueryContext
Expand Down Expand Up @@ -209,9 +208,9 @@ class Meta:
performance_metric = graphene.JSONString()
usage = graphene.JSONString()
proxy = graphene.String(
description=description_msg("24.03.0", "Name of the proxy which this volume belongs to.")
description="Added since 24.03.0. Name of the proxy which this volume belongs to."
)
name = graphene.String(description=description_msg("24.03.0", "Name of the storage."))
name = graphene.String(description="Added since 24.03.0. Name of the storage.")

async def resolve_hardware_metadata(self, info: graphene.ResolveInfo) -> HardwareMetadata:
ctx: GraphQueryContext = info.context
Expand Down
14 changes: 0 additions & 14 deletions src/ai/backend/manager/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,17 +362,3 @@ def agg_to_array(column: sa.Column) -> sa.sql.functions.Function:

def is_db_retry_error(e: Exception) -> bool:
return isinstance(e, DBAPIError) and getattr(e.orig, "pgcode", None) == "40001"


def description_msg(version: str, detail: str | None = None) -> str:
val = f"Added since {version}."
if detail:
val = f"{val} {detail}"
return val


def deprecation_reason_msg(version: str, detail: str | None = None) -> str:
val = f"Deprecated since {version}."
if detail:
val = f"{val} {detail}"
return val
Loading