From 11eb7c5a4d96348a22656c0018481409f2c8f06f Mon Sep 17 00:00:00 2001 From: Zeke Marffy Date: Sat, 23 Mar 2024 16:02:30 -0400 Subject: [PATCH 1/2] enable log collection --- docs/examples/introduction/inmemory_run.py | 7 +- docs/examples/schedule/intro.py | 2 +- poetry.lock | 25 +++- pyproject.toml | 1 + taskiq/cli/worker/log_collector.py | 127 ++++++++++++--------- taskiq/receiver/receiver.py | 61 ++++++---- taskiq/result/v1.py | 8 +- taskiq/result/v2.py | 8 +- taskiq/serialization.py | 12 +- tests/cli/worker/test_log_collector.py | 47 ++++---- 10 files changed, 175 insertions(+), 123 deletions(-) diff --git a/docs/examples/introduction/inmemory_run.py b/docs/examples/introduction/inmemory_run.py index fdc1e7d2..2e504187 100644 --- a/docs/examples/introduction/inmemory_run.py +++ b/docs/examples/introduction/inmemory_run.py @@ -1,13 +1,17 @@ # broker.py import asyncio +import logging from taskiq import InMemoryBroker +from taskiq.task import AsyncTaskiqTask broker = InMemoryBroker() +task_logger = logging.getLogger("taskiq.tasklogger") @broker.task async def add_one(value: int) -> int: + task_logger.info(f"Adding 1 to {value}") return value + 1 @@ -17,10 +21,11 @@ async def main() -> None: # Send the task to the broker. task = await add_one.kiq(1) # Wait for the result. - result = await task.wait_result(timeout=2) + result = await task.wait_result(with_logs=True) print(f"Task execution took: {result.execution_time} seconds.") if not result.is_err: print(f"Returned value: {result.return_value}") + print(f"Logs: {result.log}") else: print("Error found while executing task.") await broker.shutdown() diff --git a/docs/examples/schedule/intro.py b/docs/examples/schedule/intro.py index 2faf91cb..039170d9 100644 --- a/docs/examples/schedule/intro.py +++ b/docs/examples/schedule/intro.py @@ -1,7 +1,7 @@ from taskiq_aio_pika import AioPikaBroker -from taskiq.schedule_sources import LabelScheduleSource from taskiq import TaskiqScheduler +from taskiq.schedule_sources import LabelScheduleSource broker = AioPikaBroker("amqp://guest:guest@localhost:5672/") diff --git a/poetry.lock b/poetry.lock index 036a979e..97f63862 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "annotated-types" @@ -63,6 +63,17 @@ files = [ [package.extras] tzdata = ["tzdata"] +[[package]] +name = "bidict" +version = "0.23.1" +description = "The bidirectional mapping library for Python." +optional = false +python-versions = ">=3.8" +files = [ + {file = "bidict-0.23.1-py3-none-any.whl", hash = "sha256:5dae8d4d79b552a71cbabc7deb25dfe8ce710b17ff41711e13010ead2abfc3e5"}, + {file = "bidict-0.23.1.tar.gz", hash = "sha256:03069d763bc387bbd20e7d49914e75fc4132a41937fa3405417e1a5a2d006d71"}, +] + [[package]] name = "black" version = "22.12.0" @@ -1121,6 +1132,7 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -1128,8 +1140,15 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -1146,6 +1165,7 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -1153,6 +1173,7 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -1573,4 +1594,4 @@ zmq = ["pyzmq"] [metadata] lock-version = "2.0" python-versions = "^3.8.1" -content-hash = "d324e106ec5335886db06fbe8b91aabdc7ffa8547bc075a3c4e7275f1be8d4c7" +content-hash = "85e44290033f48a4f24cd6996ef64e8f3e660bcc4ddb68687e1be0de170f6473" diff --git a/pyproject.toml b/pyproject.toml index 373263e9..85b19265 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ pytz = "*" orjson = { version = "^3.9.9", optional = true } msgpack = { version = "^1.0.7", optional = true } cbor2 = { version = "^5.4.6", optional = true } +bidict = "^0.23.1" [tool.poetry.dev-dependencies] pytest = "^7.1.2" diff --git a/taskiq/cli/worker/log_collector.py b/taskiq/cli/worker/log_collector.py index 6ae72a99..f80e2fc1 100644 --- a/taskiq/cli/worker/log_collector.py +++ b/taskiq/cli/worker/log_collector.py @@ -1,63 +1,80 @@ +import asyncio import logging -import sys -from contextlib import contextmanager -from typing import IO, Any, Generator, List, TextIO +from logging import LogRecord +from typing import Dict, List, Union +from bidict import bidict -class Redirector: - """A class to write to multiple streams.""" - def __init__(self, *streams: IO[Any]) -> None: - self.streams = streams +class TaskiqLogHandler(logging.Handler): + """Log handler class.""" - def write(self, message: Any) -> None: + def __init__(self, level: Union[int, str] = 0) -> None: + self.stream: Dict[Union[str, None], List[logging.LogRecord]] = {} + self._associations: bidict[Union[str, None], Union[str, None]] = bidict() + super().__init__(level) + + @staticmethod + def _get_async_task_name() -> Union[str, None]: + try: + task = asyncio.current_task() + except RuntimeError: + return None + else: + if task: + return task.get_name() + + raise RuntimeError + + def associate(self, task_id: str) -> None: + """ + Associate the current async task with the Taskiq task ID. + + :param task_id: The Taskiq task ID. + :type task_id: str + """ + async_task_name = self._get_async_task_name() + self._associations[task_id] = async_task_name + + def retrieve_logs(self, task_id: str) -> List[LogRecord]: + """ + Collect logs. + + Collect the logs of a Taskiq task and return + them after removing them from memory. + + :param task_id: The Taskiq task ID + :type task_id: str + :return: A list of LogRecords + :rtype: List[LogRecord] + """ + async_task_name = self._associations[task_id] + try: + stream = self.stream[async_task_name] + except KeyError: + stream = [] + else: + del self._associations[task_id] + return stream + + def emit(self, record: LogRecord) -> None: """ - This write request writes to all available streams. + Collect an outputted log record. - :param message: message to write. + :param record: The log record to collect. + :type record: LogRecord """ - for stream in self.streams: - stream.write(message) - - -@contextmanager -def log_collector( - new_target: TextIO, - custom_format: str, -) -> "Generator[TextIO, None, None]": - """ - Context manager to collect logs. - - This useful class redirects all logs - from stdout, stderr and root logger - to some new target. - - It can be used like this: - - >>> logs = io.StringIO() - >>> with log_collector(logs, "%(levelname)s %(message)s"): - >>> print("A") - >>> - >>> print(f"Collected logs: {logs.get_value()}") - - :param new_target: new target for logs. All - logs are written in new_target. - :param custom_format: custom format for - collected logging calls. - :yields: new target. - """ - old_targets: "List[TextIO]" = [] - log_handler = logging.StreamHandler(new_target) - log_handler.setFormatter(logging.Formatter(custom_format)) - - old_targets.extend([sys.stdout, sys.stderr]) - logging.root.addHandler(log_handler) - sys.stdout = Redirector(new_target, sys.stdout) # type: ignore - sys.stderr = Redirector(new_target, sys.stderr) # type: ignore - - try: - yield new_target - finally: - sys.stderr = old_targets.pop() - sys.stdout = old_targets.pop() - logging.root.removeHandler(log_handler) + try: + async_task_name = self._get_async_task_name() + except RuntimeError: + # If not in an async context, do nothing + return + record.async_task_name = async_task_name + try: + record.task_id = self._associations.inverse[async_task_name] + except KeyError: + return + try: + self.stream[async_task_name].append(record) + except KeyError: + self.stream[async_task_name] = [record] diff --git a/taskiq/receiver/receiver.py b/taskiq/receiver/receiver.py index 73a56ed3..0f2af95e 100644 --- a/taskiq/receiver/receiver.py +++ b/taskiq/receiver/receiver.py @@ -1,7 +1,8 @@ import asyncio import inspect +import logging from concurrent.futures import Executor -from logging import getLogger +from logging import Formatter, getLogger from time import time from typing import Any, Callable, Dict, List, Optional, Set, Union, get_type_hints @@ -11,6 +12,7 @@ from taskiq.abc.broker import AckableMessage, AsyncBroker from taskiq.abc.middleware import TaskiqMiddleware from taskiq.acks import AcknowledgeType +from taskiq.cli.worker.log_collector import TaskiqLogHandler from taskiq.context import Context from taskiq.exceptions import NoResultError from taskiq.message import TaskiqMessage @@ -20,6 +22,8 @@ from taskiq.utils import maybe_awaitable logger = getLogger(__name__) +task_logger = getLogger("taskiq.tasklogger") +task_logger.setLevel(logging.DEBUG) QUEUE_DONE = b"-1" @@ -79,6 +83,12 @@ def __init__( "can result in undefined behavior", ) self.sem_prefetch = asyncio.Semaphore(max_prefetch) + self._logging_handler = TaskiqLogHandler(logging.DEBUG) + self._logging_formatter = Formatter( + fmt="[%(asctime)s] [%(name)s] [%(levelname)s] > %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + task_logger.addHandler(self._logging_handler) async def callback( # noqa: C901, PLR0912 self, @@ -236,6 +246,7 @@ async def run_task( # noqa: C901, PLR0912, PLR0915 # Start a timer. start_time = time() + log = None try: # We put kwargs resolving here, # to be able to catch any exception (for example ), @@ -245,26 +256,32 @@ async def run_task( # noqa: C901, PLR0912, PLR0915 # We udpate kwargs with kwargs from network. kwargs.update(message.kwargs) is_coroutine = True - # If the function is a coroutine, we await it. - if asyncio.iscoroutinefunction(target): - target_future = target(*message.args, **kwargs) - else: - is_coroutine = False - # If this is a synchronous function, we - # run it in executor. - target_future = loop.run_in_executor( - self.executor, - _run_sync, - target, - message.args, - kwargs, - ) - timeout = message.labels.get("timeout") - if timeout is not None: - if not is_coroutine: - logger.warning("Timeouts for sync tasks don't work in python well.") - target_future = asyncio.wait_for(target_future, float(timeout)) - returned = await target_future + self._logging_handler.associate(message.task_id) + try: + # If the function is a coroutine, we await it. + if asyncio.iscoroutinefunction(target): + target_future = target(*message.args, **kwargs) + else: + is_coroutine = False + # If this is a synchronous function, we + # run it in executor. + target_future = loop.run_in_executor( + self.executor, + _run_sync, + target, + message.args, + kwargs, + ) + timeout = message.labels.get("timeout") + if timeout is not None: + if not is_coroutine: + logger.warning( + "Timeouts for sync tasks don't work in python well.", + ) + target_future = asyncio.wait_for(target_future, float(timeout)) + returned = await target_future + finally: + log = self._logging_handler.retrieve_logs(message.task_id) except NoResultError as no_res_exc: found_exception = no_res_exc logger.warning( @@ -294,7 +311,7 @@ async def run_task( # noqa: C901, PLR0912, PLR0915 # Assemble result. result: "TaskiqResult[Any]" = TaskiqResult( is_err=found_exception is not None, - log=None, + log=log, return_value=returned, execution_time=round(execution_time, 2), error=found_exception, diff --git a/taskiq/result/v1.py b/taskiq/result/v1.py index 95297053..e1ce7d75 100644 --- a/taskiq/result/v1.py +++ b/taskiq/result/v1.py @@ -1,7 +1,8 @@ import json import pickle from functools import partial -from typing import Any, Callable, Dict, Generic, Optional, TypeVar +from logging import LogRecord +from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar from pydantic import Field, validator from pydantic.generics import GenericModel @@ -27,10 +28,7 @@ class TaskiqResult(GenericModel, Generic[_ReturnType]): """Result of a remote task invocation.""" is_err: bool - # Log is a deprecated field. It would be removed in future - # releases of not, if we find a way to capture logs in async - # environment. - log: Optional[str] = None + log: Optional[List[LogRecord]] = None return_value: _ReturnType execution_time: float labels: Dict[str, Any] = Field(default_factory=dict) diff --git a/taskiq/result/v2.py b/taskiq/result/v2.py index 6294a2e3..d2e40e21 100644 --- a/taskiq/result/v2.py +++ b/taskiq/result/v2.py @@ -1,6 +1,7 @@ import json import pickle -from typing import Any, Dict, Generic, Optional, TypeVar +from logging import LogRecord +from typing import Any, Dict, Generic, List, Optional, TypeVar from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator from typing_extensions import Self @@ -14,10 +15,7 @@ class TaskiqResult(BaseModel, Generic[_ReturnType]): """Result of a remote task invocation.""" is_err: bool - # Log is a deprecated field. It would be removed in future - # releases of not, if we find a way to capture logs in async - # environment. - log: Optional[str] = None + log: Optional[List[LogRecord]] = None return_value: _ReturnType execution_time: float labels: Dict[str, Any] = Field(default_factory=dict) diff --git a/taskiq/serialization.py b/taskiq/serialization.py index 1aecfed9..7255ddf3 100644 --- a/taskiq/serialization.py +++ b/taskiq/serialization.py @@ -2,17 +2,7 @@ import traceback from inspect import getmro from itertools import takewhile -from typing import ( - Any, - Generic, - Iterable, - List, - Optional, - Set, - Tuple, - Type, - Union, -) +from typing import Any, Generic, Iterable, List, Optional, Set, Tuple, Type, Union import pydantic from typing_extensions import Protocol, TypeVar, runtime_checkable diff --git a/tests/cli/worker/test_log_collector.py b/tests/cli/worker/test_log_collector.py index 7bf1288b..21105217 100644 --- a/tests/cli/worker/test_log_collector.py +++ b/tests/cli/worker/test_log_collector.py @@ -1,26 +1,31 @@ +import asyncio import logging -import sys -from io import StringIO -from taskiq.cli.worker.log_collector import log_collector +import pytest +from taskiq.cli.worker.log_collector import TaskiqLogHandler -def test_log_collector_std_success() -> None: - """Tests that stdout and stderr calls are collected correctly.""" - log = StringIO() - with log_collector(log, "%(message)s"): - print("log1") # noqa: T201 - print("log2", file=sys.stderr) # noqa: T201 - assert log.getvalue() == "log1\nlog2\n" - -def test_log_collector_logging_success() -> None: - """Tests that logging calls are collected correctly.""" - log = StringIO() - with log_collector(log, "%(levelname)s %(message)s"): - logger = logging.getLogger(__name__) - logger.setLevel(logging.DEBUG) - logger.info("log1") - logger.warning("log2") - logger.debug("log3") - assert log.getvalue() == "INFO log1\nWARNING log2\nDEBUG log3\n" +@pytest.mark.anyio +async def test_log_collector_success() -> None: + """Tests that logs are collected correctly.""" + handler = TaskiqLogHandler() + logger = logging.getLogger("taskiq.tasklogger") + logger.addHandler(handler) + logger.setLevel(logging.DEBUG) + handler.associate("someid") + logger.info("Thing 1") + logger.info("Thing 2") + try: + task = asyncio.current_task() + except RuntimeError: + return + else: + if task: + task_name = task.get_name() + else: + raise RuntimeError + assert [record.message for record in handler.stream[task_name]] == [ + "Thing 1", + "Thing 2", + ] From dd0d0116ec456e009203e1276a95a858397b0b11 Mon Sep 17 00:00:00 2001 From: Zeke Marffy Date: Sat, 23 Mar 2024 19:43:52 -0400 Subject: [PATCH 2/2] add another test for log collection format logs properly --- taskiq/cli/worker/log_collector.py | 12 +++--- taskiq/receiver/receiver.py | 10 +++-- tests/receiver/test_receiver.py | 60 +++++++++++++++++++++++++++++- 3 files changed, 69 insertions(+), 13 deletions(-) diff --git a/taskiq/cli/worker/log_collector.py b/taskiq/cli/worker/log_collector.py index f80e2fc1..1448bd69 100644 --- a/taskiq/cli/worker/log_collector.py +++ b/taskiq/cli/worker/log_collector.py @@ -24,7 +24,7 @@ def _get_async_task_name() -> Union[str, None]: if task: return task.get_name() - raise RuntimeError + return None def associate(self, task_id: str) -> None: """ @@ -59,17 +59,15 @@ def retrieve_logs(self, task_id: str) -> List[LogRecord]: def emit(self, record: LogRecord) -> None: """ - Collect an outputted log record. + Collect a log record. :param record: The log record to collect. :type record: LogRecord """ - try: - async_task_name = self._get_async_task_name() - except RuntimeError: - # If not in an async context, do nothing + self.format(record) + async_task_name = self._get_async_task_name() + if not async_task_name: return - record.async_task_name = async_task_name try: record.task_id = self._associations.inverse[async_task_name] except KeyError: diff --git a/taskiq/receiver/receiver.py b/taskiq/receiver/receiver.py index 0f2af95e..f9562ac9 100644 --- a/taskiq/receiver/receiver.py +++ b/taskiq/receiver/receiver.py @@ -84,9 +84,11 @@ def __init__( ) self.sem_prefetch = asyncio.Semaphore(max_prefetch) self._logging_handler = TaskiqLogHandler(logging.DEBUG) - self._logging_formatter = Formatter( - fmt="[%(asctime)s] [%(name)s] [%(levelname)s] > %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", + self._logging_handler.setFormatter( + Formatter( + fmt="[%(asctime)s] [%(name)s] [%(levelname)s] > %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ), ) task_logger.addHandler(self._logging_handler) @@ -246,7 +248,7 @@ async def run_task( # noqa: C901, PLR0912, PLR0915 # Start a timer. start_time = time() - log = None + log = [] try: # We put kwargs resolving here, # to be able to catch any exception (for example ), diff --git a/tests/receiver/test_receiver.py b/tests/receiver/test_receiver.py index 6b79e325..025655e5 100644 --- a/tests/receiver/test_receiver.py +++ b/tests/receiver/test_receiver.py @@ -1,15 +1,18 @@ import asyncio +import logging import random import time from concurrent.futures import ThreadPoolExecutor from typing import Any, ClassVar, List, Optional import pytest +from anyio import sleep from taskiq_dependencies import Depends from taskiq.abc.broker import AckableMessage, AsyncBroker from taskiq.abc.middleware import TaskiqMiddleware from taskiq.brokers.inmemory_broker import InMemoryBroker +from taskiq.cli.worker.log_collector import TaskiqLogHandler from taskiq.exceptions import NoResultError, TaskiqResultTimeoutError from taskiq.message import TaskiqMessage from taskiq.receiver import Receiver @@ -447,14 +450,22 @@ async def task_no_result() -> int: async def test_result() -> None: broker = InMemoryBroker() + handler = TaskiqLogHandler() + logger = logging.getLogger("taskiq.tasklogger") + logger.addHandler(handler) + logger.setLevel(logging.DEBUG) + @broker.task - async def task_no_result() -> str: + async def task_with_result() -> str: + logger.info("Some stuff") return "some value" - task = await task_no_result.kiq() + task = await task_with_result.kiq() resp = await task.wait_result(timeout=1) assert resp.return_value == "some value" + assert resp.log is not None + assert resp.log[0].message == "Some stuff" assert not broker._running_tasks @@ -472,3 +483,48 @@ async def task_no_result() -> str: assert resp.return_value is None assert not broker._running_tasks assert isinstance(resp.error, ValueError) + + +@pytest.mark.anyio +async def test_concurrent_logs() -> None: + broker = InMemoryBroker() + + handler = TaskiqLogHandler() + logger = logging.getLogger("taskiq.tasklogger") + logger.addHandler(handler) + logger.setLevel(logging.DEBUG) + + @broker.task + async def task1() -> int: + await sleep(1) + logger.info("Some stuff") + await sleep(1) + logger.info("End some stuff") + return 1 + + @broker.task + async def task2() -> int: + logger.info("Some more stuff") + logger.info("End more stuff") + return 2 + + t1 = await task1.kiq() + t2 = await task2.kiq() + await sleep(1) + t3 = await task1.kiq() + await sleep(1) + t4 = await task2.kiq() + resp1 = await t1.wait_result(timeout=5) + resp2 = await t2.wait_result(timeout=5) + resp3 = await t3.wait_result(timeout=5) + resp4 = await t4.wait_result(timeout=5) + + assert ( + resp1.log is not None + and resp2.log is not None + and resp3.log is not None + and resp4.log is not None + ) + assert "some" in resp1.log[1].message and "more" in resp2.log[1].message + assert "some" in resp3.log[1].message and "more" in resp4.log[1].message + assert not broker._running_tasks