Skip to content

Commit

Permalink
track active kernels instead of sessions and refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
fregataa committed Oct 23, 2023
1 parent fe14158 commit 13d15c3
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 82 deletions.
6 changes: 6 additions & 0 deletions src/ai/backend/common/redis_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
"read_stream",
"read_stream_by_group",
"get_redis_object",
"get_redis_now",
)

_keepalive_options: MutableMapping[int, int] = {}
Expand Down Expand Up @@ -538,6 +539,11 @@ def get_redis_object(
)


async def get_redis_now(redis_obj: RedisConnectionInfo) -> float:
    """
    Return the Redis server's current time as a single float.

    The reply of ``r.time()`` is read as a two-element value: element 0 is
    used as whole seconds and element 1 is scaled by one millionth
    (i.e. treated as microseconds) before being added.
    """
    reply = await execute(redis_obj, lambda r: r.time())
    whole_seconds = reply[0]
    microseconds = reply[1]
    return whole_seconds + (microseconds / 1_000_000)


async def ping_redis_connection(redis_client: Redis) -> bool:
try:
return await redis_client.ping()
Expand Down
60 changes: 27 additions & 33 deletions src/ai/backend/manager/api/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
import zmq.asyncio
from aiohttp import web
from aiotools import adefer
from sqlalchemy.orm import load_only, noload, selectinload
from sqlalchemy.orm import load_only, noload

from ai.backend.common import redis_helper
from ai.backend.common import validators as tx
Expand All @@ -52,12 +52,12 @@
from ai.backend.common.types import AccessKey, AgentId, KernelId, SessionId
from ai.backend.manager.idle import (
AppStreamingStatus,
get_redis_now,
update_and_check_kernel_activeness,
get_kernel_conn_tracker_key,
)

