Skip to content

Commit

Permalink
feat: changed approach to dealing with idle tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
Anton committed Jul 25, 2024
1 parent 49c0408 commit 25ffb83
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 2 deletions.
11 changes: 9 additions & 2 deletions taskiq/context.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import TYPE_CHECKING
from contextlib import _AsyncGeneratorContextManager
from typing import TYPE_CHECKING, Callable, Optional

from taskiq.abc.broker import AsyncBroker
from taskiq.exceptions import NoResultError, TaskRejectedError
Expand All @@ -11,11 +12,17 @@
class Context:
"""Context class."""

def __init__(self, message: TaskiqMessage, broker: AsyncBroker) -> None:
def __init__(
self,
message: TaskiqMessage,
broker: AsyncBroker,
idle: "Callable[[Optional[int]], _AsyncGeneratorContextManager[None]]",
) -> None:
self.message = message
self.broker = broker
self.state: "TaskiqState" = None # type: ignore
self.state = broker.state
self.idle = idle

async def requeue(self) -> None:
"""
Expand Down
19 changes: 19 additions & 0 deletions taskiq/depends/task_idler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from contextlib import asynccontextmanager
from typing import AsyncIterator, Optional

from taskiq_dependencies import Depends

from taskiq.context import Context


class TaskIdler:
"""Task's dependency to idle task."""

def __init__(self, context: Context = Depends()) -> None:
self.context = context

@asynccontextmanager
async def __call__(self, timeout: Optional[int] = None) -> AsyncIterator[None]:
"""Idle task."""
async with self.context.idle(timeout):
yield
76 changes: 76 additions & 0 deletions tests/depends/test_task_idler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import asyncio
from asyncio.exceptions import CancelledError

import anyio
import pytest
from taskiq_dependencies import Depends

from taskiq.api.receiver import run_receiver_task
from taskiq.brokers.inmemory_broker import InmemoryResultBackend
from taskiq.depends.task_idler import TaskIdler
from tests.utils import AsyncQueueBroker


@pytest.mark.anyio
async def test_task_idler() -> None:
broker = AsyncQueueBroker().with_result_backend(InmemoryResultBackend())
kicked = 0
desired_kicked = 20

@broker.task(timeout=1)
async def test_func(idle: TaskIdler = Depends()) -> None:
nonlocal kicked
async with idle():
await asyncio.sleep(0.5)
kicked += 1

receiver_task = asyncio.create_task(run_receiver_task(broker, max_async_tasks=1))

tasks = []
for _ in range(desired_kicked):
tasks.append(await test_func.kiq())

with anyio.fail_after(1):
for task in tasks:
await task.wait_result(check_interval=0.01)

receiver_task.cancel()
assert kicked == desired_kicked


@pytest.mark.anyio
async def test_task_idler_task_cancelled() -> None:
broker = AsyncQueueBroker().with_result_backend(InmemoryResultBackend())
kicked = 0
desired_kicked = 20

@broker.task(timeout=0.2)
async def test_func_timeout(idle: TaskIdler = Depends()) -> None:
nonlocal kicked
try:
async with idle():
await asyncio.sleep(2)
except CancelledError:
kicked += 1
raise

@broker.task(timeout=2)
async def test_func(idle: TaskIdler = Depends()) -> None:
nonlocal kicked
async with idle():
await asyncio.sleep(0.5)
kicked += 1

receiver_task = asyncio.create_task(run_receiver_task(broker, max_async_tasks=1))

tasks = []
tasks.append(await test_func_timeout.kiq())
for _ in range(desired_kicked):
tasks.append(await test_func.kiq())

with anyio.fail_after(1):
for task in tasks:
await task.wait_result(check_interval=0.01)

receiver_task.cancel()
assert kicked == desired_kicked + 1

0 comments on commit 25ffb83

Please sign in to comment.