Skip to content

Commit

Permalink
fix: calculate session dependency on session insertion transaction (#…
Browse files Browse the repository at this point in the history
…1720)

Backported-from: main
Backported-to: 23.09
  • Loading branch information
fregataa committed Dec 11, 2023
1 parent 8f27839 commit 031bee9
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 19 deletions.
1 change: 1 addition & 0 deletions changes/1720.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Minimize latency between session insertion and dependency insertion.
41 changes: 22 additions & 19 deletions src/ai/backend/manager/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -1202,26 +1202,8 @@ async def enqueue_session(

async def _enqueue() -> None:
async with self.db.begin_session() as db_sess:
if sudo_session_enabled:
environ["SUDO_SESSION_ENABLED"] = "1"

session_data["environ"] = environ
session_data["requested_slots"] = session_requested_slots
session = SessionRow(**session_data)
kernels = [KernelRow(**kernel) for kernel in kernel_data]
db_sess.add(session)
db_sess.add_all(kernels)

await execute_with_retry(_enqueue)

async def _post_enqueue() -> None:
async with self.db.begin_session() as db_sess:
if route_id:
routing_row = await RoutingRow.get(db_sess, route_id)
routing_row.session = session_id

matched_dependency_session_ids = []
if dependency_sessions:
matched_dependency_session_ids = []
for dependency_id in dependency_sessions:
try:
match_info = await SessionRow.get_session(
Expand All @@ -1238,11 +1220,32 @@ async def _post_enqueue() -> None:
else:
matched_dependency_session_ids.append(match_info.id)

if sudo_session_enabled:
environ["SUDO_SESSION_ENABLED"] = "1"

session_data["environ"] = environ
session_data["requested_slots"] = session_requested_slots
session = SessionRow(**session_data)
kernels = [KernelRow(**kernel) for kernel in kernel_data]
db_sess.add(session)
db_sess.add_all(kernels)
await db_sess.flush()

if matched_dependency_session_ids:
dependency_rows = [
SessionDependencyRow(session_id=session_id, depends_on=depend_id)
for depend_id in matched_dependency_session_ids
]
db_sess.add_all(dependency_rows)

await execute_with_retry(_enqueue)

async def _post_enqueue() -> None:
async with self.db.begin_session() as db_sess:
if route_id:
routing_row = await RoutingRow.get(db_sess, route_id)
routing_row.session = session_id

await db_sess.commit()

await execute_with_retry(_post_enqueue)
Expand Down

0 comments on commit 031bee9

Please sign in to comment.