diff --git a/src/ai/backend/agent/agent.py b/src/ai/backend/agent/agent.py index 8e3a5dbd7d..1087bcb1c2 100644 --- a/src/ai/backend/agent/agent.py +++ b/src/ai/backend/agent/agent.py @@ -576,8 +576,8 @@ def get_overriding_uid(self) -> Optional[int]: def get_overriding_gid(self) -> Optional[int]: return None - def get_supplementary_gids(self) -> Optional[list[int]]: - return None + def get_supplementary_gids(self) -> set[int]: + return set() KernelCreationContextType = TypeVar( @@ -1881,14 +1881,17 @@ async def create_kernel( uid = self.local_config["container"]["kernel-uid"] environ["LOCAL_USER_ID"] = str(uid) + sgids = set(ctx.get_supplementary_gids() or []) + kernel_gid: int = self.local_config["container"]["kernel-gid"] if (ogid := ctx.get_overriding_gid()) is not None: environ["LOCAL_GROUP_ID"] = str(ogid) + if KernelFeatures.UID_MATCH in ctx.kernel_features: + sgids.add(kernel_gid) else: if KernelFeatures.UID_MATCH in ctx.kernel_features: - gid = self.local_config["container"]["kernel-gid"] - environ["LOCAL_GROUP_ID"] = str(gid) + environ["LOCAL_GROUP_ID"] = str(kernel_gid) - update_additional_gids(environ, ctx.get_supplementary_gids() or []) + update_additional_gids(environ, sgids) environ.update( await ctx.get_extra_envs(), ) diff --git a/src/ai/backend/agent/docker/agent.py b/src/ai/backend/agent/docker/agent.py index e8bda69dfc..0b5af3e243 100644 --- a/src/ai/backend/agent/docker/agent.py +++ b/src/ai/backend/agent/docker/agent.py @@ -237,7 +237,7 @@ def __init__( self.work_dir = scratch_dir / "work" self.uid = kernel_config["uid"] self.main_gid = kernel_config["main_gid"] - self.supplementary_gids = kernel_config["supplementary_gids"] + self.supplementary_gids = set(kernel_config["supplementary_gids"]) self.port_pool = port_pool self.agent_sockpath = agent_sockpath @@ -261,7 +261,7 @@ def get_overriding_gid(self) -> Optional[int]: return self.main_gid @override - def get_supplementary_gids(self) -> Optional[list[int]]: + def get_supplementary_gids(self) -> set[int]: return self.supplementary_gids def _kernel_resource_spec_read(self, filename): diff --git a/src/ai/backend/common/types.py b/src/ai/backend/common/types.py index 8eb7f9c017..0213365ca5 100644 --- a/src/ai/backend/common/types.py +++ b/src/ai/backend/common/types.py @@ -1092,7 +1092,7 @@ class KernelCreationConfig(TypedDict): cluster_hostname: str # the kernel's hostname in the cluster uid: Optional[int] main_gid: Optional[int] - supplementary_gids: Optional[list[int]] + supplementary_gids: list[int] resource_slots: Mapping[str, str] # json form of ResourceSlot resource_opts: Mapping[str, str] # json form of resource options environ: Mapping[str, str] @@ -1125,7 +1125,7 @@ class KernelEnqueueingConfig(TypedDict): startup_command: Optional[str] uid: Optional[int] main_gid: Optional[int] - supplementary_gids: Optional[list[int]] + supplementary_gids: list[int] def _stringify_number(v: Union[BinarySize, int, float, Decimal]) -> str: diff --git a/src/ai/backend/manager/registry.py b/src/ai/backend/manager/registry.py index c6fb26a37d..3a590413f9 100644 --- a/src/ai/backend/manager/registry.py +++ b/src/ai/backend/manager/registry.py @@ -605,6 +605,11 @@ async def create_session( script, _ = await query_bootstrap_script(conn, owner_access_key) bootstrap_script = script + user_row = await db_sess.scalar( + sa.select(UserRow).where(UserRow.uuid == user_scope.user_uuid) + ) + user_row = cast(UserRow, user_row) + public_sgroup_only = session_type not in PRIVATE_SESSION_TYPES if dry_run: return {} @@ -619,9 +624,11 @@ async def create_session( "creation_config": config, "kernel_configs": [ { - "uid": sess.user.container_uid, - "main_gid": sess.user.container_main_gid, - "supplementary_gids": sess.user.container_supplementary_gids, + "uid": user_row.container_uid, + "main_gid": user_row.container_main_gid, + "supplementary_gids": ( + user_row.container_supplementary_gids or [] + ), "image_ref": image_ref, "cluster_role": DEFAULT_ROLE, "cluster_idx": 1, @@ -1847,7 +1854,7 @@ def get_image_conf(kernel: KernelRow) -> ImageConfig: "cluster_hostname": binding.kernel.cluster_hostname, "uid": binding.kernel.uid, "main_gid": binding.kernel.main_gid, - "supplementary_gids": binding.kernel.supplementary_gids, + "supplementary_gids": binding.kernel.supplementary_gids or [], "idle_timeout": int(idle_timeout), "mounts": [item.to_json() for item in scheduled_session.vfolder_mounts], "environ": {