Skip to content

Commit

Permalink
fix: Adapt the client SDK's functional wrappers
Browse files Browse the repository at this point in the history
  • Loading branch information
achimnol committed Nov 3, 2023
1 parent 1cdb02c commit a5d4bab
Show file tree
Hide file tree
Showing 6 changed files with 208 additions and 190 deletions.
47 changes: 24 additions & 23 deletions src/ai/backend/client/func/domain.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import textwrap
from typing import Iterable, Sequence

from ai.backend.client.output.fields import domain_fields
from ai.backend.client.output.types import FieldSpec
from typing import Any, Iterable, Sequence

from ..output.fields import domain_fields
from ..output.types import FieldSpec
from ..session import api_session
from ..types import GraphQLInputVars, Undefined, set_if_set, undefined
from ..types import Undefined, set_if_set, undefined
from .base import BaseFunction, api_function, resolve_fields

__all__ = ("Domain",)
Expand Down Expand Up @@ -114,17 +113,18 @@ async def create(
""")
resolved_fields = resolve_fields(fields, domain_fields, (domain_fields["name"],))
query = query.replace("$fields", " ".join(resolved_fields))
variables: GraphQLInputVars = {
inputs = {
"description": description,
"is_active": is_active,
}
set_if_set(inputs, "total_resource_slots", total_resource_slots)
set_if_set(inputs, "allowed_vfolder_hosts", allowed_vfolder_hosts)
set_if_set(inputs, "allowed_docker_registries", allowed_docker_registries)
set_if_set(inputs, "integration_id", integration_id)
variables = {
"name": name,
"input": {
"description": description,
"is_active": is_active,
},
"input": inputs,
}
set_if_set(variables["input"], "total_resource_slots", total_resource_slots)
set_if_set(variables["input"], "allowed_vfolder_hosts", allowed_vfolder_hosts)
set_if_set(variables["input"], "allowed_docker_registries", allowed_docker_registries)
set_if_set(variables["input"], "integration_id", integration_id)
data = await api_session.get().Admin._query(query, variables)
return data["create_domain"]

Expand Down Expand Up @@ -153,17 +153,18 @@ async def update(
}
}
""")
variables: GraphQLInputVars = {
inputs: dict[str, Any] = {}
set_if_set(inputs, "name", new_name)
set_if_set(inputs, "description", description)
set_if_set(inputs, "is_active", is_active)
set_if_set(inputs, "total_resource_slots", total_resource_slots)
set_if_set(inputs, "allowed_vfolder_hosts", allowed_vfolder_hosts)
set_if_set(inputs, "allowed_docker_registries", allowed_docker_registries)
set_if_set(inputs, "integration_id", integration_id)
variables = {
"name": name,
"input": {},
"input": inputs,
}
set_if_set(variables["input"], "name", new_name)
set_if_set(variables["input"], "description", description)
set_if_set(variables["input"], "is_active", is_active)
set_if_set(variables["input"], "total_resource_slots", total_resource_slots)
set_if_set(variables["input"], "allowed_vfolder_hosts", allowed_vfolder_hosts)
set_if_set(variables["input"], "allowed_docker_registries", allowed_docker_registries)
set_if_set(variables["input"], "integration_id", integration_id)
data = await api_session.get().Admin._query(query, variables)
return data["modify_domain"]

Expand Down
39 changes: 21 additions & 18 deletions src/ai/backend/client/func/group.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import textwrap
from typing import Iterable, Optional, Sequence
from typing import Any, Iterable, Optional, Sequence

from ai.backend.client.output.fields import group_fields
from ai.backend.client.output.types import FieldSpec

from ..session import api_session
from ..types import Undefined, set_if_set, undefined
from .base import BaseFunction, api_function, resolve_fields

__all__ = ("Group",)
Expand Down Expand Up @@ -128,12 +129,13 @@ async def create(
cls,
domain_name: str,
name: str,
*,
description: str = "",
is_active: bool = True,
total_resource_slots: Optional[str] = None,
allowed_vfolder_hosts: Optional[str] = None,
integration_id: str = None,
fields: Iterable[FieldSpec | str] = None,
integration_id: Optional[str] = None,
fields: Iterable[FieldSpec | str] | None = None,
) -> dict:
"""
Creates a new group with the given options.
Expand Down Expand Up @@ -171,13 +173,14 @@ async def create(
async def update(
cls,
gid: str,
name: str = None,
description: str = None,
is_active: bool = None,
total_resource_slots: Optional[str] = None,
allowed_vfolder_hosts: Optional[str] = None,
integration_id: str = None,
fields: Iterable[FieldSpec | str] = None,
*,
name: str | Undefined = undefined,
description: str | Undefined = undefined,
is_active: bool | Undefined = undefined,
total_resource_slots: Optional[str] | Undefined = undefined,
allowed_vfolder_hosts: Optional[str] | Undefined = undefined,
integration_id: str | Undefined = undefined,
fields: Iterable[FieldSpec | str] | None = None,
) -> dict:
"""
Update existing group.
Expand All @@ -190,16 +193,16 @@ async def update(
}
}
""")
inputs: dict[str, Any] = {}
set_if_set(inputs, "name", name)
set_if_set(inputs, "description", description)
set_if_set(inputs, "is_active", is_active)
set_if_set(inputs, "total_resource_slots", total_resource_slots)
set_if_set(inputs, "allowed_vfolder_hosts", allowed_vfolder_hosts)
set_if_set(inputs, "integration_id", integration_id)
variables = {
"gid": gid,
"input": {
"name": name,
"description": description,
"is_active": is_active,
"total_resource_slots": total_resource_slots,
"allowed_vfolder_hosts": allowed_vfolder_hosts,
"integration_id": integration_id,
},
"input": inputs,
}
data = await api_session.get().Admin._query(query, variables)
return data["modify_group"]
Expand Down
65 changes: 33 additions & 32 deletions src/ai/backend/client/func/keypair.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import Any, Dict, Sequence, Union

from ai.backend.client.output.fields import keypair_fields
from ai.backend.client.output.types import FieldSpec, PaginatedResult
from ai.backend.client.pagination import fetch_paginated_result
from ai.backend.client.session import api_session
from typing import Any, Dict, Sequence

from ..output.fields import keypair_fields
from ..output.types import FieldSpec, PaginatedResult
from ..pagination import fetch_paginated_result
from ..session import api_session
from ..types import Undefined, set_if_set, undefined
from .base import BaseFunction, api_function

__all__ = ("KeyPair",)
Expand Down Expand Up @@ -44,11 +44,11 @@ def __init__(self, access_key: str):
@classmethod
async def create(
cls,
user_id: Union[int, str],
user_id: int | str,
is_active: bool = True,
is_admin: bool = False,
resource_policy: str = None,
rate_limit: int = None,
resource_policy: str | Undefined = undefined,
rate_limit: int | Undefined = undefined,
fields: Sequence[FieldSpec] = _default_result_fields,
) -> dict:
"""
Expand All @@ -64,14 +64,15 @@ async def create(
"}"
)
q = q.replace("$fields", " ".join(f.field_ref for f in fields))
inputs = {
"is_active": is_active,
"is_admin": is_admin,
}
set_if_set(inputs, "resource_policy", resource_policy)
set_if_set(inputs, "rate_limit", rate_limit)
variables = {
"user_id": user_id,
"input": {
"is_active": is_active,
"is_admin": is_admin,
"resource_policy": resource_policy,
"rate_limit": rate_limit,
},
"input": inputs,
}
data = await api_session.get().Admin._query(q, variables)
return data["create_keypair"]
Expand All @@ -81,10 +82,10 @@ async def create(
async def update(
cls,
access_key: str,
is_active: bool = None,
is_admin: bool = None,
resource_policy: str = None,
rate_limit: int = None,
is_active: bool | Undefined = undefined,
is_admin: bool | Undefined = undefined,
resource_policy: str | Undefined = undefined,
rate_limit: int | Undefined = undefined,
) -> dict:
"""
Creates a new keypair with the given options.
Expand All @@ -94,14 +95,14 @@ async def update(
"mutation($access_key: String!, $input: ModifyKeyPairInput!) {"
+ " modify_keypair(access_key: $access_key, props: $input) { ok msg }}"
)
inputs: dict[str, Any] = {}
set_if_set(inputs, "is_active", is_active)
set_if_set(inputs, "is_admin", is_admin)
set_if_set(inputs, "resource_policy", resource_policy)
set_if_set(inputs, "rate_limit", rate_limit)
variables = {
"access_key": access_key,
"input": {
"is_active": is_active,
"is_admin": is_admin,
"resource_policy": resource_policy,
"rate_limit": rate_limit,
},
"input": inputs,
}
data = await api_session.get().Admin._query(q, variables)
return data["modify_keypair"]
Expand Down Expand Up @@ -129,8 +130,8 @@ async def delete(cls, access_key: str):
@classmethod
async def list(
cls,
user_id: Union[int, str] = None,
is_active: bool = None,
user_id: int | str | None = None,
is_active: bool | None = None,
fields: Sequence[FieldSpec] = _default_list_fields,
) -> Sequence[dict]:
"""
Expand Down Expand Up @@ -158,15 +159,15 @@ async def list(
@classmethod
async def paginated_list(
cls,
is_active: bool = None,
domain_name: str = None,
is_active: bool | None = None,
domain_name: str | None = None,
*,
user_id: str = None,
user_id: str | None = None,
fields: Sequence[FieldSpec] = _default_list_fields,
page_offset: int = 0,
page_size: int = 20,
filter: str = None,
order: str = None,
filter: str | None = None,
order: str | None = None,
) -> PaginatedResult[dict]:
"""
Lists the keypairs.
Expand Down
Loading

0 comments on commit a5d4bab

Please sign in to comment.