From e74a2ce47961197962f5b135e3e0175e051b14e9 Mon Sep 17 00:00:00 2001 From: "Dr. Dennis Wittich" Date: Fri, 26 Apr 2024 15:35:11 +0200 Subject: [PATCH] fix test_ensure_all_detections_are_uploaded (behaviour of the tested function was corrected earlier) --- learning_loop_node/trainer/io_helpers.py | 11 +++----- .../states/test_state_upload_detections.py | 25 ++++++++++--------- 2 files changed, 17 insertions(+), 19 deletions(-) diff --git a/learning_loop_node/trainer/io_helpers.py b/learning_loop_node/trainer/io_helpers.py index 852b0f70..cb0db600 100644 --- a/learning_loop_node/trainer/io_helpers.py +++ b/learning_loop_node/trainer/io_helpers.py @@ -162,14 +162,14 @@ async def _upload_detections_batched(self, context: Context, detections: List[De up_count = 0 for i in range(skip_detections, len(detections), batch_size): up_count += 1 - up_progress = i+batch_size + up_progress = i + batch_size if i + batch_size < len(detections) else 0 batch_detections = detections[i:up_progress] - await self._upload_detections(context, batch_detections, up_progress) + await self._upload_detections_and_save_progress(context, batch_detections, up_progress) skip_detections = up_progress logging.info('uploaded %d detections', len(detections)) - async def _upload_detections(self, context: Context, batch_detections: List[Detections], up_progress: int): + async def _upload_detections_and_save_progress(self, context: Context, batch_detections: List[Detections], up_progress: int): detections_json = [jsonable_encoder(asdict(detections)) for detections in batch_detections] response = await self.loop_communicator.post( f'/{context.organization}/projects/{context.project}/detections', json=detections_json) @@ -179,7 +179,4 @@ async def _upload_detections(self, context: Context, batch_detections: List[Dete raise Exception(msg) logging.info('successfully uploaded detections') - if up_progress >= len(batch_detections): - self.save_detection_upload_progress(0) - else: - self.save_detection_upload_progress(up_progress) + self.save_detection_upload_progress(up_progress) diff --git a/learning_loop_node/trainer/tests/states/test_state_upload_detections.py b/learning_loop_node/trainer/tests/states/test_state_upload_detections.py index e2784514..ca588282 100644 --- a/learning_loop_node/trainer/tests/states/test_state_upload_detections.py +++ b/learning_loop_node/trainer/tests/states/test_state_upload_detections.py @@ -37,7 +37,7 @@ async def create_valid_detection_file(trainer: TrainerLogic, number_of_entries: 'point_detections': [], 'segmentation_detections': []}) detections = [detection_entry] * number_of_entries - assert trainer.active_training_io is not None # pylint: disable=protected-access + assert trainer.active_training_io is not None trainer.active_training_io.save_detections(detections, file_index) @@ -80,21 +80,21 @@ async def test_ensure_all_detections_are_uploaded(test_initialized_trainer: Test create_active_training_file(trainer, training_state=TrainerState.Detected) trainer._init_from_last_training() - await create_valid_detection_file(trainer, 2, 0) - await create_valid_detection_file(trainer, 2, 1) + await create_valid_detection_file(trainer, 4, 0) + await create_valid_detection_file(trainer, 4, 1) assert trainer.active_training_io.load_detections_upload_file_index() == 0 detections = trainer.active_training_io.load_detections(0) - assert len(detections) == 2 + assert len(detections) == 4 - batch_size = 1 + batch_size = 2 skip_detections = trainer.active_training_io.load_detection_upload_progress() for i in range(skip_detections, len(detections), batch_size): batch_detections = detections[i:i+batch_size] - # pylint: disable=protected-access - await trainer.active_training_io._upload_detections(trainer.training.context, batch_detections, i + batch_size) + progress = i + batch_size if i + batch_size < len(detections) else 0 + await trainer.active_training_io._upload_detections_and_save_progress(trainer.training.context, batch_detections, progress) - expected_value = i + batch_size if i + batch_size < len(detections) else 0 # Progress is reset for every file + expected_value = progress # Progress is reset for every file assert trainer.active_training_io.load_detection_upload_progress() == expected_value assert trainer.active_training_io.load_detections_upload_file_index() == 0 @@ -107,10 +107,11 @@ async def test_ensure_all_detections_are_uploaded(test_initialized_trainer: Test skip_detections = trainer.active_training_io.load_detection_upload_progress() for i in range(skip_detections, len(detections), batch_size): batch_detections = detections[i:i+batch_size] - # pylint: disable=protected-access - await trainer.active_training_io._upload_detections(trainer.training.context, batch_detections, i + batch_size) - expected_value = i + batch_size if i + batch_size < len(detections) else 0 # Progress is reset for every file + progress = i + batch_size if i + batch_size < len(detections) else 0 + await trainer.active_training_io._upload_detections_and_save_progress(trainer.training.context, batch_detections, progress) + + expected_value = progress # Progress is reset for every file assert trainer.active_training_io.load_detection_upload_progress() == expected_value assert trainer.active_training_io.load_detections_upload_file_index() == 1 @@ -160,5 +161,5 @@ async def test_abort_uploading(test_initialized_trainer: TestingTrainerLogic): await trainer.stop() await asyncio.sleep(0.1) - assert trainer._training is None # pylint: disable=protected-access + assert trainer._training is None assert trainer.node.last_training_io.exists() is False