Skip to content

Commit

Permalink
better in-training checkpoint selector
Browse files (browse the repository at this point in the history)
  • Loading branch information
Roman Fitzjalen authored and committed on Oct 22, 2024
1 parent 1e2aa0a commit 3c499ac
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions nnunetv2/run/run_training.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -74,15 +74,18 @@ def maybe_load_checkpoint(nnunet_trainer: nnUNetTrainer, continue_training: bool
raise RuntimeError('Cannot both continue a training AND load pretrained weights. Pretrained weights can only '
'be used at the beginning of the training.')
if continue_training:
expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_final.pth')
if not isfile(expected_checkpoint_file):
expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_latest.pth')
# special case where --c is used to run a previously aborted validation
if not isfile(expected_checkpoint_file):
expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_best.pth')
if not isfile(expected_checkpoint_file):
print(f"WARNING: Cannot continue training because there seems to be no checkpoint available to "
f"continue from. Starting a new training...")
checkpoint_files = [
join(nnunet_trainer.output_folder, 'checkpoint_final.pth'),
join(nnunet_trainer.output_folder, 'checkpoint_latest.pth'),
join(nnunet_trainer.output_folder, 'checkpoint_best.pth'),
]
# Filter out the files that actually exist
existing_checkpoints = [ckpt for ckpt in checkpoint_files if isfile(ckpt)]
if existing_checkpoints:
# Select the checkpoint with the most recent modification time
expected_checkpoint_file = max(existing_checkpoints, key=os.path.getmtime)
else:
print("WARNING: Cannot continue training because there seems to be no checkpoint available to continue from. Starting a new training...")
expected_checkpoint_file = None
elif validation_only:
expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_final.pth')
Expand Down

0 comments on commit 3c499ac

Please sign in to comment.