Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Split PendingSession Scheduler into PendingSession Scheduler and AgentSelector #1655

Merged
merged 46 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
4251d4a
Replace RoundRobin flag with AgentSelectionStrategy.RoundRobin
jopemachine Oct 25, 2023
c6b4c8e
Remove MOFScheduler
jopemachine Oct 26, 2023
328826f
Add use_num_extras flag temporarily
jopemachine Oct 26, 2023
2bb06a9
Trim imports
jopemachine Oct 26, 2023
26fbed2
Fix wrong type
jopemachine Oct 26, 2023
2ec9bf2
Distinguish `compatible_agents` and `possible_agents`
jopemachine Oct 26, 2023
ae9a2cc
Format with ruff
jopemachine Mar 27, 2024
4addb8b
fix: Remove unused import
achimnol Aug 24, 2024
5554e3c
fix: merge conflicts
achimnol Aug 24, 2024
68b5e83
test: Use the actual sum of total capacity from the agent list fixture
achimnol Aug 25, 2024
980948e
refactor: Move common.types.RoundRobinContext to manager.scheduler.types
achimnol Aug 26, 2024
9f3e500
refactor: Split out `AgentSelector` interface and implmentations.
achimnol Aug 26, 2024
be85eb6
fix: oops
achimnol Aug 26, 2024
9418066
refactor: Use modern type imports
achimnol Aug 26, 2024
8f6fcd2
test: Update test codes
achimnol Aug 26, 2024
ad59f73
chore: Correct BUILD file and fix typos
jopemachine Aug 27, 2024
ac4c848
refactor: WIP
jopemachine Aug 28, 2024
08d99d1
chore: Rename news fragment
jopemachine Aug 28, 2024
151eeef
docs: Add comment
jopemachine Aug 28, 2024
d9984eb
chore: Remove useless `__init__`
jopemachine Aug 28, 2024
e72f873
refactor: Move `roundrobin_states` -> `agent_selector_states`
jopemachine Aug 30, 2024
2af1105
fix: Edit comment
jopemachine Aug 30, 2024
6011f5f
refactor: Remove useless type alias
jopemachine Aug 30, 2024
39c81bc
feat: Add `from_json`, `as_trafaret` to `AgentSelectorState`
jopemachine Aug 30, 2024
87e9322
chore: refactor
jopemachine Aug 30, 2024
0596185
chore: change to snake case
jopemachine Aug 30, 2024
ab03430
chore: Update comment
jopemachine Aug 30, 2024
6a1a985
fix: `store-type` -> `store_type`
jopemachine Aug 30, 2024
68a376b
feat: Separate the Agentselector configuration into local config and …
jopemachine Aug 30, 2024
a87b1a4
refactor: Remove code duplication
jopemachine Aug 30, 2024
0dfcf10
feat: Change to etcd key path to kebab case, Remove storage_type from…
jopemachine Sep 3, 2024
682fa80
feat: Enable RoundRobinAgentSelector for multi node session
jopemachine Sep 3, 2024
1ead959
test: Clarify the intention
achimnol Sep 16, 2024
5c91bec
fix: Require state_store as the mandatory kwarg and add missing defau…
achimnol Sep 16, 2024
e4e6d1c
refactor: Minimize the scope of interests for module-specific type defs
achimnol Sep 16, 2024
7cde57e
fix: Add comments and debug logging.
achimnol Sep 16, 2024
a24e2e4
refactor: Clean up ResourceGroupState type hierarchy
achimnol Sep 16, 2024
b6f140e
refactor,fix: Add missing state_name key for resource group states
achimnol Sep 16, 2024
ae9c537
fix: remove debug print
achimnol Sep 16, 2024
e727671
refactor: Clarify that this is a generic ResourceGroupStateStore
achimnol Sep 16, 2024
ec9944f
refactor: Reduce verbosity
achimnol Sep 16, 2024
369315f
refactor: Use pydantic instead of trafaret-based mixins
achimnol Sep 16, 2024
cc15f8d
fix: Add database migration for agent selector configs
achimnol Sep 16, 2024
15fbc7b
fix: typo in the class name
achimnol Sep 16, 2024
b40cfe0
fix: Ensure removal of 'roundrobin' key in the scheduler_opts
achimnol Sep 16, 2024
69083d9
test: Update the test case to be more realistic
achimnol Sep 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/1655.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor `PendingSession` Scheduler into `PendingSession` scheduler and `AgentSelector`, and replace `roundrobin` flag with `AgentSelectionStrategy.RoundRobin` policy.
8 changes: 8 additions & 0 deletions src/ai/backend/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,14 @@ class BaseSchema(BaseModel):
),
}).allow_extra("*")

