Skip to content

Commit

Permalink
Added ack config.
Browse files Browse the repository at this point in the history
Signed-off-by: Pavel Kirilin <[email protected]>
  • Loading branch information
s3rius committed Nov 30, 2023
1 parent fb27ddf commit a49d9d4
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 3 deletions.
17 changes: 17 additions & 0 deletions taskiq/acks.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,25 @@
import enum
from typing import Awaitable, Callable, Union

from pydantic import BaseModel


@enum.unique
class AcknowledgeType(enum.StrEnum):
"""Enum with possible acknowledge times."""

# The message is acknowledged right when it's received,
# before it's executed.
WHEN_RECEIVED = enum.auto()
# This option means that the message will be
# acknowledged right after it's executed.
WHEN_EXECUTED = enum.auto()
# This option means that the message will be
# acknowledged when the task will be saved
# only after it's saved in the result backend.
WHEN_SAVED = enum.auto()


class AckableMessage(BaseModel):
"""
Message that can be acknowledged.
Expand Down
5 changes: 4 additions & 1 deletion taskiq/api/receiver.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import asyncio
from concurrent.futures import ThreadPoolExecutor
from logging import getLogger
from typing import Type
from typing import Optional, Type

from taskiq.abc.broker import AsyncBroker
from taskiq.acks import AcknowledgeType
from taskiq.receiver.receiver import Receiver

logger = getLogger("taskiq.receiver")
Expand All @@ -18,6 +19,7 @@ async def run_receiver_task(
max_prefetch: int = 0,
propagate_exceptions: bool = True,
run_startup: bool = False,
ack_time: Optional[AcknowledgeType] = None,
) -> None:
"""
Function to run receiver programmatically.
Expand Down Expand Up @@ -71,6 +73,7 @@ def on_exit(_: Receiver) -> None:
max_prefetch=max_prefetch,
propagate_exceptions=propagate_exceptions,
on_exit=on_exit,
ack_type=ack_time,
)
await receiver.listen()
except asyncio.CancelledError:
Expand Down
9 changes: 9 additions & 0 deletions taskiq/cli/worker/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from dataclasses import dataclass, field
from typing import List, Optional, Sequence, Tuple

from taskiq.acks import AcknowledgeType
from taskiq.cli.common_args import LogLevel


Expand Down Expand Up @@ -41,6 +42,7 @@ class WorkerArgs:
max_prefetch: int = 0
no_propagate_errors: bool = False
max_fails: int = -1
ack_type: AcknowledgeType = AcknowledgeType.WHEN_SAVED

@classmethod
def from_cli(
Expand Down Expand Up @@ -187,6 +189,13 @@ def from_cli(
default=-1,
help="Maximum number of child process exits.",
)
parser.add_argument(
"--ack-type",
type=lambda value: AcknowledgeType(value.lower()),
default=AcknowledgeType.WHEN_SAVED,
choices=[ack_type.name.lower() for ack_type in AcknowledgeType],
help="When to acknowledge message.",
)

namespace = parser.parse_args(args)
return WorkerArgs(**namespace.__dict__)
1 change: 1 addition & 0 deletions taskiq/cli/worker/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def interrupt_handler(signum: int, _frame: Any) -> None:
max_async_tasks=args.max_async_tasks,
max_prefetch=args.max_prefetch,
propagate_exceptions=not args.no_propagate_errors,
ack_type=args.ack_type,
**receiver_kwargs, # type: ignore
)
loop.run_until_complete(receiver.listen())
Expand Down
23 changes: 21 additions & 2 deletions taskiq/receiver/receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from taskiq.abc.broker import AckableMessage, AsyncBroker
from taskiq.abc.middleware import TaskiqMiddleware
from taskiq.acks import AcknowledgeType
from taskiq.context import Context
from taskiq.exceptions import NoResultError
from taskiq.message import TaskiqMessage
Expand Down Expand Up @@ -53,6 +54,7 @@ def __init__(
max_prefetch: int = 0,
propagate_exceptions: bool = True,
run_starup: bool = True,
ack_type: Optional[AcknowledgeType] = None,
on_exit: Optional[Callable[["Receiver"], None]] = None,
) -> None:
self.broker = broker
Expand All @@ -64,6 +66,7 @@ def __init__(
self.dependency_graphs: Dict[str, DependencyGraph] = {}
self.propagate_exceptions = propagate_exceptions
self.on_exit = on_exit
self.ack_time = ack_type or AcknowledgeType.WHEN_SAVED
self.known_tasks: Set[str] = set()
for task in self.broker.get_all_tasks().values():
self._prepare_task(task.task_name, task.original_func)
Expand Down Expand Up @@ -131,13 +134,21 @@ async def callback( # noqa: C901, PLR0912
taskiq_msg.task_id,
)

if self.ack_time == AcknowledgeType.WHEN_RECEIVED and isinstance(
message,
AckableMessage,
):
await maybe_awaitable(message.ack())

result = await self.run_task(
target=task.original_func,
message=taskiq_msg,
)

# If broker has an ability to ack messages.
if isinstance(message, AckableMessage):
if self.ack_time == AcknowledgeType.WHEN_EXECUTED and isinstance(
message,
AckableMessage,
):
await maybe_awaitable(message.ack())

for middleware in self.broker.middlewares:
Expand All @@ -147,9 +158,11 @@ async def callback( # noqa: C901, PLR0912
try:
if not isinstance(result.error, NoResultError):
await self.broker.result_backend.set_result(taskiq_msg.task_id, result)

for middleware in self.broker.middlewares:
if middleware.__class__.post_save != TaskiqMiddleware.post_save:
await maybe_awaitable(middleware.post_save(taskiq_msg, result))

except Exception as exc:
logger.exception(
"Can't set result in result backend. Cause: %s",
Expand All @@ -159,6 +172,12 @@ async def callback( # noqa: C901, PLR0912
if raise_err:
raise exc

if self.ack_time == AcknowledgeType.WHEN_EXECUTED and isinstance(
message,
AckableMessage,
):
await maybe_awaitable(message.ack())

async def run_task( # noqa: C901, PLR0912, PLR0915
self,
target: Callable[..., Any],
Expand Down

0 comments on commit a49d9d4

Please sign in to comment.