Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Split PendingSession Scheduler into PendingSession Scheduler and AgentSelector #1655

Merged
merged 46 commits into from
Sep 16, 2024
Merged
Changes from 1 commit
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
4251d4a
Replace RoundRobin flag with AgentSelectionStrategy.RoundRobin
jopemachine Oct 25, 2023
c6b4c8e
Remove MOFScheduler
jopemachine Oct 26, 2023
328826f
Add use_num_extras flag temporarily
jopemachine Oct 26, 2023
2bb06a9
Trim imports
jopemachine Oct 26, 2023
26fbed2
Fix wrong type
jopemachine Oct 26, 2023
2ec9bf2
Distinguish `compatible_agents` and `possible_agents`
jopemachine Oct 26, 2023
ae9a2cc
Format with ruff
jopemachine Mar 27, 2024
4addb8b
fix: Remove unused import
achimnol Aug 24, 2024
5554e3c
fix: merge conflicts
achimnol Aug 24, 2024
68b5e83
test: Use the actual sum of total capacity from the agent list fixture
achimnol Aug 25, 2024
980948e
refactor: Move common.types.RoundRobinContext to manager.scheduler.types
achimnol Aug 26, 2024
9f3e500
refactor: Split out `AgentSelector` interface and implementations.
achimnol Aug 26, 2024
be85eb6
fix: oops
achimnol Aug 26, 2024
9418066
refactor: Use modern type imports
achimnol Aug 26, 2024
8f6fcd2
test: Update test codes
achimnol Aug 26, 2024
ad59f73
chore: Correct BUILD file and fix typos
jopemachine Aug 27, 2024
ac4c848
refactor: WIP
jopemachine Aug 28, 2024
08d99d1
chore: Rename news fragment
jopemachine Aug 28, 2024
151eeef
docs: Add comment
jopemachine Aug 28, 2024
d9984eb
chore: Remove useless `__init__`
jopemachine Aug 28, 2024
e72f873
refactor: Move `roundrobin_states` -> `agent_selector_states`
jopemachine Aug 30, 2024
2af1105
fix: Edit comment
jopemachine Aug 30, 2024
6011f5f
refactor: Remove useless type alias
jopemachine Aug 30, 2024
39c81bc
feat: Add `from_json`, `as_trafaret` to `AgentSelectorState`
jopemachine Aug 30, 2024
87e9322
chore: refactor
jopemachine Aug 30, 2024
0596185
chore: change to snake case
jopemachine Aug 30, 2024
ab03430
chore: Update comment
jopemachine Aug 30, 2024
6a1a985
fix: `store-type` -> `store_type`
jopemachine Aug 30, 2024
68a376b
feat: Separate the Agentselector configuration into local config and …
jopemachine Aug 30, 2024
a87b1a4
refactor: Remove code duplication
jopemachine Aug 30, 2024
0dfcf10
feat: Change to etcd key path to kebab case, Remove storage_type from…
jopemachine Sep 3, 2024
682fa80
feat: Enable RoundRobinAgentSelector for multi node session
jopemachine Sep 3, 2024
1ead959
test: Clarify the intention
achimnol Sep 16, 2024
5c91bec
fix: Require state_store as the mandatory kwarg and add missing defau…
achimnol Sep 16, 2024
e4e6d1c
refactor: Minimize the scope of interests for module-specific type defs
achimnol Sep 16, 2024
7cde57e
fix: Add comments and debug logging.
achimnol Sep 16, 2024
a24e2e4
refactor: Clean up ResourceGroupState type hierarchy
achimnol Sep 16, 2024
b6f140e
refactor,fix: Add missing state_name key for resource group states
achimnol Sep 16, 2024
ae9c537
fix: remove debug print
achimnol Sep 16, 2024
e727671
refactor: Clarify that this is a generic ResourceGroupStateStore
achimnol Sep 16, 2024
ec9944f
refactor: Reduce verbosity
achimnol Sep 16, 2024
369315f
refactor: Use pydantic instead of trafaret-based mixins
achimnol Sep 16, 2024
cc15f8d
fix: Add database migration for agent selector configs
achimnol Sep 16, 2024
15fbc7b
fix: typo in the class name
achimnol Sep 16, 2024
b40cfe0
fix: Ensure removal of 'roundrobin' key in the scheduler_opts
achimnol Sep 16, 2024
69083d9
test: Update the test case to be more realistic
achimnol Sep 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
refactor: Minimize the scope of interests for module-specific type defs
achimnol committed Sep 16, 2024
commit e4e6d1cf2cbf106327da42025d308601a2485b5f
42 changes: 0 additions & 42 deletions src/ai/backend/common/types.py
Original file line number Diff line number Diff line change
@@ -46,7 +46,6 @@
from pydantic import BaseModel, ConfigDict, Field
from redis.asyncio import Redis

