diff --git a/changes/2275.feature.md b/changes/2275.feature.md new file mode 100644 index 0000000000..2c17a3465c --- /dev/null +++ b/changes/2275.feature.md @@ -0,0 +1 @@ +Allow superadmins to force-update session status through the destroy API. diff --git a/src/ai/backend/client/cli/session/lifecycle.py b/src/ai/backend/client/cli/session/lifecycle.py index ec5975b73f..e7e5322fce 100644 --- a/src/ai/backend/client/cli/session/lifecycle.py +++ b/src/ai/backend/client/cli/session/lifecycle.py @@ -568,6 +568,12 @@ def destroy(session_names, forced, owner, stats, recursive): else: if not has_failure: print_done("Done.") + if forced: + print_warn( + "If you have destroyed a session whose status is one of " + "[`PULLING`, `SCHEDULED`, `PREPARING`, `TERMINATING`, `ERROR`], " + "manual cleanup of actual containers may be required." + ) if stats: stats = ret.get("stats", None) if ret else None if stats: diff --git a/src/ai/backend/manager/api/session.py b/src/ai/backend/manager/api/session.py index 41c7be6112..2148ffd664 100644 --- a/src/ai/backend/manager/api/session.py +++ b/src/ai/backend/manager/api/session.py @@ -1346,8 +1346,9 @@ async def rename_session(request: web.Request, params: Any) -> web.Response: async def destroy(request: web.Request, params: Any) -> web.Response: root_ctx: RootContext = request.app["_root.context"] session_name = request.match_info["session_name"] + user_role = cast(UserRole, request["user"]["role"]) requester_access_key, owner_access_key = await get_access_key_scopes(request, params) - if requester_access_key != owner_access_key and request["user"]["role"] not in ( + if requester_access_key != owner_access_key and user_role not in ( UserRole.ADMIN, UserRole.SUPERADMIN, ): @@ -1395,7 +1396,9 @@ async def destroy(request: web.Request, params: Any) -> web.Response: last_stats = await asyncio.gather( *[ - root_ctx.registry.destroy_session(sess, forced=params["forced"]) + root_ctx.registry.destroy_session( + sess, forced=params["forced"], user_role=user_role + ) for sess in sessions if isinstance(sess, SessionRow) ], @@ -1420,6 +1423,7 @@ async def destroy(request: web.Request, params: Any) -> web.Response: last_stat = await root_ctx.registry.destroy_session( session, forced=params["forced"], + user_role=user_role, ) resp = { "stats": last_stat, diff --git a/src/ai/backend/manager/models/agent.py b/src/ai/backend/manager/models/agent.py index 88cff60670..b4c9347f7c 100644 --- a/src/ai/backend/manager/models/agent.py +++ b/src/ai/backend/manager/models/agent.py @@ -2,7 +2,7 @@ import enum import uuid -from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Sequence +from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Sequence, cast import graphene import sqlalchemy as sa @@ -12,7 +12,7 @@ from sqlalchemy.engine.row import Row from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection from sqlalchemy.ext.asyncio import AsyncSession as SASession -from sqlalchemy.orm import relationship +from sqlalchemy.orm import relationship, selectinload, with_loader_criteria from sqlalchemy.sql.expression import false, true from ai.backend.common import msgpack, redis_helper @@ -32,7 +32,7 @@ simple_db_mutate, ) from .group import association_groups_users -from .kernel import AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES, kernels +from .kernel import AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES, KernelRow, kernels from .keypair import keypairs from .minilang.ordering import OrderSpecItem, QueryOrderParser from .minilang.queryfilter import FieldSpecItem, 
QueryFilterParser, enum_field_getter @@ -41,6 +41,7 @@ if TYPE_CHECKING: from ai.backend.manager.models.gql import GraphQueryContext + __all__: Sequence[str] = ( "agents", "AgentRow", @@ -619,6 +620,28 @@ async def recalc_agent_resource_occupancy(db_conn: SAConnection, agent_id: Agent await db_conn.execute(query) +async def recalc_agent_resource_occupancy_using_orm( + db_session: SASession, agent_id: AgentId +) -> None: + agent_query = ( + sa.select(AgentRow) + .where(AgentRow.id == agent_id) + .options( + selectinload(AgentRow.kernels), + with_loader_criteria( + KernelRow, KernelRow.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES) + ), + ) + ) + occupied_slots = ResourceSlot() + agent_row = cast(AgentRow, await db_session.scalar(agent_query)) + kernel_rows = cast(list[KernelRow], agent_row.kernels) + for kernel in kernel_rows: + if kernel.status in AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES: + occupied_slots += kernel.occupied_slots + agent_row.occupied_slots = occupied_slots + + class ModifyAgent(graphene.Mutation): allowed_roles = (UserRole.SUPERADMIN,) diff --git a/src/ai/backend/manager/models/session.py b/src/ai/backend/manager/models/session.py index fc7f7a14ed..6eebeb4e55 100644 --- a/src/ai/backend/manager/models/session.py +++ b/src/ai/backend/manager/models/session.py @@ -3,17 +3,16 @@ import asyncio import enum import logging +from collections.abc import Iterable, Mapping, Sequence from contextlib import asynccontextmanager as actxmgr +from dataclasses import dataclass, field from datetime import datetime from typing import ( TYPE_CHECKING, Any, AsyncIterator, - Iterable, List, - Mapping, Optional, - Sequence, Union, ) from uuid import UUID @@ -507,6 +506,31 @@ async def _match_sessions_by_name( return result.scalars().all() +COMPUTE_CONCURRENCY_USED_KEY_PREFIX = "keypair.concurrency_used." +SYSTEM_CONCURRENCY_USED_KEY_PREFIX = "keypair.sftp_concurrency_used." 
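Reviewer note on `recalc_agent_resource_occupancy_using_orm` above: `with_loader_criteria()` attaches an extra WHERE clause to the eager load of `AgentRow.kernels`, so only kernels in an occupying status are fetched and the in-loop status check is just a redundant guard. A minimal, self-contained sketch of the same SQLAlchemy 2.0 pattern — `Parent`/`Child` are toy stand-ins for `AgentRow`/`KernelRow`, not code from this PR:

```python
# Toy illustration of the selectinload + with_loader_criteria pattern;
# Parent/Child are hypothetical stand-ins for AgentRow/KernelRow.
import sqlalchemy as sa
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    mapped_column,
    relationship,
    selectinload,
    with_loader_criteria,
)


class Base(DeclarativeBase):
    pass


class Parent(Base):
    __tablename__ = "parents"
    id: Mapped[int] = mapped_column(primary_key=True)
    children: Mapped[list["Child"]] = relationship()


class Child(Base):
    __tablename__ = "children"
    id: Mapped[int] = mapped_column(primary_key=True)
    parent_id: Mapped[int] = mapped_column(sa.ForeignKey("parents.id"))
    status: Mapped[str]


async def load_active_children(db_session: AsyncSession, parent_id: int) -> list[Child]:
    # selectinload() fetches the relationship eagerly in a second SELECT;
    # with_loader_criteria() appends "WHERE children.status IN (...)" to that
    # SELECT only, so parent.children contains just the matching rows.
    stmt = (
        sa.select(Parent)
        .where(Parent.id == parent_id)
        .options(
            selectinload(Parent.children),
            with_loader_criteria(Child, Child.status.in_(["RUNNING", "TERMINATING"])),
        )
    )
    parent = await db_session.scalar(stmt)
    return list(parent.children) if parent is not None else []
```

Because the criteria is applied at load time, `agent_row.kernels` never materializes terminated kernels, which keeps the occupancy recalculation cheap on busy agents.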
+ + +@dataclass +class ConcurrencyUsed: + access_key: AccessKey + compute_session_ids: set[SessionId] = field(default_factory=set) + system_session_ids: set[SessionId] = field(default_factory=set) + + @property + def compute_concurrency_used_key(self) -> str: + return f"{COMPUTE_CONCURRENCY_USED_KEY_PREFIX}{self.access_key}" + + @property + def system_concurrency_used_key(self) -> str: + return f"{SYSTEM_CONCURRENCY_USED_KEY_PREFIX}{self.access_key}" + + def to_cnt_map(self) -> Mapping[str, int]: + return { + self.compute_concurrency_used_key: len(self.compute_session_ids), + self.system_concurrency_used_key: len(self.system_session_ids), + } + + class SessionOp(enum.StrEnum): CREATE = "create_session" DESTROY = "destroy_session" diff --git a/src/ai/backend/manager/registry.py b/src/ai/backend/manager/registry.py index 456959a5dc..3427285bf6 100644 --- a/src/ai/backend/manager/registry.py +++ b/src/ai/backend/manager/registry.py @@ -20,6 +20,7 @@ Any, Dict, List, + Literal, Mapping, MutableMapping, Optional, @@ -150,6 +151,7 @@ SessionDependencyRow, SessionRow, SessionStatus, + UserRole, UserRow, agents, domains, @@ -164,6 +166,12 @@ scaling_groups, verify_vfolder_name, ) +from .models.session import ( + COMPUTE_CONCURRENCY_USED_KEY_PREFIX, + SESSION_KERNEL_STATUS_MAPPING, + SYSTEM_CONCURRENCY_USED_KEY_PREFIX, + ConcurrencyUsed, +) from .models.utils import ( ExtendedAsyncSAEngine, execute_with_retry, @@ -1957,17 +1965,11 @@ async def _update_agent_resource() -> None: await execute_with_retry(_update_agent_resource) async def recalc_resource_usage(self, do_fullscan: bool = False) -> None: - concurrency_used_per_key: MutableMapping[str, set] = defaultdict( - set - ) # key: access_key, value: set of session_id - sftp_concurrency_used_per_key: MutableMapping[str, set] = defaultdict( - set - ) # key: access_key, value: set of session_id - - async def _recalc() -> None: + async def _recalc() -> Mapping[AccessKey, ConcurrencyUsed]: occupied_slots_per_agent: MutableMapping[str, ResourceSlot] = defaultdict( lambda: ResourceSlot({"cpu": 0, "mem": 0}) ) + access_key_to_concurrency_used: dict[AccessKey, ConcurrencyUsed] = {} async with self.db.begin_session() as db_sess: # Query running containers and calculate concurrency_used per AK and @@ -1998,12 +2000,19 @@ async def _recalc() -> None: kernel.occupied_slots ) if session_status in USER_RESOURCE_OCCUPYING_SESSION_STATUSES: + access_key = cast(AccessKey, session_row.access_key) + if access_key not in access_key_to_concurrency_used: + access_key_to_concurrency_used[access_key] = ConcurrencyUsed( + access_key + ) if kernel.role in PRIVATE_KERNEL_ROLES: - sftp_concurrency_used_per_key[session_row.access_key].add( + access_key_to_concurrency_used[access_key].system_session_ids.add( session_row.id ) else: - concurrency_used_per_key[session_row.access_key].add(session_row.id) + access_key_to_concurrency_used[access_key].compute_session_ids.add( + session_row.id + ) if len(occupied_slots_per_agent) > 0: # Update occupied_slots for agents with running containers. @@ -2033,54 +2042,54 @@ async def _recalc() -> None: .where(AgentRow.status == AgentStatus.ALIVE) ) await db_sess.execute(query) + return access_key_to_concurrency_used - await execute_with_retry(_recalc) + access_key_to_concurrency_used = await execute_with_retry(_recalc) # Update keypair resource usage for keypairs with running containers. 
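As a quick sanity check of the new dataclass (all IDs below are invented for illustration), `to_cnt_map()` reports the set sizes under the per-access-key Redis keys that the `_update` callback below MSETs:

```python
# Illustrative only: exercising ConcurrencyUsed with made-up session IDs.
from uuid import uuid4

from ai.backend.common.types import AccessKey, SessionId
from ai.backend.manager.models.session import ConcurrencyUsed

usage = ConcurrencyUsed(access_key=AccessKey("AKIAEXAMPLE"))
usage.compute_session_ids.update({SessionId(uuid4()), SessionId(uuid4())})
usage.system_session_ids.add(SessionId(uuid4()))  # e.g. an SFTP session

assert usage.to_cnt_map() == {
    "keypair.concurrency_used.AKIAEXAMPLE": 2,
    "keypair.sftp_concurrency_used.AKIAEXAMPLE": 1,
}
```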
- kp_key = "keypair.concurrency_used" - sftp_kp_key = "keypair.sftp_concurrency_used" - async def _update(r: Redis): - updates = { - f"{kp_key}.{ak}": len(session_ids) - for ak, session_ids in concurrency_used_per_key.items() - } | { - f"{sftp_kp_key}.{ak}": len(session_ids) - for ak, session_ids in sftp_concurrency_used_per_key.items() - } + updates: dict[str, int] = {} + for concurrency in access_key_to_concurrency_used.values(): + updates |= concurrency.to_cnt_map() if updates: await r.mset(typing.cast(MSetType, updates)) async def _update_by_fullscan(r: Redis): updates = {} - keys = await r.keys(f"{kp_key}.*") + keys = await r.keys(f"{COMPUTE_CONCURRENCY_USED_KEY_PREFIX}*") for stat_key in keys: if isinstance(stat_key, bytes): _stat_key = stat_key.decode("utf-8") else: - _stat_key = stat_key - ak = _stat_key.replace(f"{kp_key}.", "") - session_concurrency = concurrency_used_per_key.get(ak) - usage = len(session_concurrency) if session_concurrency is not None else 0 + _stat_key = cast(str, stat_key) + ak = _stat_key.replace(COMPUTE_CONCURRENCY_USED_KEY_PREFIX, "") + concurrent_sessions = access_key_to_concurrency_used.get(AccessKey(ak)) + usage = ( + len(concurrent_sessions.compute_session_ids) + if concurrent_sessions is not None + else 0 + ) updates[_stat_key] = usage - keys = await r.keys(f"{sftp_kp_key}.*") + keys = await r.keys(f"{SYSTEM_CONCURRENCY_USED_KEY_PREFIX}*") for stat_key in keys: if isinstance(stat_key, bytes): _stat_key = stat_key.decode("utf-8") else: - _stat_key = stat_key - ak = _stat_key.replace(f"{sftp_kp_key}.", "") - session_concurrency = sftp_concurrency_used_per_key.get(ak) - usage = len(session_concurrency) if session_concurrency is not None else 0 + _stat_key = cast(str, stat_key) + ak = _stat_key.replace(SYSTEM_CONCURRENCY_USED_KEY_PREFIX, "") + concurrent_sessions = access_key_to_concurrency_used.get(AccessKey(ak)) + usage = ( + len(concurrent_sessions.system_session_ids) + if concurrent_sessions is not None + else 0 + ) updates[_stat_key] = usage if updates: await r.mset(typing.cast(MSetType, updates)) # Do full scan if the entire system does not have ANY sessions/sftp-sessions # to set all concurrency_used to 0 - _do_fullscan = do_fullscan or ( - not concurrency_used_per_key and not sftp_concurrency_used_per_key - ) + _do_fullscan = do_fullscan or not access_key_to_concurrency_used if _do_fullscan: await redis_helper.execute( self.redis_stat, @@ -2138,6 +2147,7 @@ async def destroy_session( *, forced: bool = False, reason: Optional[KernelLifecycleEventReason] = None, + user_role: UserRole | None = None, ) -> Mapping[str, Any]: """ Destroy session kernels. 
Do not destroy @@ -2162,6 +2172,50 @@ async def destroy_session( if hook_result.status != PASSED: raise RejectedByHook.from_hook_result(hook_result) + async def _force_destroy_for_suadmin( + target_status: Literal[SessionStatus.CANCELLED, SessionStatus.TERMINATED], + ) -> None: + current_time = datetime.now(tzutc()) + destroy_reason = str(KernelLifecycleEventReason.FORCE_TERMINATED) + + async def _destroy(db_session: AsyncSession) -> SessionRow: + _stmt = ( + sa.select(SessionRow) + .where(SessionRow.id == session_id) + .options(selectinload(SessionRow.kernels)) + ) + session_row = cast(SessionRow | None, await db_session.scalar(_stmt)) + if session_row is None: + raise SessionNotFound(f"Session not found (id: {session_id})") + kernel_rows = cast(list[KernelRow], session_row.kernels) + kernel_target_status = SESSION_KERNEL_STATUS_MAPPING[target_status] + for kern in kernel_rows: + kern.status = kernel_target_status + kern.terminated_at = current_time + kern.status_info = destroy_reason + kern.status_history = sql_json_merge( + KernelRow.status_history, + (), + { + kernel_target_status.name: current_time.isoformat(), + }, + ) + session_row.status = target_status + session_row.terminated_at = current_time + session_row.status_info = destroy_reason + session_row.status_history = sql_json_merge( + SessionRow.status_history, + (), + { + target_status.name: current_time.isoformat(), + }, + ) + return session_row + + async with self.db.connect() as db_conn: + await execute_with_txn_retry(_destroy, self.db.begin_session, db_conn) + await self.recalc_resource_usage() + async with handle_session_exception( self.db, "destroy_session", @@ -2200,6 +2254,17 @@ async def destroy_session( self.db, session_id, SessionStatus.CANCELLED ) case SessionStatus.PULLING: + # Exceptionally allow superadmins to destroy PULLING sessions. + # Clients should be informed that they may need to clean up the remaining containers manually. + # TODO: detach image-pull process from kernel-start process and allow all users to destroy PULLING sessions. + if forced and user_role == UserRole.SUPERADMIN: + log.warning( + "force-terminating session (s:{}, status:{})", + session_id, + target_session.status, + ) + await _force_destroy_for_suadmin(SessionStatus.CANCELLED) + return {} raise GenericForbidden("Cannot destroy sessions in pulling status") case ( SessionStatus.SCHEDULED | SessionStatus.PREPARING | SessionStatus.TERMINATING | SessionStatus.ERROR ): @@ -2217,12 +2282,18 @@ async def destroy_session( session_id, target_session.status, ) - await SessionRow.set_session_status( - self.db, session_id, SessionStatus.TERMINATING - ) - await self.event_producer.produce_event( - SessionTerminatingEvent(session_id, reason), - ) + if user_role == UserRole.SUPERADMIN: + # Exceptionally let superadmins set the session status to 'TERMINATED' and return early. + # TODO: refactor Session/Kernel status management and remove this. + await _force_destroy_for_suadmin(SessionStatus.TERMINATED) + return {} + else: + await SessionRow.set_session_status( + self.db, session_id, SessionStatus.TERMINATING + ) + await self.event_producer.produce_event( + SessionTerminatingEvent(session_id, reason), + ) case SessionStatus.TERMINATED: raise GenericForbidden( "Cannot destroy sessions that have already been terminated"
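For completeness, a superadmin would exercise the new path from the client side roughly as follows. This is a hypothetical sketch: the `destroy(forced=...)` signature is inferred from the CLI's `forced` option above, and the session name and credentials are made up. It only takes the fast path when the session is stuck in one of PULLING / SCHEDULED / PREPARING / TERMINATING / ERROR:

```python
# Hypothetical usage sketch; assumes a superadmin keypair is configured in
# the environment (e.g. BACKEND_ACCESS_KEY / BACKEND_SECRET_KEY).
import asyncio

from ai.backend.client.session import AsyncSession


async def force_destroy(session_name: str) -> None:
    async with AsyncSession() as api_session:
        compute_session = api_session.ComputeSession(session_name)
        # forced=True lets the manager take the superadmin-only branch:
        # the session/kernel rows are flipped to CANCELLED or TERMINATED
        # in the DB, and leftover containers may need manual cleanup.
        result = await compute_session.destroy(forced=True)
        print(result)


asyncio.run(force_destroy("stuck-session"))
```

On success the manager returns the last collected statistics, or `{}` when the superadmin fast path short-circuits (matching the `return {}` branches above), which is why the CLI then prints the manual-cleanup warning.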