diff --git a/tests/task/test_task.py b/tests/task/test_task.py index 9c83d427..230d95cd 100644 --- a/tests/task/test_task.py +++ b/tests/task/test_task.py @@ -289,3 +289,169 @@ class MyService(pydase.DataService): await asyncio.sleep(0.01) assert "Task 'my_task' was cancelled" in caplog.text + + +@pytest.mark.asyncio(scope="function") +async def test_restart_on_failure(caplog: LogCaptureFixture) -> None: + class MyService(pydase.DataService): + @task(restart_on_failure=True, restart_sec=0.1) + async def my_task(self) -> None: + logger.info("Triggered task.") + raise Exception("Task failure") + + service_instance = MyService() + state_manager = StateManager(service_instance) + DataServiceObserver(state_manager) + service_instance.my_task.start() + + await asyncio.sleep(0.01) + assert "Task 'my_task' encountered an exception" in caplog.text + caplog.clear() + await asyncio.sleep(0.1) + assert service_instance.my_task.status == TaskStatus.RUNNING + assert "Task 'my_task' encountered an exception" in caplog.text + assert "Triggered task." in caplog.text + + +@pytest.mark.asyncio(scope="function") +async def test_restart_sec(caplog: LogCaptureFixture) -> None: + class MyService(pydase.DataService): + @task(restart_on_failure=True, restart_sec=0.1) + async def my_task(self) -> None: + logger.info("Triggered task.") + raise Exception("Task failure") + + service_instance = MyService() + state_manager = StateManager(service_instance) + DataServiceObserver(state_manager) + service_instance.my_task.start() + + await asyncio.sleep(0.001) + assert "Triggered task." in caplog.text + caplog.clear() + await asyncio.sleep(0.05) + assert "Triggered task." not in caplog.text + await asyncio.sleep(0.05) + assert "Triggered task." in caplog.text # Ensures the task restarted after 0.2s + + +@pytest.mark.asyncio(scope="function") +async def test_exceeding_start_limit_interval_sec_and_burst( + caplog: LogCaptureFixture, +) -> None: + class MyService(pydase.DataService): + @task( + restart_on_failure=True, + restart_sec=0.0, + start_limit_interval_sec=1.0, + start_limit_burst=2, + ) + async def my_task(self) -> None: + raise Exception("Task failure") + + service_instance = MyService() + state_manager = StateManager(service_instance) + DataServiceObserver(state_manager) + service_instance.my_task.start() + + await asyncio.sleep(0.1) + assert "Task 'my_task' exceeded restart burst limit" in caplog.text + assert service_instance.my_task.status == TaskStatus.NOT_RUNNING + + +@pytest.mark.asyncio(scope="function") +async def test_non_exceeding_start_limit_interval_sec_and_burst( + caplog: LogCaptureFixture, +) -> None: + class MyService(pydase.DataService): + @task( + restart_on_failure=True, + restart_sec=0.1, + start_limit_interval_sec=0.1, + start_limit_burst=2, + ) + async def my_task(self) -> None: + raise Exception("Task failure") + + service_instance = MyService() + state_manager = StateManager(service_instance) + DataServiceObserver(state_manager) + service_instance.my_task.start() + + await asyncio.sleep(0.5) + assert "Task 'my_task' exceeded restart burst limit" not in caplog.text + assert service_instance.my_task.status == TaskStatus.RUNNING + + +@pytest.mark.asyncio(scope="function") +async def test_timeout_start_sec(caplog: LogCaptureFixture) -> None: + class MyService(pydase.DataService): + @task(timeout_start_sec=0.2) + async def my_task(self) -> None: + logger.info("Starting task.") + await asyncio.sleep(1) + + service_instance = MyService() + state_manager = StateManager(service_instance) + DataServiceObserver(state_manager) + service_instance.my_task.start() + + await asyncio.sleep(0.1) + assert "Starting task." not in caplog.text + await asyncio.sleep(0.2) + assert "Starting task." in caplog.text + + +@pytest.mark.asyncio(scope="function") +async def test_exit_on_failure( + monkeypatch: pytest.MonkeyPatch, caplog: LogCaptureFixture +) -> None: + class MyService(pydase.DataService): + @task(restart_on_failure=False, exit_on_failure=True) + async def my_task(self) -> None: + logger.info("Triggered task.") + raise Exception("Critical failure") + + def mock_os_kill(pid: int, signal: int) -> None: + logger.critical("os.kill called with signal=%s and pid=%s", signal, pid) + + monkeypatch.setattr("os.kill", mock_os_kill) + + service_instance = MyService() + state_manager = StateManager(service_instance) + DataServiceObserver(state_manager) + service_instance.my_task.start() + + await asyncio.sleep(0.1) + assert "os.kill called with signal=15 and pid=" in caplog.text + assert "Task 'my_task' encountered an exception" in caplog.text + + +@pytest.mark.asyncio(scope="function") +async def test_exit_on_failure_exceeding_rate_limit( + monkeypatch: pytest.MonkeyPatch, caplog: LogCaptureFixture +) -> None: + class MyService(pydase.DataService): + @task( + restart_on_failure=True, + restart_sec=0.0, + start_limit_interval_sec=0.1, + start_limit_burst=2, + exit_on_failure=True, + ) + async def my_task(self) -> None: + raise Exception("Critical failure") + + def mock_os_kill(pid: int, signal: int) -> None: + logger.critical("os.kill called with signal=%s and pid=%s", signal, pid) + + monkeypatch.setattr("os.kill", mock_os_kill) + + service_instance = MyService() + state_manager = StateManager(service_instance) + DataServiceObserver(state_manager) + service_instance.my_task.start() + + await asyncio.sleep(0.5) + assert "os.kill called with signal=15 and pid=" in caplog.text + assert "Task 'my_task' encountered an exception" in caplog.text