Skip to content

Commit

Permalink
use list rather than set
Browse files Browse the repository at this point in the history
  • Loading branch information
fregataa committed Jan 6, 2025
1 parent a36514a commit b630681
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 13 deletions.
13 changes: 8 additions & 5 deletions src/ai/backend/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,8 +576,8 @@ def get_overriding_uid(self) -> Optional[int]:
def get_overriding_gid(self) -> Optional[int]:
return None

def get_supplementary_gids(self) -> Optional[list[int]]:
return None
def get_supplementary_gids(self) -> set[int]:
return set()


KernelCreationContextType = TypeVar(
Expand Down Expand Up @@ -1881,14 +1881,17 @@ async def create_kernel(
uid = self.local_config["container"]["kernel-uid"]
environ["LOCAL_USER_ID"] = str(uid)

sgids = set(ctx.get_supplementary_gids() or [])
kernel_gid: int = self.local_config["container"]["kernel-gid"]
if (ogid := ctx.get_overriding_gid()) is not None:
environ["LOCAL_GROUP_ID"] = str(ogid)
if KernelFeatures.UID_MATCH in ctx.kernel_features:
sgids.add(kernel_gid)
else:
if KernelFeatures.UID_MATCH in ctx.kernel_features:
gid = self.local_config["container"]["kernel-gid"]
environ["LOCAL_GROUP_ID"] = str(gid)
environ["LOCAL_GROUP_ID"] = str(kernel_gid)

update_additional_gids(environ, ctx.get_supplementary_gids() or [])
update_additional_gids(environ, sgids)
environ.update(
await ctx.get_extra_envs(),
)
Expand Down
4 changes: 2 additions & 2 deletions src/ai/backend/agent/docker/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def __init__(
self.work_dir = scratch_dir / "work"
self.uid = kernel_config["uid"]
self.main_gid = kernel_config["main_gid"]
self.supplementary_gids = kernel_config["supplementary_gids"]
self.supplementary_gids = set(kernel_config["supplementary_gids"])

self.port_pool = port_pool
self.agent_sockpath = agent_sockpath
Expand All @@ -261,7 +261,7 @@ def get_overriding_gid(self) -> Optional[int]:
return self.main_gid

@override
def get_supplementary_gids(self) -> Optional[list[int]]:
def get_supplementary_gids(self) -> set[int]:
return self.supplementary_gids

def _kernel_resource_spec_read(self, filename):
Expand Down
4 changes: 2 additions & 2 deletions src/ai/backend/common/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1092,7 +1092,7 @@ class KernelCreationConfig(TypedDict):
cluster_hostname: str # the kernel's hostname in the cluster
uid: Optional[int]
main_gid: Optional[int]
supplementary_gids: Optional[list[int]]
supplementary_gids: list[int]
resource_slots: Mapping[str, str] # json form of ResourceSlot
resource_opts: Mapping[str, str] # json form of resource options
environ: Mapping[str, str]
Expand Down Expand Up @@ -1125,7 +1125,7 @@ class KernelEnqueueingConfig(TypedDict):
startup_command: Optional[str]
uid: Optional[int]
main_gid: Optional[int]
supplementary_gids: Optional[list[int]]
supplementary_gids: list[int]


def _stringify_number(v: Union[BinarySize, int, float, Decimal]) -> str:
Expand Down
15 changes: 11 additions & 4 deletions src/ai/backend/manager/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,11 @@ async def create_session(
script, _ = await query_bootstrap_script(conn, owner_access_key)
bootstrap_script = script

user_row = await db_sess.scalar(
sa.select(UserRow).where(UserRow.uuid == user_scope.user_uuid)
)
user_row = cast(UserRow, user_row)

public_sgroup_only = session_type not in PRIVATE_SESSION_TYPES
if dry_run:
return {}
Expand All @@ -619,9 +624,11 @@ async def create_session(
"creation_config": config,
"kernel_configs": [
{
"uid": sess.user.container_uid,
"main_gid": sess.user.container_main_gid,
"supplementary_gids": sess.user.container_supplementary_gids,
"uid": user_row.container_uid,
"main_gid": user_row.container_main_gid,
"supplementary_gids": (
user_row.container_supplementary_gids or []
),
"image_ref": image_ref,
"cluster_role": DEFAULT_ROLE,
"cluster_idx": 1,
Expand Down Expand Up @@ -1847,7 +1854,7 @@ def get_image_conf(kernel: KernelRow) -> ImageConfig:
"cluster_hostname": binding.kernel.cluster_hostname,
"uid": binding.kernel.uid,
"main_gid": binding.kernel.main_gid,
"supplementary_gids": binding.kernel.supplementary_gids,
"supplementary_gids": binding.kernel.supplementary_gids or [],
"idle_timeout": int(idle_timeout),
"mounts": [item.to_json() for item in scheduled_session.vfolder_mounts],
"environ": {
Expand Down

0 comments on commit b630681

Please sign in to comment.