From cdb431b7c9a1226c67efb914e229e3aff5da2dbc Mon Sep 17 00:00:00 2001 From: Anton Date: Thu, 13 Jun 2024 10:57:43 +0300 Subject: [PATCH] feat: set/get progress (#130) Co-authored-by: Pavel Kirilin Co-authored-by: Anton --- .github/workflows/test.yml | 2 +- taskiq/abc/result_backend.py | 28 +++++- taskiq/brokers/inmemory_broker.py | 35 ++++++- taskiq/depends/progress_tracker.py | 72 +++++++++++++++ taskiq/task.py | 26 +++++- tests/depends/test_progress_tracker.py | 121 +++++++++++++++++++++++++ 6 files changed, 280 insertions(+), 4 deletions(-) create mode 100644 taskiq/depends/progress_tracker.py create mode 100644 tests/depends/test_progress_tracker.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9093d5e3..bbc4e1f1 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -37,7 +37,7 @@ jobs: strategy: matrix: py_version: ["3.8", "3.9", "3.10", "3.11", "3.12"] - pydantic_ver: ["<2", ">=2,<3"] + pydantic_ver: ["<2", ">=2.5,<3"] os: [ubuntu-latest, windows-latest] runs-on: "${{ matrix.os }}" steps: diff --git a/taskiq/abc/result_backend.py b/taskiq/abc/result_backend.py index 7e0ebb65..257d0b04 100644 --- a/taskiq/abc/result_backend.py +++ b/taskiq/abc/result_backend.py @@ -1,8 +1,12 @@ from abc import ABC, abstractmethod -from typing import Generic, TypeVar +from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar from taskiq.result import TaskiqResult +if TYPE_CHECKING: # pragma: no cover + from taskiq.depends.progress_tracker import TaskProgress + + _ReturnType = TypeVar("_ReturnType") @@ -50,3 +54,25 @@ async def get_result( :param with_logs: if True it will download task's logs. :return: task's return value. """ + + async def set_progress( + self, + task_id: str, + progress: "TaskProgress[Any]", + ) -> None: + """ + Saves progress. + + :param task_id: task's id. + :param progress: progress of execution. + """ + + async def get_progress( + self, + task_id: str, + ) -> "Optional[TaskProgress[Any]]": + """ + Gets progress. + + :param task_id: task's id. + """ diff --git a/taskiq/brokers/inmemory_broker.py b/taskiq/brokers/inmemory_broker.py index 6c6eed86..544289ff 100644 --- a/taskiq/brokers/inmemory_broker.py +++ b/taskiq/brokers/inmemory_broker.py @@ -1,10 +1,11 @@ import asyncio from collections import OrderedDict from concurrent.futures import ThreadPoolExecutor -from typing import Any, AsyncGenerator, Set, TypeVar +from typing import Any, AsyncGenerator, Optional, Set, TypeVar from taskiq.abc.broker import AsyncBroker from taskiq.abc.result_backend import AsyncResultBackend, TaskiqResult +from taskiq.depends.progress_tracker import TaskProgress from taskiq.events import TaskiqEvents from taskiq.exceptions import TaskiqError from taskiq.message import BrokerMessage @@ -27,6 +28,7 @@ class InmemoryResultBackend(AsyncResultBackend[_ReturnType]): def __init__(self, max_stored_results: int = 100) -> None: self.max_stored_results = max_stored_results self.results: OrderedDict[str, TaskiqResult[_ReturnType]] = OrderedDict() + self.progress: OrderedDict[str, TaskProgress[Any]] = OrderedDict() async def set_result(self, task_id: str, result: TaskiqResult[_ReturnType]) -> None: """ @@ -79,6 +81,37 @@ async def get_result( """ return self.results[task_id] + async def set_progress( + self, + task_id: str, + progress: TaskProgress[Any], + ) -> None: + """ + Set progress of task exection. + + :param task_id: task id + :param progress: task execution progress + """ + if ( + self.max_stored_results != -1 + and len(self.progress) >= self.max_stored_results + ): + self.progress.popitem(last=False) + + self.progress[task_id] = progress + + async def get_progress( + self, + task_id: str, + ) -> Optional[TaskProgress[Any]]: + """ + Get progress of task execution. + + :param task_id: task id + :return: progress or None + """ + return self.progress.get(task_id) + class InMemoryBroker(AsyncBroker): """ diff --git a/taskiq/depends/progress_tracker.py b/taskiq/depends/progress_tracker.py new file mode 100644 index 00000000..9f0161ae --- /dev/null +++ b/taskiq/depends/progress_tracker.py @@ -0,0 +1,72 @@ +import enum +from typing import Generic, Optional, Union + +from taskiq_dependencies import Depends +from typing_extensions import TypeVar + +from taskiq.compat import IS_PYDANTIC2 +from taskiq.context import Context + +if IS_PYDANTIC2: + from pydantic import BaseModel as GenericModel +else: + from pydantic.generics import GenericModel # type: ignore[no-redef] + + +_ProgressType = TypeVar("_ProgressType") + + +class TaskState(str, enum.Enum): + """State of task execution.""" + + STARTED = "STARTED" + FAILURE = "FAILURE" + SUCCESS = "SUCCESS" + RETRY = "RETRY" + + +class TaskProgress(GenericModel, Generic[_ProgressType]): + """Progress of task execution.""" + + state: Union[TaskState, str] + meta: Optional[_ProgressType] + + +class ProgressTracker(Generic[_ProgressType]): + """Task's dependency to set progress.""" + + def __init__( + self, + context: Context = Depends(), + ) -> None: + self.context = context + + async def set_progress( + self, + state: Union[TaskState, str], + meta: Optional[_ProgressType] = None, + ) -> None: + """Set progress. + + :param state: TaskState or str + :param meta: progress data + """ + if meta is None: + progress = await self.get_progress() + meta = progress.meta if progress else None + + progress = TaskProgress( + state=state, + meta=meta, + ) + + await self.context.broker.result_backend.set_progress( + self.context.message.task_id, + progress, + ) + + async def get_progress(self) -> Optional[TaskProgress[_ProgressType]]: + """Get progress.""" + return await self.context.broker.result_backend.get_progress( + self.context.message.task_id, + ) diff --git a/taskiq/task.py b/taskiq/task.py index 691046b5..b4d2be61 100644 --- a/taskiq/task.py +++ b/taskiq/task.py @@ -1,7 +1,9 @@ import asyncio from abc import ABC, abstractmethod from time import time -from typing import TYPE_CHECKING, Any, Coroutine, Generic, TypeVar, Union +from typing import TYPE_CHECKING, Any, Coroutine, Generic, Optional, Union + +from typing_extensions import TypeVar from taskiq.exceptions import ( ResultGetError, @@ -11,6 +13,7 @@ if TYPE_CHECKING: # pragma: no cover from taskiq.abc.result_backend import AsyncResultBackend + from taskiq.depends.progress_tracker import TaskProgress from taskiq.result import TaskiqResult _ReturnType = TypeVar("_ReturnType") @@ -65,6 +68,19 @@ def wait_result( :return: TaskiqResult. """ + @abstractmethod + def get_progress( + self, + ) -> Union[ + "Optional[TaskProgress[Any]]", + Coroutine[Any, Any, "Optional[TaskProgress[Any]]"], + ]: + """ + Get task progress. + + :return: task's progress. + """ + class AsyncTaskiqTask(_Task[_ReturnType]): """AsyncTask for AsyncResultBackend.""" @@ -137,3 +153,11 @@ async def wait_result( if 0 < timeout < time() - start_time: raise TaskiqResultTimeoutError return await self.get_result(with_logs=with_logs) + + async def get_progress(self) -> "Optional[TaskProgress[Any]]": + """ + Get task progress. + + :return: task's progress. + """ + return await self.result_backend.get_progress(self.task_id) diff --git a/tests/depends/test_progress_tracker.py b/tests/depends/test_progress_tracker.py new file mode 100644 index 00000000..040381b0 --- /dev/null +++ b/tests/depends/test_progress_tracker.py @@ -0,0 +1,121 @@ +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Dict, Optional + +import pytest +from pydantic import ValidationError + +from taskiq import ( + AsyncTaskiqDecoratedTask, + InMemoryBroker, + TaskiqDepends, + TaskiqMessage, +) +from taskiq.abc import AsyncBroker +from taskiq.depends.progress_tracker import ProgressTracker, TaskState +from taskiq.receiver import Receiver + + +def get_receiver( + broker: Optional[AsyncBroker] = None, + no_parse: bool = False, + max_async_tasks: Optional[int] = None, +) -> Receiver: + """ + Returns receiver with custom broker and args. + + :param broker: broker, defaults to None + :param no_parse: parameter to taskiq_args, defaults to False + :param cli_args: Taskiq worker CLI arguments. + :return: new receiver. + """ + if broker is None: + broker = InMemoryBroker() + return Receiver( + broker, + executor=ThreadPoolExecutor(max_workers=10), + validate_params=not no_parse, + max_async_tasks=max_async_tasks, + ) + + +def get_message( + task: AsyncTaskiqDecoratedTask[Any, Any], + task_id: Optional[str] = None, + *args: Any, + labels: Optional[Dict[str, str]] = None, + **kwargs: Dict[str, Any], +) -> TaskiqMessage: + if labels is None: + labels = {} + return TaskiqMessage( + task_id=task_id or task.broker.id_generator(), + task_name=task.task_name, + labels=labels, + args=list(args), + kwargs=kwargs, + ) + + +@pytest.mark.anyio +@pytest.mark.parametrize( + "state,meta", + [ + (TaskState.STARTED, "hello world!"), + ("retry", "retry error!"), + ("custom state", {"Complex": "Value"}), + ], +) +async def test_progress_tracker_ctx_raw(state: Any, meta: Any) -> None: + broker = InMemoryBroker() + + @broker.task + async def test_func(tes_val: ProgressTracker[Any] = TaskiqDepends()) -> None: + await tes_val.set_progress(state, meta) + + kicker = await test_func.kiq() + result = await kicker.wait_result() + + assert not result.is_err + progress = await broker.result_backend.get_progress(kicker.task_id) + assert progress is not None + assert progress.meta == meta + assert progress.state == state + + +@pytest.mark.anyio +async def test_progress_tracker_ctx_none() -> None: + broker = InMemoryBroker() + + @broker.task + async def test_func() -> None: + pass + + kicker = await test_func.kiq() + result = await kicker.wait_result() + + assert not result.is_err + progress = await broker.result_backend.get_progress(kicker.task_id) + assert progress is None + + +@pytest.mark.anyio +@pytest.mark.parametrize( + "state,meta", + [ + (("state", "error"), 1), + ], +) +async def test_progress_tracker_validation_error(state: Any, meta: Any) -> None: + broker = InMemoryBroker() + + @broker.task + async def test_func(progress: ProgressTracker[int] = TaskiqDepends()) -> None: + await progress.set_progress(state, meta) # type: ignore + + kicker = await test_func.kiq() + result = await kicker.wait_result() + with pytest.raises(ValidationError): + result.raise_for_error() + + progress = await broker.result_backend.get_progress(kicker.task_id) + assert progress is None