Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added non-string labels in tasks. #243

Merged
merged 3 commits into from
Feb 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion taskiq/kicker.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from taskiq.abc.middleware import TaskiqMiddleware
from taskiq.compat import model_dump
from taskiq.exceptions import SendTaskError
from taskiq.labels import prepare_label
from taskiq.message import TaskiqMessage
from taskiq.scheduler.created_schedule import CreatedSchedule
from taskiq.scheduler.scheduled_task import CronSpec, ScheduledTask
Expand Down Expand Up @@ -245,12 +246,14 @@ def _prepare_message(
formatted_args = []
formatted_kwargs = {}
labels = {}
labels_types = {}
for arg in args:
formatted_args.append(self._prepare_arg(arg))
for kwarg_name, kwarg_val in kwargs.items():
formatted_kwargs[kwarg_name] = self._prepare_arg(kwarg_val)

for label, label_val in self.labels.items():
labels[label] = str(label_val)
labels[label], labels_types[label] = prepare_label(label_val)

task_id = self.custom_task_id
if task_id is None:
Expand All @@ -260,6 +263,7 @@ def _prepare_message(
task_id=task_id,
task_name=self.task_name,
labels=labels,
labels_types=labels_types,
args=formatted_args,
kwargs=formatted_kwargs,
)
55 changes: 55 additions & 0 deletions taskiq/labels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import base64
import enum
from typing import Any, Callable, Dict, Optional, Tuple


class LabelType(enum.IntEnum):
"""Possible label types."""

ANY = enum.auto()
INT = enum.auto()
STR = enum.auto()
FLOAT = enum.auto()
BOOL = enum.auto()
BYTES = enum.auto()


_LABEL_PARSERS: Dict[LabelType, Callable[[str], Any]] = {
LabelType.INT: int,
LabelType.STR: str,
LabelType.FLOAT: float,
LabelType.BOOL: lambda x: x.lower() == "true",
LabelType.BYTES: base64.b64decode,
LabelType.ANY: lambda x: x,
}


def prepare_label(label_value: Any) -> Tuple[str, int]:
"""
Prepare label value for serialization.

:param label_value: label value to prepare.
:return: tuple of prepared label value and its type.
"""
var_type = type(label_value)
if var_type in (int, str, float, bool):
return str(label_value), LabelType[var_type.__name__.upper()].value
if var_type == bytes:
return base64.b64encode(label_value).decode(), LabelType.BYTES.value
return str(label_value), LabelType.ANY.value


def parse_label(label_value: Any, label_type: Optional[int] = None) -> Any:
"""
Parse label value from serialized format.

:param label_value: label value to parse.
:param label_type: label type.
:return: parsed label value.
"""
if label_type is None:
return label_value
label_type = LabelType(label_type)
if label_type in _LABEL_PARSERS:
return _LABEL_PARSERS[label_type](label_value)
raise ValueError(f"Unsupported label type: {label_type}")
18 changes: 17 additions & 1 deletion taskiq/message.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional

from pydantic import BaseModel

from taskiq.labels import parse_label


class TaskiqMessage(BaseModel):
"""
Expand All @@ -15,9 +17,23 @@ class TaskiqMessage(BaseModel):
task_id: str
task_name: str
labels: Dict[str, Any]
labels_types: Optional[Dict[str, int]] = None
args: List[Any]
kwargs: Dict[str, Any]

def parse_labels(self) -> None:
"""
Parse labels.

:return: None
"""
if self.labels_types is None:
return

for label, label_type in self.labels_types.items():
if label in self.labels:
self.labels[label] = parse_label(self.labels[label], label_type)


class BrokerMessage(BaseModel):
"""Format of messages for brokers."""
Expand Down
25 changes: 16 additions & 9 deletions taskiq/middlewares/retry_middleware.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from copy import deepcopy
from logging import getLogger
from typing import Any

from taskiq.abc.middleware import TaskiqMiddleware
from taskiq.exceptions import NoResultError
from taskiq.kicker import AsyncKicker
from taskiq.message import TaskiqMessage
from taskiq.result import TaskiqResult

Expand Down Expand Up @@ -47,25 +47,32 @@ async def on_error(
return

retry_on_error = message.labels.get("retry_on_error")
if isinstance(retry_on_error, str):
retry_on_error = retry_on_error.lower() == "true"

if retry_on_error is None:
retry_on_error = "true" if self.default_retry_label else "false"
retry_on_error = self.default_retry_label
# Check if retrying is enabled for the task.
if retry_on_error.lower() != "true":
if not retry_on_error:
return
new_msg = deepcopy(message)

kicker: AsyncKicker[Any, Any] = AsyncKicker(
task_name=message.task_name,
broker=self.broker,
labels=message.labels,
).with_task_id(message.task_id)

# Getting number of previous retries.
retries = int(new_msg.labels.get("_retries", 0)) + 1
new_msg.labels["_retries"] = str(retries)
max_retries = int(new_msg.labels.get("max_retries", self.default_retry_count))
retries = int(message.labels.get("_retries", 0)) + 1
kicker.with_labels(_retries=retries)
max_retries = int(message.labels.get("max_retries", self.default_retry_count))

if retries < max_retries:
logger.info(
"Task '%s' invocation failed. Retrying.",
message.task_name,
)
broker_message = self.broker.formatter.dumps(message=new_msg)
await self.broker.kick(broker_message)
await kicker.kiq(*message.args, **message.kwargs)

if self.no_result_on_retry:
result.error = NoResultError()
Expand Down
1 change: 1 addition & 0 deletions taskiq/receiver/receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ async def callback( # noqa: C901, PLR0912
message_data = message.data if isinstance(message, AckableMessage) else message
try:
taskiq_msg = self.broker.formatter.loads(message=message_data)
taskiq_msg.parse_labels()
except Exception as exc:
logger.warning(
"Cannot parse message: %s. Skipping execution.\n %s",
Expand Down
2 changes: 1 addition & 1 deletion taskiq/result/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class TaskiqResult(GenericModel, Generic[_ReturnType]):
log: Optional[str] = None
return_value: _ReturnType
execution_time: float
labels: Dict[str, str] = Field(default_factory=dict)
labels: Dict[str, Any] = Field(default_factory=dict)

error: Optional[BaseException] = None

Expand Down
2 changes: 1 addition & 1 deletion taskiq/result/v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class TaskiqResult(BaseModel, Generic[_ReturnType]):
log: Optional[str] = None
return_value: _ReturnType
execution_time: float
labels: Dict[str, str] = Field(default_factory=dict)
labels: Dict[str, Any] = Field(default_factory=dict)

error: Optional[BaseException] = None

Expand Down
10 changes: 8 additions & 2 deletions taskiq/serializers/json_serializer.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,26 @@
from json import dumps, loads
from typing import Any
from typing import Any, Callable, Optional

from taskiq.abc.serializer import TaskiqSerializer


class JSONSerializer(TaskiqSerializer):
"""Default taskiq serizalizer."""

def __init__(self, default: Optional[Callable[..., None]] = None) -> None:
self.default = default

def dumpb(self, value: Any) -> bytes:
"""
Dumps taskiq message to some broker message format.

:param message: message to send.
:return: Dumped message.
"""
return dumps(value).encode()
return dumps(
value,
default=self.default,
).encode()

def loadb(self, value: bytes) -> Any:
"""
Expand Down
1 change: 1 addition & 0 deletions tests/formatters/test_json_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ async def test_json_dumps() -> None:
message=(
b'{"task_id":"task-id","task_name":"task.name",'
b'"labels":{"label1":1,"label2":"text"},'
b'"labels_types":null,'
b'"args":[1,"a"],"kwargs":{"p1":"v1"}}'
),
labels={"label1": 1, "label2": "text"},
Expand Down
1 change: 1 addition & 0 deletions tests/formatters/test_proxy_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ async def test_proxy_dumps() -> None:
message=(
b'{"task_id": "task-id", "task_name": "task.name", '
b'"labels": {"label1": 1, "label2": "text"}, '
b'"labels_types": null, '
b'"args": [1, "a"], "kwargs": {"p1": "v1"}}'
),
labels={"label1": 1, "label2": "text"},
Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@ env_list =
skip_install = true
allowlist_externals = poetry
commands_pre =
poetry install
poetry install --all-extras
commands =
poetry run pytest -vv -n auto
Loading