[NeMo-UX] Checkpointing fixes (#10376)
* remove save_best_model from default logger

Signed-off-by: ashors1 <[email protected]>

* fix broken checkpoint restore

Signed-off-by: ashors1 <[email protected]>

* fix fsdp

Signed-off-by: ashors1 <[email protected]>

* rename weights path to avoid confusion

Signed-off-by: ashors1 <[email protected]>

* Revert "rename weights path to avoid confusion". We'll add this in a separate PR

This reverts commit 72bae8b.

---------

Signed-off-by: ashors1 <[email protected]>
ashors1 authored Sep 7, 2024
Parent: 62c1dce · Commit: 9e372d3
Showing 3 changed files with 2 additions and 3 deletions.
1 change: 0 additions & 1 deletion nemo/collections/llm/recipes/log/default.py
@@ -32,7 +32,6 @@ def default_log(
 ) -> Config[nl.NeMoLogger]:
     ckpt = Config(
         nl.ModelCheckpoint,
-        save_best_model=False,
         save_last=True,
         save_top_k=10,
         every_n_train_steps=200,
2 changes: 1 addition & 1 deletion nemo/lightning/pytorch/strategies/fsdp_strategy.py
@@ -216,7 +216,7 @@ def save_checkpoint(
             and self.trainer.state.fn == TrainerFn.FITTING
             and self.ckpt_save_optimizer
         ):
-            del checkpoint["optimizer_states"]
+            checkpoint["optimizer_states"] = {}
             checkpoint['optimizer'] = get_optimizer_state_dict(self.model, self.optimizers)
             pyt_to_mcore_state_dict(checkpoint['optimizer']['state'], prefix="optimizer.state.")
2 changes: 1 addition & 1 deletion nemo/lightning/pytorch/strategies/megatron_strategy.py
@@ -634,7 +634,7 @@ def save_checkpoint(
             and self.trainer.state.fn == TrainerFn.FITTING
             and self.ckpt_save_optimizer
         ):
-            del checkpoint["optimizer_states"]
+            checkpoint["optimizer_states"] = {}
             checkpoint["optimizer"] = [self.optimizer_sharded_state_dict()]

         self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options)
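Both strategy changes follow the same pattern: replace `del checkpoint["optimizer_states"]` with `checkpoint["optimizer_states"] = {}`. A minimal, self-contained sketch of why keeping the key (emptied) is safer than deleting it; the helper name, the hypothetical consumer, and the stand-in values below are illustrative, not NeMo APIs:

```python
# Sketch, assuming downstream checkpoint-handling code indexes
# "optimizer_states" unconditionally: deleting the key makes that
# lookup raise KeyError, while an empty placeholder does not.

def strip_optimizer_states(checkpoint, use_del):
    """Replace Lightning's default optimizer states before saving the
    strategy-specific ones under a separate 'optimizer' key."""
    if use_del:
        del checkpoint["optimizer_states"]       # old behavior: key disappears
    else:
        checkpoint["optimizer_states"] = {}      # fixed behavior: key stays, empty
    checkpoint["optimizer"] = ["sharded-state"]  # stand-in for the real state dict
    return checkpoint

ckpt = {"state_dict": {"w": 1.0}, "optimizer_states": [{"step": 10}]}

fixed = strip_optimizer_states(dict(ckpt), use_del=False)
assert fixed["optimizer_states"] == {}           # key still present for any reader

broken = strip_optimizer_states(dict(ckpt), use_del=True)
assert "optimizer_states" not in broken          # a later lookup would KeyError
```

Under this assumption, restore paths that expect the key to exist keep working, which matches the commit message's "fix broken checkpoint restore".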
