diff --git a/src/ai/backend/common/redis_helper.py b/src/ai/backend/common/redis_helper.py index 55a27a7aaf..0883b5a345 100644 --- a/src/ai/backend/common/redis_helper.py +++ b/src/ai/backend/common/redis_helper.py @@ -39,6 +39,7 @@ "read_stream", "read_stream_by_group", "get_redis_object", + "get_redis_now", ) _keepalive_options: MutableMapping[int, int] = {} @@ -538,6 +539,11 @@ def get_redis_object( ) +async def get_redis_now(redis_obj: RedisConnectionInfo) -> float: + t = await execute(redis_obj, lambda r: r.time()) + return t[0] + (t[1] / (10**6)) + + async def ping_redis_connection(redis_client: Redis) -> bool: try: return await redis_client.ping() diff --git a/src/ai/backend/manager/api/stream.py b/src/ai/backend/manager/api/stream.py index 8239e129dc..9d5ae3af85 100644 --- a/src/ai/backend/manager/api/stream.py +++ b/src/ai/backend/manager/api/stream.py @@ -43,7 +43,7 @@ import zmq.asyncio from aiohttp import web from aiotools import adefer -from sqlalchemy.orm import load_only, noload, selectinload +from sqlalchemy.orm import load_only, noload from ai.backend.common import redis_helper from ai.backend.common import validators as tx @@ -52,12 +52,12 @@ from ai.backend.common.types import AccessKey, AgentId, KernelId, SessionId from ai.backend.manager.idle import ( AppStreamingStatus, - get_redis_now, - update_and_check_kernel_activeness, + get_kernel_conn_tracker_key, ) from ..defs import DEFAULT_ROLE -from ..models import KernelLoadingStrategy, KernelRow, SessionRow +from ..models import KernelLoadingStrategy, SessionRow +from ..models.kernel import KernelRow, update_and_check_disconnection from .auth import auth_required from .exceptions import ( AppNotFound, @@ -521,7 +521,7 @@ async def stream_proxy( raise InvalidAPIParameters(f"Unsupported service protocol: {sport['protocol']}") redis_live = root_ctx.redis_live - conn_tracker_key = f"kernel.{kernel_id}.active_app_connections" + conn_tracker_key = get_kernel_conn_tracker_key(kernel_id) conn_tracker_val = f"{kernel_id}:{service}:{stream_id}" _conn_tracker_script = textwrap.dedent(""" @@ -559,7 +559,7 @@ async def refresh_cb( async def add_conn_track() -> None: async with app_ctx.conn_tracker_lock: - app_ctx.active_session_ids[session_id] += 1 + app_ctx.active_kernel_ids[kernel_id] += 1 now = await redis_helper.execute(redis_live, lambda r: r.time()) now = now[0] + (now[1] / (10**6)) await redis_helper.execute( @@ -574,9 +574,9 @@ async def add_conn_track() -> None: async def clear_conn_track() -> None: async with app_ctx.conn_tracker_lock: - app_ctx.active_session_ids[session_id] -= 1 - if app_ctx.active_session_ids[session_id] <= 0: - del app_ctx.active_session_ids[session_id] + app_ctx.active_kernel_ids[kernel_id] -= 1 + if app_ctx.active_kernel_ids[kernel_id] <= 0: + del app_ctx.active_kernel_ids[kernel_id] await redis_helper.execute( redis_live, lambda r: r.zrem(conn_tracker_key, conn_tracker_val) ) @@ -728,35 +728,29 @@ async def stream_conn_tracker_gc(root_ctx: RootContext, app_ctx: PrivateContext) ) else: raise e - session_ids = list(app_ctx.active_session_ids.keys()) + kernel_ids = list(app_ctx.active_kernel_ids.keys()) async with app_ctx.conn_tracker_lock: - session_query = ( - sa.select(SessionRow) - .where(SessionRow.id.in_(session_ids)) + kernel_query = ( + sa.select(KernelRow) + .where(KernelRow.id.in_(kernel_ids)) .options( noload("*"), - selectinload(SessionRow.kernels).options( - noload("*"), - load_only(KernelRow.id, KernelRow.cluster_role), - ), + load_only(KernelRow.id, KernelRow.cluster_role), ) ) async with 
root_ctx.db.begin_readonly_session() as db_session: - session_rows = (await db_session.scalars(session_query)).all() - now = await get_redis_now(redis_live) - for session in session_rows: - for kernel in session.kernels: - is_active = await update_and_check_kernel_activeness( - redis_live, - kernel, - now, - no_packet_timeout, + kernel_rows: list[KernelRow] = (await db_session.scalars(kernel_query)).all() + for kernel in kernel_rows: + is_dead = await update_and_check_disconnection( + redis_live, + kernel, + no_packet_timeout, + ) + if is_dead: + await root_ctx.idle_checker_host.update_app_streaming_status( + kernel.id, + AppStreamingStatus.NO_ACTIVE_CONNECTIONS, ) - if not is_active: - await root_ctx.idle_checker_host.update_app_streaming_status( - kernel.id, - AppStreamingStatus.NO_ACTIVE_CONNECTIONS, - ) await asyncio.sleep(10) except asyncio.CancelledError: pass @@ -771,7 +765,7 @@ class PrivateContext: zctx: zmq.asyncio.Context conn_tracker_lock: asyncio.Lock conn_tracker_gc_task: asyncio.Task - active_session_ids: DefaultDict[SessionId, int] + active_kernel_ids: DefaultDict[KernelId, int] async def stream_app_ctx(app: web.Application) -> AsyncIterator[None]: @@ -784,7 +778,7 @@ async def stream_app_ctx(app: web.Application) -> AsyncIterator[None]: app_ctx.stream_stdin_socks = defaultdict(weakref.WeakSet) app_ctx.zctx = zmq.asyncio.Context() app_ctx.conn_tracker_lock = asyncio.Lock() - app_ctx.active_session_ids = defaultdict(int) # multiset[int] + app_ctx.active_kernel_ids = defaultdict(int) # multiset[int] app_ctx.conn_tracker_gc_task = asyncio.create_task(stream_conn_tracker_gc(root_ctx, app_ctx)) root_ctx.event_dispatcher.subscribe(KernelTerminatingEvent, app, handle_kernel_terminating) diff --git a/src/ai/backend/manager/idle.py b/src/ai/backend/manager/idle.py index 5552570baa..2d6006f778 100644 --- a/src/ai/backend/manager/idle.py +++ b/src/ai/backend/manager/idle.py @@ -66,11 +66,12 @@ from ai.backend.common.utils import nmget from .defs import DEFAULT_ROLE, LockID -from .models.kernel import KernelRow +from .models.kernel import KernelRow, get_kernel_conn_tracker_key from .models.keypair import KeyPairRow from .models.resource_policy import KeyPairResourcePolicyRow from .models.session import SessionRow, SessionStatus from .models.user import UserRow +from .models.utils import get_db_now from .types import DistributedLockFactory if TYPE_CHECKING: @@ -114,52 +115,6 @@ def calculate_remaining_time( return remaining.total_seconds() -async def get_redis_now(redis_obj: RedisConnectionInfo) -> float: - t = await redis_helper.execute(redis_obj, lambda r: r.time()) - return t[0] + (t[1] / (10**6)) - - -async def get_db_now(db_session: SASession) -> datetime: - return await db_session.scalar(sa.select(sa.func.now())) - - -def get_kernel_conn_tracker_key(kernel_id: KernelId) -> str: - return f"kernel.{kernel_id}.active_app_connections" - - -async def update_and_check_kernel_activeness( - redis_conn: RedisConnectionInfo, - kernel: KernelRow, - current_time: float, - timeout: timedelta, -) -> bool: - """ - Update kernel's activeness. - Return True if the kernel has any active connection. 
-    """
-    conn_tracker_key = get_kernel_conn_tracker_key(kernel.id)
-    prev_remaining_count = await redis_helper.execute(
-        redis_conn,
-        lambda r: r.zcount(conn_tracker_key, float("-inf"), float("+inf")),
-    )
-    removed_count = await redis_helper.execute(
-        redis_conn,
-        lambda r: r.zremrangebyscore(
-            conn_tracker_key,
-            float("-inf"),
-            current_time - timeout.total_seconds(),
-        ),
-    )
-    remaining_count = await redis_helper.execute(
-        redis_conn,
-        lambda r: r.zcount(conn_tracker_key, float("-inf"), float("+inf")),
-    )
-    log.debug(
-        f"conn_tracker: gc {kernel.id} removed/remaining = {removed_count}/{remaining_count}",
-    )
-    return not (prev_remaining_count > 0 and remaining_count == 0)
-
-
 async def get_session_activeness(
     redis_conn: RedisConnectionInfo, session: SessionRow
 ) -> float | None:
@@ -756,7 +711,7 @@ async def check_idleness(
         active_streams = await get_session_activeness(self._redis_live, session)
         if active_streams is not None and active_streams > 0:
            return True
-        now: float = await get_redis_now(self._redis_live)
+        now: float = await redis_helper.get_redis_now(self._redis_live)
         kernel_remaining: dict[KernelId, KernelRemainingTimeData] = {}

         async def _check_kernel_idleness(kernel: KernelRow) -> None:
diff --git a/src/ai/backend/manager/models/kernel.py b/src/ai/backend/manager/models/kernel.py
index 09f996c5bf..349e8b1f6d 100644
--- a/src/ai/backend/manager/models/kernel.py
+++ b/src/ai/backend/manager/models/kernel.py
@@ -5,7 +5,7 @@
 import logging
 import uuid
 from contextlib import asynccontextmanager as actxmgr
-from datetime import datetime
+from datetime import datetime, timedelta
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -372,6 +372,54 @@ async def handle_kernel_exception(
     raise


+def get_kernel_conn_tracker_key(kernel_id: KernelId) -> str:
+    return f"kernel.{kernel_id}.active_app_connections"
+
+
+async def get_remaining_connection(
+    redis_conn: RedisConnectionInfo,
+    kernel: KernelRow,
+) -> int:
+    return await redis_helper.execute(
+        redis_conn,
+        lambda r: r.zcount(
+            get_kernel_conn_tracker_key(kernel.id),
+            float("-inf"),
+            float("+inf"),
+        ),
+    )
+
+
+async def update_and_check_disconnection(
+    redis_conn: RedisConnectionInfo,
+    kernel: KernelRow,
+    timeout: timedelta,
+) -> bool:
+    """
+    Prune expired entries from the kernel's app-connection tracker, then
+    return True if the kernel had active connections but none remain.
+ """ + conn_tracker_key = get_kernel_conn_tracker_key(kernel.id) + prev_remaining_count = await get_remaining_connection(redis_conn, kernel) + current_time = await redis_helper.get_redis_now(redis_conn) + removed_count = await redis_helper.execute( + redis_conn, + lambda r: r.zremrangebyscore( + conn_tracker_key, + float("-inf"), + current_time - timeout.total_seconds(), + ), + ) + remaining_count = await redis_helper.execute( + redis_conn, + lambda r: r.zcount(conn_tracker_key, float("-inf"), float("+inf")), + ) + log.debug( + f"conn_tracker: gc {kernel.id} removed/remaining = {removed_count}/{remaining_count}", + ) + return prev_remaining_count > 0 and remaining_count == 0 + + kernels = sa.Table( "kernels", mapper_registry.metadata, diff --git a/src/ai/backend/manager/models/utils.py b/src/ai/backend/manager/models/utils.py index f8a1afb65f..73e201cb2c 100644 --- a/src/ai/backend/manager/models/utils.py +++ b/src/ai/backend/manager/models/utils.py @@ -5,6 +5,7 @@ import json import logging from contextlib import asynccontextmanager as actxmgr +from datetime import datetime from typing import TYPE_CHECKING, Any, AsyncIterator, Awaitable, Callable, Mapping, Tuple, TypeVar from urllib.parse import quote_plus as urlquote @@ -358,3 +359,7 @@ def agg_to_str(column: sa.Column) -> sa.sql.functions.Function: def agg_to_array(column: sa.Column) -> sa.sql.functions.Function: return sa.func.array_agg(psql.aggregate_order_by(column, column.asc())) + + +async def get_db_now(db_session: SASession) -> datetime: + return await db_session.scalar(sa.select(sa.func.now()))
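For reviewers, a minimal usage sketch (not part of the patch) of the two relocated time helpers. It assumes a RedisConnectionInfo named `redis_live` and the manager's async database engine named `db` exposing `begin_readonly_session()`, as held by RootContext in the hunks above; the function name `show_server_clocks` is hypothetical.

from datetime import timedelta

from ai.backend.common import redis_helper
from ai.backend.manager.models.utils import get_db_now


async def show_server_clocks(redis_live, db) -> None:
    # `redis_live` is assumed to be a RedisConnectionInfo and `db` the manager's
    # async SQLAlchemy engine, as used by RootContext in the hunks above.
    # Redis server time as a float Unix timestamp (seconds + microseconds).
    redis_now = await redis_helper.get_redis_now(redis_live)
    # Conn-tracker entries scored below this cutoff would be pruned after 5 minutes.
    cutoff = redis_now - timedelta(minutes=5).total_seconds()
    print(f"redis now={redis_now:.6f}, prune-below cutoff={cutoff:.6f}")

    # Database server time as a timezone-aware datetime, read in a read-only session.
    async with db.begin_readonly_session() as db_session:
        db_now = await get_db_now(db_session)
        print(f"db now: {db_now.isoformat()}")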