diff --git a/changes/1655.fix.md b/changes/1655.fix.md new file mode 100644 index 0000000000..366099ab22 --- /dev/null +++ b/changes/1655.fix.md @@ -0,0 +1 @@ +Refactor `PendingSession` Scheduler into `PendingSession` scheduler and `AgentSelector`, and replace `roundrobin` flag with `AgentSelectionStrategy.RoundRobin` policy. \ No newline at end of file diff --git a/src/ai/backend/common/config.py b/src/ai/backend/common/config.py index d54436862c..30a0513ec1 100644 --- a/src/ai/backend/common/config.py +++ b/src/ai/backend/common/config.py @@ -89,6 +89,14 @@ class BaseSchema(BaseModel): ), }).allow_extra("*") +# Used in Etcd as a global config. +# If `scalingGroup.scheduler_opts` contains an `agent_selector_config`, it will override this. +agent_selector_globalconfig_iv = t.Dict({}).allow_extra("*") + +# Used in `scalingGroup.scheduler_opts` as a per scaling_group config. +agent_selector_config_iv = t.Dict({}) | agent_selector_globalconfig_iv + + model_definition_iv = t.Dict({ t.Key("models"): t.List( t.Dict({ diff --git a/src/ai/backend/common/types.py b/src/ai/backend/common/types.py index 00a774a1d3..412b877913 100644 --- a/src/ai/backend/common/types.py +++ b/src/ai/backend/common/types.py @@ -27,6 +27,7 @@ NewType, NotRequired, Optional, + Self, Sequence, Tuple, Type, @@ -218,7 +219,9 @@ def check_typed_dict(value: Mapping[Any, Any], expected_type: Type[TD]) -> TD: SessionId = NewType("SessionId", uuid.UUID) KernelId = NewType("KernelId", uuid.UUID) ImageAlias = NewType("ImageAlias", str) +ArchName = NewType("ArchName", str) +ResourceGroupID = NewType("ResourceGroupID", str) AgentId = NewType("AgentId", str) DeviceName = NewType("DeviceName", str) DeviceId = NewType("DeviceId", str) @@ -840,7 +843,7 @@ def to_json(self) -> dict[str, Any]: raise NotImplementedError @classmethod - def from_json(cls, obj: Mapping[str, Any]) -> JSONSerializableMixin: + def from_json(cls, obj: Mapping[str, Any]) -> Self: return cls(**cls.as_trafaret().check(obj)) @classmethod @@ -985,7 +988,7 @@ def to_json(self) -> dict[str, Any]: return {host: [perm.value for perm in perms] for host, perms in self.items()} @classmethod - def from_json(cls, obj: Mapping[str, Any]) -> JSONSerializableMixin: + def from_json(cls, obj: Mapping[str, Any]) -> Self: return cls(**cls.as_trafaret().check(obj)) @classmethod @@ -1197,6 +1200,7 @@ class AcceleratorMetadata(TypedDict): class AgentSelectionStrategy(enum.StrEnum): DISPERSED = "dispersed" CONCENTRATED = "concentrated" + ROUNDROBIN = "roundrobin" # LEGACY chooses the largest agent (the sort key is a tuple of resource slots). LEGACY = "legacy" @@ -1215,29 +1219,6 @@ class VolumeMountableNodeType(enum.StrEnum): STORAGE_PROXY = enum.auto() -@dataclass -class RoundRobinState(JSONSerializableMixin): - schedulable_group_id: str - next_index: int - - def to_json(self) -> dict[str, Any]: - return dataclasses.asdict(self) - - @classmethod - def from_json(cls, obj: Mapping[str, Any]) -> RoundRobinState: - return cls(**cls.as_trafaret().check(obj)) - - @classmethod - def as_trafaret(cls) -> t.Trafaret: - return t.Dict({ - t.Key("schedulable_group_id"): t.String, - t.Key("next_index"): t.Int, - }) - - -# States of the round-robin scheduler for each resource group and architecture. -RoundRobinStates: TypeAlias = dict[str, dict[str, RoundRobinState]] - SSLContextType: TypeAlias = bool | Fingerprint | SSLContext diff --git a/src/ai/backend/common/validators.py b/src/ai/backend/common/validators.py index 7f26cbe324..5c3463a428 100644 --- a/src/ai/backend/common/validators.py +++ b/src/ai/backend/common/validators.py @@ -47,7 +47,6 @@ from .types import BinarySize as _BinarySize from .types import HostPortPair as _HostPortPair from .types import QuotaScopeID as _QuotaScopeID -from .types import RoundRobinState, RoundRobinStates from .types import VFolderID as _VFolderID __all__ = ( @@ -727,25 +726,3 @@ def check_and_return(self, value: Any) -> float: return 0 case _: self._failure(f"Value must be (float, tuple of float or None), not {type(value)}.") - - -class RoundRobinStatesJSONString(t.Trafaret): - def check_and_return(self, value: Any) -> RoundRobinStates: - try: - rr_states_dict: dict[str, dict[str, dict[str, Any]]] = json.loads(value) - except (KeyError, ValueError, json.decoder.JSONDecodeError): - self._failure( - f"Expected valid JSON string, got `{value}`. RoundRobinStatesJSONString should" - " be a valid JSON string", - value=value, - ) - - rr_states: RoundRobinStates = {} - for resource_group, arch_rr_states_dict in rr_states_dict.items(): - rr_states[resource_group] = {} - for arch, rr_state_dict in arch_rr_states_dict.items(): - if "next_index" not in rr_state_dict or "schedulable_group_id" not in rr_state_dict: - self._failure("Invalid roundrobin states") - rr_states[resource_group][arch] = RoundRobinState.from_json(rr_state_dict) - - return rr_states diff --git a/src/ai/backend/manager/BUILD b/src/ai/backend/manager/BUILD index 41aa23c244..aa735fa7a4 100644 --- a/src/ai/backend/manager/BUILD +++ b/src/ai/backend/manager/BUILD @@ -43,7 +43,12 @@ python_distribution( "fifo": "ai.backend.manager.scheduler.fifo:FIFOSlotScheduler", "lifo": "ai.backend.manager.scheduler.fifo:LIFOSlotScheduler", "drf": "ai.backend.manager.scheduler.drf:DRFScheduler", - "mof": "ai.backend.manager.scheduler.mof:MOFScheduler", + }, + "backendai_agentselector_v10": { + "legacy": "ai.backend.manager.scheduler.agent_selector:LegacyAgentSelector", + "roundrobin": "ai.backend.manager.scheduler.agent_selector:RoundRobinAgentSelector", + "concentrated": "ai.backend.manager.scheduler.agent_selector:ConcentratedAgentSelector", + "dispersed": "ai.backend.manager.scheduler.agent_selector:DispersedAgentSelector", }, "backendai_error_monitor_v20": { "intrinsic": "ai.backend.manager.plugin.error_monitor:ErrorMonitor", diff --git a/src/ai/backend/manager/config.py b/src/ai/backend/manager/config.py index 6f87342e96..2b4011208e 100644 --- a/src/ai/backend/manager/config.py +++ b/src/ai/backend/manager/config.py @@ -174,7 +174,6 @@ from __future__ import annotations -import json import logging import os import secrets @@ -212,7 +211,6 @@ from ai.backend.common.lock import EtcdLock, FileLock, RedisLock from ai.backend.common.types import ( HostPortPair, - RoundRobinState, SlotName, SlotTypes, current_resource_slots, @@ -350,6 +348,7 @@ "plugins": { "accelerator": {}, "scheduler": {}, + "agent_selector": {}, }, "watcher": { "token": None, @@ -439,6 +438,9 @@ def container_registry_serialize(v: dict[str, Any]) -> dict[str, str]: t.Key("scheduler", default=_config_defaults["plugins"]["scheduler"]): t.Mapping( t.String, t.Mapping(t.String, t.Any) ), + t.Key("agent_selector", default=_config_defaults["plugins"]["agent_selector"]): t.Mapping( + t.String, config.agent_selector_globalconfig_iv + ), }).allow_extra("*"), t.Key("network", default=_config_defaults["network"]): t.Dict({ t.Key("subnet", default=_config_defaults["network"]["subnet"]): t.Dict({ @@ -471,7 +473,6 @@ def container_registry_serialize(v: dict[str, Any]) -> dict[str, str]: ): session_hang_tolerance_iv, }, ).allow_extra("*"), - t.Key("roundrobin_states", default=None): t.Null | tx.RoundRobinStatesJSONString, }).allow_extra("*") _volume_defaults: dict[str, Any] = { @@ -847,38 +848,3 @@ def get_redis_url(self, db: int = 0) -> yarl.URL: self.data["redis"]["addr"][1] ).with_password(self.data["redis"]["password"]) / str(db) return url - - async def get_roundrobin_state( - self, resource_group_name: str, architecture: str - ) -> RoundRobinState | None: - """ - Return the roundrobin state for the given resource group and architecture. - If given resource group's roundrobin states or roundrobin state of the given architecture is not found, return None. - """ - if (rr_state_str := await self.get_raw("roundrobin_states")) is not None: - rr_states_dict: dict[str, dict[str, Any]] = json.loads(rr_state_str) - resource_group_rr_states_dict = rr_states_dict.get(resource_group_name, None) - - if resource_group_rr_states_dict is not None: - rr_state_dict = resource_group_rr_states_dict.get(architecture, None) - - if rr_state_dict is not None: - return RoundRobinState( - schedulable_group_id=rr_state_dict["schedulable_group_id"], - next_index=rr_state_dict["next_index"], - ) - - return None - - async def put_roundrobin_state( - self, resource_group_name: str, architecture: str, state: RoundRobinState - ) -> None: - """ - Update the roundrobin states using the given resource group and architecture key. - """ - rr_states_dict = json.loads(await self.get_raw("roundrobin_states") or "{}") - if resource_group_name not in rr_states_dict: - rr_states_dict[resource_group_name] = {} - - rr_states_dict[resource_group_name][architecture] = state.to_json() - await self.etcd.put("roundrobin_states", json.dumps(rr_states_dict)) diff --git a/src/ai/backend/manager/models/alembic/versions/c4b7ec740b36_migrate_roundrobin_strategy_to_agent_.py b/src/ai/backend/manager/models/alembic/versions/c4b7ec740b36_migrate_roundrobin_strategy_to_agent_.py new file mode 100644 index 0000000000..90870f2c71 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/c4b7ec740b36_migrate_roundrobin_strategy_to_agent_.py @@ -0,0 +1,59 @@ +"""migrate-roundrobin-strategy-to-agent-selector-config + +Revision ID: c4b7ec740b36 +Revises: 59a622c31820 +Create Date: 2024-09-17 00:31:31.379466 + +""" + +from alembic import op +from sqlalchemy.sql import text + +# revision identifiers, used by Alembic. +revision = "c4b7ec740b36" +down_revision = "59a622c31820" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.execute( + text(""" + UPDATE scaling_groups + SET scheduler_opts = + CASE + WHEN scheduler_opts ? 'roundrobin' AND scheduler_opts->>'roundrobin' = 'true' THEN + jsonb_set( + jsonb_set( + scheduler_opts - 'roundrobin', + '{agent_selection_strategy}', '"roundrobin"'::jsonb + ), + '{agent_selector_config}', '{}'::jsonb + ) + ELSE jsonb_set( + scheduler_opts - 'roundrobin', + '{agent_selector_config}', '{}'::jsonb + ) + END; + """) + ) + + +def downgrade() -> None: + op.execute( + text(""" + UPDATE scaling_groups + SET scheduler_opts = + CASE + WHEN scheduler_opts->>'agent_selection_strategy' = 'roundrobin' THEN + jsonb_set( + jsonb_set( + scheduler_opts - 'agent_selector_config', + '{agent_selection_strategy}', '"legacy"'::jsonb + ), + '{roundrobin}', 'true'::jsonb + ) + ELSE scheduler_opts - 'agent_selector_config' + END; + """) + ) diff --git a/src/ai/backend/manager/models/scaling_group.py b/src/ai/backend/manager/models/scaling_group.py index 345b00f908..05c67757f3 100644 --- a/src/ai/backend/manager/models/scaling_group.py +++ b/src/ai/backend/manager/models/scaling_group.py @@ -29,6 +29,7 @@ from sqlalchemy.sql.expression import true from ai.backend.common import validators as tx +from ai.backend.common.config import agent_selector_config_iv from ai.backend.common.types import ( AgentSelectionStrategy, JSONSerializableMixin, @@ -98,9 +99,12 @@ class ScalingGroupOpts(JSONSerializableMixin): ], ) pending_timeout: timedelta = timedelta(seconds=0) - config: Mapping[str, Any] = attr.Factory(dict) + config: Mapping[str, Any] = attr.field(factory=dict) + + # Scheduler has a dedicated database column to store its name, + # but agent selector configuration is stored as a part of the scheduler_opts column. agent_selection_strategy: AgentSelectionStrategy = AgentSelectionStrategy.DISPERSED - roundrobin: bool = False + agent_selector_config: Mapping[str, Any] = attr.field(factory=dict) def to_json(self) -> dict[str, Any]: return { @@ -108,7 +112,7 @@ def to_json(self) -> dict[str, Any]: "pending_timeout": self.pending_timeout.total_seconds(), "config": self.config, "agent_selection_strategy": self.agent_selection_strategy, - "roundrobin": self.roundrobin, + "agent_selector_config": self.agent_selector_config, } @classmethod @@ -127,7 +131,7 @@ def as_trafaret(cls) -> t.Trafaret: t.Key("agent_selection_strategy", default=AgentSelectionStrategy.DISPERSED): tx.Enum( AgentSelectionStrategy ), - t.Key("roundrobin", default=False): t.Bool(), + t.Key("agent_selector_config", default={}): agent_selector_config_iv, }).allow_extra("*") diff --git a/src/ai/backend/manager/scheduler/agent_selector.py b/src/ai/backend/manager/scheduler/agent_selector.py new file mode 100644 index 0000000000..87d2a07070 --- /dev/null +++ b/src/ai/backend/manager/scheduler/agent_selector.py @@ -0,0 +1,237 @@ +from __future__ import annotations + +import logging +import sys +from decimal import Decimal +from typing import Optional, Self, Sequence, override + +import pydantic +import trafaret as t + +from ai.backend.common.types import ( + AgentId, + ArchName, + ResourceSlot, +) + +from ..models import AgentRow, KernelRow, SessionRow +from .types import ( + AbstractAgentSelector, + NullAgentSelectorState, + ResourceGroupState, + T_ResourceGroupState, +) +from .utils import ( + get_requested_architecture, + sort_requested_slots_by_priority, +) + +log = logging.Logger(__spec__.name) + + +def get_num_extras(agent: AgentRow, requested_slots: ResourceSlot) -> int: + """ + Get the number of resource slots that: + 1) are requested but zero (unused), + 2) are available in the given agent. + + This is to prefer (or not) agents with additional unused slots, + depending on the selection strategy. + """ + unused_slot_keys = set() + for k, v in requested_slots.items(): + if v == Decimal(0): + unused_slot_keys.add(k) + num_extras = 0 + for k, v in agent.available_slots.items(): + if k in unused_slot_keys and v > Decimal(0): + num_extras += 1 + + return num_extras + + +class BaseAgentSelector(AbstractAgentSelector[T_ResourceGroupState]): + @property + @override + def config_iv(self) -> t.Dict: + return t.Dict({}).allow_extra("*") + + @override + @classmethod + def get_state_cls(cls) -> type[T_ResourceGroupState]: + raise NotImplementedError("must use a concrete subclass") + + def filter_agents( + self, + compatible_agents: Sequence[AgentRow], + pending_session_or_kernel: SessionRow | KernelRow, + ) -> Sequence[AgentRow]: + """ + Filter the agents by checking if it can host the picked session. + """ + return [ + agent + for agent in compatible_agents + if ( + agent.available_slots - agent.occupied_slots + >= pending_session_or_kernel.requested_slots + ) + ] + + +class LegacyAgentSelector(BaseAgentSelector[NullAgentSelectorState]): + @override + @classmethod + def get_state_cls(cls) -> type[NullAgentSelectorState]: + return NullAgentSelectorState + + @override + async def select_agent( + self, + agents: Sequence[AgentRow], + pending_session_or_kernel: SessionRow | KernelRow, + ) -> Optional[AgentId]: + agents = self.filter_agents(agents, pending_session_or_kernel) + if not agents: + return None + requested_slots = pending_session_or_kernel.requested_slots + resource_priorities = sort_requested_slots_by_priority( + requested_slots, self.agent_selection_resource_priority + ) + chosen_agent = max( + agents, + key=lambda agent: [ + -get_num_extras(agent, requested_slots), + *[agent.available_slots.get(key, -sys.maxsize) for key in resource_priorities], + ], + ) + return chosen_agent.id + + +class RoundRobinState(pydantic.BaseModel): + next_index: int = 0 + + +class RRAgentSelectorState(ResourceGroupState): + roundrobin_states: dict[ArchName, RoundRobinState] + + @override + @classmethod + def create_empty_state(cls) -> Self: + return cls(roundrobin_states={}) + + +class RoundRobinAgentSelector(BaseAgentSelector[RRAgentSelectorState]): + @override + @classmethod + def get_state_cls(cls) -> type[RRAgentSelectorState]: + return RRAgentSelectorState + + @override + async def select_agent( + self, + agents: Sequence[AgentRow], + pending_session_or_kernel: SessionRow | KernelRow, + ) -> Optional[AgentId]: + if isinstance(pending_session_or_kernel, KernelRow): + sgroup_name = pending_session_or_kernel.scaling_group + requested_architecture = ArchName(pending_session_or_kernel.architecture) + else: + sgroup_name = pending_session_or_kernel.scaling_group_name + requested_architecture = ArchName(get_requested_architecture(pending_session_or_kernel)) + + state = await self.state_store.load(sgroup_name, "agselector.roundrobin") + rr_state = state.roundrobin_states.get(requested_architecture, None) + + if rr_state is None: + start_idx = 0 + else: + # Since the number of agents may have changed, + # clamp the index to the current number of agents. + start_idx = rr_state.next_index % len(agents) + + # Make a consistent ordering of the agents. + agents = sorted(agents, key=lambda agent: agent.id) + chosen_agent = None + + for idx in range(len(agents)): + inspected_idx = (start_idx + idx) % len(agents) + if ( + agents[inspected_idx].available_slots - agents[inspected_idx].occupied_slots + >= pending_session_or_kernel.requested_slots + ): + chosen_agent = agents[inspected_idx] + state.roundrobin_states[requested_architecture] = RoundRobinState( + next_index=(inspected_idx + 1) % len(agents) + ) + break + await self.state_store.store(sgroup_name, "agselector.roundrobin", state) + + if not chosen_agent: + return None + + return chosen_agent.id + + +class ConcentratedAgentSelector(BaseAgentSelector[NullAgentSelectorState]): + @override + @classmethod + def get_state_cls(cls) -> type[NullAgentSelectorState]: + return NullAgentSelectorState + + @override + async def select_agent( + self, + agents: Sequence[AgentRow], + pending_session_or_kernel: SessionRow | KernelRow, + ) -> Optional[AgentId]: + agents = self.filter_agents(agents, pending_session_or_kernel) + if not agents: + return None + requested_slots = pending_session_or_kernel.requested_slots + resource_priorities = sort_requested_slots_by_priority( + requested_slots, self.agent_selection_resource_priority + ) + chosen_agent = min( + agents, + key=lambda agent: [ + get_num_extras(agent, requested_slots), + *[ + (agent.available_slots - agent.occupied_slots).get(key, sys.maxsize) + for key in resource_priorities + ], + ], + ) + return chosen_agent.id + + +class DispersedAgentSelector(BaseAgentSelector[NullAgentSelectorState]): + @override + @classmethod + def get_state_cls(cls) -> type[NullAgentSelectorState]: + return NullAgentSelectorState + + @override + async def select_agent( + self, + agents: Sequence[AgentRow], + pending_session_or_kernel: SessionRow | KernelRow, + ) -> Optional[AgentId]: + agents = self.filter_agents(agents, pending_session_or_kernel) + if not agents: + return None + requested_slots = pending_session_or_kernel.requested_slots + resource_priorities = sort_requested_slots_by_priority( + requested_slots, self.agent_selection_resource_priority + ) + chosen_agent = max( + agents, + key=lambda agent: [ + -get_num_extras(agent, requested_slots), + *[ + (agent.available_slots - agent.occupied_slots).get(key, -sys.maxsize) + for key in resource_priorities + ], + ], + ) + return chosen_agent.id diff --git a/src/ai/backend/manager/scheduler/dispatcher.py b/src/ai/backend/manager/scheduler/dispatcher.py index e4c7504a56..552fec2d35 100644 --- a/src/ai/backend/manager/scheduler/dispatcher.py +++ b/src/ai/backend/manager/scheduler/dispatcher.py @@ -1,11 +1,11 @@ from __future__ import annotations import asyncio -import hashlib import itertools import json import logging import uuid +from collections.abc import Awaitable, Mapping, Sequence from contextvars import ContextVar from datetime import datetime, timedelta from decimal import Decimal @@ -13,12 +13,7 @@ from typing import ( TYPE_CHECKING, Any, - Awaitable, - Final, - List, Optional, - Sequence, - Tuple, Union, ) @@ -55,10 +50,10 @@ from ai.backend.common.plugin.hook import PASSED, HookResult from ai.backend.common.types import ( AgentId, + AgentSelectionStrategy, ClusterMode, RedisConnectionInfo, ResourceSlot, - RoundRobinState, SessionId, aobject, ) @@ -106,12 +101,16 @@ check_user_resource_limit, ) from .types import ( + AbstractAgentSelector, AbstractScheduler, AgentAllocationContext, + DefaultResourceGroupStateStore, KernelAgentBinding, PendingSession, PredicateResult, + ResourceGroupState, SchedulingContext, + T_ResourceGroupState, ) if TYPE_CHECKING: @@ -120,25 +119,20 @@ __all__ = ( "load_scheduler", + "load_agent_selector", "SchedulerDispatcher", ) log = BraceStyleAdapter(logging.getLogger("ai.backend.manager.scheduler")) _log_fmt: ContextVar[str] = ContextVar("_log_fmt") -_log_args: ContextVar[Tuple[Any, ...]] = ContextVar("_log_args") - -_key_schedule_prep_tasks: Final = "scheduler.preptasks" - - -def get_schedulable_group_id(agents: list[AgentRow]) -> str: - return hashlib.md5("#".join(list(map(lambda agent: agent.id, agents))).encode()).hexdigest() +_log_args: ContextVar[tuple[Any, ...]] = ContextVar("_log_args") def load_scheduler( name: str, sgroup_opts: ScalingGroupOpts, - scheduler_config: dict[str, Any], + scheduler_config: Mapping[str, Any], ) -> AbstractScheduler: entry_prefix = "backendai_scheduler_v10" for entrypoint in scan_entrypoints(entry_prefix): @@ -149,11 +143,40 @@ def load_scheduler( raise ImportError("Cannot load the scheduler plugin", name) -StartTaskArgs = Tuple[ - Tuple[Any, ...], +def load_agent_selector( + name: str, + sgroup_opts: ScalingGroupOpts, + selector_config: Mapping[str, Any], + agent_selection_resource_priority: list[str], + shared_config: SharedConfig, +) -> AbstractAgentSelector[ResourceGroupState]: + def create_agent_selector( + selector_cls: type[AbstractAgentSelector[T_ResourceGroupState]], + ) -> AbstractAgentSelector[T_ResourceGroupState]: + # An extra inner function to parametrize the generic type arguments + state_cls = selector_cls.get_state_cls() + state_store = DefaultResourceGroupStateStore(state_cls, shared_config) + return selector_cls( + sgroup_opts, + selector_config, + agent_selection_resource_priority, + state_store=state_store, + ) + + entry_prefix = "backendai_agentselector_v10" + for entrypoint in scan_entrypoints(entry_prefix): + if entrypoint.name == name: + log.debug('loading agent-selector plugin "{}" from {}', name, entrypoint.module) + selector_cls = entrypoint.load() + return create_agent_selector(selector_cls) + raise ImportError("Cannot load the agent-selector plugin", name) + + +StartTaskArgs = tuple[ + tuple[Any, ...], SchedulingContext, - Tuple[PendingSession, List[KernelAgentBinding]], - List[Tuple[str, Union[Exception, PredicateResult]]], + tuple[PendingSession, Sequence[KernelAgentBinding]], + list[tuple[str, Union[Exception, PredicateResult]]], ] @@ -353,7 +376,7 @@ async def _load_scheduler( self, db_sess: SASession, sgroup_name: str, - ) -> AbstractScheduler: + ) -> tuple[AbstractScheduler, AbstractAgentSelector]: query = sa.select(ScalingGroupRow.scheduler, ScalingGroupRow.scheduler_opts).where( ScalingGroupRow.name == sgroup_name ) @@ -361,13 +384,50 @@ async def _load_scheduler( row = result.first() scheduler_name = row.scheduler sgroup_opts: ScalingGroupOpts = row.scheduler_opts + match sgroup_opts.agent_selection_strategy: + # The names correspond to the entrypoint names (backendai_agentselector_v10). + case AgentSelectionStrategy.LEGACY: + agselector_name = "legacy" + case AgentSelectionStrategy.ROUNDROBIN: + agselector_name = "roundrobin" + case AgentSelectionStrategy.CONCENTRATED: + agselector_name = "concentrated" + case AgentSelectionStrategy.DISPERSED: + agselector_name = "dispersed" + case _ as unknown: + raise ValueError( + f"Unknown agent selection strategy: {unknown!r}. Possible values: {[*AgentSelectionStrategy.__members__.keys()]}" + ) + global_scheduler_opts = {} + global_agselector_opts = {} if self.shared_config["plugins"]["scheduler"]: global_scheduler_opts = self.shared_config["plugins"]["scheduler"].get( scheduler_name, {} ) - scheduler_specific_config = {**global_scheduler_opts, **sgroup_opts.config} - return load_scheduler(scheduler_name, sgroup_opts, scheduler_specific_config) + scheduler_config = {**global_scheduler_opts, **sgroup_opts.config} + if self.shared_config["plugins"]["agent_selector"]: + global_agselector_opts = self.shared_config["plugins"]["agent_selector"].get( + agselector_name, {} + ) + agselector_config = {**global_agselector_opts, **sgroup_opts.agent_selector_config} + agent_selection_resource_priority = self.local_config["manager"][ + "agent-selection-resource-priority" + ] + + scheduler = load_scheduler( + scheduler_name, + sgroup_opts, + scheduler_config, + ) + agent_selector = load_agent_selector( + agselector_name, + sgroup_opts, + agselector_config, + agent_selection_resource_priority, + self.shared_config, + ) + return scheduler, agent_selector async def _schedule_in_sgroup( self, @@ -414,7 +474,7 @@ async def _apply_cancellation( await db_sess.execute(query) async with self.db.begin_readonly_session() as db_sess: - scheduler = await self._load_scheduler(db_sess, sgroup_name) + scheduler, agent_selector = await self._load_scheduler(db_sess, sgroup_name) existing_sessions, pending_sessions, cancelled_sessions = await _list_managed_sessions( db_sess, sgroup_name, scheduler.sgroup_opts.pending_timeout ) @@ -478,10 +538,10 @@ async def _update(): _log_args.set(log_args) log.debug(log_fmt + "try-scheduling", *log_args) - async def _check_predicates() -> List[Tuple[str, Union[Exception, PredicateResult]]]: - check_results: List[Tuple[str, Union[Exception, PredicateResult]]] = [] + async def _check_predicates() -> list[tuple[str, Union[Exception, PredicateResult]]]: + check_results: list[tuple[str, Union[Exception, PredicateResult]]] = [] async with self.db.begin_session() as db_sess: - predicates: list[Tuple[str, Awaitable[PredicateResult]]] = [ + predicates: list[tuple[str, Awaitable[PredicateResult]]] = [ ( "reserved_time", check_reserved_batch_session(db_sess, sched_ctx, sess_ctx), @@ -662,30 +722,24 @@ async def _update_session_status_data() -> None: ), ) - agent_selection_resource_priority = self.local_config["manager"][ - "agent-selection-resource-priority" - ] - try: match schedulable_sess.cluster_mode: case ClusterMode.SINGLE_NODE: await self._schedule_single_node_session( sched_ctx, - scheduler, + agent_selector, sgroup_name, candidate_agents, schedulable_sess, - agent_selection_resource_priority, check_results, ) case ClusterMode.MULTI_NODE: await self._schedule_multi_node_session( sched_ctx, - scheduler, + agent_selector, sgroup_name, candidate_agents, schedulable_sess, - agent_selection_resource_priority, check_results, ) case _: @@ -693,6 +747,9 @@ async def _update_session_status_data() -> None: f"should not reach here; unknown cluster_mode: {schedulable_sess.cluster_mode}" ) continue + # For complex schedulers like DRF, they may need internal state updates + # based on the scheduling result. + scheduler.update_allocation(schedulable_sess) except InstanceNotAvailable as e: # Proceed to the next pending session and come back later. log.debug( @@ -742,12 +799,11 @@ def _check(cnt: str | None) -> bool: async def _schedule_single_node_session( self, sched_ctx: SchedulingContext, - scheduler: AbstractScheduler, + agent_selector: AbstractAgentSelector, sgroup_name: str, candidate_agents: Sequence[AgentRow], sess_ctx: SessionRow, - agent_selection_resource_priority: list[str], - check_results: List[Tuple[str, Union[Exception, PredicateResult]]], + check_results: list[tuple[str, Union[Exception, PredicateResult]]], ) -> None: """ Finds and assigns an agent having resources enough to host the entire session. @@ -826,75 +882,21 @@ async def _schedule_single_node_session( f"remaining: {available_slots[key] - occupied_slots[key]})." ), ) - else: - sorted_agents = sorted(compatible_candidate_agents, key=lambda agent: agent.id) - - if scheduler.sgroup_opts.roundrobin: - rr_state: ( - RoundRobinState | None - ) = await sched_ctx.registry.shared_config.get_roundrobin_state( - sgroup_name, requested_architecture - ) - - if rr_state is not None: - schedulable_group_id = get_schedulable_group_id(sorted_agents) - - if schedulable_group_id == rr_state.schedulable_group_id: - for i in range(len(sorted_agents)): - idx = (rr_state.next_index + i) % len(sorted_agents) - agent = sorted_agents[idx] - - if ( - agent.available_slots - agent.occupied_slots - > sess_ctx.requested_slots - ): - agent_id = agent.id - rr_state.next_index = (rr_state.next_index + i + 1) % len( - sorted_agents - ) - - await sched_ctx.registry.shared_config.put_roundrobin_state( - sgroup_name, requested_architecture, rr_state - ) - break - else: - # fallback to the default behavior instead of raising an error for reducing code complexity - pass - - if agent_id is None: - # Let the scheduler check the resource availability and decide the target agent - cand_agent_id = scheduler.assign_agent_for_session( - compatible_candidate_agents, - sess_ctx, - scheduler.sgroup_opts.agent_selection_strategy, - agent_selection_resource_priority, + # Let the agent selector decide the target agent + cand_agent_id = await agent_selector.assign_agent_for_session( + compatible_candidate_agents, + sess_ctx, + ) + if cand_agent_id is None: + raise InstanceNotAvailable( + extra_msg=( + "Could not find a contiguous resource region in any agent big" + f" enough to host the session (id: {sess_ctx.id}, resource group:" + f" {sess_ctx.scaling_group_name})" + ), ) - if cand_agent_id is None: - raise InstanceNotAvailable( - extra_msg=( - "Could not find a contiguous resource region in any agent big" - f" enough to host the session (id: {sess_ctx.id}, resource group:" - f" {sess_ctx.scaling_group_name})" - ), - ) - agent_id = cand_agent_id - - if scheduler.sgroup_opts.roundrobin: - await sched_ctx.registry.shared_config.put_roundrobin_state( - sgroup_name, - requested_architecture, - RoundRobinState( - schedulable_group_id=get_schedulable_group_id( - sorted_agents, - ), - next_index=[ - (idx + 1) % len(sorted_agents) - for idx, agent in enumerate(sorted_agents) - if agent.id == agent_id - ][0], - ), - ) + agent_id = cand_agent_id async with self.db.begin_session() as agent_db_sess: agent_alloc_ctx = await _reserve_agent( @@ -1017,12 +1019,11 @@ async def _finalize_scheduled() -> None: async def _schedule_multi_node_session( self, sched_ctx: SchedulingContext, - scheduler: AbstractScheduler, + agent_selector: AbstractAgentSelector, sgroup_name: str, candidate_agents: Sequence[AgentRow], sess_ctx: SessionRow, - agent_selection_resource_priority: list[str], - check_results: List[Tuple[str, Union[Exception, PredicateResult]]], + check_results: list[tuple[str, Union[Exception, PredicateResult]]], ) -> None: """ Finds and assigns agents having resources enough to host each kernel in the session. @@ -1030,7 +1031,8 @@ async def _schedule_multi_node_session( log_fmt = _log_fmt.get() log_args = _log_args.get() agent_query_extra_conds = None - kernel_agent_bindings: List[KernelAgentBinding] = [] + + kernel_agent_bindings: list[KernelAgentBinding] = [] async with self.db.begin_session() as agent_db_sess: # This outer transaction is rolled back when any exception occurs inside, # including scheduling failures of a kernel. @@ -1100,12 +1102,10 @@ async def _schedule_multi_node_session( " reached the hard limit of the number of containers." ), ) - # Let the scheduler check the resource availability and decide the target agent - agent_id = scheduler.assign_agent_for_kernel( + # Let the agent selector decide the target agent + agent_id = await agent_selector.assign_agent_for_kernel( available_candidate_agents, kernel, - scheduler.sgroup_opts.agent_selection_strategy, - agent_selection_resource_priority, ) if agent_id is None: raise InstanceNotAvailable( @@ -1710,7 +1710,7 @@ async def _list_managed_sessions( db_sess: SASession, sgroup_name: str, pending_timeout: timedelta, -) -> Tuple[List[SessionRow], List[SessionRow], List[SessionRow]]: +) -> tuple[list[SessionRow], list[SessionRow], list[SessionRow]]: """ Return three lists of sessions. first is a list of existing sessions, @@ -1719,9 +1719,9 @@ async def _list_managed_sessions( managed_sessions = await SessionRow.get_sgroup_managed_sessions(db_sess, sgroup_name) - candidates: List[SessionRow] = [] - cancelleds: List[SessionRow] = [] - existings: List[SessionRow] = [] + candidates: list[SessionRow] = [] + cancelleds: list[SessionRow] = [] + existings: list[SessionRow] = [] now = datetime.now(tzutc()) key_func = lambda s: (s.status.value, s.created_at) diff --git a/src/ai/backend/manager/scheduler/drf.py b/src/ai/backend/manager/scheduler/drf.py index 62e07a9108..5f3afb58cd 100644 --- a/src/ai/backend/manager/scheduler/drf.py +++ b/src/ai/backend/manager/scheduler/drf.py @@ -1,86 +1,45 @@ from __future__ import annotations import logging -import sys from collections import defaultdict +from collections.abc import Mapping, Sequence from decimal import Decimal -from typing import Any, Dict, Mapping, Optional, Sequence, Set +from typing import Any, Optional, override import trafaret as t from ai.backend.common.types import ( AccessKey, - AgentId, - AgentSelectionStrategy, ResourceSlot, SessionId, ) from ai.backend.logging import BraceStyleAdapter -from ..models import AgentRow, SessionRow +from ..models import KernelRow, SessionRow from ..models.scaling_group import ScalingGroupOpts -from .types import AbstractScheduler, KernelInfo +from .types import AbstractScheduler log = BraceStyleAdapter(logging.getLogger("ai.backend.manager.scheduler")) -def get_slot_index(slotname: str, agent_selection_resource_priority: list[str]) -> int: - try: - return agent_selection_resource_priority.index(slotname) - except ValueError: - return sys.maxsize - - -def key_by_remaining_slots( - agent: AgentRow, - requested_slots: ResourceSlot, - agent_selection_strategy: AgentSelectionStrategy, - agent_selection_resource_priority: list[str], -) -> list[Decimal]: - for requested_slot_key in sorted(requested_slots.data.keys(), reverse=True): - device_name = requested_slot_key.split(".")[0] - if ( - requested_slot_key not in agent_selection_resource_priority - and device_name in agent_selection_resource_priority - ): - agent_selection_resource_priority.insert( - agent_selection_resource_priority.index(device_name) + 1, requested_slot_key - ) - - resource_priorities = sorted( - requested_slots.data.keys(), - key=lambda item: get_slot_index(item, agent_selection_resource_priority), - ) - - remaining_slots = agent.available_slots - agent.occupied_slots - - # If the requested slot does not exist in the corresponding agent, - # the agent should not be selected, in this case it puts -math.inf for avoiding to being selected. - match agent_selection_strategy: - case AgentSelectionStrategy.LEGACY: - comparators = [ - agent.available_slots.get(key, -sys.maxsize) for key in resource_priorities - ] - case AgentSelectionStrategy.CONCENTRATED: - comparators = [-remaining_slots.get(key, sys.maxsize) for key in resource_priorities] - case AgentSelectionStrategy.DISPERSED | _: - comparators = [remaining_slots.get(key, -sys.maxsize) for key in resource_priorities] - - # Put back agents with more extra slot types - # (e.g., accelerators) - # Also put front agents with exactly required slot types - return comparators - - class DRFScheduler(AbstractScheduler): - config_iv = t.Dict({}).allow_extra("*") - per_user_dominant_share: Dict[AccessKey, Decimal] + per_user_dominant_share: dict[AccessKey, Decimal] total_capacity: ResourceSlot - def __init__(self, sgroup_opts: ScalingGroupOpts, config: Mapping[str, Any]) -> None: + def __init__( + self, + sgroup_opts: ScalingGroupOpts, + config: Mapping[str, Any], + ) -> None: super().__init__(sgroup_opts, config) self.per_user_dominant_share = defaultdict(lambda: Decimal(0)) + @property + @override + def config_iv(self) -> t.Dict: + return t.Dict({}).allow_extra("*") + + @override def pick_session( self, total_capacity: ResourceSlot, @@ -105,7 +64,7 @@ def pick_session( log.debug("per-user dominant share: {}", dict(self.per_user_dominant_share)) # Find who has the least dominant share among the pending session. - users_with_pending_session: Set[AccessKey] = { + users_with_pending_session: set[AccessKey] = { pending_sess.access_key for pending_sess in pending_sessions } if not users_with_pending_session: @@ -124,83 +83,27 @@ def pick_session( return None - def _assign_agent( + @override + def update_allocation( self, - agents: Sequence[AgentRow], - access_key: AccessKey, - requested_slots: ResourceSlot, - agent_selection_strategy: AgentSelectionStrategy, - agent_selection_resource_priority: list[str], - ) -> Optional[AgentId]: - # If some predicate checks for a picked session fail, - # this method is NOT called at all for the picked session. + scheduled_session_or_kernel: SessionRow | KernelRow, + ) -> None: # In such case, we just skip updating self.per_user_dominant_share state # and the scheduler dispatcher continues to pick another session within the same scaling group. - - possible_agents = [] - for agent in agents: - remaining_slots = agent.available_slots - agent.occupied_slots - if remaining_slots >= requested_slots: - possible_agents.append(agent) - - if possible_agents: - # We have one or more agents that can host the picked session. - - # Update the dominant share. - # This is required to use to the latest dominant share information - # when iterating over multiple pending sessions in a single scaling group. - dominant_share_from_request = Decimal(0) - for slot, value in requested_slots.items(): - self.total_capacity.sync_keys(requested_slots) - slot_cap = Decimal(self.total_capacity[slot]) - if slot_cap == 0: - continue - slot_share = Decimal(value) / slot_cap - if dominant_share_from_request < slot_share: - dominant_share_from_request = slot_share - if self.per_user_dominant_share[access_key] < dominant_share_from_request: - self.per_user_dominant_share[access_key] = dominant_share_from_request - - # Choose the agent. - chosen_agent = max( - possible_agents, - key=lambda agent: key_by_remaining_slots( - agent, - requested_slots, - agent_selection_strategy, - agent_selection_resource_priority, - ), - ) - return chosen_agent.id - - return None - - def assign_agent_for_session( - self, - agents: Sequence[AgentRow], - pending_session: SessionRow, - agent_selection_strategy: AgentSelectionStrategy, - agent_selection_resource_priority: list[str], - ) -> Optional[AgentId]: - return self._assign_agent( - agents, - pending_session.access_key, - pending_session.requested_slots, - agent_selection_strategy, - agent_selection_resource_priority, - ) - - def assign_agent_for_kernel( - self, - agents: Sequence[AgentRow], - pending_kernel: KernelInfo, - agent_selection_strategy: AgentSelectionStrategy, - agent_selection_resource_priority: list[str], - ) -> Optional[AgentId]: - return self._assign_agent( - agents, - pending_kernel.access_key, - pending_kernel.requested_slots, - agent_selection_strategy, - agent_selection_resource_priority, - ) + access_key = scheduled_session_or_kernel.access_key + requested_slots = scheduled_session_or_kernel.requested_slots + + # Update the dominant share. + # This is required to use to the latest dominant share information + # when iterating over multiple pending sessions in a single scaling group. + dominant_share_from_request = Decimal(0) + for slot, value in requested_slots.items(): + self.total_capacity.sync_keys(requested_slots) + slot_cap = Decimal(self.total_capacity[slot]) + if slot_cap == 0: + continue + slot_share = Decimal(value) / slot_cap + if dominant_share_from_request < slot_share: + dominant_share_from_request = slot_share + if self.per_user_dominant_share[access_key] < dominant_share_from_request: + self.per_user_dominant_share[access_key] = dominant_share_from_request diff --git a/src/ai/backend/manager/scheduler/fifo.py b/src/ai/backend/manager/scheduler/fifo.py index f305b12a93..bde71a34fe 100644 --- a/src/ai/backend/manager/scheduler/fifo.py +++ b/src/ai/backend/manager/scheduler/fifo.py @@ -1,84 +1,31 @@ from __future__ import annotations -import sys -from decimal import Decimal -from typing import List, Optional, Sequence, Tuple +from collections.abc import Sequence +from typing import ( + Optional, + override, +) import trafaret as t from ai.backend.common.types import ( - AgentId, - AgentSelectionStrategy, ResourceSlot, SessionId, ) -from ..models import AgentRow, SessionRow -from .types import AbstractScheduler, KernelInfo - - -def get_slot_index(slotname: str, agent_selection_resource_priority: list[str]) -> int: - try: - return agent_selection_resource_priority.index(slotname) - except ValueError: - return sys.maxsize - - -def key_by_remaining_slots( - agent: AgentRow, - requested_slots: ResourceSlot, - agent_selection_strategy: AgentSelectionStrategy, - agent_selection_resource_priority: list[str], -) -> Tuple[int, ...]: - unused_slot_keys = set() - for k, v in requested_slots.items(): - if v == Decimal(0): - unused_slot_keys.add(k) - num_extras = 0 - for k, v in agent.available_slots.items(): - if k in unused_slot_keys and v > Decimal(0): - num_extras += 1 - - for requested_slot_key in sorted(requested_slots.data.keys(), reverse=True): - device_name = requested_slot_key.split(".")[0] - if ( - requested_slot_key not in agent_selection_resource_priority - and device_name in agent_selection_resource_priority - ): - agent_selection_resource_priority.insert( - agent_selection_resource_priority.index(device_name) + 1, requested_slot_key - ) - - resource_priorities = sorted( - requested_slots.data.keys(), - key=lambda item: get_slot_index(item, agent_selection_resource_priority), - ) - - remaining_slots = agent.available_slots - agent.occupied_slots - - # If the requested slot does not exist in the corresponding agent, - # the agent should not be selected, in this case it puts -math.inf for avoiding to being selected. - match agent_selection_strategy: - case AgentSelectionStrategy.LEGACY: - comparators = [ - agent.available_slots.get(key, -sys.maxsize) for key in resource_priorities - ] - case AgentSelectionStrategy.CONCENTRATED: - comparators = [-remaining_slots.get(key, sys.maxsize) for key in resource_priorities] - case AgentSelectionStrategy.DISPERSED | _: - comparators = [remaining_slots.get(key, -sys.maxsize) for key in resource_priorities] - - # Put back agents with more extra slot types - # (e.g., accelerators) - # Also put front agents with exactly required slot types - return (-num_extras, *comparators) +from ..models import KernelRow, SessionRow +from .types import AbstractScheduler class FIFOSlotScheduler(AbstractScheduler): - config_iv = t.Dict({ - t.Key("num_retries_to_skip", default=0): t.ToInt(gte=0), - }).allow_extra("*") - + @property + @override + def config_iv(self) -> t.Dict: + return t.Dict({ + t.Key("num_retries_to_skip", default=0): t.ToInt(gte=0), + }).allow_extra("*") + + @override def pick_session( self, total_capacity: ResourceSlot, @@ -86,9 +33,10 @@ def pick_session( existing_sessions: Sequence[SessionRow], ) -> Optional[SessionId]: local_pending_sessions = list(pending_sessions) - skipped_sessions: List[SessionRow] = [] + skipped_sessions: list[SessionRow] = [] max_retries = self.config["num_retries_to_skip"] while local_pending_sessions: + # This is the HoL blocking avoidance mechanism. # Just pick the first pending session, but skip it # if it has more than 3 failures. s = local_pending_sessions.pop(0) @@ -105,63 +53,21 @@ def pick_session( return skipped_sessions[0].id return None - def _assign_agent( + @override + def update_allocation( self, - agents: Sequence[AgentRow], - requested_slots: ResourceSlot, - agent_selection_strategy: AgentSelectionStrategy, - agent_selection_resource_priority: list[str], - ) -> Optional[AgentId]: - possible_agents = [] - for agent in agents: - remaining_slots = agent.available_slots - agent.occupied_slots - if remaining_slots >= requested_slots: - possible_agents.append(agent) - if possible_agents: - chosen_agent = max( - possible_agents, - key=lambda agent: key_by_remaining_slots( - agent, - requested_slots, - agent_selection_strategy, - agent_selection_resource_priority, - ), - ) - return chosen_agent.id - return None - - def assign_agent_for_session( - self, - agents: Sequence[AgentRow], - pending_session: SessionRow, - agent_selection_strategy: AgentSelectionStrategy, - agent_selection_resource_priority: list[str], - ) -> Optional[AgentId]: - return self._assign_agent( - agents, - pending_session.requested_slots, - agent_selection_strategy, - agent_selection_resource_priority, - ) - - def assign_agent_for_kernel( - self, - agents: Sequence[AgentRow], - pending_kernel: KernelInfo, - agent_selection_strategy: AgentSelectionStrategy, - agent_selection_resource_priority: list[str], - ) -> Optional[AgentId]: - return self._assign_agent( - agents, - pending_kernel.requested_slots, - agent_selection_strategy, - agent_selection_resource_priority, - ) + scheduled_session_or_kernel: SessionRow | KernelRow, + ) -> None: + pass class LIFOSlotScheduler(AbstractScheduler): - config_iv = t.Dict({}).allow_extra("*") + @property + @override + def config_iv(self) -> t.Dict: + return t.Dict({}).allow_extra("*") + @override def pick_session( self, total_capacity: ResourceSlot, @@ -171,55 +77,9 @@ def pick_session( # Just pick the last pending session. return SessionId(pending_sessions[-1].id) - def _assign_agent( - self, - agents: Sequence[AgentRow], - requested_slots: ResourceSlot, - agent_selection_strategy: AgentSelectionStrategy, - agent_selection_resource_priority: list[str], - ) -> Optional[AgentId]: - possible_agents = [] - for agent in agents: - remaining_slots = agent.available_slots - agent.occupied_slots - if remaining_slots >= requested_slots: - possible_agents.append(agent) - if possible_agents: - chosen_agent = max( - possible_agents, - key=lambda agent: key_by_remaining_slots( - agent, - requested_slots, - agent_selection_strategy, - agent_selection_resource_priority, - ), - ) - return chosen_agent.id - return None - - def assign_agent_for_session( - self, - agents: Sequence[AgentRow], - pending_session: SessionRow, - agent_selection_strategy: AgentSelectionStrategy, - agent_selection_resource_priority: list[str], - ) -> Optional[AgentId]: - return self._assign_agent( - agents, - pending_session.requested_slots, - agent_selection_strategy, - agent_selection_resource_priority, - ) - - def assign_agent_for_kernel( + @override + def update_allocation( self, - agents: Sequence[AgentRow], - pending_kernel: KernelInfo, - agent_selection_strategy: AgentSelectionStrategy, - agent_selection_resource_priority: list[str], - ) -> Optional[AgentId]: - return self._assign_agent( - agents, - pending_kernel.requested_slots, - agent_selection_strategy, - agent_selection_resource_priority, - ) + scheduled_session_or_kernel: SessionRow | KernelRow, + ) -> None: + pass diff --git a/src/ai/backend/manager/scheduler/mof.py b/src/ai/backend/manager/scheduler/mof.py deleted file mode 100644 index 381460c808..0000000000 --- a/src/ai/backend/manager/scheduler/mof.py +++ /dev/null @@ -1,87 +0,0 @@ -from __future__ import annotations - -from typing import Optional, Sequence - -import trafaret as t - -from ai.backend.common.types import ( - AccessKey, - AgentId, - AgentSelectionStrategy, - ResourceSlot, - SessionId, -) - -from ..models import AgentRow, SessionRow -from .types import AbstractScheduler, KernelInfo - - -class MOFScheduler(AbstractScheduler): - """Minimum Occupied slot First Scheduler""" - - config_iv = t.Dict({}).allow_extra("*") - - def pick_session( - self, - total_capacity: ResourceSlot, - pending_sessions: Sequence[SessionRow], - existing_sessions: Sequence[SessionRow], - ) -> Optional[SessionId]: - # Just pick the first pending session. - return SessionId(pending_sessions[0].id) - - def _assign_agent( - self, - agents: Sequence[AgentRow], - access_key: AccessKey, - requested_slots: ResourceSlot, - agent_selection_strategy: AgentSelectionStrategy, - agent_selection_resource_priority: list[str], - ) -> Optional[AgentId]: - # return min occupied slot agent or None - return next( - ( - one_agent.id - for one_agent in ( - sorted( - ( - agent - for agent in agents - if ((agent.available_slots - agent.occupied_slots) >= requested_slots) - ), - key=lambda agent: agent.occupied_slots, - ) - ) - ), - None, - ) - - def assign_agent_for_session( - self, - agents: Sequence[AgentRow], - pending_session: SessionRow, - agent_selection_strategy: AgentSelectionStrategy, - agent_selection_resource_priority: list[str], - ) -> Optional[AgentId]: - return self._assign_agent( - agents, - pending_session.access_key, - pending_session.requested_slots, - agent_selection_strategy, - agent_selection_resource_priority, - ) - - def assign_agent_for_kernel( - self, - agents: Sequence[AgentRow], - pending_kernel: KernelInfo, - agent_selection_strategy: AgentSelectionStrategy, - agent_selection_resource_priority: list[str], - ) -> Optional[AgentId]: - return self._assign_agent( - agents, - pending_kernel.access_key, - pending_kernel.requested_slots, - agent_selection_strategy, - agent_selection_resource_priority, - ) diff --git a/src/ai/backend/manager/scheduler/types.py b/src/ai/backend/manager/scheduler/types.py index 5313ba6bec..bad6f72bb6 100644 --- a/src/ai/backend/manager/scheduler/types.py +++ b/src/ai/backend/manager/scheduler/types.py @@ -2,22 +2,28 @@ import logging import uuid -from abc import ABCMeta, abstractmethod +from abc import ABC, abstractmethod from datetime import datetime from typing import ( Any, Dict, + Final, + Generic, List, Mapping, MutableMapping, MutableSequence, Optional, Protocol, + Self, Sequence, Set, + TypeVar, + override, ) import attrs +import pydantic import sqlalchemy as sa import trafaret as t from sqlalchemy.engine.row import Row @@ -28,9 +34,9 @@ from ai.backend.common.types import ( AccessKey, AgentId, - AgentSelectionStrategy, ClusterMode, KernelId, + ResourceGroupID, ResourceSlot, SessionId, SessionTypes, @@ -39,6 +45,7 @@ VFolderMount, ) from ai.backend.logging import BraceStyleAdapter +from ai.backend.manager.config import SharedConfig from ..defs import DEFAULT_ROLE from ..models import AgentRow, KernelRow, SessionRow, kernels, keypairs @@ -395,20 +402,31 @@ async def __call__( ) -> PredicateResult: ... -class AbstractScheduler(metaclass=ABCMeta): +class AbstractScheduler(ABC): """ - Interface for scheduling algorithms where the - ``schedule()`` method is a pure function. + The interface for scheduling algorithms to choose a pending session to schedule. """ sgroup_opts: ScalingGroupOpts # sgroup-specific config config: Mapping[str, Any] # scheduler-specific config - config_iv: t.Dict - def __init__(self, sgroup_opts: ScalingGroupOpts, config: Mapping[str, Any]) -> None: + def __init__( + self, + sgroup_opts: ScalingGroupOpts, + config: Mapping[str, Any], + ) -> None: self.sgroup_opts = sgroup_opts self.config = self.config_iv.check(config) + @property + @abstractmethod + def config_iv(self) -> t.Dict: + """ + The partial schema to extract configuration from the ``scaling_groups.scheduler_opts`` column. + The returned ``t.Dict`` should set ``.allow_extra("*")`` to coexist with the agent-selector config. + """ + raise NotImplementedError + @abstractmethod def pick_session( self, @@ -420,36 +438,243 @@ def pick_session( Pick a session to try schedule. This is where the queueing semantics is implemented such as prioritization. """ - return None + raise NotImplementedError + + def update_allocation( + self, + scheduled_session_or_kernel: SessionRow | KernelRow, + ) -> None: + """ + An optional method to update internal states of the scheduler after a session is allocated + and PASSED all predicate checks. + + This method is not called when any predicate check fails. + """ + pass + + +class ResourceGroupState(pydantic.BaseModel, ABC): + @classmethod + @abstractmethod + def create_empty_state(cls) -> Self: + raise NotImplementedError("must use a concrete subclass") + + +class NullAgentSelectorState(ResourceGroupState): + @override + @classmethod + def create_empty_state(cls) -> Self: + return cls() + + +T_ResourceGroupState = TypeVar("T_ResourceGroupState", bound=ResourceGroupState) + + +class AbstractAgentSelector(Generic[T_ResourceGroupState], ABC): + """ + The interface for agent-selection logic to choose one or more agents to map with the given + scheduled session. + """ + + sgroup_opts: ScalingGroupOpts # sgroup-specific config + config: Mapping[str, Any] # agent-selector-specific config + agent_selection_resource_priority: list[str] + state_store: AbstractResourceGroupStateStore[T_ResourceGroupState] + + def __init__( + self, + sgroup_opts: ScalingGroupOpts, + config: Mapping[str, Any], + agent_selection_resource_priority: list[str], + *, + state_store: AbstractResourceGroupStateStore[T_ResourceGroupState], + ) -> None: + self.sgroup_opts = sgroup_opts + self.config = self.config_iv.check(config) + self.agent_selection_resource_priority = agent_selection_resource_priority + self.state_store = state_store + + @property + @abstractmethod + def config_iv(self) -> t.Dict: + """ + The partial schema to extract configuration from the ``scaling_groups.scheduler_opts`` column. + The returned ``t.Dict`` should set ``.allow_extra("*")`` to coexist with the scheduler config. + """ + raise NotImplementedError + @classmethod @abstractmethod - def assign_agent_for_session( + def get_state_cls(cls) -> type[T_ResourceGroupState]: + raise NotImplementedError() + + async def assign_agent_for_session( self, - possible_agents: Sequence[AgentRow], + agents: Sequence[AgentRow], pending_session: SessionRow, - agent_selection_strategy: AgentSelectionStrategy, - agent_selection_resource_priority: list[str], ) -> Optional[AgentId]: """ - Assign an agent for the entire session, only considering the total requested - slots of the session. This is used for both single-container sessions and + Assign an agent for the entire (single-node) session, only considering + the total requested slots of the session. + This method is used for both single-container sessions and single-node multi-container sessions. In single-node multi-container sessions, all sub-containers are spawned by slicing the assigned agent's resource. + + The default implementation is to simply call ``select_agent()`` method. + """ + return await self.select_agent(agents, pending_session) + + async def assign_agent_for_kernel( + self, + agents: Sequence[AgentRow], + pending_kernel: KernelRow, + ) -> Optional[AgentId]: """ - return None + Assign an agent for a kernel of a multi-node multi-container session. + This may be called multiple times. + + The default implementation is to simply call ``select_agent()`` method. + """ + return await self.select_agent(agents, pending_kernel) @abstractmethod - def assign_agent_for_kernel( + async def select_agent( self, - possible_agents: Sequence[AgentRow], - pending_kernel: KernelInfo, - agent_selection_strategy: AgentSelectionStrategy, - agent_selection_resource_priority: list[str], + agents: Sequence[AgentRow], + pending_session_or_kernel: SessionRow | KernelRow, ) -> Optional[AgentId]: """ - Assign an agent for a kernel of the session. - This may be called multiple times for multi-node multi-container sessions. + Select an agent for the pending session or kernel. """ - return None + raise NotImplementedError + + +class AbstractResourceGroupStateStore(Generic[T_ResourceGroupState], ABC): + """ + Store and load the state of the pending session scheduler and agent selector for each resource group. + """ + + def __init__(self, state_cls: type[T_ResourceGroupState]) -> None: + self.state_cls = state_cls + + @abstractmethod + async def load( + self, + resource_group_name: ResourceGroupID, + state_name: str, + ) -> T_ResourceGroupState: + raise NotImplementedError + + @abstractmethod + async def store( + self, + resource_group_name: ResourceGroupID, + state_name: str, + state_value: T_ResourceGroupState, + ) -> None: + raise NotImplementedError + + @abstractmethod + async def reset( + self, + resource_group_name: ResourceGroupID, + state_name: str, + ) -> None: + raise NotImplementedError + + +class DefaultResourceGroupStateStore(AbstractResourceGroupStateStore[T_ResourceGroupState]): + """ + The default AgentSelector state store using the etcd + """ + + base_key: Final[str] = "resource-group-states" + + def __init__(self, state_cls: type[T_ResourceGroupState], shared_config: SharedConfig) -> None: + super().__init__(state_cls) + self.shared_config = shared_config + + @override + async def load( + self, + resource_group_name: ResourceGroupID, + state_name: str, + ) -> T_ResourceGroupState: + log.debug("{}: load agselector state for {}", type(self).__qualname__, resource_group_name) + if ( + raw_agent_selector_state := await self.shared_config.get_raw( + f"{self.base_key}/{resource_group_name}/{state_name}", + ) + ) is not None: + return self.state_cls.model_validate_json(raw_agent_selector_state) + return self.state_cls.create_empty_state() + + @override + async def store( + self, + resource_group_name: ResourceGroupID, + state_name: str, + state_value: T_ResourceGroupState, + ) -> None: + log.debug("{}: store agselector state for {}", type(self).__qualname__, resource_group_name) + await self.shared_config.etcd.put( + f"{self.base_key}/{resource_group_name}/{state_name}", + state_value.model_dump_json(), + ) + + @override + async def reset( + self, + resource_group_name: ResourceGroupID, + state_name: str, + ) -> None: + log.debug("{}: reset agselector state for {}", type(self).__qualname__, resource_group_name) + await self.shared_config.etcd.delete_prefix( + f"{self.base_key}/{resource_group_name}/{state_name}", + ) + + +class InMemoryResourceGroupStateStore(AbstractResourceGroupStateStore[T_ResourceGroupState]): + """ + An in-memory AgentSelector state store to use in test codes. + This cannot be used for the actual dispatcher loop since the state is NOT preserved whenever the + Scheduler and AgentSelector instances are recreated. + """ + + states: dict[tuple[ResourceGroupID, str], T_ResourceGroupState] + + def __init__(self, state_cls: type[T_ResourceGroupState]) -> None: + super().__init__(state_cls) + self.states = {} + + @override + async def load( + self, + resource_group_name: ResourceGroupID, + state_name: str, + ) -> T_ResourceGroupState: + log.debug("{}: load agselector state for {}", type(self).__qualname__, resource_group_name) + return self.states.get( + (resource_group_name, state_name), self.state_cls.create_empty_state() + ) + + @override + async def store( + self, + resource_group_name: ResourceGroupID, + state_name: str, + state_value: T_ResourceGroupState, + ) -> None: + log.debug("{}: store agselector state for {}", type(self).__qualname__, resource_group_name) + self.states[(resource_group_name, state_name)] = state_value + + @override + async def reset( + self, + resource_group_name: ResourceGroupID, + state_name: str, + ) -> None: + log.debug("{}: reset agselector state for {}", type(self).__qualname__, resource_group_name) + del self.states[(resource_group_name, state_name)] diff --git a/src/ai/backend/manager/scheduler/utils.py b/src/ai/backend/manager/scheduler/utils.py new file mode 100644 index 0000000000..4d77b5168f --- /dev/null +++ b/src/ai/backend/manager/scheduler/utils.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +import sys +from typing import TYPE_CHECKING + +from ai.backend.common.types import ( + ResourceSlot, +) + +from ..api.exceptions import GenericBadRequest + +if TYPE_CHECKING: + from ..models.session import SessionRow + + +def get_slot_index(slotname: str, agent_selection_resource_priority: list[str]) -> int: + try: + return agent_selection_resource_priority.index(slotname) + except ValueError: + return sys.maxsize + + +def sort_requested_slots_by_priority( + requested_slots: ResourceSlot, agent_selection_resource_priority: list[str] +) -> list[str]: + """ + Sort ``requested_slots``'s keys by the given ``agent_selection_resource_priority`` list. + """ + + for requested_slot_key in sorted(requested_slots.data.keys(), reverse=True): + device_name = requested_slot_key.split(".")[0] + if ( + requested_slot_key not in agent_selection_resource_priority + and device_name in agent_selection_resource_priority + ): + agent_selection_resource_priority.insert( + agent_selection_resource_priority.index(device_name) + 1, requested_slot_key + ) + + return sorted( + requested_slots.data.keys(), + key=lambda item: get_slot_index(item, agent_selection_resource_priority), + ) + + +def get_requested_architecture(sess_ctx: SessionRow) -> str: + requested_architectures = set(k.architecture for k in sess_ctx.kernels) + if len(requested_architectures) > 1: + raise GenericBadRequest( + "Cannot assign multiple kernels with different architectures' single node session", + ) + return requested_architectures.pop() diff --git a/tests/manager/test_scheduler.py b/tests/manager/test_scheduler.py index cafa6b72f3..57d00ce94f 100644 --- a/tests/manager/test_scheduler.py +++ b/tests/manager/test_scheduler.py @@ -1,16 +1,18 @@ from __future__ import annotations import secrets +from collections.abc import Mapping, Sequence from datetime import datetime, timedelta from decimal import Decimal from pprint import pprint -from typing import Any, Mapping, Sequence +from typing import Any, Generator from unittest import mock from unittest.mock import AsyncMock, MagicMock from uuid import UUID, uuid4 import attrs import pytest +import pytest_mock import trafaret as t from dateutil.parser import parse as dtparse from dateutil.tz import tzutc @@ -33,6 +35,10 @@ from ai.backend.manager.models.scaling_group import ScalingGroupOpts from ai.backend.manager.models.session import SessionRow, SessionStatus from ai.backend.manager.registry import AgentRegistry +from ai.backend.manager.scheduler.agent_selector import ( + DispersedAgentSelector, + RoundRobinAgentSelector, +) from ai.backend.manager.scheduler.dispatcher import ( SchedulerDispatcher, _list_managed_sessions, @@ -40,23 +46,31 @@ ) from ai.backend.manager.scheduler.drf import DRFScheduler from ai.backend.manager.scheduler.fifo import FIFOSlotScheduler, LIFOSlotScheduler -from ai.backend.manager.scheduler.mof import MOFScheduler from ai.backend.manager.scheduler.predicates import check_reserved_batch_session +from ai.backend.manager.scheduler.types import InMemoryResourceGroupStateStore ARCH_FOR_TEST = "x86_64" agent_selection_resource_priority = ["cuda", "rocm", "tpu", "cpu", "mem"] -def test_load_intrinsic(): +def test_load_intrinsic() -> None: default_sgroup_opts = ScalingGroupOpts() - assert isinstance(load_scheduler("fifo", default_sgroup_opts, {}), FIFOSlotScheduler) - assert isinstance(load_scheduler("lifo", default_sgroup_opts, {}), LIFOSlotScheduler) - assert isinstance(load_scheduler("drf", default_sgroup_opts, {}), DRFScheduler) - assert isinstance(load_scheduler("mof", default_sgroup_opts, {}), MOFScheduler) + assert isinstance( + load_scheduler("fifo", default_sgroup_opts, {}), + FIFOSlotScheduler, + ) + assert isinstance( + load_scheduler("lifo", default_sgroup_opts, {}), + LIFOSlotScheduler, + ) + assert isinstance( + load_scheduler("drf", default_sgroup_opts, {}), + DRFScheduler, + ) -def test_scheduler_configs(): +def test_scheduler_configs() -> None: example_sgroup_opts = ScalingGroupOpts( # already processed by column trafaret allowed_session_types=[SessionTypes.BATCH], pending_timeout=timedelta(seconds=86400 * 2), @@ -73,8 +87,12 @@ def test_scheduler_configs(): "num_retries_to_skip": 5, } with pytest.raises(t.DataError): - example_sgroup_opts.config["num_retries_to_skip"] = -1 # invalid value - scheduler = load_scheduler("fifo", example_sgroup_opts, example_sgroup_opts.config) + example_sgroup_opts.config["num_retries_to_skip"] = -1 # type: ignore + scheduler = load_scheduler( + "fifo", + example_sgroup_opts, + example_sgroup_opts.config, + ) example_group_id = uuid4() @@ -85,7 +103,7 @@ def test_scheduler_configs(): @pytest.fixture -def example_agents(): +def example_agents() -> Sequence[AgentRow]: return [ AgentRow( id=AgentId("i-001"), @@ -127,7 +145,106 @@ def example_agents(): @pytest.fixture -def example_mixed_agents(): +def example_agents_many() -> Sequence[AgentRow]: + return [ + AgentRow( + id=AgentId("i-001"), + addr="10.0.1.1:6001", + architecture=ARCH_FOR_TEST, + scaling_group=example_sgroup_name1, + available_slots=ResourceSlot({ + "cpu": Decimal("8"), + "mem": Decimal("4096"), + "cuda.shares": Decimal("4.0"), + }), + occupied_slots=ResourceSlot({ + "cpu": Decimal("0"), + "mem": Decimal("0"), + "cuda.shares": Decimal("0"), + }), + ), + AgentRow( + id=AgentId("i-002"), + addr="10.0.2.1:6001", + architecture=ARCH_FOR_TEST, + scaling_group=example_sgroup_name2, + available_slots=ResourceSlot({ + "cpu": Decimal("4"), + "mem": Decimal("2048"), + "cuda.shares": Decimal("1.0"), + }), + occupied_slots=ResourceSlot({ + "cpu": Decimal("0"), + "mem": Decimal("0"), + "cuda.shares": Decimal("0"), + }), + ), + AgentRow( + id=AgentId("i-003"), + addr="10.0.3.1:6001", + architecture=ARCH_FOR_TEST, + scaling_group=example_sgroup_name2, + available_slots=ResourceSlot({ + "cpu": Decimal("2"), + "mem": Decimal("1024"), + "cuda.shares": Decimal("1.0"), + }), + occupied_slots=ResourceSlot({ + "cpu": Decimal("0"), + "mem": Decimal("0"), + "cuda.shares": Decimal("0"), + }), + ), + AgentRow( + id=AgentId("i-004"), + addr="10.0.4.1:6001", + architecture=ARCH_FOR_TEST, + scaling_group=example_sgroup_name2, + available_slots=ResourceSlot({ + "cpu": Decimal("1"), + "mem": Decimal("512"), + "cuda.shares": Decimal("0.5"), + }), + occupied_slots=ResourceSlot({ + "cpu": Decimal("0"), + "mem": Decimal("0"), + "cuda.shares": Decimal("0"), + }), + ), + ] + + +@pytest.fixture +def example_agents_multi_homogeneous( + request: pytest.FixtureRequest, +) -> Generator[Sequence[AgentRow], None, None]: + repeat = request.param.get("repeat", 10) + + yield [ + AgentRow( + id=AgentId(f"i-{idx:03d}"), + addr=f"10.0.1.{idx}:6001", + architecture=ARCH_FOR_TEST, + scaling_group=example_sgroup_name1, + available_slots=ResourceSlot({ + "cpu": Decimal("4.0"), + "mem": Decimal("4096"), + "cuda.shares": Decimal("4.0"), + "rocm.devices": Decimal("2"), + }), + occupied_slots=ResourceSlot({ + "cpu": Decimal("0"), + "mem": Decimal("0"), + "cuda.shares": Decimal("0"), + "rocm.devices": Decimal("0"), + }), + ) + for idx in range(repeat) + ] + + +@pytest.fixture +def example_mixed_agents() -> Sequence[AgentRow]: return [ AgentRow( id=AgentId("i-gpu"), @@ -165,7 +282,7 @@ def example_mixed_agents(): @pytest.fixture -def example_agents_first_one_assigned(): +def example_agents_first_one_assigned() -> Sequence[AgentRow]: return [ AgentRow( id=AgentId("i-001"), @@ -207,7 +324,7 @@ def example_agents_first_one_assigned(): @pytest.fixture -def example_agents_no_valid(): +def example_agents_no_valid() -> Sequence[AgentRow]: return [ AgentRow( id=AgentId("i-001"), @@ -250,26 +367,26 @@ def example_agents_no_valid(): @attrs.define(auto_attribs=True, slots=True) class SessionKernelIdPair: - session_id: UUID + session_id: SessionId kernel_ids: Sequence[KernelId] cancelled_session_ids = [ - UUID("251907d9-1290-4126-bc6c-000000000999"), + SessionId(UUID("251907d9-1290-4126-bc6c-000000000999")), ] pending_session_kernel_ids = [ SessionKernelIdPair( - session_id=UUID("251907d9-1290-4126-bc6c-000000000100"), + session_id=SessionId(UUID("251907d9-1290-4126-bc6c-000000000100")), kernel_ids=[KernelId(UUID("251907d9-1290-4126-bc6c-000000000100"))], ), SessionKernelIdPair( - session_id=UUID("251907d9-1290-4126-bc6c-000000000200"), + session_id=SessionId(UUID("251907d9-1290-4126-bc6c-000000000200")), kernel_ids=[KernelId(UUID("251907d9-1290-4126-bc6c-000000000200"))], ), SessionKernelIdPair( # single-node mode multi-container session - session_id=UUID("251907d9-1290-4126-bc6c-000000000300"), + session_id=SessionId(UUID("251907d9-1290-4126-bc6c-000000000300")), kernel_ids=[ KernelId(UUID("251907d9-1290-4126-bc6c-000000000300")), KernelId(UUID("251907d9-1290-4126-bc6c-000000000301")), @@ -277,26 +394,26 @@ class SessionKernelIdPair: ], ), SessionKernelIdPair( - session_id=UUID("251907d9-1290-4126-bc6c-000000000400"), + session_id=SessionId(UUID("251907d9-1290-4126-bc6c-000000000400")), kernel_ids=[KernelId(UUID("251907d9-1290-4126-bc6c-000000000400"))], ), ] existing_session_kernel_ids = [ SessionKernelIdPair( - session_id=UUID("251907d9-1290-4126-bc6c-100000000100"), + session_id=SessionId(UUID("251907d9-1290-4126-bc6c-100000000100")), kernel_ids=[ KernelId(UUID("251907d9-1290-4126-bc6c-100000000100")), KernelId(UUID("251907d9-1290-4126-bc6c-100000000101")), ], ), SessionKernelIdPair( - session_id=UUID("251907d9-1290-4126-bc6c-100000000200"), + session_id=SessionId(UUID("251907d9-1290-4126-bc6c-100000000200")), kernel_ids=[KernelId(UUID("251907d9-1290-4126-bc6c-100000000200"))], ), SessionKernelIdPair( # single-node mode multi-container session - session_id=UUID("251907d9-1290-4126-bc6c-100000000300"), + session_id=SessionId(UUID("251907d9-1290-4126-bc6c-100000000300")), kernel_ids=[KernelId(UUID("251907d9-1290-4126-bc6c-100000000300"))], ), ] @@ -327,7 +444,58 @@ class SessionKernelIdPair: @pytest.fixture -def example_cancelled_sessions(): +def example_homogeneous_pending_sessions( + request: pytest.FixtureRequest, +) -> Generator[Sequence[SessionRow], None, None]: + repeat = request.param.get("repeat", 10) + yield [ + SessionRow( + kernels=[ + KernelRow( + id=pending_session_kernel_ids[2].kernel_ids[0], + session_id=pending_session_kernel_ids[2].session_id, + access_key="dummy-access-key", + agent=None, + agent_addr=None, + cluster_role=DEFAULT_ROLE, + cluster_idx=1, + local_rank=0, + cluster_hostname=f"{DEFAULT_ROLE}0", + architecture=common_image_ref.architecture, + registry=common_image_ref.registry, + image=common_image_ref.name, + requested_slots=ResourceSlot({ + "cpu": Decimal("2.0"), + "mem": Decimal("1024"), + }), + bootstrap_script=None, + startup_command=None, + created_at=dtparse("2021-12-01T23:59:59+00:00"), + ), + ], + access_key=AccessKey("user01"), + id=UUID(f"251907d9-1290-4126-bc6c-{idx:012x}"), + creation_id=f"{idx:012x}", + name=f"session-{idx}", + session_type=SessionTypes.BATCH, + status=SessionStatus.PENDING, + cluster_mode="single-node", + cluster_size=1, + scaling_group_name=example_sgroup_name1, + requested_slots=ResourceSlot({ + "cpu": Decimal("2.0"), + "mem": Decimal("1024"), + }), + target_sgroup_names=[], + **_common_dummy_for_pending_session, + created_at=dtparse("2021-12-28T23:59:59+00:00"), + ) + for idx in range(repeat) + ] + + +@pytest.fixture +def example_cancelled_sessions() -> Sequence[SessionRow]: return [ SessionRow( access_key=AccessKey("user01"), @@ -352,8 +520,54 @@ def example_cancelled_sessions(): ] +def create_pending_session( + session_id: SessionId, kernel_id: KernelId, requested_slots: ResourceSlot +) -> SessionRow: + """Create a simple single-kernel pending session.""" + return SessionRow( + kernels=[ + KernelRow( + id=session_id, + session_id=kernel_id, + access_key="dummy-access-key", + agent=None, + agent_addr=None, + cluster_role=DEFAULT_ROLE, + cluster_idx=1, + local_rank=0, + cluster_hostname=f"{DEFAULT_ROLE}0", + architecture=common_image_ref.architecture, + registry=common_image_ref.registry, + image=common_image_ref.name, + requested_slots=ResourceSlot({ + "cpu": Decimal("2.0"), + "mem": Decimal("1024"), + "cuda.shares": Decimal("0"), + "rocm.devices": Decimal("1"), + }), + bootstrap_script=None, + startup_command=None, + created_at=dtparse("2021-12-28T23:59:59+00:00"), + ), + ], + access_key=AccessKey("user01"), + id=pending_session_kernel_ids[0].session_id, + creation_id="aaa100", + name="eps01", + session_type=SessionTypes.BATCH, + status=SessionStatus.PENDING, + cluster_mode="single-node", + cluster_size=1, + scaling_group_name=example_sgroup_name1, + requested_slots=requested_slots, + target_sgroup_names=[], + **_common_dummy_for_pending_session, + created_at=dtparse("2021-12-28T23:59:59+00:00"), + ) + + @pytest.fixture -def example_pending_sessions(): +def example_pending_sessions() -> Sequence[SessionRow]: # lower indicies are enqueued first. return [ SessionRow( # rocm @@ -542,7 +756,7 @@ def example_pending_sessions(): @pytest.fixture -def example_existing_sessions(): +def example_existing_sessions() -> Sequence[SessionRow]: return [ SessionRow( kernels=[ @@ -696,7 +910,7 @@ def example_existing_sessions(): ] -def _find_and_pop_picked_session(pending_sessions, picked_session_id): +def _find_and_pop_picked_session(pending_sessions, picked_session_id) -> SessionRow: for picked_idx, pending_sess in enumerate(pending_sessions): if pending_sess.id == picked_session_id: break @@ -706,10 +920,32 @@ def _find_and_pop_picked_session(pending_sessions, picked_session_id): return pending_sessions.pop(picked_idx) -def test_fifo_scheduler(example_agents, example_pending_sessions, example_existing_sessions): +def _update_agent_assignment( + agents: list[AgentRow], + picked_agent_id: AgentId, + occupied_slots: ResourceSlot, +) -> None: + for ag in agents: + if ag.id == picked_agent_id: + ag.occupied_slots += occupied_slots + + +@pytest.mark.asyncio +async def test_fifo_scheduler( + example_agents: Sequence[AgentRow], + example_pending_sessions: Sequence[SessionRow], + example_existing_sessions: Sequence[SessionRow], +) -> None: scheduler = FIFOSlotScheduler(ScalingGroupOpts(), {}) + agstate_cls = DispersedAgentSelector.get_state_cls() + agselector = DispersedAgentSelector( + ScalingGroupOpts(), + {}, + agent_selection_resource_priority, + state_store=InMemoryResourceGroupStateStore(agstate_cls), + ) picked_session_id = scheduler.pick_session( - example_total_capacity, + sum((ag.available_slots for ag in example_agents), start=ResourceSlot()), example_pending_sessions, example_existing_sessions, ) @@ -718,19 +954,29 @@ def test_fifo_scheduler(example_agents, example_pending_sessions, example_existi example_pending_sessions, picked_session_id, ) - agent_id = scheduler.assign_agent_for_session( + agent_id = await agselector.assign_agent_for_session( example_agents, picked_session, - AgentSelectionStrategy.DISPERSED, - agent_selection_resource_priority, ) assert agent_id == AgentId("i-001") -def test_lifo_scheduler(example_agents, example_pending_sessions, example_existing_sessions): +@pytest.mark.asyncio +async def test_lifo_scheduler( + example_agents: Sequence[AgentRow], + example_pending_sessions: Sequence[SessionRow], + example_existing_sessions: Sequence[SessionRow], +) -> None: scheduler = LIFOSlotScheduler(ScalingGroupOpts(), {}) + agstate_cls = DispersedAgentSelector.get_state_cls() + agselector = DispersedAgentSelector( + ScalingGroupOpts(), + {}, + agent_selection_resource_priority, + state_store=InMemoryResourceGroupStateStore(agstate_cls), + ) picked_session_id = scheduler.pick_session( - example_total_capacity, + sum((ag.available_slots for ag in example_agents), start=ResourceSlot()), example_pending_sessions, example_existing_sessions, ) @@ -739,23 +985,30 @@ def test_lifo_scheduler(example_agents, example_pending_sessions, example_existi example_pending_sessions, picked_session_id, ) - agent_id = scheduler.assign_agent_for_session( + agent_id = await agselector.assign_agent_for_session( example_agents, picked_session, - AgentSelectionStrategy.DISPERSED, - agent_selection_resource_priority, ) assert agent_id == "i-001" -def test_fifo_scheduler_favor_cpu_for_requests_without_accelerators( - example_mixed_agents, - example_pending_sessions, -): +@pytest.mark.asyncio +async def test_fifo_scheduler_favor_cpu_for_requests_without_accelerators( + example_mixed_agents: Sequence[AgentRow], + example_pending_sessions: Sequence[SessionRow], +) -> None: scheduler = FIFOSlotScheduler(ScalingGroupOpts(), {}) + agstate_cls = DispersedAgentSelector.get_state_cls() + agselector = DispersedAgentSelector( + ScalingGroupOpts(), + {}, + agent_selection_resource_priority, + state_store=InMemoryResourceGroupStateStore(agstate_cls), + ) + total_capacity = sum((ag.available_slots for ag in example_mixed_agents), start=ResourceSlot()) for idx in range(3): picked_session_id = scheduler.pick_session( - example_total_capacity, + total_capacity, example_pending_sessions, [], ) @@ -764,11 +1017,9 @@ def test_fifo_scheduler_favor_cpu_for_requests_without_accelerators( example_pending_sessions, picked_session_id, ) - agent_id = scheduler.assign_agent_for_session( + agent_id = await agselector.assign_agent_for_session( example_mixed_agents, picked_session, - AgentSelectionStrategy.DISPERSED, - agent_selection_resource_priority, ) if idx == 0: # example_mixed_agents do not have any agent with ROCM accelerators. @@ -800,7 +1051,7 @@ def gen_pending_for_holb_tests(session_id: str, status_data: Mapping[str, Any]) ) -def test_fifo_scheduler_hol_blocking_avoidance_empty_status_data(): +def test_fifo_scheduler_hol_blocking_avoidance_empty_status_data() -> None: """ Without any status_data, it should just pick the first session. """ @@ -814,7 +1065,7 @@ def test_fifo_scheduler_hol_blocking_avoidance_empty_status_data(): assert picked_session_id == "s0" -def test_fifo_scheduler_hol_blocking_avoidance_config(): +def test_fifo_scheduler_hol_blocking_avoidance_config() -> None: """ If the upfront sessions have enough number of retries, it should skip them. @@ -838,7 +1089,7 @@ def test_fifo_scheduler_hol_blocking_avoidance_config(): assert picked_session_id == "s1" -def test_fifo_scheduler_hol_blocking_avoidance_skips(): +def test_fifo_scheduler_hol_blocking_avoidance_skips() -> None: """ If the upfront sessions have enough number of retries, it should skip them. @@ -861,7 +1112,7 @@ def test_fifo_scheduler_hol_blocking_avoidance_skips(): assert picked_session_id == "s2" -def test_fifo_scheduler_hol_blocking_avoidance_all_skipped(): +def test_fifo_scheduler_hol_blocking_avoidance_all_skipped() -> None: """ If all sessions are skipped due to excessive number of retries, then we go back to the normal FIFO by choosing the first of them. @@ -876,7 +1127,7 @@ def test_fifo_scheduler_hol_blocking_avoidance_all_skipped(): assert picked_session_id == "s0" -def test_fifo_scheduler_hol_blocking_avoidance_no_skip(): +def test_fifo_scheduler_hol_blocking_avoidance_no_skip() -> None: """ If non-first sessions have to be skipped, the scheduler should still choose the first session. @@ -891,24 +1142,30 @@ def test_fifo_scheduler_hol_blocking_avoidance_no_skip(): assert picked_session_id == "s0" -def test_lifo_scheduler_favor_cpu_for_requests_without_accelerators( - example_mixed_agents, - example_pending_sessions, -): +@pytest.mark.asyncio +async def test_lifo_scheduler_favor_cpu_for_requests_without_accelerators( + example_mixed_agents: Sequence[AgentRow], + example_pending_sessions: Sequence[SessionRow], +) -> None: # Check the reverse with the LIFO scheduler. # The result must be same. - scheduler = LIFOSlotScheduler(ScalingGroupOpts(), {}) + sgroup_opts = ScalingGroupOpts(agent_selection_strategy=AgentSelectionStrategy.DISPERSED) + scheduler = LIFOSlotScheduler(sgroup_opts, {}) + agstate_cls = DispersedAgentSelector.get_state_cls() + agselector = DispersedAgentSelector( + sgroup_opts, + {}, + agent_selection_resource_priority, + state_store=InMemoryResourceGroupStateStore(agstate_cls), + ) + total_capacity = sum((ag.available_slots for ag in example_mixed_agents), start=ResourceSlot()) for idx in range(3): - picked_session_id = scheduler.pick_session( - example_total_capacity, example_pending_sessions, [] - ) + picked_session_id = scheduler.pick_session(total_capacity, example_pending_sessions, []) assert picked_session_id == example_pending_sessions[-1].id picked_session = _find_and_pop_picked_session(example_pending_sessions, picked_session_id) - agent_id = scheduler.assign_agent_for_session( + agent_id = await agselector.assign_agent_for_session( example_mixed_agents, picked_session, - AgentSelectionStrategy.DISPERSED, - agent_selection_resource_priority, ) if idx == 2: # example_mixed_agents do not have any agent with ROCM accelerators. @@ -921,14 +1178,23 @@ def test_lifo_scheduler_favor_cpu_for_requests_without_accelerators( assert agent_id == AgentId("i-cpu") -def test_drf_scheduler( - example_agents, - example_pending_sessions, - example_existing_sessions, -): - scheduler = DRFScheduler(ScalingGroupOpts(), {}) +@pytest.mark.asyncio +async def test_drf_scheduler( + example_agents: Sequence[AgentRow], + example_pending_sessions: Sequence[SessionRow], + example_existing_sessions: Sequence[SessionRow], +) -> None: + sgroup_opts = ScalingGroupOpts(agent_selection_strategy=AgentSelectionStrategy.DISPERSED) + scheduler = DRFScheduler(sgroup_opts, {}) + agstate_cls = DispersedAgentSelector.get_state_cls() + agselector = DispersedAgentSelector( + sgroup_opts, + {}, + agent_selection_resource_priority, + state_store=InMemoryResourceGroupStateStore(agstate_cls), + ) picked_session_id = scheduler.pick_session( - example_total_capacity, + sum((ag.available_slots for ag in example_agents), start=ResourceSlot()), example_pending_sessions, example_existing_sessions, ) @@ -938,87 +1204,21 @@ def test_drf_scheduler( example_pending_sessions, picked_session_id, ) - agent_id = scheduler.assign_agent_for_session( - example_agents, - picked_session, - AgentSelectionStrategy.DISPERSED, - agent_selection_resource_priority, - ) - assert agent_id == "i-001" - - -def test_mof_scheduler_first_assign( - example_agents, - example_pending_sessions, - example_existing_sessions, -): - scheduler = MOFScheduler(ScalingGroupOpts(), {}) - picked_session_id = scheduler.pick_session( - example_total_capacity, example_pending_sessions, example_existing_sessions - ) - assert picked_session_id == example_pending_sessions[0].id - picked_session = _find_and_pop_picked_session(example_pending_sessions, picked_session_id) - - agent_id = scheduler.assign_agent_for_session( + agent_id = await agselector.assign_agent_for_session( example_agents, picked_session, - AgentSelectionStrategy.DISPERSED, - agent_selection_resource_priority, ) assert agent_id == "i-001" -def test_mof_scheduler_second_assign( - example_agents_first_one_assigned, - example_pending_sessions, - example_existing_sessions, -): - scheduler = MOFScheduler(ScalingGroupOpts(), {}) - picked_session_id = scheduler.pick_session( - example_total_capacity, example_pending_sessions, example_existing_sessions - ) - assert picked_session_id == example_pending_sessions[0].id - picked_session = _find_and_pop_picked_session(example_pending_sessions, picked_session_id) - - agent_id = scheduler.assign_agent_for_session( - example_agents_first_one_assigned, - picked_session, - AgentSelectionStrategy.DISPERSED, - agent_selection_resource_priority, - ) - assert agent_id == "i-101" - - -def test_mof_scheduler_no_valid_agent( - example_agents_no_valid, - example_pending_sessions, - example_existing_sessions, -): - scheduler = MOFScheduler(ScalingGroupOpts(), {}) - picked_session_id = scheduler.pick_session( - example_total_capacity, example_pending_sessions, example_existing_sessions - ) - assert picked_session_id == example_pending_sessions[0].id - picked_session = _find_and_pop_picked_session(example_pending_sessions, picked_session_id) - - agent_id = scheduler.assign_agent_for_session( - example_agents_no_valid, - picked_session, - AgentSelectionStrategy.DISPERSED, - agent_selection_resource_priority, - ) - assert agent_id is None - - @pytest.mark.asyncio -async def test_pending_timeout(mocker): +async def test_pending_timeout() -> None: class DummySession: def __init__(self, id, created_at, status) -> None: self.id = id self.created_at = created_at self.status = status - # mocker.patch("ai.backend.manager.scheduler.dispatcher.datetime", MockDatetime) now = datetime.now(tzutc()) mock_query_result = MagicMock() mock_query_result.scalars = MagicMock() @@ -1048,7 +1248,8 @@ def __init__(self, id, created_at, status) -> None: mock_dbsess.execute = AsyncMock(return_value=mock_query_result) scheduler = FIFOSlotScheduler( - ScalingGroupOpts(pending_timeout=timedelta(seconds=86400 * 2)), {} + ScalingGroupOpts(pending_timeout=timedelta(seconds=86400 * 2)), + {}, ) _, candidate_session_rows, cancelled_session_rows = await _list_managed_sessions( mock_dbsess, @@ -1059,7 +1260,10 @@ def __init__(self, id, created_at, status) -> None: assert len(cancelled_session_rows) == 1 assert cancelled_session_rows[0].id == "session1" - scheduler = FIFOSlotScheduler(ScalingGroupOpts(pending_timeout=timedelta(seconds=0)), {}) + scheduler = FIFOSlotScheduler( + ScalingGroupOpts(pending_timeout=timedelta(seconds=0)), + {}, + ) _, candidate_session_rows, cancelled_session_rows = await _list_managed_sessions( mock_dbsess, "default", @@ -1080,10 +1284,10 @@ async def test_manually_assign_agent_available( registry_ctx: tuple[ AgentRegistry, MagicMock, MagicMock, MagicMock, MagicMock, MagicMock, MagicMock ], - mocker, - example_agents, - example_pending_sessions, -): + mocker: pytest_mock.MockerFixture, + example_agents: Sequence[AgentRow], + example_pending_sessions: Sequence[SessionRow], +) -> None: mock_local_config = MagicMock() ( @@ -1100,7 +1304,14 @@ async def test_manually_assign_agent_available( mock_redis_wrapper = MagicMock() mock_redis_wrapper.execute = AsyncMock(return_value=[0 for _ in example_agents]) mocker.patch("ai.backend.manager.scheduler.dispatcher.redis_helper", mock_redis_wrapper) - scheduler = FIFOSlotScheduler(ScalingGroupOpts(), {}) + sgroup_opts = ScalingGroupOpts() + agstate_cls = DispersedAgentSelector.get_state_cls() + agselector = DispersedAgentSelector( + sgroup_opts, + {}, + agent_selection_resource_priority, + state_store=InMemoryResourceGroupStateStore(agstate_cls), + ) sgroup_name = example_agents[0].scaling_group candidate_agents = example_agents example_pending_sessions[0].kernels[0].agent = example_agents[0].id @@ -1119,11 +1330,10 @@ async def test_manually_assign_agent_available( mock_dbresult.scalar = MagicMock(return_value=None) await dispatcher._schedule_single_node_session( mock_sched_ctx, - scheduler, + agselector, sgroup_name, candidate_agents, sess_ctx, - agent_selection_resource_priority, mock_check_result, ) result = mock_dbresult.scalar() @@ -1133,11 +1343,10 @@ async def test_manually_assign_agent_available( mock_dbresult.scalar = MagicMock(return_value={}) await dispatcher._schedule_single_node_session( mock_sched_ctx, - scheduler, + agselector, sgroup_name, candidate_agents, sess_ctx, - agent_selection_resource_priority, mock_check_result, ) result = mock_dbresult.scalar() @@ -1154,11 +1363,10 @@ async def test_manually_assign_agent_available( ) await dispatcher._schedule_single_node_session( mock_sched_ctx, - scheduler, + agselector, sgroup_name, candidate_agents, sess_ctx, - agent_selection_resource_priority, mock_check_result, ) result = mock_dbresult.scalar() @@ -1176,11 +1384,10 @@ async def test_manually_assign_agent_available( ) await dispatcher._schedule_single_node_session( mock_sched_ctx, - scheduler, + agselector, sgroup_name, candidate_agents, sess_ctx, - agent_selection_resource_priority, mock_check_result, ) result = mock_dbresult.scalar() @@ -1190,7 +1397,7 @@ async def test_manually_assign_agent_available( @pytest.mark.asyncio @mock.patch("ai.backend.manager.scheduler.predicates.datetime") -async def test_multiple_timezones_for_reserved_batch_session_predicate(mock_dt): +async def test_multiple_timezones_for_reserved_batch_session_predicate(mock_dt: MagicMock) -> None: mock_db_conn = MagicMock() mock_sched_ctx = MagicMock() mock_sess_ctx = MagicMock() @@ -1234,4 +1441,131 @@ async def test_multiple_timezones_for_reserved_batch_session_predicate(mock_dt): assert result.passed -# TODO: write tests for multiple agents and scaling groups +@pytest.mark.asyncio +@pytest.mark.parametrize("example_agents_multi_homogeneous", [{"repeat": 10}], indirect=True) +@pytest.mark.parametrize("example_homogeneous_pending_sessions", [{"repeat": 20}], indirect=True) +async def test_agent_selection_strategy_rr( + example_agents_multi_homogeneous: Sequence[AgentRow], + example_homogeneous_pending_sessions: Sequence[SessionRow], + example_existing_sessions: Sequence[SessionRow], +) -> None: + sgroup_opts = ScalingGroupOpts( + agent_selection_strategy=AgentSelectionStrategy.ROUNDROBIN, + ) + scheduler = FIFOSlotScheduler( + sgroup_opts, + {}, + ) + + agstate_cls = RoundRobinAgentSelector.get_state_cls() + agselector = RoundRobinAgentSelector( + sgroup_opts, + {}, + agent_selection_resource_priority, + state_store=InMemoryResourceGroupStateStore(agstate_cls), + ) + + num_agents = len(example_agents_multi_homogeneous) + total_capacity = sum( + (ag.available_slots for ag in example_agents_multi_homogeneous), ResourceSlot() + ) + agent_ids = [] + # Repeat the allocation for two iterations + for _ in range(num_agents * 2): + picked_session_id = scheduler.pick_session( + total_capacity, + example_homogeneous_pending_sessions, + example_existing_sessions, + ) + assert picked_session_id == example_homogeneous_pending_sessions[0].id + picked_session = _find_and_pop_picked_session( + example_homogeneous_pending_sessions, + picked_session_id, + ) + agent_ids.append( + await agselector.assign_agent_for_session( + example_agents_multi_homogeneous, + picked_session, + ) + ) + assert agent_ids == [AgentId(f"i-{idx:03d}") for idx in range(num_agents)] * 2 + + +@pytest.mark.asyncio +async def test_agent_selection_strategy_rr_skip_unacceptable_agents( + example_agents_many: Sequence[AgentRow], +) -> None: + # example_agents_many: + # i-001: cpu=8, mem=4096, cuda.shares=4.0 + # i-002: cpu=4, mem=2048, cuda.shares=2.0 + # i-003: cpu=2, mem=1024, cuda.shares=1.0 + # i-004: cpu=1, mem=512, cuda.shares=0.5 + agents: list[AgentRow] = [*example_agents_many] + + # all pending sessions: + # cpu=2, mem=500 + pending_sessions = [ + create_pending_session( + SessionId(uuid4()), + KernelId(uuid4()), + ResourceSlot({ + "cpu": Decimal("2"), + "mem": Decimal("500"), + }), + ) + for _ in range(8) + ] + + sgroup_opts = ScalingGroupOpts( + agent_selection_strategy=AgentSelectionStrategy.ROUNDROBIN, + ) + scheduler = FIFOSlotScheduler( + sgroup_opts, + {}, + ) + + agstate_cls = RoundRobinAgentSelector.get_state_cls() + agselector = RoundRobinAgentSelector( + sgroup_opts, + {}, + agent_selection_resource_priority, + state_store=InMemoryResourceGroupStateStore(agstate_cls), + ) + + total_capacity = sum((ag.available_slots for ag in agents), ResourceSlot()) + + results: list[AgentId | None] = [] + scheduled_sessions: list[SessionRow] = [] + + for _ in range(8): + picked_session_id = scheduler.pick_session( + total_capacity, + pending_sessions, + scheduled_sessions, + ) + assert picked_session_id is not None + picked_session = _find_and_pop_picked_session( + pending_sessions, + picked_session_id, + ) + scheduled_sessions.append(picked_session) + result = await agselector.assign_agent_for_session( + agents, + picked_session, + ) + if result is not None: + _update_agent_assignment(agents, result, picked_session.requested_slots) + results.append(result) + + print() + for ag in agents: + print( + ag.id, + f"{ag.occupied_slots["cpu"]}/{ag.available_slots["cpu"]}", + f"{ag.occupied_slots["mem"]}/{ag.available_slots["mem"]}", + ) + # As more sessions have the assigned agents, the remaining capacity diminishes + # and the range of round-robin also becomes limited. + # When there is no assignable agent, it should return None. + assert len(results) == 8 + assert results == ["i-001", "i-002", "i-003", "i-001", "i-002", "i-001", "i-001", None]