Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(BA-530): Update SDK and CLI to follow-up per-user uid/gid #3361

Open
wants to merge 3 commits into
base: topic/01-02-feat_add_per-user_uid_gid_apis
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/3361.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Update SDK and CLI to support per-user UID/GID configuration
23 changes: 22 additions & 1 deletion src/ai/backend/cli/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,10 +237,31 @@ def __init__(self, value: Any) -> None: ...
TScalar = TypeVar("TScalar", bound=SingleValueConstructorType | click.ParamType)


class CommaSeparatedListType(click.ParamType, Generic[TScalar]):
name = "List Expression"

def __init__(self, type_: Optional[type[TScalar]] = None) -> None:
super().__init__()
self.type_ = type_ if type_ is not None else str

def convert(self, arg, param, ctx):
try:
match arg:
case int():
return arg
case str():
return [self.type_(elem) for elem in arg.split(",")]
except ValueError as e:
self.fail(repr(e), param, ctx)


T = TypeVar("T")


class OptionalType(click.ParamType, Generic[TScalar]):
name = "Optional Type Wrapper"

def __init__(self, type_: type[TScalar] | click.ParamType) -> None:
def __init__(self, type_: type[TScalar] | type[click.ParamType] | click.ParamType) -> None:
super().__init__()
self.type_ = type_

Expand Down
114 changes: 112 additions & 2 deletions src/ai/backend/client/cli/admin/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import sys
import uuid
from typing import Iterable, Sequence
from collections.abc import Iterable, Sequence

import click

Expand All @@ -14,7 +14,7 @@
from ai.backend.client.session import Session

from ..extensions import pass_ctx_obj
from ..pretty import print_info
from ..pretty import print_fail, print_info
from ..types import CLIContext
from . import admin

