diff --git a/src/ai/backend/agent/agent.py b/src/ai/backend/agent/agent.py index 11b893d27e..907a31e527 100644 --- a/src/ai/backend/agent/agent.py +++ b/src/ai/backend/agent/agent.py @@ -326,7 +326,7 @@ def get_runner_mount( type: MountTypes, src: Union[str, Path], target: Union[str, Path], - perm: Literal["ro", "rw"] = "ro", + perm: MountPermission = MountPermission.READ_ONLY, opts: Optional[Mapping[str, Any]] = None, ): """ @@ -428,7 +428,7 @@ def _mount( type, src, dst, - MountPermission("ro"), + MountPermission.READ_ONLY, ), ) @@ -2781,6 +2781,7 @@ async def handle_volume_umount( volume_mount_prefix = context.local_config["agent"]["mount-path"] real_path = Path(volume_mount_prefix, event.dir_name) err_msg: str | None = None + did_umount = False try: did_umount = await umount( str(real_path), diff --git a/src/ai/backend/agent/docker/agent.py b/src/ai/backend/agent/docker/agent.py index 2f62d3804f..3153d88b9f 100644 --- a/src/ai/backend/agent/docker/agent.py +++ b/src/ai/backend/agent/docker/agent.py @@ -24,7 +24,6 @@ Final, FrozenSet, List, - Literal, MutableMapping, Optional, Sequence, @@ -500,7 +499,7 @@ def get_runner_mount( type: MountTypes, src: Union[str, Path], target: Union[str, Path], - perm: Literal["ro", "rw"] = "ro", + perm: MountPermission = MountPermission.READ_ONLY, opts: Optional[Mapping[str, Any]] = None, ) -> Mount: return Mount( @@ -1050,6 +1049,8 @@ async def _rollback_container_creation() -> None: ) created_host_ports: Tuple[int, ...] + repl_in_port = 0 + repl_out_port = 0 if container_network_info: kernel_host = container_network_info.container_host port_map = container_network_info.services @@ -1100,6 +1101,8 @@ async def _rollback_container_creation() -> None: ) sport["host_ports"] = created_host_ports + assert repl_in_port != 0, "repl_in_port should have been assigned." + assert repl_out_port != 0, "repl_out_port should have been assigned." return { "container_id": container._id, "kernel_host": kernel_host, @@ -1265,6 +1268,8 @@ def get_cgroup_path(self, controller: str, container_id: str) -> Path: cgroup = f"docker/{container_id}" case "systemd": cgroup = f"system.slice/docker-{container_id}.scope" + case _: + raise ValueError(f"Unsupported cgroup driver: {driver!r}") return mount_point / cgroup async def load_resources(self) -> Mapping[DeviceName, AbstractComputePlugin]: diff --git a/src/ai/backend/agent/dummy/agent.py b/src/ai/backend/agent/dummy/agent.py index 0f5eb729db..9450a0d34c 100644 --- a/src/ai/backend/agent/dummy/agent.py +++ b/src/ai/backend/agent/dummy/agent.py @@ -28,6 +28,7 @@ ImageRegistry, KernelCreationConfig, KernelId, + MountPermission, MountTypes, ResourceSlot, ServicePort, @@ -155,7 +156,7 @@ def get_runner_mount( type: MountTypes, src: str | Path, target: str | Path, - perm: Literal["ro", "rw"] = "ro", + perm: MountPermission = MountPermission.READ_ONLY, opts: Optional[Mapping[str, Any]] = None, ): return Mount(MountTypes.BIND, Path(), Path()) diff --git a/src/ai/backend/agent/dummy/kernel.py b/src/ai/backend/agent/dummy/kernel.py index c9935614b2..b47c83e9e7 100644 --- a/src/ai/backend/agent/dummy/kernel.py +++ b/src/ai/backend/agent/dummy/kernel.py @@ -316,7 +316,12 @@ def aggregate_console( return async def get_next_result(self, api_ver=2, flush_timeout=2.0) -> NextResult: - return {} + return { + "runId": self.current_run_id, + "status": "finished", + "exitCode": None, + "options": None, + } async def attach_output_queue(self, run_id: str | None) -> None: return diff --git a/src/ai/backend/agent/kernel.py b/src/ai/backend/agent/kernel.py index edb7e287b2..198ebe18f3 100644 --- a/src/ai/backend/agent/kernel.py +++ b/src/ai/backend/agent/kernel.py @@ -20,6 +20,7 @@ List, Literal, Mapping, + NotRequired, Optional, Sequence, Set, @@ -146,18 +147,18 @@ class ResultRecord: data: Optional[str] = None -class NextResult(TypedDict, total=False): +class NextResult(TypedDict): runId: Optional[str] status: ResultType exitCode: Optional[int] options: Optional[Mapping[str, Any]] # v1 - stdout: Optional[str] - stderr: Optional[str] - media: Optional[Sequence[Any]] - html: Optional[Sequence[Any]] + stdout: NotRequired[Optional[str]] + stderr: NotRequired[Optional[str]] + media: NotRequired[Sequence[Any]] + html: NotRequired[Sequence[Any]] # v2 - console: Optional[Sequence[Any]] + console: NotRequired[Sequence[Any]] class AbstractKernel(UserDict, aobject, metaclass=ABCMeta): @@ -816,9 +817,9 @@ def aggregate_console( async def get_next_result(self, api_ver=2, flush_timeout=2.0) -> NextResult: # Context: per API request has_continuation = ClientFeatures.CONTINUATION in self.client_features + records = [] + result: NextResult try: - records = [] - result: NextResult assert self.output_queue is not None with timeout(flush_timeout if has_continuation else None): while True: diff --git a/src/ai/backend/agent/kubernetes/agent.py b/src/ai/backend/agent/kubernetes/agent.py index 78bf1858e6..8db627eb48 100644 --- a/src/ai/backend/agent/kubernetes/agent.py +++ b/src/ai/backend/agent/kubernetes/agent.py @@ -16,7 +16,6 @@ Any, FrozenSet, List, - Literal, Mapping, MutableMapping, Optional, @@ -412,7 +411,7 @@ def get_runner_mount( type: MountTypes, src: Union[str, Path], target: Union[str, Path], - perm: Literal["ro", "rw"] = "ro", + perm: MountPermission = MountPermission.READ_ONLY, opts: Optional[Mapping[str, Any]] = None, ) -> Mount: return Mount( diff --git a/src/ai/backend/common/enum_extension.pyi b/src/ai/backend/common/enum_extension.pyi index 4643c3fc6e..f3a5cae0ce 100644 --- a/src/ai/backend/common/enum_extension.pyi +++ b/src/ai/backend/common/enum_extension.pyi @@ -1,6 +1,6 @@ import enum -class StringSetFlag(enum.Flag): +class StringSetFlag(enum.StrEnum): def __eq__(self, other: object) -> bool: ... def __hash__(self) -> int: ... def __or__( # type: ignore[override] diff --git a/src/ai/backend/common/types.py b/src/ai/backend/common/types.py index 85b3db898d..71e388f24a 100644 --- a/src/ai/backend/common/types.py +++ b/src/ai/backend/common/types.py @@ -101,8 +101,6 @@ from .docker import ImageRef -T_aobj = TypeVar("T_aobj", bound="aobject") - current_resource_slots: ContextVar[Mapping[SlotName, SlotTypes]] = ContextVar( "current_resource_slots" ) @@ -121,7 +119,7 @@ class aobject(object): """ @classmethod - async def new(cls: Type[T_aobj], *args, **kwargs) -> T_aobj: + async def new(cls: Type[Self], *args, **kwargs) -> Self: """ We can do ``await SomeAObject(...)``, but this makes mypy to complain about its return type with ``await`` statement.