diff --git a/taskiq/acks.py b/taskiq/acks.py index fddb29b..548340f 100644 --- a/taskiq/acks.py +++ b/taskiq/acks.py @@ -1,8 +1,25 @@ +import enum from typing import Awaitable, Callable, Union from pydantic import BaseModel +@enum.unique +class AcknowledgeType(enum.StrEnum): + """Enum with possible acknowledge times.""" + + # The message is acknowledged right when it's received, + # before it's executed. + WHEN_RECEIVED = enum.auto() + # This option means that the message will be + # acknowledged right after it's executed. + WHEN_EXECUTED = enum.auto() + # This option means that the message will be + # acknowledged when the task will be saved + # only after it's saved in the result backend. + WHEN_SAVED = enum.auto() + + class AckableMessage(BaseModel): """ Message that can be acknowledged. diff --git a/taskiq/api/receiver.py b/taskiq/api/receiver.py index ac6959f..567e56a 100644 --- a/taskiq/api/receiver.py +++ b/taskiq/api/receiver.py @@ -1,9 +1,10 @@ import asyncio from concurrent.futures import ThreadPoolExecutor from logging import getLogger -from typing import Type +from typing import Optional, Type from taskiq.abc.broker import AsyncBroker +from taskiq.acks import AcknowledgeType from taskiq.receiver.receiver import Receiver logger = getLogger("taskiq.receiver") @@ -18,6 +19,7 @@ async def run_receiver_task( max_prefetch: int = 0, propagate_exceptions: bool = True, run_startup: bool = False, + ack_time: Optional[AcknowledgeType] = None, ) -> None: """ Function to run receiver programmatically. @@ -71,6 +73,7 @@ def on_exit(_: Receiver) -> None: max_prefetch=max_prefetch, propagate_exceptions=propagate_exceptions, on_exit=on_exit, + ack_type=ack_time, ) await receiver.listen() except asyncio.CancelledError: diff --git a/taskiq/cli/worker/args.py b/taskiq/cli/worker/args.py index 88e20cd..8de0b3c 100644 --- a/taskiq/cli/worker/args.py +++ b/taskiq/cli/worker/args.py @@ -2,6 +2,7 @@ from dataclasses import dataclass, field from typing import List, Optional, Sequence, Tuple +from taskiq.acks import AcknowledgeType from taskiq.cli.common_args import LogLevel @@ -41,6 +42,7 @@ class WorkerArgs: max_prefetch: int = 0 no_propagate_errors: bool = False max_fails: int = -1 + ack_type: AcknowledgeType = AcknowledgeType.WHEN_SAVED @classmethod def from_cli( @@ -187,6 +189,13 @@ def from_cli( default=-1, help="Maximum number of child process exits.", ) + parser.add_argument( + "--ack-type", + type=lambda value: AcknowledgeType(value.lower()), + default=AcknowledgeType.WHEN_SAVED, + choices=[ack_type.name.lower() for ack_type in AcknowledgeType], + help="When to acknowledge message.", + ) namespace = parser.parse_args(args) return WorkerArgs(**namespace.__dict__) diff --git a/taskiq/cli/worker/run.py b/taskiq/cli/worker/run.py index 02bcb36..09c71b0 100644 --- a/taskiq/cli/worker/run.py +++ b/taskiq/cli/worker/run.py @@ -141,6 +141,7 @@ def interrupt_handler(signum: int, _frame: Any) -> None: max_async_tasks=args.max_async_tasks, max_prefetch=args.max_prefetch, propagate_exceptions=not args.no_propagate_errors, + ack_type=args.ack_type, **receiver_kwargs, # type: ignore ) loop.run_until_complete(receiver.listen()) diff --git a/taskiq/receiver/receiver.py b/taskiq/receiver/receiver.py index e5584f4..37c14a4 100644 --- a/taskiq/receiver/receiver.py +++ b/taskiq/receiver/receiver.py @@ -10,6 +10,7 @@ from taskiq.abc.broker import AckableMessage, AsyncBroker from taskiq.abc.middleware import TaskiqMiddleware +from taskiq.acks import AcknowledgeType from taskiq.context import Context from taskiq.exceptions import NoResultError from taskiq.message import TaskiqMessage @@ -53,6 +54,7 @@ def __init__( max_prefetch: int = 0, propagate_exceptions: bool = True, run_starup: bool = True, + ack_type: Optional[AcknowledgeType] = None, on_exit: Optional[Callable[["Receiver"], None]] = None, ) -> None: self.broker = broker @@ -64,6 +66,7 @@ def __init__( self.dependency_graphs: Dict[str, DependencyGraph] = {} self.propagate_exceptions = propagate_exceptions self.on_exit = on_exit + self.ack_time = ack_type or AcknowledgeType.WHEN_SAVED self.known_tasks: Set[str] = set() for task in self.broker.get_all_tasks().values(): self._prepare_task(task.task_name, task.original_func) @@ -131,13 +134,21 @@ async def callback( # noqa: C901, PLR0912 taskiq_msg.task_id, ) + if self.ack_time == AcknowledgeType.WHEN_RECEIVED and isinstance( + message, + AckableMessage, + ): + await maybe_awaitable(message.ack()) + result = await self.run_task( target=task.original_func, message=taskiq_msg, ) - # If broker has an ability to ack messages. - if isinstance(message, AckableMessage): + if self.ack_time == AcknowledgeType.WHEN_EXECUTED and isinstance( + message, + AckableMessage, + ): await maybe_awaitable(message.ack()) for middleware in self.broker.middlewares: @@ -147,9 +158,11 @@ async def callback( # noqa: C901, PLR0912 try: if not isinstance(result.error, NoResultError): await self.broker.result_backend.set_result(taskiq_msg.task_id, result) + for middleware in self.broker.middlewares: if middleware.__class__.post_save != TaskiqMiddleware.post_save: await maybe_awaitable(middleware.post_save(taskiq_msg, result)) + except Exception as exc: logger.exception( "Can't set result in result backend. Cause: %s", @@ -159,6 +172,12 @@ async def callback( # noqa: C901, PLR0912 if raise_err: raise exc + if self.ack_time == AcknowledgeType.WHEN_EXECUTED and isinstance( + message, + AckableMessage, + ): + await maybe_awaitable(message.ack()) + async def run_task( # noqa: C901, PLR0912, PLR0915 self, target: Callable[..., Any],