fix: update session's occupying slots when kernel starts #1832

Merged

merged 20 commits on Jun 16, 2024

Commits (20)
3ee9c74
fix: update session's occupying slots when kernel starts
fregataa Jan 10, 2024
c76c178
add news fragment
fregataa Jan 10, 2024
7764d92
Merge branch 'main' into fix/session-occupying-slot-update
fregataa Mar 29, 2024
b147512
revert convert_resource_spec_to_resource_slot to return stringified R…
fregataa Mar 29, 2024
f7f03a1
convert value of occupying_slots to Decimal when sum session's all ke…
fregataa Mar 29, 2024
a68f7d0
chore: update GraphQL schema dump
fregataa Mar 29, 2024
dde28a7
classmethod to function
fregataa Mar 29, 2024
2b739f1
Merge branch 'main' into fix/session-occupying-slot-update
fregataa Apr 1, 2024
296bd2e
resolve occupying_slots GQL field from row instead of sibling kernels
fregataa Apr 1, 2024
c0cfd29
add alembic migration to sync occupying_slots
fregataa Apr 1, 2024
d3b615c
Merge branch 'main' into fix/session-occupying-slot-update
fregataa Apr 4, 2024
8591708
Merge branch 'main' into fix/session-occupying-slot-update
fregataa Apr 5, 2024
4013193
update alembic migration
fregataa Apr 5, 2024
1e3879e
Merge branch 'main' into fix/session-occupying-slot-update
fregataa Jun 10, 2024
b710811
update alembic migration
fregataa Jun 10, 2024
d38ef73
update news fragment
fregataa Jun 10, 2024
3176ea1
Merge branch 'main' into fix/session-occupying-slot-update
fregataa Jun 14, 2024
bde8d20
load kernels.occupied_slots only when migration
fregataa Jun 14, 2024
63b9dde
revert usage of SessionRow.finalize_running()
fregataa Jun 16, 2024
b8a7530
Merge branch 'main' into fix/session-occupying-slot-update
kyujin-cho Jun 16, 2024
1 change: 1 addition & 0 deletions changes/1832.fix.md
@@ -0,0 +1 @@
Do not skip updating a session's occupying resources in the DB when a kernel starts.
5 changes: 5 additions & 0 deletions src/ai/backend/common/types.py
@@ -632,6 +632,11 @@ def __format__(self, format_spec):


class ResourceSlot(UserDict):
"""
key: `str` type slot name.
value: `str` or `Decimal` type value. Do not convert this to `float` or `int`.
"""

__slots__ = ("data",)

def __init__(self, *args, **kwargs) -> None:
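
The docstring added here states the invariant the rest of the PR relies on: slot values stay `str` or `Decimal` so that aggregation never passes through lossy `float`/`int`. Below is a minimal usage sketch, assuming `ResourceSlot` addition and `to_json()` behave as they are used elsewhere in this diff; the slot names and amounts are made up.

```python
from decimal import Decimal

from ai.backend.common.types import ResourceSlot

# Hypothetical per-kernel allocations; values are kept as Decimal, never float.
kernel_a = ResourceSlot({"cpu": Decimal("2"), "mem": Decimal("4294967296")})
kernel_b = ResourceSlot({"cpu": Decimal("1"), "mem": Decimal("2147483648")})

# Element-wise sum across sibling kernels; the same pattern the migration below
# uses with sum(..., start=ResourceSlot()) to derive a session's occupying_slots.
session_total = sum([kernel_a, kernel_b], start=ResourceSlot())

# to_json() stringifies the values for storage/transport, e.g.
# {"cpu": "3", "mem": "6442450944"}
print(session_total.to_json())
```
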
113 changes: 113 additions & 0 deletions src/ai/backend/manager/models/alembic/versions/679e5721e94d_sync_session_occupying_slots_to_sibling_kernels.py
@@ -0,0 +1,113 @@
"""sync_session_occupying_slots_to_sibling_kernels

Revision ID: 679e5721e94d
Revises: f56a82d0ac9f
Create Date: 2024-04-01 17:34:33.480996

"""

import textwrap
from typing import Any, cast

import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql as pgsql
from sqlalchemy.orm import Session, load_only, registry, relationship, selectinload
from sqlalchemy.sql import text

from ai.backend.common.types import ResourceSlot
from ai.backend.manager.models.base import GUID, IDColumn, ResourceSlotColumn, convention

# revision identifiers, used by Alembic.
revision = "679e5721e94d"
down_revision = "f56a82d0ac9f"
branch_labels = None
depends_on = None

metadata = sa.MetaData(naming_convention=convention)
mapper_registry = registry(metadata=metadata)
Base: Any = mapper_registry.generate_base()

PAGE_SIZE = 100


class SessionRow(Base):
__tablename__ = "sessions"
__table_args__ = {"extend_existing": True}

id = IDColumn()
cluster_size = sa.Column("cluster_size", sa.Integer, nullable=False, default=1)
starts_at = sa.Column("starts_at", sa.DateTime(timezone=True), nullable=True, default=sa.null())
status_history = sa.Column("status_history", pgsql.JSONB(), nullable=True, default=sa.null())
occupying_slots = sa.Column("occupying_slots", ResourceSlotColumn(), nullable=False)

kernels = relationship("KernelRow")


class KernelRow(Base):
__tablename__ = "kernels"
__table_args__ = {"extend_existing": True}

id = IDColumn()
session_id = sa.Column(
"session_id",
GUID,
sa.ForeignKey("sessions.id"),
unique=False,
index=True,
nullable=False,
)
occupied_slots = sa.Column("occupied_slots", ResourceSlotColumn(), nullable=False)


def _sync_single_kernel_cluster_session():
conn = op.get_bind()
sync_stmt = textwrap.dedent(
"""
UPDATE sessions
SET occupying_slots = kernels.occupied_slots
FROM kernels
WHERE sessions.id = kernels.session_id
AND sessions.cluster_size = 1;
"""
)
conn.execute(text(sync_stmt))


def _sync_multi_kernel_cluster_session():
db_sess = Session(op.get_bind())

while True:
select_stmt = (
sa.select(SessionRow)
.where(
(SessionRow.cluster_size != 1)
& (SessionRow.occupying_slots == {})
& (SessionRow.status_history.op("?")("RUNNING"))
)
.limit(PAGE_SIZE)
.options(selectinload(SessionRow.kernels).options(load_only(KernelRow.occupied_slots)))
)
session_list = cast(list[SessionRow], db_sess.scalars(select_stmt).all())
if not session_list:
return

update_stmt = (
sa.update(SessionRow)
.where(SessionRow.id == sa.bindparam("session_id"))
.values(occupying_slots=sa.bindparam("occupying_slots"))
)
data = []
for session in session_list:
occupying_slots = sum([k.occupied_slots for k in session.kernels], start=ResourceSlot())
data.append({"session_id": session.id, "occupying_slots": occupying_slots})
db_sess.execute(update_stmt, data)


def upgrade():
_sync_single_kernel_cluster_session()
_sync_multi_kernel_cluster_session()


def downgrade():
pass
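
A note on the `.op("?")` filter in `_sync_multi_kernel_cluster_session`: it emits PostgreSQL's JSONB key-existence operator, so only sessions whose `status_history` has ever recorded `RUNNING` are backfilled. The sketch below is an illustration with a trimmed stand-in table, showing how that operator renders when compiled against the PostgreSQL dialect.

```python
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql as pgsql

# Trimmed stand-in for the sessions table defined above (illustration only).
sessions = sa.table("sessions", sa.column("status_history", pgsql.JSONB))

# .op("?") renders PostgreSQL's JSONB key-existence operator, which is how the
# migration limits the backfill to sessions that have ever reached RUNNING.
stmt = sa.select(sessions.c.status_history).where(
    sessions.c.status_history.op("?")("RUNNING")
)

# Prints something like:
#   SELECT sessions.status_history FROM sessions
#   WHERE sessions.status_history ? %(status_history_1)s
print(stmt.compile(dialect=pgsql.dialect()))
```
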
36 changes: 2 additions & 34 deletions src/ai/backend/manager/models/session.py
@@ -5,7 +5,6 @@
import logging
from contextlib import asynccontextmanager as actxmgr
from datetime import datetime
from decimal import Decimal
from typing import (
TYPE_CHECKING,
Any,
@@ -34,11 +33,9 @@
AccessKey,
ClusterMode,
KernelId,
ResourceSlot,
SessionId,
SessionResult,
SessionTypes,
SlotName,
VFolderMount,
)

@@ -1263,6 +1260,8 @@ def parse_row(cls, ctx: GraphQueryContext, row: Row) -> Mapping[str, Any]:
"service_ports": row.main_kernel.service_ports,
"mounts": [mount.name for mount in row.vfolder_mounts],
"vfolder_mounts": row.vfolder_mounts,
"occupying_slots": row.occupying_slots.to_json(),
"occupied_slots": row.occupying_slots.to_json(),
"requested_slots": row.requested_slots.to_json(),
# statistics
"num_queries": row.num_queries,
@@ -1275,23 +1274,6 @@ def from_row(cls, ctx: GraphQueryContext, row: Row) -> ComputeSession | None:
props = cls.parse_row(ctx, row)
return cls(**props)

async def resolve_occupying_slots(self, info: graphene.ResolveInfo) -> Mapping[str, Any]:
"""
Calculate the sum of occupying resource slots of all sub-kernels,
and return the JSON-serializable object from the sum result.
"""
graph_ctx: GraphQueryContext = info.context
loader = graph_ctx.dataloader_manager.get_loader(graph_ctx, "ComputeContainer.by_session")
containers = await loader.load(self.session_id)
zero = ResourceSlot()
return sum(
(
ResourceSlot({SlotName(k): Decimal(v) for k, v in c.occupied_slots.items()})
for c in containers
),
start=zero,
).to_json()

async def resolve_inference_metrics(
self, info: graphene.ResolveInfo
) -> Optional[Mapping[str, Any]]:
@@ -1301,20 +1283,6 @@ async def resolve_inference_metrics
)
return await loader.load(self.id)

# legacy
async def resolve_occupied_slots(self, info: graphene.ResolveInfo) -> Mapping[str, Any]:
graph_ctx: GraphQueryContext = info.context
loader = graph_ctx.dataloader_manager.get_loader(graph_ctx, "ComputeContainer.by_session")
containers = await loader.load(self.session_id)
zero = ResourceSlot()
return sum(
(
ResourceSlot({SlotName(k): Decimal(v) for k, v in c.occupied_slots.items()})
for c in containers
),
start=zero,
).to_json()

async def resolve_containers(
self,
info: graphene.ResolveInfo,
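
The two deleted resolvers each triggered a `ComputeContainer.by_session` dataloader round-trip and re-summed kernel slots on every GraphQL query. With this change, both `occupying_slots` and the legacy `occupied_slots` fields are filled in `parse_row` straight from the now-authoritative `sessions.occupying_slots` column. A rough sketch of the resulting field resolution, using a hypothetical `FakeSessionRow` in place of a real result row:

```python
from decimal import Decimal

from ai.backend.common.types import ResourceSlot


class FakeSessionRow:
    """Hypothetical stand-in for the sessions row passed to parse_row()."""

    occupying_slots = ResourceSlot({"cpu": Decimal("3"), "mem": Decimal("6442450944")})


row = FakeSessionRow()

# Both GraphQL fields now read the persisted column; no kernel dataloader needed.
props = {
    "occupying_slots": row.occupying_slots.to_json(),
    "occupied_slots": row.occupying_slots.to_json(),  # legacy field, same source
}
print(props)
```
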
18 changes: 17 additions & 1 deletion src/ai/backend/manager/registry.py
@@ -167,6 +167,7 @@
from .models.utils import (
ExtendedAsyncSAEngine,
execute_with_retry,
execute_with_txn_retry,
is_db_retry_error,
reenter_txn,
reenter_txn_session,
@@ -1581,12 +1582,27 @@ async def finalize_running(
),
}
self._kernel_actual_allocated_resources[kernel_id] = actual_allocs

async def _update_session_occupying_slots(db_session: AsyncSession) -> None:
_stmt = sa.select(SessionRow).where(SessionRow.id == session_id)
session_row = cast(SessionRow | None, await db_session.scalar(_stmt))
if session_row is None:
raise SessionNotFound(f"Failed to fetch session (id:{session_id})")
session_occupying_slots = ResourceSlot.from_json({**session_row.occupying_slots})
session_occupying_slots.sync_keys(actual_allocs)
for key, val in session_occupying_slots.items():
session_occupying_slots[key] = str(Decimal(val) + Decimal(actual_allocs[key]))
session_row.occupying_slots = session_occupying_slots

async with self.db.connect() as db_conn:
await execute_with_txn_retry(
_update_session_occupying_slots, self.db.begin_session, db_conn
)
kernel_did_update = await KernelRow.update_kernel(
self.db, kernel_id, new_status, update_data=update_data
)
if not kernel_did_update:
return

new_session_status = await SessionRow.transit_session_status(self.db, session_id)
if new_session_status is None or new_session_status != SessionStatus.RUNNING:
return
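
The inner `_update_session_occupying_slots` coroutine adds the kernel's `actual_allocs` onto whatever the session row already occupies, storing the result as stringified decimals. Below is a standalone sketch of just that accumulation step; the helper name and sample values are illustrative and not part of the PR.

```python
from decimal import Decimal

from ai.backend.common.types import ResourceSlot


def accumulate_occupying_slots(current: ResourceSlot, incoming: ResourceSlot) -> ResourceSlot:
    """Mirror of the per-key addition in _update_session_occupying_slots()."""
    merged = ResourceSlot.from_json({**current})  # work on a fresh copy, as the PR does
    merged.sync_keys(incoming)  # make both sides carry the same slot names
    for key, val in merged.items():
        merged[key] = str(Decimal(val) + Decimal(incoming[key]))
    return merged


# Hypothetical values: a session already occupying one kernel's worth of resources
# receives the actual allocation of a newly started second kernel.
current = ResourceSlot({"cpu": "2", "mem": "4294967296"})
incoming = ResourceSlot({"cpu": "1", "mem": "2147483648"})
print(accumulate_occupying_slots(current, incoming))  # roughly {'cpu': '3', 'mem': '6442450944'}
```

In the PR, this logic runs inside `execute_with_txn_retry`, so the read-modify-write of the session row can be retried as a whole if the transaction conflicts.
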