-
-
Notifications
You must be signed in to change notification settings - Fork 52
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: Pavel Kirilin <[email protected]> Co-authored-by: Anton <[email protected]>
- Loading branch information
1 parent
20f92b0
commit cdb431b
Showing
6 changed files
with
280 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import enum | ||
from typing import Generic, Optional, Union | ||
|
||
from taskiq_dependencies import Depends | ||
from typing_extensions import TypeVar | ||
|
||
from taskiq.compat import IS_PYDANTIC2 | ||
from taskiq.context import Context | ||
|
||
if IS_PYDANTIC2: | ||
from pydantic import BaseModel as GenericModel | ||
else: | ||
from pydantic.generics import GenericModel # type: ignore[no-redef] | ||
|
||
|
||
_ProgressType = TypeVar("_ProgressType") | ||
|
||
|
||
class TaskState(str, enum.Enum): | ||
"""State of task execution.""" | ||
|
||
STARTED = "STARTED" | ||
FAILURE = "FAILURE" | ||
SUCCESS = "SUCCESS" | ||
RETRY = "RETRY" | ||
|
||
|
||
class TaskProgress(GenericModel, Generic[_ProgressType]): | ||
"""Progress of task execution.""" | ||
|
||
state: Union[TaskState, str] | ||
meta: Optional[_ProgressType] | ||
|
||
|
||
class ProgressTracker(Generic[_ProgressType]): | ||
"""Task's dependency to set progress.""" | ||
|
||
def __init__( | ||
self, | ||
context: Context = Depends(), | ||
) -> None: | ||
self.context = context | ||
|
||
async def set_progress( | ||
self, | ||
state: Union[TaskState, str], | ||
meta: Optional[_ProgressType] = None, | ||
) -> None: | ||
"""Set progress. | ||
:param state: TaskState or str | ||
:param meta: progress data | ||
""" | ||
if meta is None: | ||
progress = await self.get_progress() | ||
meta = progress.meta if progress else None | ||
|
||
progress = TaskProgress( | ||
state=state, | ||
meta=meta, | ||
) | ||
|
||
await self.context.broker.result_backend.set_progress( | ||
self.context.message.task_id, | ||
progress, | ||
) | ||
|
||
async def get_progress(self) -> Optional[TaskProgress[_ProgressType]]: | ||
"""Get progress.""" | ||
return await self.context.broker.result_backend.get_progress( | ||
self.context.message.task_id, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
from concurrent.futures import ThreadPoolExecutor | ||
from typing import Any, Dict, Optional | ||
|
||
import pytest | ||
from pydantic import ValidationError | ||
|
||
from taskiq import ( | ||
AsyncTaskiqDecoratedTask, | ||
InMemoryBroker, | ||
TaskiqDepends, | ||
TaskiqMessage, | ||
) | ||
from taskiq.abc import AsyncBroker | ||
from taskiq.depends.progress_tracker import ProgressTracker, TaskState | ||
from taskiq.receiver import Receiver | ||
|
||
|
||
def get_receiver( | ||
broker: Optional[AsyncBroker] = None, | ||
no_parse: bool = False, | ||
max_async_tasks: Optional[int] = None, | ||
) -> Receiver: | ||
""" | ||
Returns receiver with custom broker and args. | ||
:param broker: broker, defaults to None | ||
:param no_parse: parameter to taskiq_args, defaults to False | ||
:param cli_args: Taskiq worker CLI arguments. | ||
:return: new receiver. | ||
""" | ||
if broker is None: | ||
broker = InMemoryBroker() | ||
return Receiver( | ||
broker, | ||
executor=ThreadPoolExecutor(max_workers=10), | ||
validate_params=not no_parse, | ||
max_async_tasks=max_async_tasks, | ||
) | ||
|
||
|
||
def get_message( | ||
task: AsyncTaskiqDecoratedTask[Any, Any], | ||
task_id: Optional[str] = None, | ||
*args: Any, | ||
labels: Optional[Dict[str, str]] = None, | ||
**kwargs: Dict[str, Any], | ||
) -> TaskiqMessage: | ||
if labels is None: | ||
labels = {} | ||
return TaskiqMessage( | ||
task_id=task_id or task.broker.id_generator(), | ||
task_name=task.task_name, | ||
labels=labels, | ||
args=list(args), | ||
kwargs=kwargs, | ||
) | ||
|
||
|
||
@pytest.mark.anyio | ||
@pytest.mark.parametrize( | ||
"state,meta", | ||
[ | ||
(TaskState.STARTED, "hello world!"), | ||
("retry", "retry error!"), | ||
("custom state", {"Complex": "Value"}), | ||
], | ||
) | ||
async def test_progress_tracker_ctx_raw(state: Any, meta: Any) -> None: | ||
broker = InMemoryBroker() | ||
|
||
@broker.task | ||
async def test_func(tes_val: ProgressTracker[Any] = TaskiqDepends()) -> None: | ||
await tes_val.set_progress(state, meta) | ||
|
||
kicker = await test_func.kiq() | ||
result = await kicker.wait_result() | ||
|
||
assert not result.is_err | ||
progress = await broker.result_backend.get_progress(kicker.task_id) | ||
assert progress is not None | ||
assert progress.meta == meta | ||
assert progress.state == state | ||
|
||
|
||
@pytest.mark.anyio | ||
async def test_progress_tracker_ctx_none() -> None: | ||
broker = InMemoryBroker() | ||
|
||
@broker.task | ||
async def test_func() -> None: | ||
pass | ||
|
||
kicker = await test_func.kiq() | ||
result = await kicker.wait_result() | ||
|
||
assert not result.is_err | ||
progress = await broker.result_backend.get_progress(kicker.task_id) | ||
assert progress is None | ||
|
||
|
||
@pytest.mark.anyio | ||
@pytest.mark.parametrize( | ||
"state,meta", | ||
[ | ||
(("state", "error"), 1), | ||
], | ||
) | ||
async def test_progress_tracker_validation_error(state: Any, meta: Any) -> None: | ||
broker = InMemoryBroker() | ||
|
||
@broker.task | ||
async def test_func(progress: ProgressTracker[int] = TaskiqDepends()) -> None: | ||
await progress.set_progress(state, meta) # type: ignore | ||
|
||
kicker = await test_func.kiq() | ||
result = await kicker.wait_result() | ||
with pytest.raises(ValidationError): | ||
result.raise_for_error() | ||
|
||
progress = await broker.result_backend.get_progress(kicker.task_id) | ||
assert progress is None |