Skip to content

Commit

Permalink
fix: Correct handling of undefined values in the manager side
Browse files Browse the repository at this point in the history
  • Loading branch information
achimnol committed Nov 3, 2023
1 parent 0504645 commit 1cdb02c
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 14 deletions.
5 changes: 5 additions & 0 deletions src/ai/backend/manager/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1036,6 +1036,11 @@ def set_if_set(
clean_func=None,
target_key: Optional[str] = None,
) -> None:
"""
Set the target dict with only non-undefined keys and their values
from a Graphene's input object.
(server-side function)
"""
v = getattr(src, name)
# NOTE: unset optional fields are passed as graphql.Undefined.
if v is not Undefined:
Expand Down
16 changes: 11 additions & 5 deletions src/ai/backend/manager/models/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,12 +377,12 @@ async def get_groups_for_user(


class GroupInput(graphene.InputObjectType):
description = graphene.String(required=False)
description = graphene.String(required=False, default="")
is_active = graphene.Boolean(required=False, default=True)
domain_name = graphene.String(required=True)
total_resource_slots = graphene.JSONString(required=False)
allowed_vfolder_hosts = graphene.JSONString(required=False)
integration_id = graphene.String(required=False)
integration_id = graphene.String(required=False, default="")
resource_policy = graphene.String(required=False, default="default")


Expand Down Expand Up @@ -430,11 +430,17 @@ async def mutate(
"description": props.description,
"is_active": props.is_active,
"domain_name": props.domain_name,
"total_resource_slots": ResourceSlot.from_user_input(props.total_resource_slots, None),
"allowed_vfolder_hosts": props.allowed_vfolder_hosts,
"integration_id": props.integration_id,
"resource_policy": props.resource_policy or "default",
"resource_policy": props.resource_policy,
}
# set_if_set() applies to optional without defaults
set_if_set(
props,
data,
"total_resource_slots",
clean_func=lambda v: ResourceSlot.from_user_input(v, None),
)
set_if_set(props, data, "allowed_vfolder_hosts")
insert_query = sa.insert(groups).values(data)
return await simple_db_mutate_returning_item(cls, graph_ctx, insert_query, item_cls=Group)

Expand Down
2 changes: 2 additions & 0 deletions src/ai/backend/manager/models/keypair.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,7 @@ async def mutate(
"is_admin": props.is_admin,
"resource_policy": props.resource_policy,
"rate_limit": props.rate_limit,
# props.concurrency_limit is always ignored
},
)
insert_query = sa.insert(keypairs).values(
Expand Down Expand Up @@ -606,6 +607,7 @@ async def mutate(
set_if_set(props, data, "is_admin")
set_if_set(props, data, "resource_policy")
set_if_set(props, data, "rate_limit")
# props.concurrency_limit is always ignored
update_query = sa.update(keypairs).values(data).where(keypairs.c.access_key == access_key)
return await simple_db_mutate(cls, ctx, update_query)

Expand Down
6 changes: 3 additions & 3 deletions src/ai/backend/manager/models/scaling_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,8 +485,8 @@ class CreateScalingGroupInput(graphene.InputObjectType):
description = graphene.String(required=False, default="")
is_active = graphene.Boolean(required=False, default=True)
is_public = graphene.Boolean(required=False, default=True)
wsproxy_addr = graphene.String(required=False)
wsproxy_api_token = graphene.String(required=False)
wsproxy_addr = graphene.String(required=False, default=None)
wsproxy_api_token = graphene.String(required=False, default=None)
driver = graphene.String(required=True)
driver_opts = graphene.JSONString(required=False, default={})
scheduler = graphene.String(required=True)
Expand Down Expand Up @@ -570,9 +570,9 @@ async def mutate(
set_if_set(props, data, "description")
set_if_set(props, data, "is_active")
set_if_set(props, data, "is_public")
set_if_set(props, data, "driver")
set_if_set(props, data, "wsproxy_addr")
set_if_set(props, data, "wsproxy_api_token")
set_if_set(props, data, "driver")
set_if_set(props, data, "driver_opts")
set_if_set(props, data, "scheduler")
set_if_set(
Expand Down
10 changes: 5 additions & 5 deletions src/ai/backend/manager/models/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ class UserInput(graphene.InputObjectType):
domain_name = graphene.String(required=True, default="default")
role = graphene.String(required=False, default=UserRole.USER)
group_ids = graphene.List(lambda: graphene.String, required=False)
allowed_client_ip = graphene.List(lambda: graphene.String, required=False)
allowed_client_ip = graphene.List(lambda: graphene.String, required=False, defualt=None)
totp_activated = graphene.Boolean(required=False, default=False)
resource_policy = graphene.String(required=False, default="default")
sudo_session_enabled = graphene.Boolean(required=False, default=False)
Expand Down Expand Up @@ -585,8 +585,8 @@ async def mutate(
"role": UserRole(props.role),
"allowed_client_ip": props.allowed_client_ip,
"totp_activated": props.totp_activated,
"resource_policy": props.resource_policy or "default",
"sudo_session_enabled": props.sudo_session_enabled or False,
"resource_policy": props.resource_policy,
"sudo_session_enabled": props.sudo_session_enabled,
}
user_insert_query = sa.insert(users).values(user_data)

Expand Down Expand Up @@ -614,9 +614,9 @@ async def _post_func(conn: SAConnection, result: Result) -> Row:
await conn.execute(kp_insert_query)

# Add user to groups if group_ids parameter is provided.
from .group import association_groups_users, groups

if props.group_ids:
from .group import association_groups_users, groups

query = (
sa.select([groups.c.id])
.select_from(groups)
Expand Down
10 changes: 9 additions & 1 deletion src/ai/backend/manager/models/vfolder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from dateutil.parser import parse as dtparse
from dateutil.tz import tzutc
from graphene.types.datetime import DateTime as GQLDateTime
from graphql import Undefined
from sqlalchemy.dialects import postgresql as pgsql
from sqlalchemy.engine.row import Row
from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection
Expand Down Expand Up @@ -1721,7 +1722,14 @@ async def mutate(
graph_ctx: GraphQueryContext = info.context
async with graph_ctx.db.begin_readonly_session() as sess:
await ensure_quota_scope_accessible_by_user(sess, qsid, graph_ctx.user)

if props.hard_limit_bytes is Undefined:
# Do nothing but just return the quota scope object.
return cls(
QuotaScope(
quota_scope_id=quota_scope_id,
storage_host_name=storage_host_name,
)
)
max_vfolder_size = props.hard_limit_bytes
proxy_name, volume_name = graph_ctx.storage_manager.split_host(storage_host_name)
request_body = {
Expand Down

0 comments on commit 1cdb02c

Please sign in to comment.