Skip to content

Commit

Permalink
refactor: Resolve some pyright errors (#3299) (#3301)
Browse files Browse the repository at this point in the history
Co-authored-by: Joongi Kim <[email protected]>
Co-authored-by: Jeongseok Kang <[email protected]>
  • Loading branch information
3 people authored Dec 26, 2024
1 parent 650e052 commit 60e2d23
Show file tree
Hide file tree
Showing 8 changed files with 30 additions and 20 deletions.
5 changes: 3 additions & 2 deletions src/ai/backend/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def get_runner_mount(
type: MountTypes,
src: Union[str, Path],
target: Union[str, Path],
perm: Literal["ro", "rw"] = "ro",
perm: MountPermission = MountPermission.READ_ONLY,
opts: Optional[Mapping[str, Any]] = None,
):
"""
Expand Down Expand Up @@ -426,7 +426,7 @@ def _mount(
type,
src,
dst,
MountPermission("ro"),
MountPermission.READ_ONLY,
),
)

Expand Down Expand Up @@ -2744,6 +2744,7 @@ async def handle_volume_umount(
volume_mount_prefix = context.local_config["agent"]["mount-path"]
real_path = Path(volume_mount_prefix, event.dir_name)
err_msg: str | None = None
did_umount = False
try:
did_umount = await umount(
str(real_path),
Expand Down
9 changes: 7 additions & 2 deletions src/ai/backend/agent/docker/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
Final,
FrozenSet,
List,
Literal,
MutableMapping,
Optional,
Sequence,
Expand Down Expand Up @@ -494,7 +493,7 @@ def get_runner_mount(
type: MountTypes,
src: Union[str, Path],
target: Union[str, Path],
perm: Literal["ro", "rw"] = "ro",
perm: MountPermission = MountPermission.READ_ONLY,
opts: Optional[Mapping[str, Any]] = None,
) -> Mount:
return Mount(
Expand Down Expand Up @@ -1034,6 +1033,8 @@ async def _rollback_container_creation() -> None:
ctnr_host_port_map: MutableMapping[int, int] = {}
stdin_port = 0
stdout_port = 0
repl_in_port = 0
repl_out_port = 0
for idx, port in enumerate(exposed_ports):
if container_config["HostConfig"].get("NetworkMode") == "host":
host_port = host_ports[idx]
Expand Down Expand Up @@ -1069,6 +1070,8 @@ async def _rollback_container_creation() -> None:
if container_config["HostConfig"].get("NetworkMode") == "host":
sport["container_ports"] = created_host_ports

assert repl_in_port != 0, "repl_in_port should have been assigned."
assert repl_out_port != 0, "repl_out_port should have been assigned."
return {
"container_id": container._id,
"kernel_host": advertised_kernel_host or container_bind_host,
Expand Down Expand Up @@ -1225,6 +1228,8 @@ def get_cgroup_path(self, controller: str, container_id: str) -> Path:
cgroup = f"docker/{container_id}"
case "systemd":
cgroup = f"system.slice/docker-{container_id}.scope"
case _:
raise ValueError(f"Unsupported cgroup driver: {driver!r}")
return mount_point / cgroup

async def load_resources(self) -> Mapping[DeviceName, AbstractComputePlugin]:
Expand Down
3 changes: 2 additions & 1 deletion src/ai/backend/agent/dummy/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
ImageRegistry,
KernelCreationConfig,
KernelId,
MountPermission,
MountTypes,
ResourceSlot,
ServicePort,
Expand Down Expand Up @@ -155,7 +156,7 @@ def get_runner_mount(
type: MountTypes,
src: str | Path,
target: str | Path,
perm: Literal["ro", "rw"] = "ro",
perm: MountPermission = MountPermission.READ_ONLY,
opts: Optional[Mapping[str, Any]] = None,
):
return Mount(MountTypes.BIND, Path(), Path())
Expand Down
7 changes: 6 additions & 1 deletion src/ai/backend/agent/dummy/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,12 @@ def aggregate_console(
return

async def get_next_result(self, api_ver=2, flush_timeout=2.0) -> NextResult:
return {}
return {
"runId": self.current_run_id,
"status": "finished",
"exitCode": None,
"options": None,
}

async def attach_output_queue(self, run_id: str | None) -> None:
return
Expand Down
17 changes: 9 additions & 8 deletions src/ai/backend/agent/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
List,
Literal,
Mapping,
NotRequired,
Optional,
Sequence,
Set,
Expand Down Expand Up @@ -146,18 +147,18 @@ class ResultRecord:
data: Optional[str] = None


class NextResult(TypedDict, total=False):
class NextResult(TypedDict):
runId: Optional[str]
status: ResultType
exitCode: Optional[int]
options: Optional[Mapping[str, Any]]
# v1
stdout: Optional[str]
stderr: Optional[str]
media: Optional[Sequence[Any]]
html: Optional[Sequence[Any]]
stdout: NotRequired[Optional[str]]
stderr: NotRequired[Optional[str]]
media: NotRequired[Sequence[Any]]
html: NotRequired[Sequence[Any]]
# v2
console: Optional[Sequence[Any]]
console: NotRequired[Sequence[Any]]


class AbstractKernel(UserDict, aobject, metaclass=ABCMeta):
Expand Down Expand Up @@ -813,9 +814,9 @@ def aggregate_console(
async def get_next_result(self, api_ver=2, flush_timeout=2.0) -> NextResult:
# Context: per API request
has_continuation = ClientFeatures.CONTINUATION in self.client_features
records = []
result: NextResult
try:
records = []
result: NextResult
assert self.output_queue is not None
with timeout(flush_timeout if has_continuation else None):
while True:
Expand Down
3 changes: 1 addition & 2 deletions src/ai/backend/agent/kubernetes/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
Any,
FrozenSet,
List,
Literal,
Mapping,
MutableMapping,
Optional,
Expand Down Expand Up @@ -412,7 +411,7 @@ def get_runner_mount(
type: MountTypes,
src: Union[str, Path],
target: Union[str, Path],
perm: Literal["ro", "rw"] = "ro",
perm: MountPermission = MountPermission.READ_ONLY,
opts: Optional[Mapping[str, Any]] = None,
) -> Mount:
return Mount(
Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/common/enum_extension.pyi
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import enum

class StringSetFlag(enum.Flag):
class StringSetFlag(enum.StrEnum):
def __eq__(self, other: object) -> bool: ...
def __hash__(self) -> int: ...
def __or__( # type: ignore[override]
Expand Down
4 changes: 1 addition & 3 deletions src/ai/backend/common/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,6 @@
from .docker import ImageRef


T_aobj = TypeVar("T_aobj", bound="aobject")

current_resource_slots: ContextVar[Mapping[SlotName, SlotTypes]] = ContextVar(
"current_resource_slots"
)
Expand All @@ -119,7 +117,7 @@ class aobject(object):
"""

@classmethod
async def new(cls: Type[T_aobj], *args, **kwargs) -> T_aobj:
async def new(cls: Type[Self], *args, **kwargs) -> Self:
"""
We can do ``await SomeAObject(...)``, but this makes mypy
to complain about its return type with ``await`` statement.
Expand Down

0 comments on commit 60e2d23

Please sign in to comment.