from ..logging.types import CIStrEnum
from .exception import InvalidIpAddressValue
from .models.minilang.mount import MountPointParser

@@ -1217,47 +1216,6 @@ class VolumeMountableNodeType(enum.StrEnum):
STORAGE_PROXY = enum.auto()


@dataclass
class AgentSelectorState:
    """
    Serializable container for agent selector state.

    ``roundrobin_states`` maps a string key to the round-robin state kept for
    that key (presumably an architecture name -- confirm against callers).
    ``None`` means that no round-robin state has been recorded.
    """

    roundrobin_states: dict[str, RoundRobinState] | None = None

    def to_json(self) -> dict[str, Any]:
        """Return a plain ``dict`` form; nested dataclasses are converted recursively."""
        return dataclasses.asdict(self)

    @classmethod
    def from_json(cls, obj: Mapping[str, Any]) -> AgentSelectorState:
        """
        Validate ``obj`` against :meth:`as_trafaret` and build an instance.

        The validated nested mappings are converted back into
        ``RoundRobinState`` objects so that the result matches the declared
        field type and ``from_json(x.to_json())`` round-trips.
        """
        checked: dict[str, Any] = dict(cls.as_trafaret().check(obj))
        raw_states = checked.get("roundrobin_states")
        if raw_states is not None:
            # The trafaret yields plain dicts; rebuild the dataclass instances.
            checked["roundrobin_states"] = {
                key: RoundRobinState(**state) for key, state in raw_states.items()
            }
        return cls(**checked)

    @classmethod
    def as_trafaret(cls) -> t.Trafaret:
        """
        Return the trafaret describing the JSON form of this class.

        The key is optional and nullable because the field defaults to ``None``
        and ``to_json()`` emits ``{"roundrobin_states": None}`` in that case.
        """
        return t.Dict({
            t.Key("roundrobin_states", optional=True): (
                t.Null() | t.Mapping(t.String, RoundRobinState.as_trafaret())
            ),
        })


@dataclass
class RoundRobinState(JSONSerializableMixin):
    """
    Round-robin state holding a single integer field, ``next_index``.

    The JSON form is ``{"next_index": <int>}``.
    """

    next_index: int

    def to_json(self) -> dict[str, Any]:
        """Return the fields of this dataclass as a plain ``dict``."""
        serialized: dict[str, Any] = dataclasses.asdict(self)
        return serialized

    @classmethod
    def from_json(cls, obj: Mapping[str, Any]) -> RoundRobinState:
        """Validate ``obj`` with :meth:`as_trafaret` and construct an instance from it."""
        validated = cls.as_trafaret().check(obj)
        return cls(**validated)

    @classmethod
    def as_trafaret(cls) -> t.Trafaret:
        """Return the trafaret that requires an integer ``next_index`` key."""
        schema = {
            t.Key("next_index"): t.Int,
        }
        return t.Dict(schema)


class StateStoreType(CIStrEnum):
    """
    Enumerates the state store kinds: ``ETCD`` and ``INMEMORY``.

    NOTE(review): the member values are produced by ``enum.auto()`` as resolved
    by ``CIStrEnum``, which is defined outside this view -- confirm the exact
    string values before relying on them.
    """

    ETCD = enum.auto()
    INMEMORY = enum.auto()


SSLContextType: TypeAlias = bool | Fingerprint | SSLContext


21 changes: 0 additions & 21 deletions src/ai/backend/common/validators.py
Original file line number Diff line number Diff line change
@@ -44,7 +44,6 @@
from trafaret.base import TrafaretMeta, ensure_trafaret
from trafaret.lib import _empty

from .types import AgentSelectorState, RoundRobinState
from .types import BinarySize as _BinarySize
from .types import HostPortPair as _HostPortPair
from .types import QuotaScopeID as _QuotaScopeID
@@ -727,23 +726,3 @@ def check_and_return(self, value: Any) -> float:
return 0
case _:
self._failure(f"Value must be (float, tuple of float or None), not {type(value)}.")


