Skip to content

Commit

Permalink
feat: Update SDK and CLI to follow-up per-user uid/gid
Browse files Browse the repository at this point in the history
  • Loading branch information
fregataa committed Jan 3, 2025
1 parent 553894d commit 10fc52d
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 18 deletions.
33 changes: 20 additions & 13 deletions src/ai/backend/cli/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,40 +169,47 @@ def convert(self, arg, param, ctx):
self.fail(str(e), param, ctx)


class CommaSeparatedListType(click.ParamType):
class SingleValueConstructorType(Protocol):
def __init__(self, value: Any) -> None: ...


TScalar = TypeVar("TScalar", bound=SingleValueConstructorType)


class CommaSeparatedListType(click.ParamType, Generic[TScalar]):
name = "List Expression"

def __init__(self, type_: Optional[type[TScalar]] = None) -> None:
super().__init__()
self.type_ = type_ if type_ is not None else str

def convert(self, arg, param, ctx):
try:
if isinstance(arg, int):
return arg
elif isinstance(arg, str):
return arg.split(",")
match arg:
case int():
return arg
case str():
return [self.type_(elem) for elem in arg.split(",")]
except ValueError as e:
self.fail(repr(e), param, ctx)


T = TypeVar("T")


class SingleValueConstructorType(Protocol):
def __init__(self, value: Any) -> None: ...


TScalar = TypeVar("TScalar", bound=SingleValueConstructorType)


class OptionalType(click.ParamType, Generic[TScalar]):
name = "Optional Type Wrapper"

def __init__(self, type_: type[TScalar] | type[click.ParamType]) -> None:
def __init__(self, type_: type[TScalar] | type[click.ParamType] | click.ParamType) -> None:
super().__init__()
self.type_ = type_

def convert(self, value: Any, param, ctx) -> TScalar | Undefined:
try:
if value is undefined:
return undefined
if isinstance(self.type_, click.ParamType):
return self.type_(value)
if issubclass(self.type_, click.ParamType):
return self.type_()(value)
return self.type_(value)
Expand Down
116 changes: 114 additions & 2 deletions src/ai/backend/client/cli/admin/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import sys
import uuid
from typing import Iterable, Sequence
from collections.abc import Iterable, Sequence

import click

Expand All @@ -14,7 +14,7 @@
from ai.backend.client.session import Session

from ..extensions import pass_ctx_obj
from ..pretty import print_info
from ..pretty import print_fail, print_info
from ..types import CLIContext
from . import admin

