diff --git a/src/ai/backend/manager/scheduler/drf.py b/src/ai/backend/manager/scheduler/drf.py index c51ba4092c6..638402b91a6 100644 --- a/src/ai/backend/manager/scheduler/drf.py +++ b/src/ai/backend/manager/scheduler/drf.py @@ -16,9 +16,9 @@ SessionId, ) -from ..models import AgentRow, SessionRow +from ..models import AgentRow, KernelRow, SessionRow from ..models.scaling_group import ScalingGroupOpts -from .types import AbstractScheduler, KernelInfo +from .types import AbstractScheduler log = BraceStyleAdapter(logging.getLogger("ai.backend.manager.scheduler")) @@ -83,7 +83,7 @@ def pick_session( async def _assign_agent( self, possible_agents: Sequence[AgentRow], - pending_session_or_kernel: SessionRow | KernelInfo, + pending_session_or_kernel: SessionRow | KernelRow, roundrobin_context: Optional[RoundRobinContext] = None, ) -> Optional[AgentId]: # If some predicate checks for a picked session fail, @@ -132,7 +132,7 @@ async def assign_agent_for_session( async def assign_agent_for_kernel( self, possible_agents: Sequence[AgentRow], - pending_kernel: KernelInfo, + pending_kernel: KernelRow, ) -> Optional[AgentId]: return await self._assign_agent( possible_agents, diff --git a/src/ai/backend/manager/scheduler/fifo.py b/src/ai/backend/manager/scheduler/fifo.py index e0a68c99c6a..f89b8f1226e 100644 --- a/src/ai/backend/manager/scheduler/fifo.py +++ b/src/ai/backend/manager/scheduler/fifo.py @@ -12,8 +12,8 @@ SessionId, ) -from ..models import AgentRow, SessionRow -from .types import AbstractScheduler, KernelInfo +from ..models import AgentRow, KernelRow, SessionRow +from .types import AbstractScheduler def get_num_extras(agent: AgentRow, requested_slots: ResourceSlot) -> int: @@ -78,7 +78,7 @@ async def assign_agent_for_session( async def assign_agent_for_kernel( self, agents: Sequence[AgentRow], - pending_kernel: KernelInfo, + pending_kernel: KernelRow, ) -> Optional[AgentId]: return await self.select_agent( agents, @@ -115,7 +115,7 @@ async def assign_agent_for_session( async def assign_agent_for_kernel( self, agents: Sequence[AgentRow], - pending_kernel: KernelInfo, + pending_kernel: KernelRow, ) -> Optional[AgentId]: return await self.select_agent( agents, diff --git a/src/ai/backend/manager/scheduler/types.py b/src/ai/backend/manager/scheduler/types.py index f8d73dc3be0..b3b0ed373df 100644 --- a/src/ai/backend/manager/scheduler/types.py +++ b/src/ai/backend/manager/scheduler/types.py @@ -460,7 +460,7 @@ async def assign_agent_for_session( async def assign_agent_for_kernel( self, possible_agents: Sequence[AgentRow], - pending_kernel: KernelInfo, + pending_kernel: KernelRow, ) -> Optional[AgentId]: """ Assign an agent for a kernel of the session. @@ -471,7 +471,7 @@ async def assign_agent_for_kernel( async def select_agent( self, possible_agents: Sequence[AgentRow], - pending_session_or_kernel: SessionRow | KernelInfo, + pending_session_or_kernel: SessionRow | KernelRow, use_num_extras: bool, roundrobin_context: Optional[RoundRobinContext] = None, ) -> Optional[AgentId]: @@ -498,15 +498,15 @@ async def select_agent( # Note that ROUNDROBIN is not working with the multi-node multi-container session. # It assumes the pending session type is single-node session. # Otherwise, it will use 'Dispersed' strategy as default strategy. - if ( - agent_selection_strategy == AgentSelectionStrategy.ROUNDROBIN - and type(pending_session_or_kernel) is KernelInfo + + if agent_selection_strategy == AgentSelectionStrategy.ROUNDROBIN and isinstance( + pending_session_or_kernel, KernelRow ): agent_selection_strategy = AgentSelectionStrategy.DISPERSED match agent_selection_strategy: case AgentSelectionStrategy.ROUNDROBIN: - assert type(pending_session_or_kernel) is SessionRow + assert isinstance(pending_session_or_kernel, SessionRow) assert roundrobin_context is not None sched_ctx = roundrobin_context.sched_ctx sgroup_name = roundrobin_context.sgroup_name