class AgentSelectorStateJSONString(t.Trafaret):
    """
    Trafaret that parses a JSON-encoded string into an ``AgentSelectorState``.

    Expected JSON shape::

        {"roundrobin_states": {"<key>": {"next_index": <int>}, ...}}

    A missing, null, or empty ``roundrobin_states`` entry yields an empty mapping.
    """

    def check_and_return(self, value: Any) -> AgentSelectorState:
        """
        Decode ``value`` as JSON and build the corresponding state object.

        Every malformed input is reported through ``self._failure`` (which
        raises the trafaret validation error) instead of leaking
        ``TypeError``/``AttributeError`` from the parsing code.
        """
        try:
            agent_selector_state_dict = json.loads(value)
        except (TypeError, ValueError):
            # TypeError: ``value`` is not str/bytes.
            # ValueError: also covers json.JSONDecodeError, its subclass.
            self._failure(f'Expected valid JSON string, but found "{value}"')
        if not isinstance(agent_selector_state_dict, dict):
            # Valid JSON such as "[]" or "1" has no ``.get()``.
            self._failure(f'Expected a JSON object, but found "{value}"', value)

        roundrobin_states: dict[str, RoundRobinState] = {}
        roundrobin_states_dict = agent_selector_state_dict.get("roundrobin_states", None)
        if roundrobin_states_dict:
            if not isinstance(roundrobin_states_dict, dict):
                self._failure(
                    f"Got invalid roundrobin states: {roundrobin_states_dict}",
                    roundrobin_states_dict,
                )
            for arch, roundrobin_state_dict in roundrobin_states_dict.items():
                if (
                    not isinstance(roundrobin_state_dict, dict)
                    or "next_index" not in roundrobin_state_dict
                ):
                    # ``_failure`` does not apply str.format(), so the offending
                    # value is interpolated into the message here.
                    self._failure(
                        f"Got invalid roundrobin state: {roundrobin_state_dict}",
                        roundrobin_state_dict,
                    )
                roundrobin_states[arch] = RoundRobinState.from_json(roundrobin_state_dict)

        return AgentSelectorState(
            roundrobin_states=roundrobin_states,
        )
3 changes: 0 additions & 3 deletions src/ai/backend/manager/config.py
Original file line number Diff line number Diff line change
@@ -473,9 +473,6 @@ def container_registry_serialize(v: dict[str, Any]) -> dict[str, str]:
): session_hang_tolerance_iv,
},
).allow_extra("*"),
t.Key("agent-selector-states", default={}): t.Mapping(
t.String, tx.AgentSelectorStateJSONString
),
}).allow_extra("*")

