refactor overall idle checker to check sessions not kernels
fregataa committed Oct 16, 2023
1 parent 158a644 commit bea0bea
Showing 2 changed files with 167 additions and 210 deletions.
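
The gist of the refactor: instead of joining the `kernels` table against `users` and iterating per kernel, `_do_idle_check` now runs one ORM query over `SessionRow` in RUNNING status, loading only the columns the checkers use and eagerly fetching kernel ids with `selectinload`. The sketch below is illustrative, not the commit's code — the toy models, the in-memory SQLite engine, and the `aiosqlite` driver are assumptions for the demo; only the `noload`/`load_only`/`selectinload` option shape mirrors the diff (which also spells the enum comparisons as `.is_()`/`.is_not()` where this sketch uses `==`).

```python
# Self-contained sketch of the session-level query pattern this commit adopts.
import asyncio
import enum
import uuid

import sqlalchemy as sa
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    load_only,
    mapped_column,
    noload,
    relationship,
    selectinload,
)


class Base(DeclarativeBase):
    pass


class SessionStatus(enum.Enum):
    RUNNING = "RUNNING"
    TERMINATED = "TERMINATED"


class SessionRow(Base):
    __tablename__ = "sessions"
    id: Mapped[uuid.UUID] = mapped_column(primary_key=True, default=uuid.uuid4)
    access_key: Mapped[str]
    status: Mapped[SessionStatus]
    kernels: Mapped[list["KernelRow"]] = relationship(back_populates="session")


class KernelRow(Base):
    __tablename__ = "kernels"
    id: Mapped[uuid.UUID] = mapped_column(primary_key=True, default=uuid.uuid4)
    session_id: Mapped[uuid.UUID] = mapped_column(sa.ForeignKey("sessions.id"))
    session: Mapped[SessionRow] = relationship(back_populates="kernels")


def build_idle_check_query() -> sa.Select:
    return (
        sa.select(SessionRow)
        .where(SessionRow.status == SessionStatus.RUNNING)
        .options(
            noload("*"),  # skip every relationship not listed below
            load_only(SessionRow.id, SessionRow.access_key),
            # one extra SELECT ... WHERE session_id IN (...) fetches all kernel ids
            selectinload(SessionRow.kernels).options(
                noload("*"), load_only(KernelRow.id)
            ),
        )
    )


async def main() -> None:
    engine = create_async_engine("sqlite+aiosqlite://")  # demo engine, assumed driver
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    async with async_sessionmaker(engine)() as db:
        db.add(SessionRow(access_key="AKDEMO", status=SessionStatus.RUNNING,
                          kernels=[KernelRow(), KernelRow()]))
        await db.commit()
        for sess in (await db.scalars(build_idle_check_query())).all():
            print(sess.id, [k.id for k in sess.kernels])


asyncio.run(main())
```

With the kernel ids preloaded this way, the utilization checker near the end of the diff can drop its conditional per-cluster query and simply read `session.kernels`.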
200 changes: 104 additions & 96 deletions src/ai/backend/manager/idle.py
@@ -4,6 +4,7 @@
import enum
import logging
import math
import uuid
from abc import ABCMeta, abstractmethod
from collections import UserDict, defaultdict
from datetime import datetime, timedelta
@@ -31,6 +32,7 @@
import trafaret as t
from aiotools import TaskGroupError
from sqlalchemy.engine import Row
from sqlalchemy.orm import joinedload, load_only, noload, selectinload

import ai.backend.common.validators as tx
from ai.backend.common import msgpack, redis_helper
Expand All @@ -55,19 +57,21 @@
AccessKey,
BinarySize,
RedisConnectionInfo,
ResourceSlot,
SessionTypes,
)
from ai.backend.common.utils import nmget

from .defs import DEFAULT_ROLE, LockID
from .models.kernel import LIVE_STATUS, kernels
from .models.keypair import keypairs
from .models.resource_policy import keypair_resource_policies
from .models.user import users
from .defs import LockID
from .models.kernel import KernelRow
from .models.keypair import KeyPairRow
from .models.resource_policy import KeyPairResourcePolicyRow
from .models.session import SessionRow, SessionStatus
from .models.user import UserRow
from .types import DistributedLockFactory

if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection
from sqlalchemy.ext.asyncio import AsyncSession as SASession

from ai.backend.common.types import AgentId, KernelId, SessionId

@@ -112,8 +116,8 @@ async def get_redis_now(redis_obj: RedisConnectionInfo) -> float:
return t[0] + (t[1] / (10**6))


async def get_db_now(dbconn: SAConnection) -> datetime:
return await dbconn.scalar(sa.select(sa.func.now()))
async def get_db_now(db_session: SASession) -> datetime:
return await db_session.scalar(sa.select(sa.func.now()))


class UtilizationExtraInfo(NamedTuple):
@@ -197,6 +201,7 @@ def __init__(
self._grace_period_checker: NewUserGracePeriodChecker = NewUserGracePeriodChecker(
event_dispatcher, self._redis_live, self._redis_stat
)
self.user_created_cache: dict[uuid.UUID, datetime] = {}

def add_checker(self, checker: BaseIdleChecker):
if self._frozen:
@@ -253,59 +258,70 @@ async def _do_idle_check(
event: DoIdleCheckEvent,
) -> None:
log.debug("do_idle_check(): triggered")
policy_cache: dict[AccessKey, Row] = {}
async with self._db.begin_readonly() as conn:
j = sa.join(kernels, users, kernels.c.user_uuid == users.c.uuid)
query = (
sa.select(
[
kernels.c.id,
kernels.c.access_key,
kernels.c.session_id,
kernels.c.session_type,
kernels.c.created_at,
kernels.c.occupied_slots,
kernels.c.cluster_size,
users.c.created_at.label("user_created_at"),
]
)
.select_from(j)
.where(
(kernels.c.status.in_(LIVE_STATUS))
& (kernels.c.cluster_role == DEFAULT_ROLE)
& (kernels.c.session_type != SessionTypes.INFERENCE),
)
policy_cache: dict[AccessKey, KeyPairResourcePolicyRow] = {}
query = (
sa.select(SessionRow)
.select_from(SessionRow)
.where(
(SessionRow.status.is_(SessionStatus.RUNNING))
& (SessionRow.session_type.is_not(SessionTypes.INFERENCE))
)
.options(
noload("*"),
load_only(
SessionRow.id,
SessionRow.access_key,
SessionRow.session_type,
SessionRow.created_at,
SessionRow.occupying_slots,
SessionRow.user_uuid,
),
selectinload(SessionRow.kernels).options(noload("*"), load_only(KernelRow.id)),
)
result = await conn.execute(query)
rows = result.fetchall()
for kernel in rows:
grace_period_end = await self._grace_period_checker.get_grace_period_end(kernel)
policy = policy_cache.get(kernel["access_key"], None)
if policy is None:
)
async with self._db.begin_readonly_session() as db_session:
session_rows: list[SessionRow] = (await db_session.scalars(query)).all()
for session in session_rows:
if session.user_uuid not in self.user_created_cache:
user_created_query = (
sa.select(UserRow)
.where(UserRow.uuid.is_(session.user_uuid))
.options(noload("*"), load_only(UserRow.created_at))
)
user_row: UserRow = (await db_session.scalars(user_created_query)).first()
_user_created_at: datetime = user_row.created_at
self.user_created_cache[session.user_uuid] = _user_created_at
user_created_at = self.user_created_cache[session.user_uuid]
grace_period_end = await self._grace_period_checker.get_grace_period_end(
user_created_at
)

if session.access_key not in policy_cache:
query = (
sa.select(
[
keypair_resource_policies.c.max_session_lifetime,
keypair_resource_policies.c.idle_timeout,
]
)
.select_from(
sa.join(
keypairs,
keypair_resource_policies,
keypair_resource_policies.c.name == keypairs.c.resource_policy,
sa.select(KeyPairRow)
.where(KeyPairRow.access_key.is_(session.access_key))
.options(
noload("*"),
joinedload(KeyPairRow.resource_policy_row).options(
load_only(
KeyPairResourcePolicyRow.max_session_lifetime,
KeyPairResourcePolicyRow.idle_timeout,
)
),
)
.where(keypairs.c.access_key == kernel["access_key"])
)
result = await conn.execute(query)
policy = result.first()
assert policy is not None
policy_cache[kernel["access_key"]] = policy
_policy: KeyPairResourcePolicyRow = (await db_session.scalars(query)).first()
assert _policy is not None
policy_cache[session.access_key] = _policy
policy = policy_cache[session.access_key]

check_task = [
checker.check_idleness(
kernel, conn, policy, self._redis_live, grace_period_end=grace_period_end
session,
db_session,
policy,
self._redis_live,
grace_period_end=grace_period_end,
)
for checker in self._checkers
]
@@ -324,13 +340,13 @@ async def _do_idle_check(
log.info(
"The {} idle checker triggered termination of s:{}",
checker.name,
kernel["session_id"],
session.id,
)
if not terminated:
terminated = True
await self._event_producer.produce_event(
DoTerminateSessionEvent(
kernel["session_id"],
session.id,
checker.terminate_reason,
),
)
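
The refactored `_do_idle_check` above also adds two fetch-once caches: resource policies keyed by access key (a local dict rebuilt each pass) and user creation times keyed by user uuid (kept on the instance as `user_created_cache`, which is safe because a user's `created_at` never changes, whereas policies can be edited between passes). A minimal sketch of that memoization pattern, with hypothetical names not taken from the diff:

```python
# Fetch-once memoization: the fetcher hits the database only on a cache miss.
import asyncio
import uuid
from datetime import datetime, timezone
from typing import Awaitable, Callable, TypeVar

K = TypeVar("K")
V = TypeVar("V")


async def get_cached(cache: dict[K, V], key: K, fetch: Callable[[], Awaitable[V]]) -> V:
    if key not in cache:
        cache[key] = await fetch()  # DB round-trip happens only here
    return cache[key]


async def main() -> None:
    user_created_cache: dict[uuid.UUID, datetime] = {}
    user_id = uuid.uuid4()

    async def fetch_created_at() -> datetime:
        print("db hit")  # printed once despite two lookups
        return datetime.now(timezone.utc)

    first = await get_cached(user_created_cache, user_id, fetch_created_at)
    second = await get_cached(user_created_cache, user_id, fetch_created_at)
    assert first is second


asyncio.run(main())
```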
@@ -419,12 +435,12 @@ class AbstractIdleChecker(metaclass=ABCMeta):
@abstractmethod
async def check_idleness(
self,
kernel: Row,
dbconn: SAConnection,
policy: Row,
session: SessionRow,
db_session: SASession,
policy: KeyPairResourcePolicyRow,
redis_obj: RedisConnectionInfo,
*,
grace_period_end: Optional[datetime] = None,
grace_period_end: datetime | None = None,
) -> bool:
"""
Check whether the kernel is idle or not.
@@ -473,7 +489,7 @@ async def del_remaining_time_report(

async def get_grace_period_end(
self,
kernel: Row,
user_created_at: datetime,
) -> Optional[datetime]:
"""
Calculate the user's initial grace period for idle checkers.
@@ -482,7 +498,6 @@ async def get_grace_period_end(
"""
if self.user_initial_grace_period is None:
return None
user_created_at: datetime = kernel["user_created_at"]
return user_created_at + self.user_initial_grace_period

@property
Expand Down Expand Up @@ -621,20 +636,20 @@ async def get_extra_info(self, session_id: SessionId) -> Optional[dict[str, Any]

async def check_idleness(
self,
kernel: Row,
dbconn: SAConnection,
policy: Row,
session: SessionRow,
db_session: SASession,
policy: KeyPairResourcePolicyRow,
redis_obj: RedisConnectionInfo,
*,
grace_period_end: Optional[datetime] = None,
grace_period_end: datetime | None = None,
) -> bool:
"""
Check whether the kernel has timed out or not.
And save the remaining time until timeout to Redis.
"""
session_id = kernel["session_id"]
session_id: SessionId = session.id

if kernel["session_type"] == SessionTypes.BATCH:
if session.session_type == SessionTypes.BATCH:
return True

active_streams = await redis_helper.execute(
@@ -701,27 +716,27 @@ async def get_extra_info(self, session_id: SessionId) -> Optional[dict[str, Any]

async def check_idleness(
self,
kernel: Row,
dbconn: SAConnection,
policy: Row,
session: SessionRow,
db_session: SASession,
policy: KeyPairResourcePolicyRow,
redis_obj: RedisConnectionInfo,
*,
grace_period_end: Optional[datetime] = None,
grace_period_end: datetime | None = None,
) -> bool:
"""
Check the kernel has been living longer than resource policy's `max_session_lifetime`.
And save remaining time until `max_session_lifetime` of kernel to Redis.
Check the session has been living longer than resource policy's `max_session_lifetime`.
And save remaining time until `max_session_lifetime` of session to Redis.
"""

session_id = kernel["session_id"]
session_id = session.id
if (max_session_lifetime := policy["max_session_lifetime"]) > 0:
# TODO: once per-status time tracking is implemented, let's change created_at
# to the timestamp when the session entered PREPARING status.
idle_timeout = timedelta(seconds=max_session_lifetime)
now: datetime = await get_db_now(dbconn)
kernel_created_at: datetime = kernel["created_at"]
now: datetime = await get_db_now(db_session)
session_created_at: datetime = session.created_at
remaining = calculate_remaining_time(
now, kernel_created_at, idle_timeout, grace_period_end
now, session_created_at, idle_timeout, grace_period_end
)
await self.set_remaining_time_report(
redis_obj, session_id, remaining if remaining > 0 else IDLE_TIMEOUT_VALUE
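
`calculate_remaining_time` itself is defined elsewhere in `idle.py`. Inferred from the call sites above, the grace period simply pushes the countdown baseline forward; the following reconstruction is an assumption based on those calls, not the commit's code:

```python
# Reconstructed sketch of the remaining-time arithmetic, inferred from how
# calculate_remaining_time(now, created_at, timeout, grace_period_end) is
# called in the diff; the actual implementation lives elsewhere in idle.py.
from datetime import datetime, timedelta, timezone


def calculate_remaining_time(
    now: datetime,
    idle_baseline: datetime,
    timeout_period: timedelta,
    grace_period_end: datetime | None = None,
) -> float:
    # The countdown starts at the later of the baseline (e.g. session
    # creation) and the end of the user's initial grace period.
    if grace_period_end is not None:
        idle_baseline = max(idle_baseline, grace_period_end)
    return (idle_baseline + timeout_period - now).total_seconds()


# e.g. a session created 30 min ago with a 1 h max lifetime has ~30 min left:
now = datetime.now(timezone.utc)
print(calculate_remaining_time(now, now - timedelta(minutes=30), timedelta(hours=1)))
```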
@@ -829,23 +844,23 @@ def get_last_collected_key(self, session_id: SessionId) -> str:

async def check_idleness(
self,
kernel: Row,
dbconn: SAConnection,
policy: Row,
session: SessionRow,
db_session: SASession,
policy: KeyPairResourcePolicyRow,
redis_obj: RedisConnectionInfo,
*,
grace_period_end: Optional[datetime] = None,
grace_period_end: datetime | None = None,
) -> bool:
"""
Check the average utilization of the kernel and whether it exceeds the threshold or not.
And save the average utilization of the kernel to Redis.
"""
session_id = kernel["session_id"]
session_id: SessionId = session.id

interval = IdleCheckerHost.check_interval
# time_window: Utilization is calculated within this window.
time_window: timedelta = self.get_time_window(policy)
occupied_slots = kernel["occupied_slots"]
occupied_slots: ResourceSlot = session.occupying_slots
unavailable_resources: Set[str] = set()

util_series_key = f"session.{session_id}.util_series"
@@ -870,15 +885,15 @@ async def check_idleness(
return True

# Report time remaining until the first time window is full as expire time
db_now: datetime = await get_db_now(dbconn)
kernel_created_at: datetime = kernel["created_at"]
db_now: datetime = await get_db_now(db_session)
session_created_at: datetime = session.created_at
if grace_period_end is not None:
start_from = max(grace_period_end, kernel_created_at)
start_from = max(grace_period_end, session_created_at)
else:
start_from = kernel_created_at
start_from = session_created_at
total_initial_grace_period_end = start_from + self.initial_grace_period
remaining = calculate_remaining_time(
db_now, kernel_created_at, time_window, total_initial_grace_period_end
db_now, session_created_at, time_window, total_initial_grace_period_end
)
await self.set_remaining_time_report(
redis_obj, session_id, remaining if remaining > 0 else IDLE_TIMEOUT_VALUE
@@ -902,14 +917,7 @@ async def check_idleness(
unavailable_resources.update(self.slot_resource_map[slot])

# Get current utilization data from all containers of the session.
if kernel["cluster_size"] > 1:
query = sa.select([kernels.c.id]).where(
(kernels.c.session_id == session_id) & (kernels.c.status.in_(LIVE_STATUS)),
)
rows = (await dbconn.execute(query)).fetchall()
kernel_ids = [k["id"] for k in rows]
else:
kernel_ids = [kernel["id"]]
kernel_ids = [k.id for k in session.kernels]
current_utilizations = await self.get_current_utilization(kernel_ids, occupied_slots)
if current_utilizations is None:
return True