# Used in Etcd as a global config.
# If `scalingGroup.scheduler_opts` contains an `agent_selector_config`, it will override this.
agent_selector_globalconfig_iv = t.Dict({}).allow_extra("*")

# Used in `scalingGroup.scheduler_opts` as a per scaling_group config.
agent_selector_config_iv = t.Dict({}) | agent_selector_globalconfig_iv


model_definition_iv = t.Dict({
t.Key("models"): t.List(
t.Dict({
Expand Down
31 changes: 6 additions & 25 deletions src/ai/backend/common/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
NewType,
NotRequired,
Optional,
Self,
Sequence,
Tuple,
Type,
Expand Down Expand Up @@ -218,7 +219,9 @@ def check_typed_dict(value: Mapping[Any, Any], expected_type: Type[TD]) -> TD:
SessionId = NewType("SessionId", uuid.UUID)
KernelId = NewType("KernelId", uuid.UUID)
ImageAlias = NewType("ImageAlias", str)
ArchName = NewType("ArchName", str)

ResourceGroupID = NewType("ResourceGroupID", str)
AgentId = NewType("AgentId", str)
DeviceName = NewType("DeviceName", str)
DeviceId = NewType("DeviceId", str)
Expand Down Expand Up @@ -840,7 +843,7 @@ def to_json(self) -> dict[str, Any]:
raise NotImplementedError

@classmethod
def from_json(cls, obj: Mapping[str, Any]) -> JSONSerializableMixin:
def from_json(cls, obj: Mapping[str, Any]) -> Self:
return cls(**cls.as_trafaret().check(obj))

@classmethod
Expand Down Expand Up @@ -985,7 +988,7 @@ def to_json(self) -> dict[str, Any]:
return {host: [perm.value for perm in perms] for host, perms in self.items()}

@classmethod
def from_json(cls, obj: Mapping[str, Any]) -> JSONSerializableMixin:
def from_json(cls, obj: Mapping[str, Any]) -> Self:
return cls(**cls.as_trafaret().check(obj))

@classmethod
Expand Down Expand Up @@ -1197,6 +1200,7 @@ class AcceleratorMetadata(TypedDict):
class AgentSelectionStrategy(enum.StrEnum):
DISPERSED = "dispersed"
CONCENTRATED = "concentrated"
ROUNDROBIN = "roundrobin"
# LEGACY chooses the largest agent (the sort key is a tuple of resource slots).
LEGACY = "legacy"

Expand All @@ -1215,29 +1219,6 @@ class VolumeMountableNodeType(enum.StrEnum):
STORAGE_PROXY = enum.auto()


@dataclass
class RoundRobinState(JSONSerializableMixin):
schedulable_group_id: str
next_index: int

def to_json(self) -> dict[str, Any]:
return dataclasses.asdict(self)

@classmethod
def from_json(cls, obj: Mapping[str, Any]) -> RoundRobinState:
return cls(**cls.as_trafaret().check(obj))

@classmethod
def as_trafaret(cls) -> t.Trafaret:
return t.Dict({
t.Key("schedulable_group_id"): t.String,
t.Key("next_index"): t.Int,
})


# States of the round-robin scheduler for each resource group and architecture.
RoundRobinStates: TypeAlias = dict[str, dict[str, RoundRobinState]]

SSLContextType: TypeAlias = bool | Fingerprint | SSLContext


Expand Down
23 changes: 0 additions & 23 deletions src/ai/backend/common/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
from .types import BinarySize as _BinarySize
from .types import HostPortPair as _HostPortPair
from .types import QuotaScopeID as _QuotaScopeID
from .types import RoundRobinState, RoundRobinStates
from .types import VFolderID as _VFolderID

__all__ = (
Expand Down Expand Up @@ -727,25 +726,3 @@ def check_and_return(self, value: Any) -> float:
return 0
case _:
self._failure(f"Value must be (float, tuple of float or None), not {type(value)}.")


class RoundRobinStatesJSONString(t.Trafaret):
def check_and_return(self, value: Any) -> RoundRobinStates:
try:
rr_states_dict: dict[str, dict[str, dict[str, Any]]] = json.loads(value)
except (KeyError, ValueError, json.decoder.JSONDecodeError):
self._failure(
f"Expected valid JSON string, got `{value}`. RoundRobinStatesJSONString should"
" be a valid JSON string",
value=value,
)

rr_states: RoundRobinStates = {}
for resource_group, arch_rr_states_dict in rr_states_dict.items():
rr_states[resource_group] = {}
for arch, rr_state_dict in arch_rr_states_dict.items():
if "next_index" not in rr_state_dict or "schedulable_group_id" not in rr_state_dict:
self._failure("Invalid roundrobin states")
rr_states[resource_group][arch] = RoundRobinState.from_json(rr_state_dict)

return rr_states
7 changes: 6 additions & 1 deletion src/ai/backend/manager/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,12 @@ python_distribution(
"fifo": "ai.backend.manager.scheduler.fifo:FIFOSlotScheduler",
"lifo": "ai.backend.manager.scheduler.fifo:LIFOSlotScheduler",
"drf": "ai.backend.manager.scheduler.drf:DRFScheduler",
"mof": "ai.backend.manager.scheduler.mof:MOFScheduler",
},
"backendai_agentselector_v10": {
"legacy": "ai.backend.manager.scheduler.agent_selector:LegacyAgentSelector",
"roundrobin": "ai.backend.manager.scheduler.agent_selector:RoundRobinAgentSelector",
"concentrated": "ai.backend.manager.scheduler.agent_selector:ConcentratedAgentSelector",
"dispersed": "ai.backend.manager.scheduler.agent_selector:DispersedAgentSelector",
},
"backendai_error_monitor_v20": {
"intrinsic": "ai.backend.manager.plugin.error_monitor:ErrorMonitor",
Expand Down
42 changes: 4 additions & 38 deletions src/ai/backend/manager/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,6 @@

from __future__ import annotations

import json
import logging
import os
import secrets
Expand Down Expand Up @@ -212,7 +211,6 @@
from ai.backend.common.lock import EtcdLock, FileLock, RedisLock
from ai.backend.common.types import (
HostPortPair,
RoundRobinState,
SlotName,
SlotTypes,
current_resource_slots,
Expand Down Expand Up @@ -350,6 +348,7 @@
"plugins": {
"accelerator": {},
"scheduler": {},
"agent_selector": {},
},
"watcher": {
"token": None,
Expand Down Expand Up @@ -439,6 +438,9 @@ def container_registry_serialize(v: dict[str, Any]) -> dict[str, str]:
t.Key("scheduler", default=_config_defaults["plugins"]["scheduler"]): t.Mapping(
t.String, t.Mapping(t.String, t.Any)
),
t.Key("agent_selector", default=_config_defaults["plugins"]["agent_selector"]): t.Mapping(
t.String, config.agent_selector_globalconfig_iv
),
}).allow_extra("*"),
t.Key("network", default=_config_defaults["network"]): t.Dict({
t.Key("subnet", default=_config_defaults["network"]["subnet"]): t.Dict({
Expand Down Expand Up @@ -471,7 +473,6 @@ def container_registry_serialize(v: dict[str, Any]) -> dict[str, str]:
): session_hang_tolerance_iv,
},
).allow_extra("*"),
t.Key("roundrobin_states", default=None): t.Null | tx.RoundRobinStatesJSONString,
}).allow_extra("*")

_volume_defaults: dict[str, Any] = {
Expand Down Expand Up @@ -847,38 +848,3 @@ def get_redis_url(self, db: int = 0) -> yarl.URL:
self.data["redis"]["addr"][1]
).with_password(self.data["redis"]["password"]) / str(db)
return url

async def get_roundrobin_state(
self, resource_group_name: str, architecture: str
) -> RoundRobinState | None:
"""
Return the roundrobin state for the given resource group and architecture.
If given resource group's roundrobin states or roundrobin state of the given architecture is not found, return None.
"""
if (rr_state_str := await self.get_raw("roundrobin_states")) is not None:
rr_states_dict: dict[str, dict[str, Any]] = json.loads(rr_state_str)
resource_group_rr_states_dict = rr_states_dict.get(resource_group_name, None)

if resource_group_rr_states_dict is not None:
rr_state_dict = resource_group_rr_states_dict.get(architecture, None)

if rr_state_dict is not None:
return RoundRobinState(
schedulable_group_id=rr_state_dict["schedulable_group_id"],
next_index=rr_state_dict["next_index"],
)

return None

async def put_roundrobin_state(
self, resource_group_name: str, architecture: str, state: RoundRobinState
) -> None:
"""
Update the roundrobin states using the given resource group and architecture key.
"""
rr_states_dict = json.loads(await self.get_raw("roundrobin_states") or "{}")
if resource_group_name not in rr_states_dict:
rr_states_dict[resource_group_name] = {}

rr_states_dict[resource_group_name][architecture] = state.to_json()
await self.etcd.put("roundrobin_states", json.dumps(rr_states_dict))
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""migrate-roundrobin-strategy-to-agent-selector-config

Revision ID: c4b7ec740b36
Revises: 59a622c31820
Create Date: 2024-09-17 00:31:31.379466

"""

from alembic import op
from sqlalchemy.sql import text

# revision identifiers, used by Alembic.
revision = "c4b7ec740b36"
down_revision = "59a622c31820"
branch_labels = None
depends_on = None


def upgrade() -> None:
op.execute(
text("""
UPDATE scaling_groups
SET scheduler_opts =
CASE
WHEN scheduler_opts ? 'roundrobin' AND scheduler_opts->>'roundrobin' = 'true' THEN
jsonb_set(
jsonb_set(
scheduler_opts - 'roundrobin',
'{agent_selection_strategy}', '"roundrobin"'::jsonb
),
'{agent_selector_config}', '{}'::jsonb
)
ELSE jsonb_set(
scheduler_opts - 'roundrobin',
'{agent_selector_config}', '{}'::jsonb
)
END;
""")
)


def downgrade() -> None:
op.execute(
text("""
UPDATE scaling_groups
SET scheduler_opts =
CASE
WHEN scheduler_opts->>'agent_selection_strategy' = 'roundrobin' THEN
jsonb_set(
jsonb_set(
scheduler_opts - 'agent_selector_config',
'{agent_selection_strategy}', '"legacy"'::jsonb
),
'{roundrobin}', 'true'::jsonb
)
ELSE scheduler_opts - 'agent_selector_config'
END;
""")
)
12 changes: 8 additions & 4 deletions src/ai/backend/manager/models/scaling_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from sqlalchemy.sql.expression import true

from ai.backend.common import validators as tx
from ai.backend.common.config import agent_selector_config_iv
from ai.backend.common.types import (
AgentSelectionStrategy,
JSONSerializableMixin,
Expand Down Expand Up @@ -98,17 +99,20 @@ class ScalingGroupOpts(JSONSerializableMixin):
],
)
pending_timeout: timedelta = timedelta(seconds=0)
config: Mapping[str, Any] = attr.Factory(dict)
config: Mapping[str, Any] = attr.field(factory=dict)

# Scheduler has a dedicated database column to store its name,
# but agent selector configuration is stored as a part of the scheduler_opts column.
agent_selection_strategy: AgentSelectionStrategy = AgentSelectionStrategy.DISPERSED
roundrobin: bool = False
agent_selector_config: Mapping[str, Any] = attr.field(factory=dict)

def to_json(self) -> dict[str, Any]:
return {
"allowed_session_types": [item.value for item in self.allowed_session_types],
"pending_timeout": self.pending_timeout.total_seconds(),
"config": self.config,
"agent_selection_strategy": self.agent_selection_strategy,
"roundrobin": self.roundrobin,
"agent_selector_config": self.agent_selector_config,
}

@classmethod
Expand All @@ -127,7 +131,7 @@ def as_trafaret(cls) -> t.Trafaret:
t.Key("agent_selection_strategy", default=AgentSelectionStrategy.DISPERSED): tx.Enum(
AgentSelectionStrategy
),
t.Key("roundrobin", default=False): t.Bool(),
t.Key("agent_selector_config", default={}): agent_selector_config_iv,
}).allow_extra("*")


Expand Down
Loading