From 031bee9d7da17a00092f36e6bc87d2b940f65766 Mon Sep 17 00:00:00 2001 From: Sanghun Lee Date: Tue, 28 Nov 2023 11:09:10 +0900 Subject: [PATCH] fix: calculate session dependency on session insertion transaction (#1720) Backported-from: main Backported-to: 23.09 --- changes/1720.fix.md | 1 + src/ai/backend/manager/registry.py | 41 ++++++++++++++++-------------- 2 files changed, 23 insertions(+), 19 deletions(-) create mode 100644 changes/1720.fix.md diff --git a/changes/1720.fix.md b/changes/1720.fix.md new file mode 100644 index 0000000000..7298dcda70 --- /dev/null +++ b/changes/1720.fix.md @@ -0,0 +1 @@ +Minimize latency between session insertion and dependency insertion. diff --git a/src/ai/backend/manager/registry.py b/src/ai/backend/manager/registry.py index a6a529ec36..065250a0a5 100644 --- a/src/ai/backend/manager/registry.py +++ b/src/ai/backend/manager/registry.py @@ -1202,26 +1202,8 @@ async def enqueue_session( async def _enqueue() -> None: async with self.db.begin_session() as db_sess: - if sudo_session_enabled: - environ["SUDO_SESSION_ENABLED"] = "1" - - session_data["environ"] = environ - session_data["requested_slots"] = session_requested_slots - session = SessionRow(**session_data) - kernels = [KernelRow(**kernel) for kernel in kernel_data] - db_sess.add(session) - db_sess.add_all(kernels) - - await execute_with_retry(_enqueue) - - async def _post_enqueue() -> None: - async with self.db.begin_session() as db_sess: - if route_id: - routing_row = await RoutingRow.get(db_sess, route_id) - routing_row.session = session_id - + matched_dependency_session_ids = [] if dependency_sessions: - matched_dependency_session_ids = [] for dependency_id in dependency_sessions: try: match_info = await SessionRow.get_session( @@ -1238,11 +1220,32 @@ async def _post_enqueue() -> None: else: matched_dependency_session_ids.append(match_info.id) + if sudo_session_enabled: + environ["SUDO_SESSION_ENABLED"] = "1" + + session_data["environ"] = environ + session_data["requested_slots"] = session_requested_slots + session = SessionRow(**session_data) + kernels = [KernelRow(**kernel) for kernel in kernel_data] + db_sess.add(session) + db_sess.add_all(kernels) + await db_sess.flush() + + if matched_dependency_session_ids: dependency_rows = [ SessionDependencyRow(session_id=session_id, depends_on=depend_id) for depend_id in matched_dependency_session_ids ] db_sess.add_all(dependency_rows) + + await execute_with_retry(_enqueue) + + async def _post_enqueue() -> None: + async with self.db.begin_session() as db_sess: + if route_id: + routing_row = await RoutingRow.get(db_sess, route_id) + routing_row.session = session_id + await db_sess.commit() await execute_with_retry(_post_enqueue)