Skip to content

Commit

Permalink
update epochs to save latest checkpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
ahmeda14960 committed Oct 28, 2024
1 parent c38b076 commit 7331774
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions src/levanter/main/train_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,15 @@ def compute_log_probs(model, example):
train_loader = iter(train_loader)

## OK, actually run training!
trainer.train(state, train_loader)
last_info = trainer.train(state, train_loader)

# checkpointer.on_step(last_step, force=True)


# If running with an EpochDataset, save the latest checkpoint by default
if trainer.config.checkpointer is not None and config.epoch > 0:
trainer.run_hooks(last_info, force=True)
checkpointer = trainer.config.checkpointer.create(trainer.run_id)
checkpointer.wait_until_finished()


if __name__ == "__main__":
Expand Down

0 comments on commit 7331774

Please sign in to comment.