_volume_defaults: dict[str, Any] = {
2 changes: 1 addition & 1 deletion src/ai/backend/manager/scheduler/agent_selector.py
Original file line number Diff line number Diff line change
@@ -9,12 +9,12 @@
from ai.backend.common.types import (
AgentId,
ResourceSlot,
RoundRobinState,
)

from ..models import AgentRow, KernelRow, SessionRow
from .types import (
AbstractAgentSelector,
RoundRobinState,
)
from .utils import (
get_requested_architecture,
68 changes: 66 additions & 2 deletions src/ai/backend/manager/scheduler/types.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from __future__ import annotations

import dataclasses
import enum
import json
import logging
import uuid
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from datetime import datetime
from typing import (
Any,
@@ -33,8 +36,8 @@
from ai.backend.common.types import (
AccessKey,
AgentId,
AgentSelectorState,
ClusterMode,
JSONSerializableMixin,
KernelId,
ResourceSlot,
SessionId,
@@ -43,8 +46,8 @@
SlotTypes,
VFolderMount,
)
from ai.backend.common.validators import AgentSelectorStateJSONString
from ai.backend.logging import BraceStyleAdapter
from ai.backend.logging.types import CIStrEnum
from ai.backend.manager.config import SharedConfig

from ..defs import DEFAULT_ROLE
@@ -533,6 +536,67 @@ async def select_agent(
ResourceGroupID = NewType("ResourceGroupID", str)


@dataclass
class AgentSelectorState:
    """
    Serializable container for agent selector state.

    ``roundrobin_states`` maps a string key to the round-robin state kept for
    that key (presumably an architecture name -- confirm against callers).
    ``None`` means that no round-robin state has been recorded.
    """

    roundrobin_states: dict[str, RoundRobinState] | None = None

    def to_json(self) -> dict[str, Any]:
        """Return a plain ``dict`` form; nested dataclasses are converted recursively."""
        return dataclasses.asdict(self)

    @classmethod
    def from_json(cls, obj: Mapping[str, Any]) -> AgentSelectorState:
        """
        Validate ``obj`` against :meth:`as_trafaret` and build an instance.

        The validated nested mappings are converted back into
        ``RoundRobinState`` objects so that the result matches the declared
        field type and ``from_json(x.to_json())`` round-trips.
        """
        checked: dict[str, Any] = dict(cls.as_trafaret().check(obj))
        raw_states = checked.get("roundrobin_states")
        if raw_states is not None:
            # The trafaret yields plain dicts; rebuild the dataclass instances.
            checked["roundrobin_states"] = {
                key: RoundRobinState(**state) for key, state in raw_states.items()
            }
        return cls(**checked)

    @classmethod
    def as_trafaret(cls) -> t.Trafaret:
        """
        Return the trafaret describing the JSON form of this class.

        The key is optional and nullable because the field defaults to ``None``
        and ``to_json()`` emits ``{"roundrobin_states": None}`` in that case.
        """
        return t.Dict({
            t.Key("roundrobin_states", optional=True): (
                t.Null() | t.Mapping(t.String, RoundRobinState.as_trafaret())
            ),
        })


@dataclass
class RoundRobinState(JSONSerializableMixin):
    """
    Round-robin state holding a single integer field, ``next_index``.

    The JSON form is ``{"next_index": <int>}``.
    """

    next_index: int

    def to_json(self) -> dict[str, Any]:
        """Return the fields of this dataclass as a plain ``dict``."""
        serialized: dict[str, Any] = dataclasses.asdict(self)
        return serialized

    @classmethod
    def from_json(cls, obj: Mapping[str, Any]) -> RoundRobinState:
        """Validate ``obj`` with :meth:`as_trafaret` and construct an instance from it."""
        validated = cls.as_trafaret().check(obj)
        return cls(**validated)

    @classmethod
    def as_trafaret(cls) -> t.Trafaret:
        """Return the trafaret that requires an integer ``next_index`` key."""
        schema = {
            t.Key("next_index"): t.Int,
        }
        return t.Dict(schema)


class StateStoreType(CIStrEnum):
    """
    Enumerates the state store kinds: ``DEFAULT`` and ``INMEMORY``.

    NOTE(review): the member values are produced by ``enum.auto()`` as resolved
    by ``CIStrEnum``, which is defined outside this view -- confirm the exact
    string values before relying on them.
    """

    DEFAULT = enum.auto()
    INMEMORY = enum.auto()


class AgentSelectorStateJSONString(t.Trafaret):
    """
    Trafaret that parses a JSON-encoded string into an ``AgentSelectorState``.

    Expected JSON shape::

        {"roundrobin_states": {"<key>": {"next_index": <int>}, ...}}

    A missing, null, or empty ``roundrobin_states`` entry yields an empty mapping.
    """

    def check_and_return(self, value: Any) -> AgentSelectorState:
        """
        Decode ``value`` as JSON and build the corresponding state object.

        Every malformed input is reported through ``self._failure`` (which
        raises the trafaret validation error) instead of leaking
        ``TypeError``/``AttributeError`` from the parsing code.
        """
        try:
            agent_selector_state_dict = json.loads(value)
        except (TypeError, ValueError):
            # TypeError: ``value`` is not str/bytes.
            # ValueError: also covers json.JSONDecodeError, its subclass.
            self._failure(f'Expected valid JSON string, but found "{value}"')
        if not isinstance(agent_selector_state_dict, dict):
            # Valid JSON such as "[]" or "1" has no ``.get()``.
            self._failure(f'Expected a JSON object, but found "{value}"', value)

        roundrobin_states: dict[str, RoundRobinState] = {}
        roundrobin_states_dict = agent_selector_state_dict.get("roundrobin_states", None)
        if roundrobin_states_dict:
            if not isinstance(roundrobin_states_dict, dict):
                self._failure(
                    f"Got invalid roundrobin states: {roundrobin_states_dict}",
                    roundrobin_states_dict,
                )
            for arch, roundrobin_state_dict in roundrobin_states_dict.items():
                if (
                    not isinstance(roundrobin_state_dict, dict)
                    or "next_index" not in roundrobin_state_dict
                ):
                    # ``_failure`` does not apply str.format(), so the offending
                    # value is interpolated into the message here.
                    self._failure(
                        f"Got invalid roundrobin state: {roundrobin_state_dict}",
                        roundrobin_state_dict,
                    )
                roundrobin_states[arch] = RoundRobinState.from_json(roundrobin_state_dict)

        return AgentSelectorState(
            roundrobin_states=roundrobin_states,
        )


class AbstractResourceGroupStateStore(Generic[StateType], metaclass=ABCMeta):
"""
Store and load the state of the pending session scheduler and agent selector for each resource group.