From b49ff45e0a27843f555b2879f0cc6c8cf292b0b1 Mon Sep 17 00:00:00 2001 From: "Dr. Dennis Wittich" Date: Thu, 24 Oct 2024 12:49:49 +0200 Subject: [PATCH] fix tests: create only one training task --- .../tests/trainer/states/test_state_detecting.py | 4 ++-- .../trainer/states/test_state_download_train_model.py | 4 ++-- .../tests/trainer/states/test_state_prepare.py | 4 ++-- .../states/test_state_sync_confusion_matrix.py | 8 ++++---- .../tests/trainer/states/test_state_train.py | 6 +++--- .../trainer/states/test_state_upload_detections.py | 6 +++--- .../tests/trainer/states/test_state_upload_model.py | 4 ++-- learning_loop_node/tests/trainer/test_errors.py | 4 ++-- learning_loop_node/trainer/trainer_logic.py | 1 + learning_loop_node/trainer/trainer_logic_generic.py | 11 +++++++---- 10 files changed, 28 insertions(+), 24 deletions(-) diff --git a/learning_loop_node/tests/trainer/states/test_state_detecting.py b/learning_loop_node/tests/trainer/states/test_state_detecting.py index 1fd8d97d..33f9441e 100644 --- a/learning_loop_node/tests/trainer/states/test_state_detecting.py +++ b/learning_loop_node/tests/trainer/states/test_state_detecting.py @@ -37,7 +37,7 @@ async def test_detecting_can_be_aborted(test_initialized_trainer: TestingTrainer trainer._init_from_last_training() trainer.training.model_uuid_for_detecting = '12345678-bobo-7e92-f95f-424242424242' - _ = asyncio.get_running_loop().create_task(trainer._run()) + trainer._begin_training_task() await assert_training_state(trainer.training, TrainerState.Detecting, timeout=5, interval=0.001) await trainer.stop() @@ -54,7 +54,7 @@ async def test_model_not_downloadable_error(test_initialized_trainer: TestingTra model_uuid_for_detecting='00000000-0000-0000-0000-000000000000') # bad model id trainer._init_from_last_training() - _ = asyncio.get_running_loop().create_task(trainer._run()) + trainer._begin_training_task() await assert_training_state(trainer.training, 'detecting', timeout=1, interval=0.001) await assert_training_state(trainer.training, 'train_model_uploaded', timeout=1, interval=0.001) diff --git a/learning_loop_node/tests/trainer/states/test_state_download_train_model.py b/learning_loop_node/tests/trainer/states/test_state_download_train_model.py index afad1989..c1b70153 100644 --- a/learning_loop_node/tests/trainer/states/test_state_download_train_model.py +++ b/learning_loop_node/tests/trainer/states/test_state_download_train_model.py @@ -37,7 +37,7 @@ async def test_abort_download_model(test_initialized_trainer: TestingTrainerLogi create_active_training_file(trainer, training_state=TrainerState.DataDownloaded) trainer._init_from_last_training() - _ = asyncio.get_running_loop().create_task(trainer._run()) + trainer._begin_training_task() await assert_training_state(trainer.training, TrainerState.TrainModelDownloading, timeout=1, interval=0.001) await trainer.stop() @@ -53,7 +53,7 @@ async def test_downloading_failed(test_initialized_trainer: TestingTrainerLogic) base_model_uuid_or_name='00000000-0000-0000-0000-000000000000') # bad model id) trainer._init_from_last_training() - _ = asyncio.get_running_loop().create_task(trainer._run()) + trainer._begin_training_task() await assert_training_state(trainer.training, TrainerState.TrainModelDownloading, timeout=1, interval=0.001) await assert_training_state(trainer.training, TrainerState.DataDownloaded, timeout=1, interval=0.001) diff --git a/learning_loop_node/tests/trainer/states/test_state_prepare.py b/learning_loop_node/tests/trainer/states/test_state_prepare.py index 8e512511..581549ec 100644 --- a/learning_loop_node/tests/trainer/states/test_state_prepare.py +++ b/learning_loop_node/tests/trainer/states/test_state_prepare.py @@ -30,7 +30,7 @@ async def test_abort_preparing(test_initialized_trainer: TestingTrainerLogic): create_active_training_file(trainer) trainer._init_from_last_training() - _ = asyncio.get_running_loop().create_task(trainer._run()) + trainer._begin_training_task() await assert_training_state(trainer.training, TrainerState.DataDownloading, timeout=1, interval=0.001) await trainer.stop() @@ -46,7 +46,7 @@ async def test_request_error(test_initialized_trainer: TestingTrainerLogic): organization='zauberzeug', project='some_bad_project')) trainer._init_from_last_training() - _ = asyncio.get_running_loop().create_task(trainer._run()) + trainer._begin_training_task() await assert_training_state(trainer.training, TrainerState.DataDownloading, timeout=3, interval=0.001) await assert_training_state(trainer.training, TrainerState.Initialized, timeout=3, interval=0.001) diff --git a/learning_loop_node/tests/trainer/states/test_state_sync_confusion_matrix.py b/learning_loop_node/tests/trainer/states/test_state_sync_confusion_matrix.py index 520a5e09..17f344fb 100644 --- a/learning_loop_node/tests/trainer/states/test_state_sync_confusion_matrix.py +++ b/learning_loop_node/tests/trainer/states/test_state_sync_confusion_matrix.py @@ -25,7 +25,7 @@ async def test_nothing_to_sync(test_initialized_trainer: TestingTrainerLogic): create_active_training_file(trainer, training_state=TrainerState.TrainingFinished) trainer._init_from_last_training() - _ = asyncio.get_running_loop().create_task(trainer._run()) + trainer._begin_training_task() await assert_training_state(trainer.training, TrainerState.ConfusionMatrixSynced, timeout=1, interval=0.001) assert trainer_has_sync_confusion_matrix_error(trainer) is False @@ -43,7 +43,7 @@ async def test_unsynced_model_available__sync_successful(test_initialized_traine trainer._init_from_last_training() trainer.has_new_model = True - _ = asyncio.get_running_loop().create_task(trainer._run()) + trainer._begin_training_task() await assert_training_state(trainer.training, TrainerState.ConfusionMatrixSynced, timeout=1, interval=0.001) assert trainer_has_sync_confusion_matrix_error(trainer) is False @@ -63,7 +63,7 @@ async def test_unsynced_model_available__sio_not_connected(test_initialized_trai assert test_initialized_trainer_node.sio_client.connected is False trainer.has_new_model = True - _ = asyncio.get_running_loop().create_task(trainer._run()) + trainer._begin_training_task() await assert_training_state(trainer.training, TrainerState.ConfusionMatrixSyncing, timeout=1, interval=0.001) await assert_training_state(trainer.training, TrainerState.TrainingFinished, timeout=1, interval=0.001) @@ -82,7 +82,7 @@ async def test_unsynced_model_available__request_is_not_successful(test_initiali create_active_training_file(trainer, training_state=TrainerState.TrainingFinished) trainer.has_new_model = True - _ = asyncio.get_running_loop().create_task(trainer._run()) + trainer._begin_training_task() await assert_training_state(trainer.training, TrainerState.ConfusionMatrixSyncing, timeout=1, interval=0.001) await assert_training_state(trainer.training, TrainerState.TrainingFinished, timeout=1, interval=0.001) diff --git a/learning_loop_node/tests/trainer/states/test_state_train.py b/learning_loop_node/tests/trainer/states/test_state_train.py index 69371e46..76a02706 100644 --- a/learning_loop_node/tests/trainer/states/test_state_train.py +++ b/learning_loop_node/tests/trainer/states/test_state_train.py @@ -14,7 +14,7 @@ async def test_successful_training(test_initialized_trainer: TestingTrainerLogic create_active_training_file(trainer, training_state=TrainerState.TrainModelDownloaded) trainer._init_from_last_training() - _ = asyncio.get_running_loop().create_task(trainer._run()) + trainer._begin_training_task() await condition(lambda: trainer._executor and trainer._executor.is_running(), timeout=1, interval=0.01) await assert_training_state(trainer.training, TrainerState.TrainingRunning, timeout=1, interval=0.01) @@ -34,7 +34,7 @@ async def test_stop_running_training(test_initialized_trainer: TestingTrainerLog create_active_training_file(trainer, training_state=TrainerState.TrainModelDownloaded) trainer._init_from_last_training() - _ = asyncio.get_running_loop().create_task(trainer._run()) + trainer._begin_training_task() await condition(lambda: trainer._executor and trainer._executor.is_running(), timeout=1, interval=0.01) await assert_training_state(trainer.training, TrainerState.TrainingRunning, timeout=1, interval=0.01) @@ -55,7 +55,7 @@ async def test_training_can_maybe_resumed(test_initialized_trainer: TestingTrain trainer._init_from_last_training() trainer._can_resume_flag = True - _ = asyncio.get_running_loop().create_task(trainer._run()) + trainer._begin_training_task() await condition(lambda: trainer._executor and trainer._executor.is_running(), timeout=1, interval=0.01) await assert_training_state(trainer.training, TrainerState.TrainingRunning, timeout=1, interval=0.001) diff --git a/learning_loop_node/tests/trainer/states/test_state_upload_detections.py b/learning_loop_node/tests/trainer/states/test_state_upload_detections.py index 574bc196..193d8792 100644 --- a/learning_loop_node/tests/trainer/states/test_state_upload_detections.py +++ b/learning_loop_node/tests/trainer/states/test_state_upload_detections.py @@ -125,7 +125,7 @@ async def test_bad_status_from_LearningLoop(test_initialized_trainer: TestingTra trainer._init_from_last_training() trainer.active_training_io.save_detections([get_dummy_detections()]) - _ = asyncio.get_running_loop().create_task(trainer._run()) + trainer._begin_training_task() await assert_training_state(trainer.training, TrainerState.DetectionUploading, timeout=1, interval=0.001) await assert_training_state(trainer.training, TrainerState.Detected, timeout=1, interval=0.001) @@ -143,7 +143,7 @@ async def test_go_to_cleanup_if_no_detections_exist(test_initialized_trainer: Te create_active_training_file(trainer, training_state=TrainerState.Detected) trainer._init_from_last_training() - _ = asyncio.get_running_loop().create_task(trainer._run()) + trainer._begin_training_task() await assert_training_state(trainer.training, TrainerState.ReadyForCleanup, timeout=1, interval=0.001) @@ -154,7 +154,7 @@ async def test_abort_uploading(test_initialized_trainer: TestingTrainerLogic): trainer._init_from_last_training() await create_valid_detection_file(trainer) - _ = asyncio.get_running_loop().create_task(trainer._run()) + trainer._begin_training_task() await assert_training_state(trainer.training, TrainerState.DetectionUploading, timeout=1, interval=0.001) diff --git a/learning_loop_node/tests/trainer/states/test_state_upload_model.py b/learning_loop_node/tests/trainer/states/test_state_upload_model.py index a1cc95e7..602b860c 100644 --- a/learning_loop_node/tests/trainer/states/test_state_upload_model.py +++ b/learning_loop_node/tests/trainer/states/test_state_upload_model.py @@ -40,7 +40,7 @@ async def test_abort_upload_model(test_initialized_trainer: TestingTrainerLogic) create_active_training_file(trainer, training_state=TrainerState.ConfusionMatrixSynced) trainer._init_from_last_training() - _ = asyncio.get_running_loop().create_task(trainer._run()) + trainer._begin_training_task() await assert_training_state(trainer.training, TrainerState.TrainModelUploading, timeout=1, interval=0.001) @@ -60,7 +60,7 @@ async def test_bad_server_response_content(test_initialized_trainer: TestingTrai create_active_training_file(trainer, training_state=TrainerState.ConfusionMatrixSynced) trainer._init_from_last_training() - _ = asyncio.get_running_loop().create_task(trainer._run()) + trainer._begin_training_task() await assert_training_state(trainer.training, TrainerState.TrainModelUploading, timeout=1, interval=0.001) # TODO goes to finished because of the error diff --git a/learning_loop_node/tests/trainer/test_errors.py b/learning_loop_node/tests/trainer/test_errors.py index e07fd692..eb2da89a 100644 --- a/learning_loop_node/tests/trainer/test_errors.py +++ b/learning_loop_node/tests/trainer/test_errors.py @@ -14,7 +14,7 @@ async def test_training_process_is_stopped_when_trainer_reports_error(test_initi trainer = test_initialized_trainer create_active_training_file(trainer, training_state=TrainerState.TrainModelDownloaded) trainer._init_from_last_training() - _ = asyncio.get_running_loop().create_task(trainer._run()) + trainer._begin_training_task() await assert_training_state(trainer.training, TrainerState.TrainingRunning, timeout=1, interval=0.001) trainer.error_msg = 'some_error' @@ -26,7 +26,7 @@ async def test_log_can_provide_only_data_for_current_run(test_initialized_traine trainer = test_initialized_trainer create_active_training_file(trainer, training_state=TrainerState.TrainModelDownloaded) trainer._init_from_last_training() - _ = asyncio.get_running_loop().create_task(trainer._run()) + trainer._begin_training_task() await assert_training_state(trainer.training, TrainerState.TrainingRunning, timeout=1, interval=0.001) await asyncio.sleep(0.1) # give tests a bit time to to check for the state diff --git a/learning_loop_node/trainer/trainer_logic.py b/learning_loop_node/trainer/trainer_logic.py index 13d8dd17..b360d4bb 100644 --- a/learning_loop_node/trainer/trainer_logic.py +++ b/learning_loop_node/trainer/trainer_logic.py @@ -142,6 +142,7 @@ async def _start_training(self): async def stop(self) -> None: """If executor is running, stop it. Else cancel training task.""" print('===============> stop received in trainer_logic.', flush=True) + print(self.training_task is None, flush=True) if not self.training_active: return diff --git a/learning_loop_node/trainer/trainer_logic_generic.py b/learning_loop_node/trainer/trainer_logic_generic.py index 500e4131..d43d8b44 100644 --- a/learning_loop_node/trainer/trainer_logic_generic.py +++ b/learning_loop_node/trainer/trainer_logic_generic.py @@ -179,7 +179,7 @@ async def try_continue_run_if_incomplete(self) -> bool: if not self.training_active and self.last_training_io.exists(): self._init_from_last_training() logger.info('found incomplete training, continuing now.') - asyncio.get_event_loop().create_task(self._run()) + self._begin_training_task() return True return False @@ -195,7 +195,11 @@ async def begin_training(self, organization: str, project: str, details: Dict) - """Called on `begin_training` event from the Learning Loop. """ self._init_new_training(Context(organization=organization, project=project), details) - asyncio.get_event_loop().create_task(self._run()) + self._begin_training_task() + + def _begin_training_task(self) -> None: + # NOTE: Task object is used to potentially cancel the task + self.training_task = asyncio.get_event_loop().create_task(self._run()) def _init_new_training(self, context: Context, details: Dict) -> None: """Called on `begin_training` event from the Learning Loop. @@ -218,8 +222,7 @@ async def _run(self) -> None: """ self.errors.reset_all() try: - self.training_task = asyncio.get_running_loop().create_task(self._training_loop()) - await self.training_task # NOTE: Task object is used to potentially cancel the task + await self._training_loop() except asyncio.CancelledError: if not self.shutdown_event.is_set(): logger.info('CancelledError in _run - training task was cancelled but not by shutdown event')