Skip to content

Commit

Permalink
feat: Allow superadmins to force-update session status by session des…
Browse files Browse the repository at this point in the history
…troy API (#2275)
  • Loading branch information
fregataa authored Jul 3, 2024
1 parent e4ca40e commit 6307df1
Show file tree
Hide file tree
Showing 6 changed files with 177 additions and 48 deletions.
1 change: 1 addition & 0 deletions changes/2275.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Allow superadmins to force-update session status through the destroy API.
6 changes: 6 additions & 0 deletions src/ai/backend/client/cli/session/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,12 @@ def destroy(session_names, forced, owner, stats, recursive):
else:
if not has_failure:
print_done("Done.")
if forced:
print_warn(
"If you have destroyed a session whose status is one of "
"[`PULLING`, `SCHEDULED`, `PREPARING`, `TERMINATING`, `ERROR`], "
"Manual cleanup of actual containers may be required."
)
if stats:
stats = ret.get("stats", None) if ret else None
if stats:
Expand Down
8 changes: 6 additions & 2 deletions src/ai/backend/manager/api/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1346,8 +1346,9 @@ async def rename_session(request: web.Request, params: Any) -> web.Response:
async def destroy(request: web.Request, params: Any) -> web.Response:
root_ctx: RootContext = request.app["_root.context"]
session_name = request.match_info["session_name"]
user_role = cast(UserRole, request["user"]["role"])
requester_access_key, owner_access_key = await get_access_key_scopes(request, params)
if requester_access_key != owner_access_key and request["user"]["role"] not in (
if requester_access_key != owner_access_key and user_role not in (
UserRole.ADMIN,
UserRole.SUPERADMIN,
):
Expand Down Expand Up @@ -1395,7 +1396,9 @@ async def destroy(request: web.Request, params: Any) -> web.Response:

last_stats = await asyncio.gather(
*[
root_ctx.registry.destroy_session(sess, forced=params["forced"])
root_ctx.registry.destroy_session(
sess, forced=params["forced"], user_role=user_role
)
for sess in sessions
if isinstance(sess, SessionRow)
],
Expand All @@ -1420,6 +1423,7 @@ async def destroy(request: web.Request, params: Any) -> web.Response:
last_stat = await root_ctx.registry.destroy_session(
session,
forced=params["forced"],
user_role=user_role,
)
resp = {
"stats": last_stat,
Expand Down
29 changes: 26 additions & 3 deletions src/ai/backend/manager/models/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import enum
import uuid
from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Sequence
from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Sequence, cast

import graphene
import sqlalchemy as sa
Expand All @@ -12,7 +12,7 @@
from sqlalchemy.engine.row import Row
from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection
from sqlalchemy.ext.asyncio import AsyncSession as SASession
from sqlalchemy.orm import relationship
from sqlalchemy.orm import relationship, selectinload, with_loader_criteria
from sqlalchemy.sql.expression import false, true

from ai.backend.common import msgpack, redis_helper
Expand All @@ -32,7 +32,7 @@
simple_db_mutate,
)
from .group import association_groups_users
from .kernel import AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES, kernels
from .kernel import AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES, KernelRow, kernels
from .keypair import keypairs
from .minilang.ordering import OrderSpecItem, QueryOrderParser
from .minilang.queryfilter import FieldSpecItem, QueryFilterParser, enum_field_getter
Expand All @@ -41,6 +41,7 @@
if TYPE_CHECKING:
from ai.backend.manager.models.gql import GraphQueryContext


__all__: Sequence[str] = (
"agents",
"AgentRow",
Expand Down Expand Up @@ -619,6 +620,28 @@ async def recalc_agent_resource_occupancy(db_conn: SAConnection, agent_id: Agent
await db_conn.execute(query)


async def recalc_agent_resource_occupancy_using_orm(
    db_session: SASession, agent_id: AgentId
) -> None:
    """Recompute an agent's ``occupied_slots`` from its kernels via the ORM.

    Loads the ``AgentRow`` matching ``agent_id`` together with its kernels,
    sums ``occupied_slots`` of every kernel whose status is in
    ``AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES``, and assigns the total to
    the agent row. Nothing is flushed or committed here; that is left to
    whoever owns ``db_session``.

    NOTE(review): if no agent matches ``agent_id``, ``scalar()`` yields
    ``None`` and the attribute access below raises ``AttributeError`` --
    confirm that callers guarantee the agent exists.
    """
    # Restrict the eagerly loaded kernel collection to resource-occupying ones.
    occupying_only = with_loader_criteria(
        KernelRow, KernelRow.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES)
    )
    stmt = (
        sa.select(AgentRow)
        .where(AgentRow.id == agent_id)
        .options(selectinload(AgentRow.kernels), occupying_only)
    )
    agent = cast(AgentRow, await db_session.scalar(stmt))

    total = ResourceSlot()
    for kern in cast(list[KernelRow], agent.kernels):
        # Status is re-checked in Python as well as in the loader criteria.
        if kern.status not in AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES:
            continue
        total += kern.occupied_slots
    agent.occupied_slots = total


class ModifyAgent(graphene.Mutation):
allowed_roles = (UserRole.SUPERADMIN,)

Expand Down
30 changes: 27 additions & 3 deletions src/ai/backend/manager/models/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,16 @@
import asyncio
import enum
import logging
from collections.abc import Iterable, Mapping, Sequence
from contextlib import asynccontextmanager as actxmgr
from dataclasses import dataclass, field
from datetime import datetime
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Iterable,
List,
Mapping,
Optional,
Sequence,
Union,
)
from uuid import UUID
Expand Down Expand Up @@ -507,6 +506,31 @@ async def _match_sessions_by_name(
return result.scalars().all()


# Key prefixes for the per-access-key concurrency counters; the access key is
# appended to form the full key.
COMPUTE_CONCURRENCY_USED_KEY_PREFIX = "keypair.concurrency_used."
SYSTEM_CONCURRENCY_USED_KEY_PREFIX = "keypair.sftp_concurrency_used."


@dataclass
class ConcurrencyUsed:
    """Session ids counted against one access key's concurrency usage.

    Sessions are tracked in two separate sets: compute sessions and system
    sessions, each with its own counter key.
    """

    access_key: AccessKey
    compute_session_ids: set[SessionId] = field(default_factory=set)
    system_session_ids: set[SessionId] = field(default_factory=set)

    @property
    def compute_concurrency_used_key(self) -> str:
        """Return the counter key for this access key's compute sessions."""
        return f"{COMPUTE_CONCURRENCY_USED_KEY_PREFIX}{self.access_key}"

    @property
    def system_concurrency_used_key(self) -> str:
        """Return the counter key for this access key's system sessions."""
        return f"{SYSTEM_CONCURRENCY_USED_KEY_PREFIX}{self.access_key}"

    def to_cnt_map(self) -> Mapping[str, int]:
        """Return a mapping of each counter key to its number of sessions."""
        # Count the tracked session ids, not the characters of the key string.
        return {
            self.compute_concurrency_used_key: len(self.compute_session_ids),
            self.system_concurrency_used_key: len(self.system_session_ids),
        }


class SessionOp(enum.StrEnum):
CREATE = "create_session"
DESTROY = "destroy_session"
Expand Down
Loading

0 comments on commit 6307df1

Please sign in to comment.