Skip to content

Commit

Permalink
chore(BA-440): Upgrade mypy to 1.14.1 and ruff to 0.8.5 (#3354)
Browse files Browse the repository at this point in the history
Backported-from: main (24.12)
Backported-to: 24.03
  • Loading branch information
achimnol committed Jan 2, 2025
1 parent 1abd0d0 commit 1c1ac19
Show file tree
Hide file tree
Showing 52 changed files with 581 additions and 634 deletions.
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ asyncio_mode = "auto"
[tool.ruff]
line-length = 100
src = ["src"]

[tool.ruff.lint]
select = [
"E", # pycodestyle errors
"W", # pycodestyle warnings
Expand All @@ -82,13 +84,13 @@ select = [
]
ignore = ["E203","E731","E501"]

[tool.ruff.isort]
[tool.ruff.lint.isort]
known-first-party = ["ai.backend"]
known-local-folder = ["src"]
known-third-party = ["alembic", "redis"]
split-on-trailing-comma = true

[tool.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
"src/ai/backend/manager/config.py" = ["E402"]
"src/ai/backend/manager/models/alembic/env.py" = ["E402"]

Expand Down
9 changes: 6 additions & 3 deletions src/ai/backend/accelerator/cuda_open/nvidia.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import ctypes
import platform
from abc import ABCMeta, abstractmethod
from collections.abc import MutableMapping, Sequence
from itertools import groupby
from operator import itemgetter
from typing import Any, MutableMapping, NamedTuple, Tuple, TypeAlias
from typing import Any, NamedTuple, TypeAlias, cast

# ref: https://developer.nvidia.com/cuda-toolkit-archive
TARGET_CUDA_VERSIONS = (
Expand Down Expand Up @@ -487,7 +488,7 @@ def load_library(cls):
return None

@classmethod
def get_version(cls) -> Tuple[int, int]:
def get_version(cls) -> tuple[int, int]:
if cls._version == (0, 0):
raw_ver = ctypes.c_int()
cls.invoke("cudaRuntimeGetVersion", ctypes.byref(raw_ver))
Expand All @@ -513,7 +514,9 @@ def get_device_props(cls, device_idx: int):
props_struct = cudaDeviceProp()
cls.invoke("cudaGetDeviceProperties", ctypes.byref(props_struct), device_idx)
props: MutableMapping[str, Any] = {
k: getattr(props_struct, k) for k, _ in props_struct._fields_
# Treat each field as two-tuple assuming that we don't have bit-fields
k: getattr(props_struct, k)
for k, _ in cast(Sequence[tuple[str, Any]], props_struct._fields_)
}
pci_bus_id = b" " * 16
cls.invoke("cudaDeviceGetPCIBusId", ctypes.c_char_p(pci_bus_id), 16, device_idx)
Expand Down
4 changes: 2 additions & 2 deletions src/ai/backend/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1992,7 +1992,7 @@ async def create_kernel(
exposed_ports.append(cport)
for index, port in enumerate(ctx.kernel_config["allocated_host_ports"]):
service_ports.append({
"name": f"hostport{index+1}",
"name": f"hostport{index + 1}",
"protocol": ServicePortProtocols.INTERNAL,
"container_ports": (port,),
"host_ports": (port,),
Expand Down Expand Up @@ -2296,7 +2296,7 @@ async def load_model_definition(

if not model_definition_path:
raise AgentError(
f"Model definition file ({" or ".join(model_definition_candidates)}) does not exist under vFolder"
f"Model definition file ({' or '.join(model_definition_candidates)}) does not exist under vFolder"
f" {model_folder.name} (ID {model_folder.vfid})",
)
try:
Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/agent/docker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ async def get_container_version_and_status(self) -> Tuple[int, bool]:
raise
if c["Config"].get("Labels", {}).get("ai.backend.system", "0") != "1":
raise RuntimeError(
f"An existing container named \"{c['Name'].lstrip('/')}\" is not a system container"
f'An existing container named "{c["Name"].lstrip("/")}" is not a system container'
" spawned by Backend.AI. Please check and remove it."
)
return (
Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/agent/kubernetes/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,7 +881,7 @@ async def check_krunner_pv_status(self):
new_pv.label("backend.ai/backend-ai-scratch-volume", "hostPath")
else:
raise NotImplementedError(
f'Scratch type {self.local_config["container"]["scratch-type"]} is not'
f"Scratch type {self.local_config['container']['scratch-type']} is not"
" supported",
)

Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/agent/kubernetes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ async def get_container_version_and_status(self) -> Tuple[int, bool]:
raise
if c["Config"].get("Labels", {}).get("ai.backend.system", "0") != "1":
raise RuntimeError(
f"An existing container named \"{c['Name'].lstrip('/')}\" is not a system container"
f'An existing container named "{c["Name"].lstrip("/")}" is not a system container'
" spawned by Backend.AI. Please check and remove it."
)
return (
Expand Down
7 changes: 4 additions & 3 deletions src/ai/backend/agent/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
import signal
import sys
from collections import OrderedDict, defaultdict
from ipaddress import _BaseAddress as BaseIPAddress
from ipaddress import ip_network
from ipaddress import IPv4Address, IPv6Address, ip_network
from pathlib import Path
from pprint import pformat, pprint
from typing import (
Expand Down Expand Up @@ -1049,7 +1048,9 @@ def main(
raise click.Abort()

rpc_host = cfg["agent"]["rpc-listen-addr"].host
if isinstance(rpc_host, BaseIPAddress) and (rpc_host.is_unspecified or rpc_host.is_link_local):
if isinstance(rpc_host, (IPv4Address, IPv6Address)) and (
rpc_host.is_unspecified or rpc_host.is_link_local
):
print(
"ConfigurationError: "
"Cannot use link-local or unspecified IP address as the RPC listening host.",
Expand Down
4 changes: 2 additions & 2 deletions src/ai/backend/client/cli/session/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -1226,7 +1226,7 @@ def watch(
session_names = _fetch_session_names()
if not session_names:
if output == "json":
sys.stderr.write(f'{json.dumps({"ok": False, "reason": "No matching items."})}\n')
sys.stderr.write(f"{json.dumps({'ok': False, 'reason': 'No matching items.'})}\n")
else:
print_fail("No matching items.")
sys.exit(ExitCode.FAILURE)
Expand All @@ -1248,7 +1248,7 @@ def watch(
else:
if output == "json":
sys.stderr.write(
f'{json.dumps({"ok": False, "reason": "No matching items."})}\n'
f"{json.dumps({'ok': False, 'reason': 'No matching items.'})}\n"
)
else:
print_fail("No matching items.")
Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/client/cli/vfolder.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ def request_download(name, filename):
with Session() as session:
try:
response = json.loads(session.VFolder(name).request_download(filename))
print_done(f'Download token: {response["token"]}')
print_done(f"Download token: {response['token']}")
except Exception as e:
print_error(e)
sys.exit(ExitCode.FAILURE)
Expand Down
15 changes: 6 additions & 9 deletions src/ai/backend/client/func/acl.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import textwrap
from typing import Sequence

from ai.backend.client.output.fields import permission_fields
from ai.backend.client.output.types import FieldSpec

from ..output.fields import permission_fields
from ..output.types import FieldSpec
from ..session import api_session
from ..utils import dedent as _d
from .base import BaseFunction, api_function

__all__ = ("Permission",)
Expand All @@ -24,13 +23,11 @@ async def list(
:param fields: Additional permission query fields to fetch.
"""
query = textwrap.dedent(
"""\
query = _d("""
query {
vfolder_host_permissions {$fields}
vfolder_host_permissions { $fields }
}
"""
)
""")
query = query.replace("$fields", " ".join(f.field_ref for f in fields))
data = await api_session.get().Admin._query(query)
return data["vfolder_host_permissions"]
19 changes: 8 additions & 11 deletions src/ai/backend/client/func/agent.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from __future__ import annotations

import textwrap
from typing import Optional, Sequence

from ai.backend.client.output.fields import agent_fields
from ai.backend.client.output.types import FieldSpec, PaginatedResult
from ai.backend.client.pagination import fetch_paginated_result
from ai.backend.client.request import Request
from ai.backend.client.session import api_session

from ..output.fields import agent_fields
from ..output.types import FieldSpec, PaginatedResult
from ..pagination import fetch_paginated_result
from ..request import Request
from ..session import api_session
from ..utils import dedent as _d
from .base import BaseFunction, api_function

__all__ = (
Expand Down Expand Up @@ -88,13 +87,11 @@ async def detail(
agent_id: str,
fields: Sequence[FieldSpec] = _default_detail_fields,
) -> Sequence[dict]:
query = textwrap.dedent(
"""\
query = _d("""
query($agent_id: String!) {
agent(agent_id: $agent_id) {$fields}
}
"""
)
""")
query = query.replace("$fields", " ".join(f.field_ref for f in fields))
variables = {"agent_id": agent_id}
data = await api_session.get().Admin._query(query, variables)
Expand Down
50 changes: 20 additions & 30 deletions src/ai/backend/client/func/domain.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import textwrap
from typing import Any, Iterable, Sequence

from ...cli.types import Undefined, undefined
from ..output.fields import domain_fields
from ..output.types import FieldSpec
from ..session import api_session
from ..types import set_if_set
from ..utils import dedent as _d
from .base import BaseFunction, api_function, resolve_fields

__all__ = ("Domain",)
Expand Down Expand Up @@ -56,13 +56,11 @@ async def list(
:param fields: Additional per-domain query fields to fetch.
"""
query = textwrap.dedent(
"""\
query = _d("""
query {
domains {$fields}
domains { $fields }
}
"""
)
""")
query = query.replace("$fields", " ".join(f.field_ref for f in fields))
data = await api_session.get().Admin._query(query)
return data["domains"]
Expand All @@ -75,18 +73,16 @@ async def detail(
fields: Sequence[FieldSpec] = _default_detail_fields,
) -> dict:
"""
Fetch information of a domain with name.
Retrieves the detail of a domain with name.
:param name: Name of the domain to fetch.
:param fields: Additional per-domain query fields to fetch.
"""
query = textwrap.dedent(
"""\
query = _d("""
query($name: String) {
domain(name: $name) {$fields}
domain(name: $name) { $fields }
}
"""
)
""")
query = query.replace("$fields", " ".join(f.field_ref for f in fields))
variables = {"name": name}
data = await api_session.get().Admin._query(query, variables)
Expand All @@ -108,17 +104,16 @@ async def create(
) -> dict:
"""
Creates a new domain with the given options.
You need an admin privilege for this operation.
"""
query = textwrap.dedent(
"""\
query = _d("""
mutation($name: String!, $input: DomainInput!) {
create_domain(name: $name, props: $input) {
ok msg domain {$fields}
ok msg domain { $fields }
}
}
"""
)
""")
resolved_fields = resolve_fields(fields, domain_fields, (domain_fields["name"],))
query = query.replace("$fields", " ".join(resolved_fields))
inputs = {
Expand Down Expand Up @@ -152,18 +147,17 @@ async def update(
fields: Iterable[FieldSpec | str] | None = None,
) -> dict:
"""
Update existing domain.
Updates an existing domain.
You need an admin privilege for this operation.
"""
query = textwrap.dedent(
"""\
query = _d("""
mutation($name: String!, $input: ModifyDomainInput!) {
modify_domain(name: $name, props: $input) {
ok msg
}
}
"""
)
""")
inputs: dict[str, Any] = {}
set_if_set(inputs, "name", new_name)
set_if_set(inputs, "description", description)
Expand All @@ -185,15 +179,13 @@ async def delete(cls, name: str):
"""
Inactivates an existing domain.
"""
query = textwrap.dedent(
"""\
query = _d("""
mutation($name: String!) {
delete_domain(name: $name) {
ok msg
}
}
"""
)
""")
variables = {"name": name}
data = await api_session.get().Admin._query(query, variables)
return data["delete_domain"]
Expand All @@ -204,15 +196,13 @@ async def purge(cls, name: str):
"""
Deletes an existing domain.
"""
query = textwrap.dedent(
"""\
query = _d("""
mutation($name: String!) {
purge_domain(name: $name) {
ok msg
}
}
"""
)
""")
variables = {"name": name}
data = await api_session.get().Admin._query(query, variables)
return data["purge_domain"]
Loading

0 comments on commit 1c1ac19

Please sign in to comment.