from ..defs import DEFAULT_ROLE
from ..models import KernelLoadingStrategy, KernelRow, SessionRow
from ..models import KernelLoadingStrategy, SessionRow
from ..models.kernel import KernelRow, update_and_check_disconnection
from .auth import auth_required
from .exceptions import (
AppNotFound,
Expand Down Expand Up @@ -521,7 +521,7 @@ async def stream_proxy(
raise InvalidAPIParameters(f"Unsupported service protocol: {sport['protocol']}")

redis_live = root_ctx.redis_live
conn_tracker_key = f"kernel.{kernel_id}.active_app_connections"
conn_tracker_key = get_kernel_conn_tracker_key(kernel_id)
conn_tracker_val = f"{kernel_id}:{service}:{stream_id}"

_conn_tracker_script = textwrap.dedent("""
Expand Down Expand Up @@ -559,7 +559,7 @@ async def refresh_cb(

async def add_conn_track() -> None:
async with app_ctx.conn_tracker_lock:
app_ctx.active_session_ids[session_id] += 1
app_ctx.active_kernel_ids[kernel_id] += 1
now = await redis_helper.execute(redis_live, lambda r: r.time())
now = now[0] + (now[1] / (10**6))
await redis_helper.execute(
Expand All @@ -574,9 +574,9 @@ async def add_conn_track() -> None:

async def clear_conn_track() -> None:
async with app_ctx.conn_tracker_lock:
app_ctx.active_session_ids[session_id] -= 1
if app_ctx.active_session_ids[session_id] <= 0:
del app_ctx.active_session_ids[session_id]
app_ctx.active_kernel_ids[kernel_id] -= 1
if app_ctx.active_kernel_ids[kernel_id] <= 0:
del app_ctx.active_kernel_ids[kernel_id]
await redis_helper.execute(
redis_live, lambda r: r.zrem(conn_tracker_key, conn_tracker_val)
)
Expand Down Expand Up @@ -728,35 +728,29 @@ async def stream_conn_tracker_gc(root_ctx: RootContext, app_ctx: PrivateContext)
)
else:
raise e
session_ids = list(app_ctx.active_session_ids.keys())
kernel_ids = list(app_ctx.active_kernel_ids.keys())
async with app_ctx.conn_tracker_lock:
session_query = (
sa.select(SessionRow)
.where(SessionRow.id.in_(session_ids))
kernel_query = (
sa.select(KernelRow)
.where(KernelRow.id.in_(kernel_ids))
.options(
noload("*"),
selectinload(SessionRow.kernels).options(
noload("*"),
load_only(KernelRow.id, KernelRow.cluster_role),
),
load_only(KernelRow.id, KernelRow.cluster_role),
)
)
async with root_ctx.db.begin_readonly_session() as db_session:
session_rows = (await db_session.scalars(session_query)).all()
now = await get_redis_now(redis_live)
for session in session_rows:
for kernel in session.kernels:
is_active = await update_and_check_kernel_activeness(
redis_live,
kernel,
now,
no_packet_timeout,
kernel_rows: list[KernelRow] = (await db_session.scalars(kernel_query)).all()
for kernel in kernel_rows:
is_dead = await update_and_check_disconnection(
redis_live,
kernel,
no_packet_timeout,
)
if is_dead:
await root_ctx.idle_checker_host.update_app_streaming_status(
kernel.id,
AppStreamingStatus.NO_ACTIVE_CONNECTIONS,
)
if not is_active:
await root_ctx.idle_checker_host.update_app_streaming_status(
kernel.id,
AppStreamingStatus.NO_ACTIVE_CONNECTIONS,
)
await asyncio.sleep(10)
except asyncio.CancelledError:
pass
Expand All @@ -771,7 +765,7 @@ class PrivateContext:
zctx: zmq.asyncio.Context
conn_tracker_lock: asyncio.Lock
conn_tracker_gc_task: asyncio.Task
active_session_ids: DefaultDict[SessionId, int]
active_kernel_ids: DefaultDict[KernelId, int]


async def stream_app_ctx(app: web.Application) -> AsyncIterator[None]:
Expand All @@ -784,7 +778,7 @@ async def stream_app_ctx(app: web.Application) -> AsyncIterator[None]:
app_ctx.stream_stdin_socks = defaultdict(weakref.WeakSet)
app_ctx.zctx = zmq.asyncio.Context()
app_ctx.conn_tracker_lock = asyncio.Lock()
app_ctx.active_session_ids = defaultdict(int) # multiset[int]
app_ctx.active_kernel_ids = defaultdict(int) # multiset[int]
app_ctx.conn_tracker_gc_task = asyncio.create_task(stream_conn_tracker_gc(root_ctx, app_ctx))

root_ctx.event_dispatcher.subscribe(KernelTerminatingEvent, app, handle_kernel_terminating)
Expand Down
51 changes: 3 additions & 48 deletions src/ai/backend/manager/idle.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,12 @@
from ai.backend.common.utils import nmget

from .defs import DEFAULT_ROLE, LockID
from .models.kernel import KernelRow
from .models.kernel import KernelRow, get_kernel_conn_tracker_key
from .models.keypair import KeyPairRow
from .models.resource_policy import KeyPairResourcePolicyRow
from .models.session import SessionRow, SessionStatus
from .models.user import UserRow
from .models.utils import get_db_now
from .types import DistributedLockFactory

if TYPE_CHECKING:
Expand Down Expand Up @@ -114,52 +115,6 @@ def calculate_remaining_time(
return remaining.total_seconds()


async def get_redis_now(redis_obj: RedisConnectionInfo) -> float:
    """
    Return the Redis server's current time as a single float.

    The reply of ``r.time()`` is read as a two-element value: element 0 is
    used as whole seconds and element 1 is scaled by one millionth
    (i.e. treated as microseconds) before being added.
    """
    reply = await redis_helper.execute(redis_obj, lambda r: r.time())
    whole_seconds = reply[0]
    microseconds = reply[1]
    return whole_seconds + (microseconds / 1_000_000)


async def get_db_now(db_session: SASession) -> datetime:
    """
    Return the current timestamp as evaluated by the database.

    Issues ``SELECT now()`` through the given session and returns the
    scalar result.
    """
    now_query = sa.select(sa.func.now())
    db_now = await db_session.scalar(now_query)
    return db_now


def get_kernel_conn_tracker_key(kernel_id: KernelId) -> str:
    """
    Build the key under which a kernel's active app connections are tracked.

    The key has the form ``kernel.<kernel_id>.active_app_connections``.
    """
    tracker_key = f"kernel.{kernel_id}.active_app_connections"
    return tracker_key


async def update_and_check_kernel_activeness(
    redis_conn: RedisConnectionInfo,
    kernel: KernelRow,
    current_time: float,
    timeout: timedelta,
) -> bool:
    """
    Expire stale connection-tracker entries of a kernel and report its activeness.

    The number of entries in the kernel's connection-tracker sorted set is
    sampled, entries whose score is not greater than
    ``current_time - timeout`` are removed, and the number is sampled again.

    Returns False only when the set was non-empty before the removal and is
    empty after it.  Returns True in every other case, including when the
    set was already empty before this call.
    """
    conn_tracker_key = get_kernel_conn_tracker_key(kernel.id)

    def _count_all(r):
        # Count every member regardless of its score.
        return r.zcount(conn_tracker_key, float("-inf"), float("+inf"))

    count_before = await redis_helper.execute(redis_conn, _count_all)

    expiry_threshold = current_time - timeout.total_seconds()
    removed_count = await redis_helper.execute(
        redis_conn,
        lambda r: r.zremrangebyscore(conn_tracker_key, float("-inf"), expiry_threshold),
    )

    remaining_count = await redis_helper.execute(redis_conn, _count_all)
    log.debug(
        f"conn_tracker: gc {kernel.id} removed/remaining = {removed_count}/{remaining_count}",
    )

    had_connections = count_before > 0
    now_empty = remaining_count == 0
    just_drained = had_connections and now_empty
    return not just_drained


async def get_session_activeness(
redis_conn: RedisConnectionInfo, session: SessionRow
) -> float | None:
Expand Down Expand Up @@ -756,7 +711,7 @@ async def check_idleness(
active_streams = await get_session_activeness(self._redis_live, session)
if active_streams is not None and active_streams > 0:
return True
now: float = await get_redis_now(self._redis_live)
now: float = await redis_helper.get_redis_now(self._redis_live)
kernel_remaining: dict[KernelId, KernelRemainingTimeData] = {}

async def _check_kernel_idleness(kernel: KernelRow) -> None:
Expand Down
50 changes: 49 additions & 1 deletion src/ai/backend/manager/models/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
import uuid
from contextlib import asynccontextmanager as actxmgr
from datetime import datetime
from datetime import datetime, timedelta
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -372,6 +372,54 @@ async def handle_kernel_exception(
raise


def get_kernel_conn_tracker_key(kernel_id: KernelId) -> str:
    """
    Build the key under which a kernel's active app connections are tracked.

    The key has the form ``kernel.<kernel_id>.active_app_connections``.
    """
    tracker_key = f"kernel.{kernel_id}.active_app_connections"
    return tracker_key


async def get_remaining_connection(
    redis_conn: RedisConnectionInfo,
    kernel: KernelRow,
) -> int:
    """
    Count the entries in a kernel's connection-tracker sorted set.

    :param redis_conn: Redis connection used to run the ``ZCOUNT`` query.
    :param kernel: The kernel whose tracked app connections are counted.
    :return: The number of members in the set, over the full score range.
    """
    # The tracker key is derived from the kernel's ID, not from the row
    # object itself; passing the row would format its string form into the
    # key and count a set that is never written to.
    conn_tracker_key = get_kernel_conn_tracker_key(kernel.id)
    return await redis_helper.execute(
        redis_conn,
        lambda r: r.zcount(
            conn_tracker_key,
            float("-inf"),
            float("+inf"),
        ),
    )


async def update_and_check_disconnection(
    redis_conn: RedisConnectionInfo,
    kernel: KernelRow,
    timeout: timedelta,
) -> bool:
    """
    Expire stale connection-tracker entries of a kernel and detect a full disconnection.

    The connection count is sampled via ``get_remaining_connection()``, the
    current Redis time is fetched, entries of the kernel's tracker sorted set
    whose score is not greater than ``now - timeout`` are removed, and the
    set is counted again.

    Returns True only when the count sampled before the removal was positive
    and the count after it is zero; returns False otherwise.
    """
    conn_tracker_key = get_kernel_conn_tracker_key(kernel.id)
    count_before = await get_remaining_connection(redis_conn, kernel)

    redis_now = await redis_helper.get_redis_now(redis_conn)
    expiry_threshold = redis_now - timeout.total_seconds()
    removed_count = await redis_helper.execute(
        redis_conn,
        lambda r: r.zremrangebyscore(conn_tracker_key, float("-inf"), expiry_threshold),
    )

    remaining_count = await redis_helper.execute(
        redis_conn,
        lambda r: r.zcount(conn_tracker_key, float("-inf"), float("+inf")),
    )
    log.debug(
        f"conn_tracker: gc {kernel.id} removed/remaining = {removed_count}/{remaining_count}",
    )

    had_connections = count_before > 0
    now_empty = remaining_count == 0
    return had_connections and now_empty


kernels = sa.Table(
"kernels",
mapper_registry.metadata,
Expand Down
5 changes: 5 additions & 0 deletions src/ai/backend/manager/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import json
import logging
from contextlib import asynccontextmanager as actxmgr
from datetime import datetime
from typing import TYPE_CHECKING, Any, AsyncIterator, Awaitable, Callable, Mapping, Tuple, TypeVar
from urllib.parse import quote_plus as urlquote

Expand Down Expand Up @@ -358,3 +359,7 @@ def agg_to_str(column: sa.Column) -> sa.sql.functions.Function:

def agg_to_array(column: sa.Column) -> sa.sql.functions.Function:
return sa.func.array_agg(psql.aggregate_order_by(column, column.asc()))


async def get_db_now(db_session: SASession) -> datetime:
    """
    Return the current timestamp as evaluated by the database.

    Issues ``SELECT now()`` through the given session and returns the
    scalar result.
    """
    now_query = sa.select(sa.func.now())
    db_now = await db_session.scalar(now_query)
    return db_now

0 comments on commit 13d15c3

Please sign in to comment.