Expand Down Expand Up @@ -228,6 +228,33 @@ def list(ctx: CLIContext, status, group, filter_, order, offset, limit) -> None:
"Note that this feature does not automatically install sudo for the session."
),
)
@click.option(
"--cuid",
"--container-uid",
"container_uid",
type=int,
default=None,
help="The user ID (UID) that will be assigned to all processes running inside containers created by this user.",
)
@click.option(
"--cgid",
"--container-main-gid",
"container_main_gid",
type=int,
default=None,
help="The primary group ID (GID) that will be assigned to all processes running inside containers created by this user.",
)
@click.option(
"--cgids",
"--container-gids",
"container_gids",
type=CommaSeparatedListType(int),
default=None,
help=(
"Supplementary group IDs that will be assigned to all processes running inside containers created by this user. "
"(e.g., --cgids 1001,1002,1003)"
),
)
@click.option(
"-g",
"--group",
Expand All @@ -250,6 +277,9 @@ def add(
allowed_ip: str | None,
description: str,
sudo_session_enabled: bool,
container_uid: int | None,
container_main_gid: int | None,
container_gids: Iterable[int] | None,
groups: Iterable[str],
):
"""
Expand Down Expand Up @@ -293,6 +323,9 @@ def add(
group_ids=group_ids,
description=description,
sudo_session_enabled=sudo_session_enabled,
container_uid=container_uid,
container_main_gid=container_main_gid,
container_gids=container_gids,
fields=(
user_fields["domain_name"],
user_fields["email"],
Expand Down Expand Up @@ -410,6 +443,57 @@ def add(
default=undefined,
help="Set main access key which works as default.",
)
@click.option(
"--cuid",
"--container-uid",
"container_uid",
type=OptionalType(int),
default=undefined,
help="The user ID (UID) that will be assigned to all processes running inside containers created by this user.",
)
@click.option(
"--unset-cuid",
"--unset-container-uid",
"unset_container_uid",
is_flag=True,
default=False,
help="Unset the user's container UID.",
)
@click.option(
"--cgid",
"--container-main-gid",
"container_main_gid",
type=OptionalType(int),
default=undefined,
help="The primary group ID (GID) that will be assigned to all processes running inside containers created by this user.",
)
@click.option(
"--unset-cgid",
"--unset-container-main-gid",
"unset_container_main_gid",
is_flag=True,
default=False,
help="Unset the user's container primary GID.",
)
@click.option(
"--cgids",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel that abbreviations like cuid and cguid might overlap with other topics.
How about avoiding using abbreviations for now?

"--container-gids",
"container_gids",
type=OptionalType(CommaSeparatedListType(int)),
default=undefined,
help=(
"Supplementary group IDs that will be assigned to all processes running inside containers created by this user. "
"(e.g., --cgids 1001,1002,1003)"
),
)
@click.option(
"--unset-cgids",
"--unset-container-gids",
"unset_container_gids",
is_flag=True,
default=False,
help="Unset the user's container supplementary group IDs.",
)
def update(
ctx: CLIContext,
email: str,
Expand All @@ -424,13 +508,36 @@ def update(
description: str | Undefined,
sudo_session_enabled: bool | Undefined,
main_access_key: str | Undefined,
container_uid: int | Undefined,
unset_container_uid: bool,
container_main_gid: int | Undefined,
unset_container_main_gid: bool,
container_gids: Iterable[int] | Undefined,
unset_container_gids: bool,
):
"""
Update an existing user.

\b
EMAIL: Email of user to update.
"""

def validate_input[T_Input](value: T_Input, null_flag: bool, field_name: str) -> T_Input | None:
_val: T_Input | None = value
if null_flag:
if value is not undefined:
print_fail(
f"Cannot set both --{field_name} and --unset-{field_name}. "
f"Use --{field_name} to set a value or --unset-{field_name} to set null, but not both."
)
sys.exit(ExitCode.FAILURE)
_val = None
return _val

cuid = validate_input(container_uid, unset_container_uid, "cuid")
cgid = validate_input(container_main_gid, unset_container_main_gid, "cgid")
cgids = validate_input(container_gids, unset_container_gids, "cgids")

with Session() as session:
try:
data = session.User.update(
Expand All @@ -446,6 +553,9 @@ def update(
description=description,
sudo_session_enabled=sudo_session_enabled,
main_access_key=main_access_key,
container_uid=cuid,
container_main_gid=cgid,
container_gids=cgids,
)
except Exception as e:
ctx.output.print_mutation_error(
Expand Down
4 changes: 2 additions & 2 deletions src/ai/backend/client/cli/session/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@
from .args import click_start_option

tabulate_mod.PRESERVE_WHITESPACE = True
range_expr = RangeExprOptionType()
list_expr = CommaSeparatedListType()
range_expr: click.ParamType = RangeExprOptionType()
list_expr: click.ParamType = CommaSeparatedListType()


async def exec_loop(
Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/client/cli/session/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
)
from .ssh import container_ssh_ctx

list_expr = CommaSeparatedListType()
list_expr: click.ParamType = CommaSeparatedListType()


@main.group()
Expand Down
12 changes: 12 additions & 0 deletions src/ai/backend/client/func/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,9 @@ async def create(
totp_activated: bool = False,
group_ids: Iterable[str] | Undefined = undefined,
sudo_session_enabled: bool = False,
container_uid: int | Undefined = undefined,
container_main_gid: int | Undefined = undefined,
container_gids: Iterable[int] | Undefined = undefined,
fields: Iterable[FieldSpec | str] | None = None,
) -> dict:
"""
Expand Down Expand Up @@ -300,6 +303,9 @@ async def create(
set_if_set(inputs, "full_name", full_name)
set_if_set(inputs, "allowed_client_ip", allowed_client_ip)
set_if_set(inputs, "group_ids", group_ids)
set_if_set(inputs, "container_uid", container_uid)
set_if_set(inputs, "container_main_gid", container_main_gid)
set_if_set(inputs, "container_gids", container_gids)
variables = {
"email": email,
"input": inputs,
Expand All @@ -326,6 +332,9 @@ async def update(
group_ids: Iterable[str] | Undefined = undefined,
sudo_session_enabled: bool | Undefined = undefined,
main_access_key: str | Undefined = undefined,
container_uid: int | None | Undefined = undefined,
container_main_gid: int | None | Undefined = undefined,
container_gids: Iterable[int] | None | Undefined = undefined,
fields: Iterable[FieldSpec | str] | None = None,
) -> dict:
"""
Expand Down Expand Up @@ -353,6 +362,9 @@ async def update(
set_if_set(inputs, "group_ids", group_ids)
set_if_set(inputs, "sudo_session_enabled", sudo_session_enabled)
set_if_set(inputs, "main_access_key", main_access_key)
set_if_set(inputs, "container_uid", container_uid)
set_if_set(inputs, "container_main_gid", container_main_gid)
set_if_set(inputs, "container_gids", container_gids)
variables = {
"email": email,
"input": inputs,
Expand Down
Loading