Skip to content

Commit

Permalink
fix test_ensure_all_detections_are_uploaded (behaviour of the tested function was corrected earlier)
Browse files Browse the repository at this point in the history
  • Loading branch information
denniswittich committed Apr 26, 2024
1 parent e5da0ba commit e74a2ce
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 19 deletions.
11 changes: 4 additions & 7 deletions learning_loop_node/trainer/io_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,14 +162,14 @@ async def _upload_detections_batched(self, context: Context, detections: List[De
up_count = 0
for i in range(skip_detections, len(detections), batch_size):
up_count += 1
up_progress = i+batch_size
up_progress = i + batch_size if i + batch_size < len(detections) else 0
batch_detections = detections[i:up_progress]
await self._upload_detections(context, batch_detections, up_progress)
await self._upload_detections_and_save_progress(context, batch_detections, up_progress)
skip_detections = up_progress

logging.info('uploaded %d detections', len(detections))

async def _upload_detections(self, context: Context, batch_detections: List[Detections], up_progress: int):
async def _upload_detections_and_save_progress(self, context: Context, batch_detections: List[Detections], up_progress: int):
detections_json = [jsonable_encoder(asdict(detections)) for detections in batch_detections]
response = await self.loop_communicator.post(
f'/{context.organization}/projects/{context.project}/detections', json=detections_json)
Expand All @@ -179,7 +179,4 @@ async def _upload_detections(self, context: Context, batch_detections: List[Dete
raise Exception(msg)

logging.info('successfully uploaded detections')
if up_progress >= len(batch_detections):
self.save_detection_upload_progress(0)
else:
self.save_detection_upload_progress(up_progress)
self.save_detection_upload_progress(up_progress)
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ async def create_valid_detection_file(trainer: TrainerLogic, number_of_entries:
'point_detections': [], 'segmentation_detections': []})
detections = [detection_entry] * number_of_entries

assert trainer.active_training_io is not None # pylint: disable=protected-access
assert trainer.active_training_io is not None
trainer.active_training_io.save_detections(detections, file_index)


Expand Down Expand Up @@ -80,21 +80,21 @@ async def test_ensure_all_detections_are_uploaded(test_initialized_trainer: Test
create_active_training_file(trainer, training_state=TrainerState.Detected)
trainer._init_from_last_training()

await create_valid_detection_file(trainer, 2, 0)
await create_valid_detection_file(trainer, 2, 1)
await create_valid_detection_file(trainer, 4, 0)
await create_valid_detection_file(trainer, 4, 1)

assert trainer.active_training_io.load_detections_upload_file_index() == 0
detections = trainer.active_training_io.load_detections(0)
assert len(detections) == 2
assert len(detections) == 4

batch_size = 1
batch_size = 2
skip_detections = trainer.active_training_io.load_detection_upload_progress()
for i in range(skip_detections, len(detections), batch_size):
batch_detections = detections[i:i+batch_size]
# pylint: disable=protected-access
await trainer.active_training_io._upload_detections(trainer.training.context, batch_detections, i + batch_size)
progress = i + batch_size if i + batch_size < len(detections) else 0
await trainer.active_training_io._upload_detections_and_save_progress(trainer.training.context, batch_detections, progress)

expected_value = i + batch_size if i + batch_size < len(detections) else 0 # Progress is reset for every file
expected_value = progress # Progress is reset for every file
assert trainer.active_training_io.load_detection_upload_progress() == expected_value

assert trainer.active_training_io.load_detections_upload_file_index() == 0
Expand All @@ -107,10 +107,11 @@ async def test_ensure_all_detections_are_uploaded(test_initialized_trainer: Test
skip_detections = trainer.active_training_io.load_detection_upload_progress()
for i in range(skip_detections, len(detections), batch_size):
batch_detections = detections[i:i+batch_size]
# pylint: disable=protected-access
await trainer.active_training_io._upload_detections(trainer.training.context, batch_detections, i + batch_size)

expected_value = i + batch_size if i + batch_size < len(detections) else 0 # Progress is reset for every file
progress = i + batch_size if i + batch_size < len(detections) else 0
await trainer.active_training_io._upload_detections_and_save_progress(trainer.training.context, batch_detections, progress)

expected_value = progress # Progress is reset for every file
assert trainer.active_training_io.load_detection_upload_progress() == expected_value
assert trainer.active_training_io.load_detections_upload_file_index() == 1

Expand Down Expand Up @@ -160,5 +161,5 @@ async def test_abort_uploading(test_initialized_trainer: TestingTrainerLogic):
await trainer.stop()
await asyncio.sleep(0.1)

assert trainer._training is None # pylint: disable=protected-access
assert trainer._training is None
assert trainer.node.last_training_io.exists() is False

0 comments on commit e74a2ce

Please sign in to comment.