diff --git a/changes/1838.feature.md b/changes/1838.feature.md new file mode 100644 index 0000000000..28dd235c0c --- /dev/null +++ b/changes/1838.feature.md @@ -0,0 +1 @@ +Allow overriding vfolder mount permissions in API calls and CLI commands to create new sessions, with addition of a generic parser of comma-separated "key=value" list for CLI args and API params diff --git a/src/ai/backend/client/cli/session/args.py b/src/ai/backend/client/cli/session/args.py index 92b989095e..a82f41193c 100644 --- a/src/ai/backend/client/cli/session/args.py +++ b/src/ai/backend/client/cli/session/args.py @@ -76,7 +76,7 @@ "-m", "--mount", "mount", - metavar="NAME[=PATH]", + metavar="NAME[=PATH] or NAME[:PATH]", type=str, multiple=True, help=( diff --git a/src/ai/backend/client/cli/session/execute.py b/src/ai/backend/client/cli/session/execute.py index 3bb7975e1a..fe7143bcd9 100644 --- a/src/ai/backend/client/cli/session/execute.py +++ b/src/ai/backend/client/cli/session/execute.py @@ -20,7 +20,7 @@ from ai.backend.cli.params import CommaSeparatedListType, RangeExprOptionType from ai.backend.cli.types import ExitCode from ai.backend.common.arch import DEFAULT_IMAGE_ARCH -from ai.backend.common.types import ClusterMode +from ai.backend.common.types import ClusterMode, MountExpression from ...compat import asyncio_run, current_loop from ...config import local_cache_path @@ -245,26 +245,35 @@ def prepare_env_arg(env: Sequence[str]) -> Mapping[str, str]: def prepare_mount_arg( - mount_args: Optional[Sequence[str]], -) -> Tuple[Sequence[str], Mapping[str, str]]: + mount_args: Optional[Sequence[str]] = None, + *, + escape: bool = True, +) -> Tuple[Sequence[str], Mapping[str, str], Mapping[str, Mapping[str, str]]]: """ Parse the list of mount arguments into a list of - vfolder name and in-container mount path pairs. + vfolder name and in-container mount path pairs, + followed by extra options. + + :param mount_args: A list of mount arguments such as + [ + "type=bind,source=/colon:path/test,target=/data", + "type=bind,source=/colon:path/abcd,target=/zxcv,readonly", + # simple formats are still supported + "vf-abcd:/home/work/zxcv", + ] """ mounts = set() mount_map = {} + mount_options = {} if mount_args is not None: - for value in mount_args: - if "=" in value: - sp = value.split("=", maxsplit=1) - elif ":" in value: # docker-like volume mount mapping - sp = value.split(":", maxsplit=1) - else: - sp = [value] - mounts.add(sp[0]) - if len(sp) == 2: - mount_map[sp[0]] = sp[1] - return list(mounts), mount_map + for mount_arg in mount_args: + mountpoint = {**MountExpression(mount_arg).parse(escape=escape)} + mount = str(mountpoint.pop("source")) + mounts.add(mount) + if target := mountpoint.pop("target", None): + mount_map[mount] = str(target) + mount_options[mount] = mountpoint + return list(mounts), mount_map, mount_options @main.command() @@ -448,7 +457,7 @@ def run( envs = prepare_env_arg(env) resources = prepare_resource_arg(resources) resource_opts = prepare_resource_arg(resource_opts) - mount, mount_map = prepare_mount_arg(mount) + mount, mount_map, mount_options = prepare_mount_arg(mount, escape=True) if env_range is None: env_range = [] # noqa @@ -628,6 +637,7 @@ async def _run(session, idx, name, envs, clean_cmd, build_cmd, exec_cmd, is_mult cluster_mode=cluster_mode, mounts=mount, mount_map=mount_map, + mount_options=mount_options, envs=envs, resources=resources, resource_opts=resource_opts, diff --git a/src/ai/backend/client/cli/session/lifecycle.py b/src/ai/backend/client/cli/session/lifecycle.py index 6e50af5ed2..110db94c41 100644 --- a/src/ai/backend/client/cli/session/lifecycle.py +++ b/src/ai/backend/client/cli/session/lifecycle.py @@ -37,7 +37,12 @@ from .. import events from ..pretty import print_done, print_error, print_fail, print_info, print_wait, print_warn from .args import click_start_option -from .execute import format_stats, prepare_env_arg, prepare_mount_arg, prepare_resource_arg +from .execute import ( + format_stats, + prepare_env_arg, + prepare_mount_arg, + prepare_resource_arg, +) from .ssh import container_ssh_ctx list_expr = CommaSeparatedListType() @@ -169,7 +174,7 @@ def create( envs = prepare_env_arg(env) parsed_resources = prepare_resource_arg(resources) parsed_resource_opts = prepare_resource_arg(resource_opts) - mount, mount_map = prepare_mount_arg(mount) + mount, mount_map, mount_options = prepare_mount_arg(mount, escape=True) preopen_ports = preopen assigned_agent_list = assign_agent @@ -189,6 +194,7 @@ def create( cluster_mode=cluster_mode, mounts=mount, mount_map=mount_map, + mount_options=mount_options, envs=envs, startup_command=startup_command, resources=parsed_resources, @@ -425,7 +431,7 @@ def create_from_template( if len(resource_opts) > 0 or no_resource else undefined ) - prepared_mount, prepared_mount_map = ( + prepared_mount, prepared_mount_map, _ = ( prepare_mount_arg(mount) if len(mount) > 0 or no_mount else (undefined, undefined) ) kwargs = { diff --git a/src/ai/backend/client/func/session.py b/src/ai/backend/client/func/session.py index 75f3afabbe..837f6b3ca3 100644 --- a/src/ai/backend/client/func/session.py +++ b/src/ai/backend/client/func/session.py @@ -176,6 +176,7 @@ async def get_or_create( callback_url: Optional[str] = None, mounts: List[str] = None, mount_map: Mapping[str, str] = None, + mount_options: Optional[Mapping[str, Mapping[str, str]]] = None, envs: Mapping[str, str] = None, startup_command: str = None, resources: Mapping[str, str | int] = None, @@ -238,6 +239,7 @@ async def get_or_create( If you want different paths, names should be absolute paths. The target mount path of vFolders should not overlap with the linux system folders. vFolders which has a dot(.) prefix in its name are not affected. + :param mount_options: Mapping which contains extra options for vfolder. :param envs: The environment variables which always bypasses the jail policy. :param resources: The resource specification. (TODO: details) :param cluster_size: The number of containers in this compute session. @@ -264,6 +266,8 @@ async def get_or_create( mounts = [] if mount_map is None: mount_map = {} + if mount_options is None: + mount_options = {} if resources is None: resources = {} if resource_opts is None: @@ -303,12 +307,14 @@ async def get_or_create( if assign_agent is not None: params["config"].update({ "mount_map": mount_map, + "mount_options": mount_options, "preopen_ports": preopen_ports, "agentList": assign_agent, }) else: params["config"].update({ "mount_map": mount_map, + "mount_options": mount_options, "preopen_ports": preopen_ports, }) if api_session.get().api_version >= (4, "20190615"): @@ -1201,6 +1207,7 @@ async def get_or_create( callback_url: Optional[str] = None, mounts: Optional[List[str]] = None, mount_map: Optional[Mapping[str, str]] = None, + mount_options: Optional[Mapping[str, Mapping[str, str]]] = None, envs: Optional[Mapping[str, str]] = None, startup_command: Optional[str] = None, resources: Optional[Mapping[str, str]] = None, diff --git a/src/ai/backend/common/models/minilang/BUILD b/src/ai/backend/common/models/minilang/BUILD new file mode 100644 index 0000000000..7357442404 --- /dev/null +++ b/src/ai/backend/common/models/minilang/BUILD @@ -0,0 +1 @@ +python_sources(name="src") diff --git a/src/ai/backend/common/models/minilang/mount.py b/src/ai/backend/common/models/minilang/mount.py new file mode 100644 index 0000000000..86db83c893 --- /dev/null +++ b/src/ai/backend/common/models/minilang/mount.py @@ -0,0 +1,60 @@ +from typing import Annotated, Mapping, Sequence, TypeAlias + +from lark import Lark, Transformer, lexer +from lark.exceptions import LarkError + +_grammar = r""" + start: pair ("," pair)* + pair: key [("="|":") value] + key: SLASH? CNAME (SEPARATOR|CNAME|DIGIT)* + value: SLASH? CNAME (SEPARATOR|CNAME|DIGIT)* | ESCAPED_STRING + + SEPARATOR: SLASH | "\\," | "\\=" | "\\:" | DASH + SLASH: "/" + DASH: "-" + + %import common.CNAME + %import common.DIGIT + %import common.ESCAPED_STRING + %import common.WS + %ignore WS +""" + +PairType: TypeAlias = tuple[str, str] + + +class DictTransformer(Transformer): + reserved_keys = frozenset({"type", "source", "target", "perm", "permission"}) + + def start(self, pairs: Sequence[PairType]) -> Mapping[str, str]: + if pairs[0][0] not in self.reserved_keys: # [["vf-000", "/home/work"]] + result = {"source": pairs[0][0]} + if target := pairs[0][1]: + result["target"] = target + return result + return dict(pairs) # [("type", "bind"), ("source", "vf-000"), ...] + + def pair(self, token: Annotated[Sequence[str], 2]) -> PairType: + return (token[0], token[1]) + + def key(self, token: list[lexer.Token]) -> str: + return "".join(token) + + def value(self, token: list[lexer.Token]) -> str: + return "".join(token) + + +_parser = Lark(_grammar, parser="lalr") + + +class MountPointParser: + def __init__(self) -> None: + self._parser = _parser + + def parse_mount(self, expr: str) -> Mapping[str, str]: + try: + ast = self._parser.parse(expr) + result = DictTransformer().transform(ast) + except LarkError as e: + raise ValueError(f"Virtual folder mount parsing error: {e}") + return result diff --git a/src/ai/backend/common/types.py b/src/ai/backend/common/types.py index ebeb67ee82..f6d9bf6de9 100644 --- a/src/ai/backend/common/types.py +++ b/src/ai/backend/common/types.py @@ -14,7 +14,7 @@ from dataclasses import dataclass from decimal import Decimal from ipaddress import ip_address, ip_network -from pathlib import PurePosixPath +from pathlib import Path, PurePosixPath from ssl import SSLContext from typing import ( TYPE_CHECKING, @@ -43,9 +43,11 @@ import trafaret as t import typeguard from aiohttp import Fingerprint +from pydantic import BaseModel, ConfigDict, Field from redis.asyncio import Redis from .exception import InvalidIpAddressValue +from .models.minilang.mount import MountPointParser __all__ = ( "aobject", @@ -73,6 +75,7 @@ "MountPermission", "MountPermissionLiteral", "MountTypes", + "MountPoint", "VFolderID", "QuotaScopeID", "VFolderUsageMode", @@ -382,6 +385,44 @@ class MountTypes(enum.StrEnum): K8S_HOSTPATH = "k8s-hostpath" +class MountPoint(BaseModel): + type: MountTypes = Field(default=MountTypes.BIND) + source: Path + target: Path | None = Field(default=None) + permission: MountPermission | None = Field(alias="perm", default=None) + + model_config = ConfigDict(populate_by_name=True) + + +class MountExpression: + def __init__(self, expression: str, *, escape_map: Optional[Mapping[str, str]] = None) -> None: + self.expression = expression + self.escape_map = { + "\\,": ",", + "\\:": ":", + "\\=": "=", + } + if escape_map is not None: + self.escape_map.update(escape_map) + # self.unescape_map = {v: k for k, v in self.escape_map.items()} + + def __str__(self) -> str: + return self.expression + + def __repr__(self) -> str: + return self.__str__() + + def parse(self, *, escape: bool = True) -> Mapping[str, str]: + parser = MountPointParser() + result = {**parser.parse_mount(self.expression)} + if escape: + for key, value in result.items(): + for raw, alternative in self.escape_map.items(): + if raw in value: + result[key] = value.replace(raw, alternative) + return MountPoint(**result).model_dump() # type: ignore[arg-type] + + class HostPortPair(namedtuple("HostPortPair", "host port")): def as_sockaddr(self) -> Tuple[str, int]: return str(self.host), self.port diff --git a/src/ai/backend/manager/api/exceptions.py b/src/ai/backend/manager/api/exceptions.py index bc762643ba..3a1166f7ba 100644 --- a/src/ai/backend/manager/api/exceptions.py +++ b/src/ai/backend/manager/api/exceptions.py @@ -303,6 +303,11 @@ class VFolderFilterStatusNotAvailable(BackendError, web.HTTPBadRequest): error_title = "There is no available virtual folder to filter its status." +class VFolderPermissionError(BackendError, web.HTTPBadRequest): + error_type = "https://api.backend.ai/probs/vfolder-permission-error" + error_title = "The virtual folder does not permit the specified permission." + + class DotfileCreationFailed(BackendError, web.HTTPBadRequest): error_type = "https://api.backend.ai/probs/generic-bad-request" error_title = "Dotfile creation has failed." diff --git a/src/ai/backend/manager/api/schema.graphql b/src/ai/backend/manager/api/schema.graphql index 97ca0cc027..823fd7365a 100644 --- a/src/ai/backend/manager/api/schema.graphql +++ b/src/ai/backend/manager/api/schema.graphql @@ -31,7 +31,7 @@ type Queries { """Added since 24.03.0. Available values: GENERAL, MODEL_STORE""" type: [String] = ["GENERAL"] ): [Group] - image(reference: String!, architecture: String = "aarch64"): Image + image(reference: String!, architecture: String = "x86_64"): Image images(is_installed: Boolean, is_operation: Boolean): [Image] user(domain_name: String, email: String): User user_from_uuid(domain_name: String, user_id: ID): User @@ -973,9 +973,9 @@ type Mutations { rescan_images(registry: String): RescanImages preload_image(references: [String]!, target_agents: [String]!): PreloadImage unload_image(references: [String]!, target_agents: [String]!): UnloadImage - modify_image(architecture: String = "aarch64", props: ModifyImageInput!, target: String!): ModifyImage - forget_image(architecture: String = "aarch64", reference: String!): ForgetImage - alias_image(alias: String!, architecture: String = "aarch64", target: String!): AliasImage + modify_image(architecture: String = "x86_64", props: ModifyImageInput!, target: String!): ModifyImage + forget_image(architecture: String = "x86_64", reference: String!): ForgetImage + alias_image(alias: String!, architecture: String = "x86_64", target: String!): AliasImage dealias_image(alias: String!): DealiasImage clear_images(registry: String): ClearImages create_keypair_resource_policy(name: String!, props: CreateKeyPairResourcePolicyInput!): CreateKeyPairResourcePolicy @@ -1595,4 +1595,4 @@ input ImageRefType { name: String! registry: String architecture: String -} +} \ No newline at end of file diff --git a/src/ai/backend/manager/api/session.py b/src/ai/backend/manager/api/session.py index 1928877fc6..ab6d3aa1ad 100644 --- a/src/ai/backend/manager/api/session.py +++ b/src/ai/backend/manager/api/session.py @@ -55,7 +55,15 @@ from ai.backend.common.exception import UnknownImageReference from ai.backend.common.logging import BraceStyleAdapter from ai.backend.common.plugin.monitor import GAUGE -from ai.backend.common.types import AccessKey, AgentId, ClusterMode, SessionTypes, VFolderID +from ai.backend.common.types import ( + AccessKey, + AgentId, + ClusterMode, + MountPermission, + MountTypes, + SessionTypes, + VFolderID, +) from ..config import DEFAULT_CHUNK_SIZE from ..defs import DEFAULT_IMAGE_ARCH, DEFAULT_ROLE @@ -186,6 +194,14 @@ def check_and_return(self, value: Any) -> object: creation_config_v5 = t.Dict({ t.Key("mounts", default=None): t.Null | t.List(t.String), tx.AliasedKey(["mount_map", "mountMap"], default=None): t.Null | t.Mapping(t.String, t.String), + tx.AliasedKey(["mount_options", "mountOptions"], default=None): t.Null + | t.Mapping( + t.String, + t.Dict({ + t.Key("type", default=MountTypes.BIND): tx.Enum(MountTypes), + tx.AliasedKey(["permission", "perm"], default=None): t.Null | tx.Enum(MountPermission), + }).ignore_extra("*"), + ), t.Key("environ", default=None): t.Null | t.Mapping(t.String, t.String), # cluster_size is moved to the root-level parameters tx.AliasedKey(["scaling_group", "scalingGroup"], default=None): t.Null | t.String, diff --git a/src/ai/backend/manager/models/vfolder.py b/src/ai/backend/manager/models/vfolder.py index 3e7938756f..ca2c4da384 100644 --- a/src/ai/backend/manager/models/vfolder.py +++ b/src/ai/backend/manager/models/vfolder.py @@ -29,6 +29,7 @@ from ai.backend.common.config import model_definition_iv from ai.backend.common.logging import BraceStyleAdapter from ai.backend.common.types import ( + MountPermission, QuotaScopeID, QuotaScopeType, VFolderHostPermission, @@ -38,7 +39,12 @@ VFolderUsageMode, ) -from ..api.exceptions import InvalidAPIParameters, VFolderNotFound, VFolderOperationFailed +from ..api.exceptions import ( + InvalidAPIParameters, + VFolderNotFound, + VFolderOperationFailed, + VFolderPermissionError, +) from ..defs import ( DEFAULT_CHUNK_SIZE, RESERVED_VFOLDER_PATTERNS, @@ -646,6 +652,7 @@ async def prepare_vfolder_mounts( resource_policy: Mapping[str, Any], requested_mount_references: Sequence[str | uuid.UUID], requested_mount_reference_map: Mapping[str | uuid.UUID, str], + requested_mount_reference_options: Mapping[str | uuid.UUID, Any], ) -> Sequence[VFolderMount]: """ Determine the actual mount information from the requested vfolder lists, @@ -657,6 +664,11 @@ async def prepare_vfolder_mounts( requested_mount_map: dict[str, str] = { name: path for name, path in requested_mount_reference_map.items() if isinstance(name, str) } + requested_mount_options: dict[str, dict[str, Any]] = { + name: options + for name, options in requested_mount_reference_options.items() + if isinstance(name, str) + } vfolder_ids_to_resolve = [ vfid for vfid in requested_mount_references if isinstance(vfid, uuid.UUID) @@ -672,6 +684,8 @@ async def prepare_vfolder_mounts( requested_mounts.append(name) if path := requested_mount_reference_map.get(vfid): requested_mount_map[name] = path + if options := requested_mount_reference_options.get(vfid): + requested_mount_options[name] = options requested_vfolder_names: dict[str, str] = {} requested_vfolder_subpaths: dict[str, str] = {} @@ -798,6 +812,18 @@ async def prepare_vfolder_mounts( kernel_path = PurePosixPath(kernel_path_raw) if not kernel_path.is_absolute(): kernel_path = PurePosixPath("/home/work", kernel_path_raw) + match requested_perm := requested_mount_options[key]["permission"]: + case MountPermission.READ_ONLY: + mount_perm = MountPermission.READ_ONLY + case MountPermission.READ_WRITE | MountPermission.RW_DELETE: + if vfolder["permission"] == VFolderPermission.READ_ONLY: + raise VFolderPermissionError( + f"VFolder {vfolder_name} is allowed to be accessed in '{vfolder['permission'].value}' mode, " + f"but attempted with '{requested_perm.value}' mode." + ) + mount_perm = requested_perm + case _: # None if unset + mount_perm = vfolder["permission"] matched_vfolder_mounts.append( VFolderMount( name=vfolder["name"], @@ -805,7 +831,7 @@ async def prepare_vfolder_mounts( vfsubpath=PurePosixPath(requested_vfolder_subpaths[key]), host_path=mount_base_path / requested_vfolder_subpaths[key], kernel_path=kernel_path, - mount_perm=vfolder["permission"], + mount_perm=mount_perm, usage_mode=vfolder["usage_mode"], ) ) diff --git a/src/ai/backend/manager/registry.py b/src/ai/backend/manager/registry.py index 105d818c45..362b806f64 100644 --- a/src/ai/backend/manager/registry.py +++ b/src/ai/backend/manager/registry.py @@ -911,9 +911,12 @@ async def enqueue_session( ) use_host_network_result = await conn.execute(use_host_network_query) use_host_network = use_host_network_result.scalar() - # Translate mounts/mount_map into vfolder mounts + # Translate mounts/mount_map/mount_options into vfolder mounts requested_mounts = session_enqueue_configs["creation_config"].get("mounts") or [] requested_mount_map = session_enqueue_configs["creation_config"].get("mount_map") or {} + requested_mount_options = ( + session_enqueue_configs["creation_config"].get("mount_options") or {} + ) allowed_vfolder_types = await self.shared_config.get_vfolder_types() vfolder_mounts = await prepare_vfolder_mounts( conn, @@ -923,6 +926,7 @@ async def enqueue_session( resource_policy, requested_mounts, requested_mount_map, + requested_mount_options, ) # Prepare internal data for common dotfiles. diff --git a/tests/client/cli/test_mount.py b/tests/client/cli/test_mount.py new file mode 100644 index 0000000000..82a4fe52ae --- /dev/null +++ b/tests/client/cli/test_mount.py @@ -0,0 +1,82 @@ +from ai.backend.client.cli.session.execute import prepare_mount_arg +from ai.backend.common.types import MountPermission, MountTypes + + +def test_vfolder_mount(): + # given + mount = [ + "type=bind,source=/colon\\:path/test,target=/data", + "type=bind,source=/usr/abcd,target=/home/work/zxcv,perm=ro", + "type=bind,source=/usr/lorem,target=/home/work/ipsum,permission=ro", + "source=/src/hello,target=/trg/hello,perm=rw", + ] + + # when + mount, mount_map, mount_options = prepare_mount_arg(mount) + + # then + assert set(mount) == {"/colon:path/test", "/usr/abcd", "/usr/lorem", "/src/hello"} + assert mount_map == { + "/colon:path/test": "/data", + "/usr/abcd": "/home/work/zxcv", + "/usr/lorem": "/home/work/ipsum", + "/src/hello": "/trg/hello", + } + assert mount_options == { + "/colon:path/test": { + "type": MountTypes.BIND, + "permission": None, + }, + "/usr/abcd": { + "type": MountTypes.BIND, + "permission": MountPermission.READ_ONLY, + }, + "/usr/lorem": { + "type": MountTypes.BIND, + "permission": MountPermission.READ_ONLY, + }, + "/src/hello": { + "type": MountTypes.BIND, + "permission": MountPermission.READ_WRITE, + }, + } + + +def test_vfolder_mount_without_target(): + # given + mount = [ + "type=volume,source=vf-dd244f7f,perm=ro", + ] + + # when + mount, mount_map, mount_options = prepare_mount_arg(mount) + + # then + assert set(mount) == {"vf-dd244f7f"} + assert mount_map == {} + assert mount_options == { + "vf-dd244f7f": { + "type": MountTypes.VOLUME, + "permission": MountPermission.READ_ONLY, + }, + } + + +def test_vfolder_mount__edge_cases_with(): + # given + mount = [ + "type=bind,source=vf-abc\\,zxc,target=/home/work", # source with a comma + "type=bind,source=vf-abc\\=zxc,target=/home/work", # source with an equals sign + ] + + # when + mount_unescaped, *_ = prepare_mount_arg(mount, escape=False) + + # then + assert set(mount_unescaped) == {"vf-abc\\,zxc", "vf-abc\\=zxc"} + + # when + mount_escaped, *_ = prepare_mount_arg(mount, escape=True) + + # then + assert set(mount_escaped) == {"vf-abc,zxc", "vf-abc=zxc"}