Skip to content

Commit

Permalink
fix tests: create only one training task
Browse files Browse the repository at this point in the history
  • Loading branch information
denniswittich committed Oct 24, 2024
1 parent 5b82697 commit b49ff45
Show file tree
Hide file tree
Showing 10 changed files with 28 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ async def test_detecting_can_be_aborted(test_initialized_trainer: TestingTrainer
trainer._init_from_last_training()
trainer.training.model_uuid_for_detecting = '12345678-bobo-7e92-f95f-424242424242'

_ = asyncio.get_running_loop().create_task(trainer._run())
trainer._begin_training_task()

await assert_training_state(trainer.training, TrainerState.Detecting, timeout=5, interval=0.001)
await trainer.stop()
Expand All @@ -54,7 +54,7 @@ async def test_model_not_downloadable_error(test_initialized_trainer: TestingTra
model_uuid_for_detecting='00000000-0000-0000-0000-000000000000') # bad model id
trainer._init_from_last_training()

_ = asyncio.get_running_loop().create_task(trainer._run())
trainer._begin_training_task()

await assert_training_state(trainer.training, 'detecting', timeout=1, interval=0.001)
await assert_training_state(trainer.training, 'train_model_uploaded', timeout=1, interval=0.001)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ async def test_abort_download_model(test_initialized_trainer: TestingTrainerLogi
create_active_training_file(trainer, training_state=TrainerState.DataDownloaded)
trainer._init_from_last_training()

_ = asyncio.get_running_loop().create_task(trainer._run())
trainer._begin_training_task()
await assert_training_state(trainer.training, TrainerState.TrainModelDownloading, timeout=1, interval=0.001)

await trainer.stop()
Expand All @@ -53,7 +53,7 @@ async def test_downloading_failed(test_initialized_trainer: TestingTrainerLogic)
base_model_uuid_or_name='00000000-0000-0000-0000-000000000000') # bad model id
trainer._init_from_last_training()

_ = asyncio.get_running_loop().create_task(trainer._run())
trainer._begin_training_task()
await assert_training_state(trainer.training, TrainerState.TrainModelDownloading, timeout=1, interval=0.001)
await assert_training_state(trainer.training, TrainerState.DataDownloaded, timeout=1, interval=0.001)

Expand Down
4 changes: 2 additions & 2 deletions learning_loop_node/tests/trainer/states/test_state_prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ async def test_abort_preparing(test_initialized_trainer: TestingTrainerLogic):
create_active_training_file(trainer)
trainer._init_from_last_training()

_ = asyncio.get_running_loop().create_task(trainer._run())
trainer._begin_training_task()
await assert_training_state(trainer.training, TrainerState.DataDownloading, timeout=1, interval=0.001)

await trainer.stop()
Expand All @@ -46,7 +46,7 @@ async def test_request_error(test_initialized_trainer: TestingTrainerLogic):
organization='zauberzeug', project='some_bad_project'))
trainer._init_from_last_training()

_ = asyncio.get_running_loop().create_task(trainer._run())
trainer._begin_training_task()
await assert_training_state(trainer.training, TrainerState.DataDownloading, timeout=3, interval=0.001)
await assert_training_state(trainer.training, TrainerState.Initialized, timeout=3, interval=0.001)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ async def test_nothing_to_sync(test_initialized_trainer: TestingTrainerLogic):
create_active_training_file(trainer, training_state=TrainerState.TrainingFinished)
trainer._init_from_last_training()

_ = asyncio.get_running_loop().create_task(trainer._run())
trainer._begin_training_task()

