Skip to content

Commit

Permalink
refactor: Resolve some pyright errors
Browse files Browse the repository at this point in the history
  • Loading branch information
achimnol committed Dec 25, 2024
1 parent 6cf11dd commit 3c8aa13
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 17 deletions.
5 changes: 3 additions & 2 deletions src/ai/backend/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def get_runner_mount(
type: MountTypes,
src: Union[str, Path],
target: Union[str, Path],
perm: Literal["ro", "rw"] = "ro",
perm: MountPermission = MountPermission.READ_ONLY,
opts: Optional[Mapping[str, Any]] = None,
):
"""
Expand Down Expand Up @@ -428,7 +428,7 @@ def _mount(
type,
src,
dst,
MountPermission("ro"),
MountPermission.READ_ONLY,
),
)

Expand Down Expand Up @@ -2781,6 +2781,7 @@ async def handle_volume_umount(
volume_mount_prefix = context.local_config["agent"]["mount-path"]
real_path = Path(volume_mount_prefix, event.dir_name)
err_msg: str | None = None
did_umount = False
try:
did_umount = await umount(
str(real_path),
Expand Down
9 changes: 7 additions & 2 deletions src/ai/backend/agent/docker/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
Final,
FrozenSet,
List,
Literal,
MutableMapping,
Optional,
Sequence,
Expand Down Expand Up @@ -500,7 +499,7 @@ def get_runner_mount(
type: MountTypes,
src: Union[str, Path],
target: Union[str, Path],
perm: Literal["ro", "rw"] = "ro",
perm: MountPermission = MountPermission.READ_ONLY,
opts: Optional[Mapping[str, Any]] = None,
) -> Mount:
return Mount(
Expand Down Expand Up @@ -1050,6 +1049,8 @@ async def _rollback_container_creation() -> None:
)

created_host_ports: Tuple[int, ...]
repl_in_port = 0
repl_out_port = 0
if container_network_info:
kernel_host = container_network_info.container_host
port_map = container_network_info.services
Expand Down Expand Up @@ -1100,6 +1101,8 @@ async def _rollback_container_creation() -> None:
)
sport["host_ports"] = created_host_ports

assert repl_in_port != 0, "repl_in_port should have bee assigned."
assert repl_out_port != 0, "repl_out_port should have bee assigned."
return {
"container_id": container._id,
"kernel_host": kernel_host,
Expand Down Expand Up @@ -1265,6 +1268,8 @@ def get_cgroup_path(self, controller: str, container_id: str) -> Path:
cgroup = f"docker/{container_id}"
case "systemd":
cgroup = f"system.slice/docker-{container_id}.scope"
case _:
raise ValueError(f"Unsupported cgroup driver: {driver!r}")
return mount_point / cgroup

async def load_resources(self) -> Mapping[DeviceName, AbstractComputePlugin]:
Expand Down
3 changes: 2 additions & 1 deletion src/ai/backend/agent/dummy/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
ImageRegistry,
KernelCreationConfig,
KernelId,
MountPermission,
MountTypes,
ResourceSlot,
ServicePort,
Expand Down Expand Up @@ -155,7 +156,7 @@ def get_runner_mount(
type: MountTypes,
src: str | Path,
target: str | Path,
perm: Literal["ro", "rw"] = "ro",
perm: MountPermission = MountPermission.READ_ONLY,
opts: Optional[Mapping[str, Any]] = None,
):
return Mount(MountTypes.BIND, Path(), Path())
Expand Down
17 changes: 9 additions & 8 deletions src/ai/backend/agent/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
List,
Literal,
Mapping,
NotRequired,
Optional,
Sequence,
Set,
Expand Down Expand Up @@ -146,18 +147,18 @@ class ResultRecord:
data: Optional[str] = None


class NextResult(TypedDict, total=False):
class NextResult(TypedDict):
runId: Optional[str]
status: ResultType
exitCode: Optional[int]
options: Optional[Mapping[str, Any]]
# v1
stdout: Optional[str]
stderr: Optional[str]
media: Optional[Sequence[Any]]
html: Optional[Sequence[Any]]
stdout: NotRequired[str]
stderr: NotRequired[str]
media: NotRequired[Sequence[Any]]
html: NotRequired[Sequence[Any]]
# v2
console: Optional[Sequence[Any]]
console: NotRequired[Sequence[Any]]


class AbstractKernel(UserDict, aobject, metaclass=ABCMeta):
Expand Down Expand Up @@ -816,9 +817,9 @@ def aggregate_console(
async def get_next_result(self, api_ver=2, flush_timeout=2.0) -> NextResult:
# Context: per API request
has_continuation = ClientFeatures.CONTINUATION in self.client_features
records = []
result: NextResult
try:
records = []
result: NextResult
assert self.output_queue is not None
with timeout(flush_timeout if has_continuation else None):
while True:
Expand Down
3 changes: 1 addition & 2 deletions src/ai/backend/agent/kubernetes/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
Any,
FrozenSet,
List,
Literal,
Mapping,
MutableMapping,
Optional,
Expand Down Expand Up @@ -412,7 +411,7 @@ def get_runner_mount(
type: MountTypes,
src: Union[str, Path],
target: Union[str, Path],
perm: Literal["ro", "rw"] = "ro",
perm: MountPermission = MountPermission.READ_ONLY,
opts: Optional[Mapping[str, Any]] = None,
) -> Mount:
return Mount(
Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/common/enum_extension.pyi
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import enum

class StringSetFlag(enum.Flag):
class StringSetFlag(enum.StrEnum):
def __eq__(self, other: object) -> bool: ...
def __hash__(self) -> int: ...
def __or__( # type: ignore[override]
Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/common/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class aobject(object):
"""

@classmethod
async def new(cls: Type[T_aobj], *args, **kwargs) -> T_aobj:
async def new(cls: Type[Self], *args, **kwargs) -> Self:
"""
We can do ``await SomeAObject(...)``, but this makes mypy
to complain about its return type with ``await`` statement.
Expand Down

0 comments on commit 3c8aa13

Please sign in to comment.