Skip to content

Commit

Permalink
fix: Regression of AgentSummary GQL resolver (#3045) (#3170)
Browse files Browse the repository at this point in the history
  • Loading branch information
jopemachine authored Nov 28, 2024
1 parent 3ae0a32 commit 4cbdc0f
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 10 deletions.
1 change: 1 addition & 0 deletions changes/3045.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix regression of the `AgentSummary` resolver caused by an incorrect `batch_load_func` assignment.
14 changes: 7 additions & 7 deletions src/ai/backend/manager/models/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import enum
import uuid
from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Sequence, cast
from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Self, Sequence, cast

import graphene
import sqlalchemy as sa
Expand All @@ -17,7 +17,7 @@
from sqlalchemy.sql.expression import false, true

from ai.backend.common import msgpack, redis_helper
from ai.backend.common.types import AgentId, BinarySize, HardwareMetadata, ResourceSlot
from ai.backend.common.types import AccessKey, AgentId, BinarySize, HardwareMetadata, ResourceSlot

from .base import (
Base,
Expand Down Expand Up @@ -528,7 +528,7 @@ def from_row(
cls,
ctx: GraphQueryContext,
row: Row,
) -> Agent:
) -> Self:
return cls(
id=row["id"],
status=row["status"].name,
Expand Down Expand Up @@ -561,11 +561,11 @@ async def batch_load(
graph_ctx: GraphQueryContext,
agent_ids: Sequence[AgentId],
*,
domain_name: str | None,
access_key: AccessKey,
domain_name: Optional[str] = None,
raw_status: Optional[str] = None,
scaling_group: Optional[str] = None,
access_key: str,
) -> Sequence[Agent | None]:
) -> Sequence[Optional[Self]]:
query = (
sa.select([agents])
.select_from(agents)
Expand Down Expand Up @@ -627,7 +627,7 @@ async def load_slice(
raw_status: Optional[str] = None,
filter: Optional[str] = None,
order: Optional[str] = None,
) -> Sequence[Agent]:
) -> Sequence[Self]:
query = sa.select([agents]).select_from(agents).limit(limit).offset(offset)
query = await _append_sgroup_from_clause(
graph_ctx, query, access_key, domain_name, scaling_group
Expand Down
9 changes: 8 additions & 1 deletion src/ai/backend/manager/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Awaitable,
Callable,
ClassVar,
Concatenate,
Coroutine,
Generic,
List,
Expand Down Expand Up @@ -712,7 +713,12 @@ def _get_func_key(
def get_loader_by_func(
self,
context: ContextT,
batch_load_func: Callable[[ContextT, Sequence[LoaderKeyT]], Awaitable[LoaderResultT]],
batch_load_func: Callable[
Concatenate[ContextT, Sequence[LoaderKeyT], ...], Awaitable[LoaderResultT]
],
# Using kwargs-only to prevent argument position confusion
# when DataLoader calls `batch_load_func(keys)` which is `partial(batch_load_func, **kwargs)(keys)`.
**kwargs,
) -> DataLoader:
key = self._get_func_key(batch_load_func)
loader = self.cache.get(key)
Expand All @@ -721,6 +727,7 @@ def get_loader_by_func(
functools.partial(
batch_load_func,
context,
**kwargs,
),
max_batch_size=128,
)
Expand Down
4 changes: 2 additions & 2 deletions src/ai/backend/manager/models/gql.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,9 +819,9 @@ async def resolve_agent_summary(
if ctx.local_config["manager"]["hide-agents"]:
raise ObjectNotFound(object_name="agent")

loader = ctx.dataloader_manager.get_loader(
loader = ctx.dataloader_manager.get_loader_by_func(
ctx,
"Agent",
AgentSummary.batch_load,
raw_status=None,
scaling_group=scaling_group,
domain_name=domain_name,
Expand Down

0 comments on commit 4cbdc0f

Please sign in to comment.