Skip to content

Commit

Permalink
refactor: test_bgtask (#3232)
Browse files Browse the repository at this point in the history
Backported-from: main (24.12)
Backported-to: 24.03
Backport-of: 3232
  • Loading branch information
achimnol committed Dec 17, 2024
1 parent 423b1a3 commit 7ae3152
Showing 1 changed file with 61 additions and 48 deletions.
109 changes: 61 additions & 48 deletions tests/manager/api/test_bgtask.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from __future__ import annotations

import asyncio
import enum
from collections.abc import AsyncIterator
from typing import Any, TypeAlias

import attr
import pytest
from aiohttp import web

from ai.backend.common import redis_helper
from ai.backend.common.bgtask import BackgroundTaskManager
from ai.backend.common.events import (
BgtaskDoneEvent,
BgtaskFailedEvent,
Expand All @@ -19,24 +22,44 @@
from ai.backend.manager.server import background_task_ctx, event_dispatcher_ctx, shared_config_ctx


@pytest.mark.timeout(60)
@pytest.mark.asyncio
async def test_background_task(etcd_fixture, create_app_and_client) -> None:
class ContextSentinel(enum.Enum):
TOKEN = enum.auto()


BgtaskFixture: TypeAlias = tuple[BackgroundTaskManager, EventProducer, EventDispatcher]


@pytest.fixture
async def bgtask_fixture(etcd_fixture, create_app_and_client) -> AsyncIterator[BgtaskFixture]:
app, client = await create_app_and_client(
[shared_config_ctx, event_dispatcher_ctx, background_task_ctx],
[".events"],
)
root_ctx: RootContext = app["_root.context"]
producer: EventProducer = root_ctx.event_producer
dispatcher: EventDispatcher = root_ctx.event_dispatcher
update_handler_ctx = {}
done_handler_ctx = {}

yield root_ctx.background_task_manager, producer, dispatcher

await root_ctx.background_task_manager.shutdown()
await producer.close()
await dispatcher.close()
await redis_helper.execute(producer.redis_client, lambda r: r.flushdb())


@pytest.mark.timeout(60)
@pytest.mark.asyncio
async def test_background_task(bgtask_fixture: BgtaskFixture) -> None:
background_task_manager, producer, dispatcher = bgtask_fixture
update_handler_ctx: dict[str, Any] = {}
done_handler_ctx: dict[str, Any] = {}

async def update_sub(
context: web.Application,
context: ContextSentinel,
source: AgentId,
event: BgtaskUpdatedEvent,
) -> None:
update_handler_ctx["context"] = context
# Copy the arguments to the uppser scope
# since assertions inside the handler does not affect the test result
# because the handlers are executed inside a separate asyncio task.
Expand All @@ -46,10 +69,11 @@ async def update_sub(
update_handler_ctx.update(**update_body)

async def done_sub(
context: web.Application,
context: ContextSentinel,
source: AgentId,
event: BgtaskDoneEvent,
) -> None:
done_handler_ctx["context"] = context
done_handler_ctx["event_name"] = event.name
update_body = attr.asdict(event) # type: ignore
done_handler_ctx.update(**update_body)
Expand All @@ -62,46 +86,38 @@ async def _mock_task(reporter):
await reporter.update(1, message="BGTask ex2")
return "hooray"

dispatcher.subscribe(BgtaskUpdatedEvent, app, update_sub)
dispatcher.subscribe(BgtaskDoneEvent, app, done_sub)
task_id = await root_ctx.background_task_manager.start(_mock_task, name="MockTask1234")
dispatcher.subscribe(BgtaskUpdatedEvent, ContextSentinel.TOKEN, update_sub)
dispatcher.subscribe(BgtaskDoneEvent, ContextSentinel.TOKEN, done_sub)
task_id = await background_task_manager.start(_mock_task, name="MockTask1234")
await asyncio.sleep(2)

try:
assert update_handler_ctx["task_id"] == task_id
assert update_handler_ctx["event_name"] == "bgtask_updated"
assert update_handler_ctx["total_progress"] == 2
assert update_handler_ctx["message"] in ["BGTask ex1", "BGTask ex2"]
if update_handler_ctx["message"] == "BGTask ex1":
assert update_handler_ctx["current_progress"] == 1
else:
assert update_handler_ctx["current_progress"] == 2
assert done_handler_ctx["task_id"] == task_id
assert done_handler_ctx["event_name"] == "bgtask_done"
assert done_handler_ctx["message"] == "hooray"
finally:
await redis_helper.execute(producer.redis_client, lambda r: r.flushdb())
await producer.close()
await dispatcher.close()
assert update_handler_ctx["context"] is ContextSentinel.TOKEN
assert update_handler_ctx["task_id"] == task_id
assert update_handler_ctx["event_name"] == "bgtask_updated"
assert update_handler_ctx["total_progress"] == 2
assert update_handler_ctx["message"] in ["BGTask ex1", "BGTask ex2"]
if update_handler_ctx["message"] == "BGTask ex1":
assert update_handler_ctx["current_progress"] == 1
else:
assert update_handler_ctx["current_progress"] == 2
assert done_handler_ctx["context"] is ContextSentinel.TOKEN
assert done_handler_ctx["task_id"] == task_id
assert done_handler_ctx["event_name"] == "bgtask_done"
assert done_handler_ctx["message"] == "hooray"


@pytest.mark.timeout(60)
@pytest.mark.asyncio
async def test_background_task_fail(etcd_fixture, create_app_and_client) -> None:
app, client = await create_app_and_client(
[shared_config_ctx, event_dispatcher_ctx, background_task_ctx],
[".events"],
)
root_ctx: RootContext = app["_root.context"]
producer: EventProducer = root_ctx.event_producer
dispatcher: EventDispatcher = root_ctx.event_dispatcher
fail_handler_ctx = {}
async def test_background_task_fail(bgtask_fixture: BgtaskFixture) -> None:
background_task_manager, producer, dispatcher = bgtask_fixture
fail_handler_ctx: dict[str, Any] = {}

async def fail_sub(
context: web.Application,
context: ContextSentinel,
source: AgentId,
event: BgtaskFailedEvent,
) -> None:
fail_handler_ctx["context"] = context
fail_handler_ctx["event_name"] = event.name
update_body = attr.asdict(event) # type: ignore
fail_handler_ctx.update(**update_body)
Expand All @@ -112,15 +128,12 @@ async def _mock_task(reporter):
await reporter.update(1, message="BGTask ex1")
raise ZeroDivisionError("oops")

dispatcher.subscribe(BgtaskFailedEvent, app, fail_sub)
task_id = await root_ctx.background_task_manager.start(_mock_task, name="MockTask1234")
dispatcher.subscribe(BgtaskFailedEvent, ContextSentinel.TOKEN, fail_sub)
task_id = await background_task_manager.start(_mock_task, name="MockTask1234")
await asyncio.sleep(2)
try:
assert fail_handler_ctx["task_id"] == task_id
assert fail_handler_ctx["event_name"] == "bgtask_failed"
assert fail_handler_ctx["message"] is not None
assert "ZeroDivisionError" in fail_handler_ctx["message"]
finally:
await redis_helper.execute(producer.redis_client, lambda r: r.flushdb())
await producer.close()
await dispatcher.close()

assert fail_handler_ctx["context"] is ContextSentinel.TOKEN
assert fail_handler_ctx["task_id"] == task_id
assert fail_handler_ctx["event_name"] == "bgtask_failed"
assert fail_handler_ctx["message"] is not None
assert "ZeroDivisionError" in fail_handler_ctx["message"]

0 comments on commit 7ae3152

Please sign in to comment.