Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Enable per-user UID/GID set for containers #3279

Open
wants to merge 7 commits into
base: topic/12-20-feat_add_uid_and_gid_columns_to_users_and_kernels_tables
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 54 additions & 12 deletions src/ai/backend/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@
import zlib
from abc import ABCMeta, abstractmethod
from collections import defaultdict
from collections.abc import (
Iterable,
Mapping,
MutableMapping,
MutableSequence,
Sequence,
)
from decimal import Decimal
from io import SEEK_END, BytesIO
from pathlib import Path
Expand All @@ -33,11 +40,7 @@
Generic,
List,
Literal,
Mapping,
MutableMapping,
MutableSequence,
Optional,
Sequence,
Set,
Tuple,
Type,
Expand Down Expand Up @@ -192,6 +195,17 @@
KernelObjectType = TypeVar("KernelObjectType", bound=AbstractKernel)


def update_additional_gids(environ: MutableMapping[str, str], gids: Iterable[int]) -> None:
if not gids:
return
if orig_additional_gids := environ.get("ADDITIONAL_GIDS"):
orig_add_gids = {int(gid) for gid in orig_additional_gids.split(",") if gid}
additional_gids = orig_add_gids | set(gids)
else:
additional_gids = set(gids)
environ["ADDITIONAL_GIDS"] = ",".join(map(str, additional_gids))


class AbstractKernelCreationContext(aobject, Generic[KernelObjectType]):
kspec_version: int
distro: str
Expand Down Expand Up @@ -219,6 +233,7 @@ def __init__(
distro: str,
local_config: Mapping[str, Any],
computers: MutableMapping[DeviceName, ComputerContext],
proc_uid: int,
restarting: bool = False,
) -> None:
self.image_labels = kernel_config["image"]["labels"]
Expand All @@ -237,6 +252,7 @@ def __init__(
self.computers = computers
self.restarting = restarting
self.local_config = local_config
self.proc_uid = proc_uid

@abstractmethod
async def get_extra_envs(self) -> Mapping[str, str]:
Expand Down Expand Up @@ -514,8 +530,8 @@ def mount_static_binary(filename: str, target_path: str) -> None:
environ["LD_PRELOAD"] = "/opt/kernel/libbaihook.so"

# Inject ComputeDevice-specific env-varibles and hooks
already_injected_hooks: Set[Path] = set()
additional_gid_set: Set[int] = set()
already_injected_hooks: set[Path] = set()
additional_gid_set: set[int] = set()

for dev_type, device_alloc in resource_spec.allocations.items():
computer_ctx = self.computers[dev_type]
Expand Down Expand Up @@ -552,7 +568,16 @@ def mount_static_binary(filename: str, target_path: str) -> None:
environ["LD_PRELOAD"] += ":" + container_hook_path
already_injected_hooks.add(hook_path)

environ["ADDITIONAL_GIDS"] = ",".join(map(str, additional_gid_set))
update_additional_gids(environ, additional_gids)

def get_overriding_uid(self) -> Optional[int]:
return None

def get_overriding_gid(self) -> Optional[int]:
return None

def get_supplementary_gids(self) -> set[int]:
return set()


KernelCreationContextType = TypeVar(
Expand Down Expand Up @@ -586,6 +611,7 @@ class AbstractAgent(
computers: MutableMapping[DeviceName, ComputerContext]
images: Mapping[str, str]
port_pool: Set[int]
proc_uid: int

redis: Redis

Expand Down Expand Up @@ -640,6 +666,7 @@ def __init__(
local_config["container"]["port-range"][1] + 1,
)
)
self.proc_uid = os.geteuid()
self.stats_monitor = stats_monitor
self.error_monitor = error_monitor
self._pending_creation_tasks = defaultdict(set)
Expand Down Expand Up @@ -1714,6 +1741,7 @@ async def init_kernel_context(
kernel_image: ImageRef,
kernel_config: KernelCreationConfig,
*,
proc_uid: int,
restarting: bool = False,
cluster_ssh_port_mapping: Optional[ClusterSSHPortMapping] = None,
) -> AbstractKernelCreationContext:
Expand Down Expand Up @@ -1840,16 +1868,30 @@ async def create_kernel(
kernel_image,
kernel_config,
restarting=restarting,
proc_uid=self.proc_uid,
cluster_ssh_port_mapping=cluster_info.get("cluster_ssh_port_mapping"),
)
environ: dict[str, str] = {**kernel_config["environ"]}

# Inject Backend.AI-intrinsic env-variables for gosu
if KernelFeatures.UID_MATCH in ctx.kernel_features:
uid = self.local_config["container"]["kernel-uid"]
gid = self.local_config["container"]["kernel-gid"]
environ["LOCAL_USER_ID"] = str(uid)
environ["LOCAL_GROUP_ID"] = str(gid)
if (ouid := ctx.get_overriding_uid()) is not None:
environ["LOCAL_USER_ID"] = str(ouid)
else:
if KernelFeatures.UID_MATCH in ctx.kernel_features:
uid = self.local_config["container"]["kernel-uid"]
environ["LOCAL_USER_ID"] = str(uid)

sgids = set(ctx.get_supplementary_gids() or [])
kernel_gid: int = self.local_config["container"]["kernel-gid"]
if (ogid := ctx.get_overriding_gid()) is not None:
environ["LOCAL_GROUP_ID"] = str(ogid)
if KernelFeatures.UID_MATCH in ctx.kernel_features:
sgids.add(kernel_gid)
else:
if KernelFeatures.UID_MATCH in ctx.kernel_features:
environ["LOCAL_GROUP_ID"] = str(kernel_gid)

update_additional_gids(environ, sgids)
environ.update(
await ctx.get_extra_envs(),
)
Expand Down
Loading
Loading