await assert_training_state(trainer.training, TrainerState.ConfusionMatrixSynced, timeout=1, interval=0.001)
assert trainer_has_sync_confusion_matrix_error(trainer) is False
Expand All @@ -43,7 +43,7 @@ async def test_unsynced_model_available__sync_successful(test_initialized_traine
trainer._init_from_last_training()
trainer.has_new_model = True

_ = asyncio.get_running_loop().create_task(trainer._run())
trainer._begin_training_task()
await assert_training_state(trainer.training, TrainerState.ConfusionMatrixSynced, timeout=1, interval=0.001)

assert trainer_has_sync_confusion_matrix_error(trainer) is False
Expand All @@ -63,7 +63,7 @@ async def test_unsynced_model_available__sio_not_connected(test_initialized_trai
assert test_initialized_trainer_node.sio_client.connected is False
trainer.has_new_model = True

_ = asyncio.get_running_loop().create_task(trainer._run())
trainer._begin_training_task()

await assert_training_state(trainer.training, TrainerState.ConfusionMatrixSyncing, timeout=1, interval=0.001)
await assert_training_state(trainer.training, TrainerState.TrainingFinished, timeout=1, interval=0.001)
Expand All @@ -82,7 +82,7 @@ async def test_unsynced_model_available__request_is_not_successful(test_initiali
create_active_training_file(trainer, training_state=TrainerState.TrainingFinished)

trainer.has_new_model = True
_ = asyncio.get_running_loop().create_task(trainer._run())
trainer._begin_training_task()

await assert_training_state(trainer.training, TrainerState.ConfusionMatrixSyncing, timeout=1, interval=0.001)
await assert_training_state(trainer.training, TrainerState.TrainingFinished, timeout=1, interval=0.001)
Expand Down
6 changes: 3 additions & 3 deletions learning_loop_node/tests/trainer/states/test_state_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ async def test_successful_training(test_initialized_trainer: TestingTrainerLogic
create_active_training_file(trainer, training_state=TrainerState.TrainModelDownloaded)
trainer._init_from_last_training()

_ = asyncio.get_running_loop().create_task(trainer._run())
trainer._begin_training_task()

await condition(lambda: trainer._executor and trainer._executor.is_running(), timeout=1, interval=0.01)
await assert_training_state(trainer.training, TrainerState.TrainingRunning, timeout=1, interval=0.01)
Expand All @@ -34,7 +34,7 @@ async def test_stop_running_training(test_initialized_trainer: TestingTrainerLog
create_active_training_file(trainer, training_state=TrainerState.TrainModelDownloaded)
trainer._init_from_last_training()

_ = asyncio.get_running_loop().create_task(trainer._run())
trainer._begin_training_task()

await condition(lambda: trainer._executor and trainer._executor.is_running(), timeout=1, interval=0.01)
await assert_training_state(trainer.training, TrainerState.TrainingRunning, timeout=1, interval=0.01)
Expand All @@ -55,7 +55,7 @@ async def test_training_can_maybe_resumed(test_initialized_trainer: TestingTrain
trainer._init_from_last_training()
trainer._can_resume_flag = True

_ = asyncio.get_running_loop().create_task(trainer._run())
trainer._begin_training_task()

await condition(lambda: trainer._executor and trainer._executor.is_running(), timeout=1, interval=0.01)
await assert_training_state(trainer.training, TrainerState.TrainingRunning, timeout=1, interval=0.001)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ async def test_bad_status_from_LearningLoop(test_initialized_trainer: TestingTra
trainer._init_from_last_training()
trainer.active_training_io.save_detections([get_dummy_detections()])

_ = asyncio.get_running_loop().create_task(trainer._run())
trainer._begin_training_task()
await assert_training_state(trainer.training, TrainerState.DetectionUploading, timeout=1, interval=0.001)
await assert_training_state(trainer.training, TrainerState.Detected, timeout=1, interval=0.001)

Expand All @@ -143,7 +143,7 @@ async def test_go_to_cleanup_if_no_detections_exist(test_initialized_trainer: Te
create_active_training_file(trainer, training_state=TrainerState.Detected)
trainer._init_from_last_training()

_ = asyncio.get_running_loop().create_task(trainer._run())
trainer._begin_training_task()
await assert_training_state(trainer.training, TrainerState.ReadyForCleanup, timeout=1, interval=0.001)


Expand All @@ -154,7 +154,7 @@ async def test_abort_uploading(test_initialized_trainer: TestingTrainerLogic):
trainer._init_from_last_training()
await create_valid_detection_file(trainer)

_ = asyncio.get_running_loop().create_task(trainer._run())
trainer._begin_training_task()

await assert_training_state(trainer.training, TrainerState.DetectionUploading, timeout=1, interval=0.001)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ async def test_abort_upload_model(test_initialized_trainer: TestingTrainerLogic)
create_active_training_file(trainer, training_state=TrainerState.ConfusionMatrixSynced)
trainer._init_from_last_training()

_ = asyncio.get_running_loop().create_task(trainer._run())
trainer._begin_training_task()

await assert_training_state(trainer.training, TrainerState.TrainModelUploading, timeout=1, interval=0.001)

Expand All @@ -60,7 +60,7 @@ async def test_bad_server_response_content(test_initialized_trainer: TestingTrai
create_active_training_file(trainer, training_state=TrainerState.ConfusionMatrixSynced)
trainer._init_from_last_training()

_ = asyncio.get_running_loop().create_task(trainer._run())
trainer._begin_training_task()

await assert_training_state(trainer.training, TrainerState.TrainModelUploading, timeout=1, interval=0.001)
# TODO goes to finished because of the error
Expand Down
4 changes: 2 additions & 2 deletions learning_loop_node/tests/trainer/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ async def test_training_process_is_stopped_when_trainer_reports_error(test_initi
trainer = test_initialized_trainer
create_active_training_file(trainer, training_state=TrainerState.TrainModelDownloaded)
trainer._init_from_last_training()
_ = asyncio.get_running_loop().create_task(trainer._run())
trainer._begin_training_task()

await assert_training_state(trainer.training, TrainerState.TrainingRunning, timeout=1, interval=0.001)
trainer.error_msg = 'some_error'
Expand All @@ -26,7 +26,7 @@ async def test_log_can_provide_only_data_for_current_run(test_initialized_traine
trainer = test_initialized_trainer
create_active_training_file(trainer, training_state=TrainerState.TrainModelDownloaded)
trainer._init_from_last_training()
_ = asyncio.get_running_loop().create_task(trainer._run())
trainer._begin_training_task()

await assert_training_state(trainer.training, TrainerState.TrainingRunning, timeout=1, interval=0.001)
await asyncio.sleep(0.1) # give the tests a bit of time to check for the state
Expand Down
1 change: 1 addition & 0 deletions learning_loop_node/trainer/trainer_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ async def _start_training(self):
async def stop(self) -> None:
"""If executor is running, stop it. Else cancel training task."""
print('===============> stop received in trainer_logic.', flush=True)
print(self.training_task is None, flush=True)

if not self.training_active:
return
Expand Down
11 changes: 7 additions & 4 deletions learning_loop_node/trainer/trainer_logic_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ async def try_continue_run_if_incomplete(self) -> bool:
if not self.training_active and self.last_training_io.exists():
self._init_from_last_training()
logger.info('found incomplete training, continuing now.')
asyncio.get_event_loop().create_task(self._run())
self._begin_training_task()
return True
return False

Expand All @@ -195,7 +195,11 @@ async def begin_training(self, organization: str, project: str, details: Dict) -
"""Called on `begin_training` event from the Learning Loop.
"""
self._init_new_training(Context(organization=organization, project=project), details)
asyncio.get_event_loop().create_task(self._run())
self._begin_training_task()

def _begin_training_task(self) -> None:
# NOTE: Task object is used to potentially cancel the task
self.training_task = asyncio.get_event_loop().create_task(self._run())

def _init_new_training(self, context: Context, details: Dict) -> None:
"""Called on `begin_training` event from the Learning Loop.
Expand All @@ -218,8 +222,7 @@ async def _run(self) -> None:
"""
self.errors.reset_all()
try:
self.training_task = asyncio.get_running_loop().create_task(self._training_loop())
await self.training_task # NOTE: Task object is used to potentially cancel the task
await self._training_loop()
except asyncio.CancelledError:
if not self.shutdown_event.is_set():
logger.info('CancelledError in _run - training task was cancelled but not by shutdown event')
Expand Down

0 comments on commit b49ff45

Please sign in to comment.