diff --git a/src/ai/backend/manager/idle.py b/src/ai/backend/manager/idle.py index 5f4cfab7d9..c7e7c740ff 100644 --- a/src/ai/backend/manager/idle.py +++ b/src/ai/backend/manager/idle.py @@ -4,6 +4,7 @@ import enum import logging import math +import uuid from abc import ABCMeta, abstractmethod from collections import UserDict, defaultdict from datetime import datetime, timedelta @@ -31,6 +32,7 @@ import trafaret as t from aiotools import TaskGroupError from sqlalchemy.engine import Row +from sqlalchemy.orm import joinedload, load_only, noload, selectinload import ai.backend.common.validators as tx from ai.backend.common import msgpack, redis_helper @@ -55,19 +57,21 @@ AccessKey, BinarySize, RedisConnectionInfo, + ResourceSlot, SessionTypes, ) from ai.backend.common.utils import nmget -from .defs import DEFAULT_ROLE, LockID -from .models.kernel import LIVE_STATUS, kernels -from .models.keypair import keypairs -from .models.resource_policy import keypair_resource_policies -from .models.user import users +from .defs import LockID +from .models.kernel import KernelRow +from .models.keypair import KeyPairRow +from .models.resource_policy import KeyPairResourcePolicyRow +from .models.session import SessionRow, SessionStatus +from .models.user import UserRow from .types import DistributedLockFactory if TYPE_CHECKING: - from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection + from sqlalchemy.ext.asyncio import AsyncSession as SASession from ai.backend.common.types import AgentId, KernelId, SessionId @@ -112,8 +116,8 @@ async def get_redis_now(redis_obj: RedisConnectionInfo) -> float: return t[0] + (t[1] / (10**6)) -async def get_db_now(dbconn: SAConnection) -> datetime: - return await dbconn.scalar(sa.select(sa.func.now())) +async def get_db_now(db_session: SASession) -> datetime: + return await db_session.scalar(sa.select(sa.func.now())) class UtilizationExtraInfo(NamedTuple): @@ -197,6 +201,7 @@ def __init__( self._grace_period_checker: NewUserGracePeriodChecker = NewUserGracePeriodChecker( event_dispatcher, self._redis_live, self._redis_stat ) + self.user_created_cache: dict[uuid.UUID, datetime] = {} def add_checker(self, checker: BaseIdleChecker): if self._frozen: @@ -253,59 +258,70 @@ async def _do_idle_check( event: DoIdleCheckEvent, ) -> None: log.debug("do_idle_check(): triggered") - policy_cache: dict[AccessKey, Row] = {} - async with self._db.begin_readonly() as conn: - j = sa.join(kernels, users, kernels.c.user_uuid == users.c.uuid) - query = ( - sa.select( - [ - kernels.c.id, - kernels.c.access_key, - kernels.c.session_id, - kernels.c.session_type, - kernels.c.created_at, - kernels.c.occupied_slots, - kernels.c.cluster_size, - users.c.created_at.label("user_created_at"), - ] - ) - .select_from(j) - .where( - (kernels.c.status.in_(LIVE_STATUS)) - & (kernels.c.cluster_role == DEFAULT_ROLE) - & (kernels.c.session_type != SessionTypes.INFERENCE), - ) + policy_cache: dict[AccessKey, KeyPairResourcePolicyRow] = {} + query = ( + sa.select(SessionRow) + .select_from(SessionRow) + .where( + (SessionRow.status.is_(SessionStatus.RUNNING)) + & (SessionRow.session_type.is_not(SessionTypes.INFERENCE)) + ) + .options( + noload("*"), + load_only( + SessionRow.id, + SessionRow.access_key, + SessionRow.session_type, + SessionRow.created_at, + SessionRow.occupying_slots, + SessionRow.user_uuid, + ), + selectinload(SessionRow.kernels).options(noload("*"), load_only(KernelRow.id)), ) - result = await conn.execute(query) - rows = result.fetchall() - for kernel in rows: - grace_period_end = await self._grace_period_checker.get_grace_period_end(kernel) - policy = policy_cache.get(kernel["access_key"], None) - if policy is None: + ) + async with self._db.begin_readonly_session() as db_session: + session_rows: list[SessionRow] = (await db_session.scalars(query)).all() + for session in session_rows: + if session.user_uuid not in self.user_created_cache: + user_created_query = ( + sa.select(UserRow) + .where(UserRow.uuid.is_(session.user_uuid)) + .options(noload("*"), load_only(UserRow.created_at)) + ) + user_row: UserRow = (await db_session.scalars(user_created_query)).first() + _user_created_at: datetime = user_row.created_at + self.user_created_cache[session.user_uuid] = _user_created_at + user_created_at = self.user_created_cache[session.user_uuid] + grace_period_end = await self._grace_period_checker.get_grace_period_end( + user_created_at + ) + + if session.access_key not in policy_cache: query = ( - sa.select( - [ - keypair_resource_policies.c.max_session_lifetime, - keypair_resource_policies.c.idle_timeout, - ] - ) - .select_from( - sa.join( - keypairs, - keypair_resource_policies, - keypair_resource_policies.c.name == keypairs.c.resource_policy, + sa.select(KeyPairRow) + .where(KeyPairRow.access_key.is_(session.access_key)) + .options( + noload("*"), + joinedload(KeyPairRow.resource_policy_row).options( + load_only( + KeyPairResourcePolicyRow.max_session_lifetime, + KeyPairResourcePolicyRow.idle_timeout, + ) ), ) - .where(keypairs.c.access_key == kernel["access_key"]) ) - result = await conn.execute(query) - policy = result.first() - assert policy is not None - policy_cache[kernel["access_key"]] = policy + _policy: KeyPairResourcePolicyRow = (await db_session.scalars()).first() + assert _policy is not None + policy_cache[session.access_key] = _policy + policy = policy_cache[session.access_key] check_task = [ checker.check_idleness( - kernel, conn, policy, self._redis_live, grace_period_end=grace_period_end + session, + db_session, + policy, + self._redis_live, + grace_period_end=grace_period_end, ) for checker in self._checkers ] @@ -324,13 +340,13 @@ async def _do_idle_check( log.info( "The {} idle checker triggered termination of s:{}", checker.name, - kernel["session_id"], + session.id, ) if not terminated: terminated = True await self._event_producer.produce_event( DoTerminateSessionEvent( - kernel["session_id"], + session.id, checker.terminate_reason, ), ) @@ -419,12 +435,12 @@ class AbstractIdleChecker(metaclass=ABCMeta): @abstractmethod async def check_idleness( self, - kernel: Row, - dbconn: SAConnection, - policy: Row, + session: SessionRow, + db_session: SASession, + policy: KeyPairResourcePolicyRow, redis_obj: RedisConnectionInfo, *, - grace_period_end: Optional[datetime] = None, + grace_period_end: datetime | None = None, ) -> bool: """ Check the kernel is whether idle or not. @@ -473,7 +489,7 @@ async def del_remaining_time_report( async def get_grace_period_end( self, - kernel: Row, + user_created_at: datetime, ) -> Optional[datetime]: """ Calculate the user's initial grace period for idle checkers. @@ -482,7 +498,6 @@ async def get_grace_period_end( """ if self.user_initial_grace_period is None: return None - user_created_at: datetime = kernel["user_created_at"] return user_created_at + self.user_initial_grace_period @property @@ -621,20 +636,20 @@ async def get_extra_info(self, session_id: SessionId) -> Optional[dict[str, Any] async def check_idleness( self, - kernel: Row, - dbconn: SAConnection, - policy: Row, + session: SessionRow, + db_session: SASession, + policy: KeyPairResourcePolicyRow, redis_obj: RedisConnectionInfo, *, - grace_period_end: Optional[datetime] = None, + grace_period_end: datetime | None = None, ) -> bool: """ Check the kernel is timeout or not. And save remaining time until timeout of kernel to Redis. """ - session_id = kernel["session_id"] + session_id: SessionId = session.id - if kernel["session_type"] == SessionTypes.BATCH: + if session.session_type == SessionTypes.BATCH: return True active_streams = await redis_helper.execute( @@ -701,27 +716,27 @@ async def get_extra_info(self, session_id: SessionId) -> Optional[dict[str, Any] async def check_idleness( self, - kernel: Row, - dbconn: SAConnection, - policy: Row, + session: SessionRow, + db_session: SASession, + policy: KeyPairResourcePolicyRow, redis_obj: RedisConnectionInfo, *, - grace_period_end: Optional[datetime] = None, + grace_period_end: datetime | None = None, ) -> bool: """ - Check the kernel has been living longer than resource policy's `max_session_lifetime`. - And save remaining time until `max_session_lifetime` of kernel to Redis. + Check the session has been living longer than resource policy's `max_session_lifetime`. + And save remaining time until `max_session_lifetime` of session to Redis. """ - session_id = kernel["session_id"] + session_id = session.id if (max_session_lifetime := policy["max_session_lifetime"]) > 0: # TODO: once per-status time tracking is implemented, let's change created_at # to the timestamp when the session entered PREPARING status. idle_timeout = timedelta(seconds=max_session_lifetime) - now: datetime = await get_db_now(dbconn) - kernel_created_at: datetime = kernel["created_at"] + now: datetime = await get_db_now(db_session) + session_created_at: datetime = session.created_at remaining = calculate_remaining_time( - now, kernel_created_at, idle_timeout, grace_period_end + now, session_created_at, idle_timeout, grace_period_end ) await self.set_remaining_time_report( redis_obj, session_id, remaining if remaining > 0 else IDLE_TIMEOUT_VALUE @@ -829,23 +844,23 @@ def get_last_collected_key(self, session_id: SessionId) -> str: async def check_idleness( self, - kernel: Row, - dbconn: SAConnection, - policy: Row, + session: SessionRow, + db_session: SASession, + policy: KeyPairResourcePolicyRow, redis_obj: RedisConnectionInfo, *, - grace_period_end: Optional[datetime] = None, + grace_period_end: datetime | None = None, ) -> bool: """ Check the the average utilization of kernel and whether it exceeds the threshold or not. And save the average utilization of kernel to Redis. """ - session_id = kernel["session_id"] + session_id: SessionId = session.id interval = IdleCheckerHost.check_interval # time_window: Utilization is calculated within this window. time_window: timedelta = self.get_time_window(policy) - occupied_slots = kernel["occupied_slots"] + occupied_slots: ResourceSlot = session.occupying_slots unavailable_resources: Set[str] = set() util_series_key = f"session.{session_id}.util_series" @@ -870,15 +885,15 @@ async def check_idleness( return True # Report time remaining until the first time window is full as expire time - db_now: datetime = await get_db_now(dbconn) - kernel_created_at: datetime = kernel["created_at"] + db_now: datetime = await get_db_now(db_session) + session_created_at: datetime = session.created_at if grace_period_end is not None: - start_from = max(grace_period_end, kernel_created_at) + start_from = max(grace_period_end, session_created_at) else: - start_from = kernel_created_at + start_from = session_created_at total_initial_grace_period_end = start_from + self.initial_grace_period remaining = calculate_remaining_time( - db_now, kernel_created_at, time_window, total_initial_grace_period_end + db_now, session_created_at, time_window, total_initial_grace_period_end ) await self.set_remaining_time_report( redis_obj, session_id, remaining if remaining > 0 else IDLE_TIMEOUT_VALUE @@ -902,14 +917,7 @@ async def check_idleness( unavailable_resources.update(self.slot_resource_map[slot]) # Get current utilization data from all containers of the session. - if kernel["cluster_size"] > 1: - query = sa.select([kernels.c.id]).where( - (kernels.c.session_id == session_id) & (kernels.c.status.in_(LIVE_STATUS)), - ) - rows = (await dbconn.execute(query)).fetchall() - kernel_ids = [k["id"] for k in rows] - else: - kernel_ids = [kernel["id"]] + kernel_ids = [k.id for k in session.kernels] current_utilizations = await self.get_current_utilization(kernel_ids, occupied_slots) if current_utilizations is None: return True diff --git a/tests/manager/test_idle_checker.py b/tests/manager/test_idle_checker.py index 2174ebd981..cbde623d01 100644 --- a/tests/manager/test_idle_checker.py +++ b/tests/manager/test_idle_checker.py @@ -17,6 +17,8 @@ calculate_remaining_time, init_idle_checkers, ) +from ai.backend.manager.models.resource_policy import KeyPairResourcePolicyRow +from ai.backend.manager.models.session import SessionRow from ai.backend.manager.server import ( background_task_ctx, database_ctx, @@ -108,7 +110,6 @@ async def new_user_grace_period_checker( }, "enabled": "", } - kernel = {"user_created_at": user_created_at} await root_ctx.shared_config.etcd.put_prefix("config/idle", idle_value) # type: ignore[arg-type] checker_host = await init_idle_checkers( @@ -120,7 +121,9 @@ async def new_user_grace_period_checker( ) try: await checker_host.start() - grace_period_end = await checker_host._grace_period_checker.get_grace_period_end(kernel) + grace_period_end = await checker_host._grace_period_checker.get_grace_period_end( + user_created_at + ) finally: await checker_host.shutdown() @@ -163,13 +166,8 @@ async def network_timeout_idle_checker( }, "enabled": "network_timeout,", } - kernel = { - "session_id": session_id, - "session_type": SessionTypes.INTERACTIVE, - } - policy = { - "idle_timeout": threshold, - } + session = SessionRow(id=session_id, session_type=SessionTypes.INTERACTIVE) + policy = KeyPairResourcePolicyRow(idle_timeout=threshold) await root_ctx.shared_config.etcd.put_prefix("config/idle", idle_value) # type: ignore[arg-type] checker_host = await init_idle_checkers( @@ -189,7 +187,7 @@ async def network_timeout_idle_checker( ) should_alive = await network_idle_checker.check_idleness( - kernel, checker_host._db, policy, checker_host._redis_live + session, checker_host._db, policy, checker_host._redis_live ) remaining = await network_idle_checker.get_checker_result( checker_host._redis_live, session_id @@ -216,13 +214,8 @@ async def network_timeout_idle_checker( }, "enabled": "network_timeout,", } - kernel = { - "session_id": session_id, - "session_type": SessionTypes.INTERACTIVE, - } - policy = { - "idle_timeout": threshold, - } + session = SessionRow(id=session_id, session_type=SessionTypes.INTERACTIVE) + policy = KeyPairResourcePolicyRow(idle_timeout=threshold) await root_ctx.shared_config.etcd.put_prefix("config/idle", idle_value) # type: ignore[arg-type] checker_host = await init_idle_checkers( @@ -242,7 +235,7 @@ async def network_timeout_idle_checker( ) should_alive = await network_idle_checker.check_idleness( - kernel, checker_host._db, policy, checker_host._redis_live + session, checker_host._db, policy, checker_host._redis_live ) remaining = await network_idle_checker.get_checker_result( checker_host._redis_live, session_id @@ -272,14 +265,8 @@ async def network_timeout_idle_checker( }, "enabled": "network_timeout,", } - kernel = { - "session_id": session_id, - "session_type": SessionTypes.INTERACTIVE, - "user_created_at": user_created_at, - } - policy = { - "idle_timeout": threshold, - } + session = SessionRow(id=session_id, session_type=SessionTypes.INTERACTIVE) + policy = KeyPairResourcePolicyRow(idle_timeout=threshold) await root_ctx.shared_config.etcd.put_prefix("config/idle", idle_value) # type: ignore[arg-type] checker_host = await init_idle_checkers( @@ -298,9 +285,11 @@ async def network_timeout_idle_checker( lambda r: r.set(f"session.{session_id}.last_access", last_access), ) - grace_period_end = await checker_host._grace_period_checker.get_grace_period_end(kernel) + grace_period_end = await checker_host._grace_period_checker.get_grace_period_end( + user_created_at + ) should_alive = await network_idle_checker.check_idleness( - kernel, + session, checker_host._db, policy, checker_host._redis_live, @@ -334,14 +323,8 @@ async def network_timeout_idle_checker( }, "enabled": "network_timeout,", } - kernel = { - "session_id": session_id, - "session_type": SessionTypes.INTERACTIVE, - "user_created_at": user_created_at, - } - policy = { - "idle_timeout": threshold, - } + session = SessionRow(id=session_id, session_type=SessionTypes.INTERACTIVE) + policy = KeyPairResourcePolicyRow(idle_timeout=threshold) await root_ctx.shared_config.etcd.put_prefix("config/idle", idle_value) # type: ignore[arg-type] checker_host = await init_idle_checkers( @@ -360,9 +343,11 @@ async def network_timeout_idle_checker( lambda r: r.set(f"session.{session_id}.last_access", last_access), ) - grace_period_end = await checker_host._grace_period_checker.get_grace_period_end(kernel) + grace_period_end = await checker_host._grace_period_checker.get_grace_period_end( + user_created_at + ) should_alive = await network_idle_checker.check_idleness( - kernel, + session, checker_host._db, policy, checker_host._redis_live, @@ -401,7 +386,7 @@ async def session_lifetime_checker( # test 1 # remaining time is positive and no grace period session_id = SessionId(uuid4()) - kernel_created_at = datetime(2020, 3, 1, 12, 30, second=0) + session_created_at = datetime(2020, 3, 1, 12, 30, second=0) max_session_lifetime = 30 now = datetime(2020, 3, 1, 12, 30, second=10) mocker.patch("ai.backend.manager.idle.get_db_now", return_value=now) @@ -410,13 +395,8 @@ async def session_lifetime_checker( "checkers": {}, "enabled": "", } - kernel = { - "session_id": session_id, - "created_at": kernel_created_at, - } - policy = { - "max_session_lifetime": max_session_lifetime, - } + session = SessionRow(id=session_id, created_at=session_created_at) + policy = KeyPairResourcePolicyRow(max_session_lifetime=max_session_lifetime) await root_ctx.shared_config.etcd.put_prefix("config/idle", idle_value) # type: ignore[arg-type] checker_host = await init_idle_checkers( @@ -431,7 +411,7 @@ async def session_lifetime_checker( session_lifetime_checker = get_checker_from_host(checker_host, SessionLifetimeChecker) should_alive = await session_lifetime_checker.check_idleness( - kernel, + session, checker_host._db, policy, checker_host._redis_live, @@ -448,7 +428,7 @@ async def session_lifetime_checker( # test 2 # remaining time is negative and no grace period session_id = SessionId(uuid4()) - kernel_created_at = datetime(2020, 3, 1, 12, 30, second=0) + session_created_at = datetime(2020, 3, 1, 12, 30, second=0) max_session_lifetime = 30 now = datetime(2020, 3, 1, 12, 30, second=50) mocker.patch("ai.backend.manager.idle.get_db_now", return_value=now) @@ -457,13 +437,8 @@ async def session_lifetime_checker( "checkers": {}, "enabled": "", } - kernel = { - "session_id": session_id, - "created_at": kernel_created_at, - } - policy = { - "max_session_lifetime": max_session_lifetime, - } + session = SessionRow(id=session_id, created_at=session_created_at) + policy = KeyPairResourcePolicyRow(max_session_lifetime=max_session_lifetime) await root_ctx.shared_config.etcd.put_prefix("config/idle", idle_value) # type: ignore[arg-type] checker_host = await init_idle_checkers( @@ -478,7 +453,7 @@ async def session_lifetime_checker( session_lifetime_checker = get_checker_from_host(checker_host, SessionLifetimeChecker) should_alive = await session_lifetime_checker.check_idleness( - kernel, + session, checker_host._db, policy, checker_host._redis_live, @@ -495,7 +470,7 @@ async def session_lifetime_checker( # test 3 # remaining time is positive with new user grace period session_id = SessionId(uuid4()) - kernel_created_at = datetime(2020, 3, 1, 12, 30, second=10) + session_created_at = datetime(2020, 3, 1, 12, 30, second=10) user_created_at = datetime(2020, 3, 1, 12, 30, second=0) max_session_lifetime = 10 now = datetime(2020, 3, 1, 12, 30, second=25) @@ -508,14 +483,8 @@ async def session_lifetime_checker( }, "enabled": "", } - kernel = { - "session_id": session_id, - "created_at": kernel_created_at, - "user_created_at": user_created_at, - } - policy = { - "max_session_lifetime": max_session_lifetime, - } + session = SessionRow(id=session_id, created_at=session_created_at) + policy = KeyPairResourcePolicyRow(max_session_lifetime=max_session_lifetime) await root_ctx.shared_config.etcd.put_prefix("config/idle", idle_value) # type: ignore[arg-type] checker_host = await init_idle_checkers( @@ -528,10 +497,12 @@ async def session_lifetime_checker( try: await checker_host.start() session_lifetime_checker = get_checker_from_host(checker_host, SessionLifetimeChecker) - grace_period_end = await checker_host._grace_period_checker.get_grace_period_end(kernel) + grace_period_end = await checker_host._grace_period_checker.get_grace_period_end( + user_created_at + ) should_alive = await session_lifetime_checker.check_idleness( - kernel, + session, checker_host._db, policy, checker_host._redis_live, @@ -549,7 +520,7 @@ async def session_lifetime_checker( # test 4 # remaining time is negative with new user grace period session_id = SessionId(uuid4()) - kernel_created_at = datetime(2020, 3, 1, 12, 30, second=40) + session_created_at = datetime(2020, 3, 1, 12, 30, second=40) user_created_at = datetime(2020, 3, 1, 12, 30, second=0) max_session_lifetime = 10 now = datetime(2020, 3, 1, 12, 30, second=55) @@ -562,14 +533,8 @@ async def session_lifetime_checker( }, "enabled": "", } - kernel = { - "session_id": session_id, - "created_at": kernel_created_at, - "user_created_at": user_created_at, - } - policy = { - "max_session_lifetime": max_session_lifetime, - } + session = SessionRow(id=session_id, created_at=session_created_at) + policy = KeyPairResourcePolicyRow(max_session_lifetime=max_session_lifetime) await root_ctx.shared_config.etcd.put_prefix("config/idle", idle_value) # type: ignore[arg-type] checker_host = await init_idle_checkers( @@ -582,10 +547,12 @@ async def session_lifetime_checker( try: await checker_host.start() session_lifetime_checker = get_checker_from_host(checker_host, SessionLifetimeChecker) - grace_period_end = await checker_host._grace_period_checker.get_grace_period_end(kernel) + grace_period_end = await checker_host._grace_period_checker.get_grace_period_end( + user_created_at + ) should_alive = await session_lifetime_checker.check_idleness( - kernel, + session, checker_host._db, policy, checker_host._redis_live, @@ -712,7 +679,7 @@ async def utilization_idle_checker( kernel_id = KernelId(uuid4()) timewindow = 30 initial_grace_period = 100 - kernel_created_at = datetime(2020, 3, 1, 12, 30, second=0) + session_created_at = datetime(2020, 3, 1, 12, 30, second=0) now = datetime(2020, 3, 1, 12, 30, second=10) expected = timedelta(seconds=120).total_seconds() @@ -730,16 +697,10 @@ async def utilization_idle_checker( "pct": "10.0", }, } - kernel = { - "id": kernel_id, - "session_id": session_id, - "created_at": kernel_created_at, - "cluster_size": 1, - "occupied_slots": occupied_slots, - } - policy = { - "idle_timeout": timewindow, - } + session = SessionRow( + id=session_id, created_at=session_created_at, occupying_slots=occupied_slots + ) + policy = KeyPairResourcePolicyRow(idle_timeout=timewindow) resource_thresholds = { "cpu_util": {"average": "0"}, @@ -775,7 +736,7 @@ async def utilization_idle_checker( utilization_idle_checker = get_checker_from_host(checker_host, UtilizationIdleChecker) should_alive = await utilization_idle_checker.check_idleness( - kernel, checker_host._db, policy, checker_host._redis_live + session, checker_host._db, policy, checker_host._redis_live ) remaining = await utilization_idle_checker.get_checker_result( checker_host._redis_live, session_id @@ -792,7 +753,7 @@ async def utilization_idle_checker( # remaining time is positive with utilization. session_id = SessionId(uuid4()) kernel_id = KernelId(uuid4()) - kernel_created_at = datetime(2020, 3, 1, 12, 30, second=0) + datetime(2020, 3, 1, 12, 30, second=0) now = datetime(2020, 3, 1, 12, 30, second=10) initial_grace_period = 0 timewindow = 15 @@ -812,16 +773,10 @@ async def utilization_idle_checker( "pct": "10.0", }, } - kernel = { - "id": kernel_id, - "session_id": session_id, - "created_at": kernel_created_at, - "cluster_size": 1, - "occupied_slots": occupied_slots, - } - policy = { - "idle_timeout": timewindow, - } + session = SessionRow( + id=session_id, created_at=session_created_at, occupying_slots=occupied_slots + ) + policy = KeyPairResourcePolicyRow(idle_timeout=timewindow) resource_thresholds = { "cpu_util": {"average": "0"}, @@ -857,7 +812,7 @@ async def utilization_idle_checker( utilization_idle_checker = get_checker_from_host(checker_host, UtilizationIdleChecker) should_alive = await utilization_idle_checker.check_idleness( - kernel, checker_host._db, policy, checker_host._redis_live + session, checker_host._db, policy, checker_host._redis_live ) remaining = await utilization_idle_checker.get_checker_result( checker_host._redis_live, session_id @@ -876,7 +831,7 @@ async def utilization_idle_checker( kernel_id = KernelId(uuid4()) timewindow = 15 initial_grace_period = 0 - kernel_created_at = datetime(2020, 3, 1, 12, 30, second=0) + datetime(2020, 3, 1, 12, 30, second=0) now = datetime(2020, 3, 1, 12, 30, second=50) expected = -1 @@ -894,16 +849,10 @@ async def utilization_idle_checker( "pct": "10.0", }, } - kernel = { - "id": kernel_id, - "session_id": session_id, - "created_at": kernel_created_at, - "cluster_size": 1, - "occupied_slots": occupied_slots, - } - policy = { - "idle_timeout": timewindow, - } + session = SessionRow( + id=session_id, created_at=session_created_at, occupying_slots=occupied_slots + ) + policy = KeyPairResourcePolicyRow(idle_timeout=timewindow) resource_thresholds = { "cpu_util": {"average": "0"}, @@ -939,7 +888,7 @@ async def utilization_idle_checker( utilization_idle_checker = get_checker_from_host(checker_host, UtilizationIdleChecker) should_alive = await utilization_idle_checker.check_idleness( - kernel, checker_host._db, policy, checker_host._redis_live + session, checker_host._db, policy, checker_host._redis_live ) remaining = await utilization_idle_checker.get_checker_result( checker_host._redis_live, session_id