Expand Down Expand Up @@ -228,6 +228,33 @@ def list(ctx: CLIContext, status, group, filter_, order, offset, limit) -> None:
"Note that this feature does not automatically install sudo for the session."
),
)
@click.option(
"--cuid",
"--container-uid",
"container_uid",
type=int,
default=None,
help="The user ID (UID) that will be assigned to all processes running inside containers created by this user.",
)
@click.option(
"--cgid",
"--container-main-gid",
"container_main_gid",
type=int,
default=None,
help="The primary group ID (GID) that will be assigned to all processes running inside containers created by this user.",
)
@click.option(
"--cgids",
"--container-supplementary-gids",
"container_supplementary_gids",
type=CommaSeparatedListType(int),
default=None,
help=(
"Supplementary group IDs that will be assigned to all processes running inside containers created by this user. "
"(e.g., --cgids 1001,1002,1003)"
),
)
@click.option(
"-g",
"--group",
Expand All @@ -250,6 +277,9 @@ def add(
allowed_ip: str | None,
description: str,
sudo_session_enabled: bool,
container_uid: int | None,
container_main_gid: int | None,
container_supplementary_gids: Iterable[int] | None,
groups: Iterable[str],
):
"""
Expand Down Expand Up @@ -293,6 +323,9 @@ def add(
group_ids=group_ids,
description=description,
sudo_session_enabled=sudo_session_enabled,
container_uid=container_uid,
container_main_gid=container_main_gid,
container_supplementary_gids=container_supplementary_gids,
fields=(
user_fields["domain_name"],
user_fields["email"],
Expand Down Expand Up @@ -410,6 +443,57 @@ def add(
default=undefined,
help="Set main access key which works as default.",
)
@click.option(
"--cuid",
"--container-uid",
"container_uid",
type=OptionalType(int),
default=undefined,
help="The user ID (UID) that will be assigned to all processes running inside containers created by this user.",
)
@click.option(
"--unset-cuid",
"--unset-container-uid",
"unset_container_uid",
is_flag=True,
default=False,
help="Unset the user's container UID.",
)
@click.option(
"--cgid",
"--container-main-gid",
"container_main_gid",
type=OptionalType(int),
default=undefined,
help="The primary group ID (GID) that will be assigned to all processes running inside containers created by this user.",
)
@click.option(
"--unset-cgid",
"--unset-container-main-gid",
"unset_container_main_gid",
is_flag=True,
default=False,
help="Unset the user's container primary GID.",
)
@click.option(
"--cgids",
"--container-supplementary-gids",
"container_supplementary_gids",
type=OptionalType(CommaSeparatedListType(int)),
default=undefined,
help=(
"Supplementary group IDs that will be assigned to all processes running inside containers created by this user. "
"(e.g., --cgids 1001,1002,1003)"
),
)
@click.option(
"--unset-cgids",
"--unset-container-supplementary-gids",
"unset_container_supplementary_gids",
is_flag=True,
default=False,
help="Unset the user's container supplementary group IDs.",
)
def update(
ctx: CLIContext,
email: str,
Expand All @@ -424,13 +508,38 @@ def update(
description: str | Undefined,
sudo_session_enabled: bool | Undefined,
main_access_key: str | Undefined,
container_uid: int | Undefined,
unset_container_uid: bool,
container_main_gid: int | Undefined,
unset_container_main_gid: bool,
container_supplementary_gids: Iterable[int] | Undefined,
unset_container_supplementary_gids: bool,
):
"""
Update an existing user.
\b
EMAIL: Email of user to update.
"""

def validate_input[T_Input](value: T_Input, null_flag: bool, field_name: str) -> T_Input | None:
_val: T_Input | None = value
if null_flag:
if value is not undefined:
print_fail(
f"Cannot set both --{field_name} and --unset-{field_name}. "
f"Use --{field_name} to set a value or --unset-{field_name} to set null, but not both."
)
sys.exit(ExitCode.FAILURE)
_val = None
return _val

cuid = validate_input(container_uid, unset_container_uid, "cuid")
cgid = validate_input(container_main_gid, unset_container_main_gid, "cgid")
cgids = validate_input(
container_supplementary_gids, unset_container_supplementary_gids, "cgids"
)

with Session() as session:
try:
data = session.User.update(
Expand All @@ -446,6 +555,9 @@ def update(
description=description,
sudo_session_enabled=sudo_session_enabled,
main_access_key=main_access_key,
container_uid=cuid,
container_main_gid=cgid,
container_supplementary_gids=cgids,
)
except Exception as e:
ctx.output.print_mutation_error(
Expand Down
4 changes: 2 additions & 2 deletions src/ai/backend/client/cli/session/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@
from .args import click_start_option

tabulate_mod.PRESERVE_WHITESPACE = True
range_expr = RangeExprOptionType()
list_expr = CommaSeparatedListType()
range_expr: click.ParamType = RangeExprOptionType()
list_expr: click.ParamType = CommaSeparatedListType()


async def exec_loop(
Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/client/cli/session/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
)
from .ssh import container_ssh_ctx

list_expr = CommaSeparatedListType()
list_expr: click.ParamType = CommaSeparatedListType()


@main.group()
Expand Down
12 changes: 12 additions & 0 deletions src/ai/backend/client/func/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,9 @@ async def create(
totp_activated: bool = False,
group_ids: Iterable[str] | Undefined = undefined,
sudo_session_enabled: bool = False,
container_uid: int | Undefined = undefined,
container_main_gid: int | Undefined = undefined,
container_supplementary_gids: Iterable[int] | Undefined = undefined,
fields: Iterable[FieldSpec | str] | None = None,
) -> dict:
"""
Expand Down Expand Up @@ -300,6 +303,9 @@ async def create(
set_if_set(inputs, "full_name", full_name)
set_if_set(inputs, "allowed_client_ip", allowed_client_ip)
set_if_set(inputs, "group_ids", group_ids)
set_if_set(inputs, "container_uid", container_uid)
set_if_set(inputs, "container_main_gid", container_main_gid)
set_if_set(inputs, "container_supplementary_gids", container_supplementary_gids)
variables = {
"email": email,
"input": inputs,
Expand All @@ -326,6 +332,9 @@ async def update(
group_ids: Iterable[str] | Undefined = undefined,
sudo_session_enabled: bool | Undefined = undefined,
main_access_key: str | Undefined = undefined,
container_uid: int | None | Undefined = undefined,
container_main_gid: int | None | Undefined = undefined,
container_supplementary_gids: Iterable[int] | None | Undefined = undefined,
fields: Iterable[FieldSpec | str] | None = None,
) -> dict:
"""
Expand Down Expand Up @@ -353,6 +362,9 @@ async def update(
set_if_set(inputs, "group_ids", group_ids)
set_if_set(inputs, "sudo_session_enabled", sudo_session_enabled)
set_if_set(inputs, "main_access_key", main_access_key)
set_if_set(inputs, "container_uid", container_uid)
set_if_set(inputs, "container_main_gid", container_main_gid)
set_if_set(inputs, "container_supplementary_gids", container_supplementary_gids)
variables = {
"email": email,
"input": inputs,
Expand Down

0 comments on commit 10fc52d

Please sign in to comment.