diff --git a/qlib/log.py b/qlib/log.py index f7683d5116..71f870c8a0 100644 --- a/qlib/log.py +++ b/qlib/log.py @@ -31,11 +31,17 @@ def __init__(self, module_name): # this feature name conflicts with the attribute with Logger # rename it to avoid some corner cases that result in comparing `str` and `int` self.__level = 0 + # Normally this should be set to `False` to avoid duplicated logging [1]. + # However, due to bug in pytest, it requires log message to propagate to root logger to be captured by `caplog` [2]. + # [1] https://github.com/microsoft/qlib/pull/1661 + # [2] https://github.com/pytest-dev/pytest/issues/3697 + self.parent_propagate = False @property def logger(self): logger = logging.getLogger(self.module_name) logger.setLevel(self.__level) + logger.parent.propagate = self.parent_propagate return logger def setLevel(self, level): diff --git a/tests/rl/test_logger.py b/tests/rl/test_logger.py index c8ceca92ad..3fdfbb8f86 100644 --- a/tests/rl/test_logger.py +++ b/tests/rl/test_logger.py @@ -1,6 +1,5 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. - from random import randint, choice from pathlib import Path @@ -69,8 +68,9 @@ def learn(self, batch): def test_simple_env_logger(caplog): set_log_with_config(C.logging_config) + writer = ConsoleWriter() + writer.console_logger.parent_propagate = True for venv_cls_name in ["dummy", "shmem", "subproc"]: - writer = ConsoleWriter() csv_writer = CsvWriter(Path(__file__).parent / ".output") venv = vectorize_env(lambda: SimpleEnv(), venv_cls_name, 4, [writer, csv_writer]) with venv.collector_guard(): @@ -80,13 +80,12 @@ def test_simple_env_logger(caplog): output_file = pd.read_csv(Path(__file__).parent / ".output" / "result.csv") assert output_file.columns.tolist() == ["reward", "a", "c"] assert len(output_file) >= 30 - line_counter = 0 for line in caplog.text.splitlines(): line = line.strip() if line: line_counter += 1 - assert re.match(r".*reward .* a .* \((4|5|6)\.\d+\) c .* \((14|15|16)\.\d+\)", line) + assert re.match(r".*reward .* {2}a .* \(([456])\.\d+\) {2}c .* \((14|15|16)\.\d+\)", line) assert line_counter >= 3 @@ -137,15 +136,17 @@ def learn(self, batch): def test_logger_with_env_wrapper(): with DataQueue(list(range(20)), shuffle=False) as data_iterator: - env_wrapper_factory = lambda: EnvWrapper( - SimpleSimulator, - DummyStateInterpreter(), - DummyActionInterpreter(), - data_iterator, - logger=LogCollector(LogLevel.DEBUG), - ) - - # loglevel can be debug here because metrics can all dump into csv + + def env_wrapper_factory(): + return EnvWrapper( + SimpleSimulator, + DummyStateInterpreter(), + DummyActionInterpreter(), + data_iterator, + logger=LogCollector(LogLevel.DEBUG), + ) + + # loglevel can be debugged here because metrics can all dump into csv # otherwise, csv writer might crash csv_writer = CsvWriter(Path(__file__).parent / ".output", loglevel=LogLevel.DEBUG) venv = vectorize_env(env_wrapper_factory, "shmem", 4, csv_writer) @@ -155,7 +156,7 @@ def test_logger_with_env_wrapper(): output_df = pd.read_csv(Path(__file__).parent / ".output" / "result.csv") assert len(output_df) == 20 - # obs has a increasing trend + # obs has an increasing trend assert output_df["obs"].to_numpy()[:10].sum() < output_df["obs"].to_numpy()[10:].sum() assert (output_df["test_a"] == 233).all() assert (output_df["test_b"] == 200).all()