Enable checkpointing with DCP #26

Merged: 5 commits merged into main on Feb 6, 2024
Conversation

@fegin (Contributor) commented on Jan 31, 2024

Summary:
This PR enables checkpointing. It only enables checkpointing to local storage; this checkpoint manager can support remote storage only once DCP provides automatic storage detection.

This PR does not checkpoint the dataloader.

Test Plan:
Changed CHECKPOINT_FOLDER to /tmp/checkpoint_chienchin and ran ./run_llama_train.sh twice. The first run went through all 100 steps and saved checkpoints. The second run loaded the checkpoint back and detected that the saved step count is 100, so no training was done in the second run.
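For readers new to DCP (torch.distributed.checkpoint), here is a minimal sketch of saving and loading to local storage, which is the scope of this PR. It assumes a recent PyTorch where dcp.save/dcp.load and FileSystemWriter/FileSystemReader are available and that torch.distributed is already initialized; the step-<N> folder layout and function names are illustrative, not the PR's exact code.

import torch
import torch.distributed.checkpoint as dcp


def save_checkpoint(model: torch.nn.Module, step: int, folder: str) -> None:
    # DCP saves whatever state_dict it is given; here we keep it to the model.
    state = {"model": model.state_dict()}
    dcp.save(state, storage_writer=dcp.FileSystemWriter(f"{folder}/step-{step}"))


def load_checkpoint(model: torch.nn.Module, step: int, folder: str) -> None:
    # DCP loads into the provided state_dict in place, so we pass the current
    # model state as the template and then apply it back to the module.
    state = {"model": model.state_dict()}
    dcp.load(state, storage_reader=dcp.FileSystemReader(f"{folder}/step-{step}"))
    model.load_state_dict(state["model"])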

@facebook-github-bot added the "CLA Signed" label on Jan 31, 2024 (this label is managed by the Meta Open Source bot).
@wanchaol (Contributor) left a comment:

Looks great! thanks for doing this super fast! I have a few suggestions inlined.

@@ -7,6 +7,11 @@ TRAINER_DIR=${1:-/home/$USER/local/torchtrain}
MODEL="debugmodel"
NGPU=8
MP=4
# Change this string to a meaningful one to enable checkpoint
CHECKPOINT_FOLDER=""
@wanchaol: Can we change this to something like /tmp/torchtrain so that it saves somewhere when we run it locally?

@fegin (author, Feb 1, 2024): We should make this an opt-in feature so that people are not surprised and do not save too many files to /tmp when sharing the same machine. Also, once a training run finishes there will be a checkpoint, so users may unknowingly skip all new training because an existing checkpoint already records the last step. That happens a lot, so it is better to keep this opt-in.

@wanchaol: Makes sense.
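A minimal sketch of the opt-in behavior agreed on here, assuming a hypothetical CheckpointManager whose folder argument comes from CHECKPOINT_FOLDER: an empty string means the user did not opt in, and save()/load() become no-ops. This is an illustration, not the PR's implementation.

class CheckpointManager:
    def __init__(self, folder: str, interval: int) -> None:
        # An empty folder string disables checkpointing entirely (opt-in behavior).
        self.folder = folder
        self.interval = interval

    def save(self, curr_step: int) -> None:
        if not self.folder:
            return
        ...  # build the state dict and write it with DCP

    def load(self, step: int = -1) -> bool:
        if not self.folder:
            return False
        ...  # read the checkpoint with DCP
        return True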

train.py (outdated):
)
parser.add_argument(
"--checkpoint-interval-type",
type=str, default="seconds",
@wanchaol: Do we need the "seconds" interval type here? I think a simple checkpoint-interval documented as a number of iterations should be enough. Since we ultimately don't know how long a model's fwd/bwd/optim step takes, a number of iterations is more sound than seconds.

@fegin (author, Feb 1, 2024): I think we can keep the time-based feature as described above but make steps the default.
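A sketch of what that compromise could look like for the flags in this diff: both interval types stay, but steps is the default. The choices, defaults, and help strings below are assumptions for illustration.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--checkpoint-interval-type",
    type=str,
    default="steps",  # step-based by default, per this thread
    choices=["steps", "seconds"],
    help="Whether --checkpoint-interval counts training steps or wall-clock seconds.",
)
parser.add_argument(
    "--checkpoint-interval",
    type=int,
    default=100,  # illustrative value only
    help="Checkpoint every N steps (or every N seconds when the interval type is 'seconds').",
)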

train.py (outdated):
rank0_log(f"current loss: {train_state.current_loss}")

checkpoint.save(train_state.step)

if train_state.step == args.steps:
@wanchaol: IIUC, this means we save a final checkpoint after all steps complete?

@fegin (author): Yes.

and (curr_step - self.begin) < self.interval
):
return
if self.interval_type == IntervalType.SECONDS:
@wanchaol: I would prefer we get rid of the seconds handling, as mentioned in the other comment, to keep our stack simple. We can add it back once we feel this mode is needed for actual training.

@fegin (author, Feb 1, 2024): It is sometimes better to use time, because many factors can change the per-iteration time, such as model type, batch size, and other settings; using steps may require some tuning to avoid hurting overall performance.

We can change the default to steps so that users don't need to worry about it for now.

@wanchaol: Sure, sounds good! My main motivation is to keep our library as simple as possible. Once we start real trainings we can evaluate whether we would use the time interval type and decide later whether to keep it.
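For illustration, a self-contained sketch of a trigger that supports both interval types with steps as the default, loosely following the IntervalType fragment above; the field names and reset logic are assumptions, not the PR's code.

import enum
import time


class IntervalType(enum.Enum):
    SECONDS = "seconds"
    STEPS = "steps"


class CheckpointTrigger:
    def __init__(self, interval: int, interval_type: IntervalType = IntervalType.STEPS) -> None:
        self.interval = interval
        self.interval_type = interval_type
        self.begin_step = 0                 # step at the last checkpoint
        self.begin_time = time.monotonic()  # time of the last checkpoint

    def should_save(self, curr_step: int) -> bool:
        if self.interval_type == IntervalType.STEPS:
            due = (curr_step - self.begin_step) >= self.interval
        else:  # IntervalType.SECONDS
            due = (time.monotonic() - self.begin_time) >= self.interval
        if due:
            self.begin_step = curr_step
            self.begin_time = time.monotonic()
        return due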

)

def load(self, step: int = -1) -> bool:
if not self.folder:
@wanchaol: Nit: we should check, either in train.py or in these save and load methods, that step % checkpoint_interval == 0, so that we skip the save/load logic when we don't need to checkpoint.
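The nit distilled into a small helper; the name should_checkpoint and its signature are hypothetical, not the PR's code.

def should_checkpoint(step: int, interval: int, folder: str) -> bool:
    # Checkpoint only when the feature is enabled (non-empty folder) and the
    # step lands exactly on the configured interval.
    return bool(folder) and interval > 0 and step % interval == 0


# e.g. should_checkpoint(500, 100, "/tmp/torchtrain") -> True
#      should_checkpoint(501, 100, "/tmp/torchtrain") -> False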

@fegin force-pushed the chienchin_enable_checkpoint branch from a0257bc to 017680f on February 1, 2024, 23:57
@wanchaol (Contributor) left a comment:

LGTM! I have a few more minor comments inlined.

@@ -30,6 +31,18 @@ class TrainState:
current_loss: float = -1
losses: List[float] = field(default_factory=list)

def state_dict(self) -> Dict[str, Any]:
@wanchaol: Nit: to avoid confusion with the model/optim state dict, we should rename this to something like train_state.

@fegin (author): This naming is required by DCP.
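Some context for this exchange: DCP treats objects that expose methods named exactly state_dict() and load_state_dict() as stateful components and calls those methods during save and load, so the names on TrainState cannot be changed. Below is a sketch of such a TrainState, with fields mirroring the diff above; the exact tensor encoding is an assumption.

from dataclasses import dataclass, field
from typing import Any, Dict, List

import torch


@dataclass
class TrainState:
    step: int = 0
    current_loss: float = -1
    losses: List[float] = field(default_factory=list)

    # DCP looks up these exact method names on stateful objects.
    def state_dict(self) -> Dict[str, Any]:
        return {
            "step": torch.tensor(self.step, dtype=torch.int32),
            "current_loss": torch.tensor(self.current_loss, dtype=torch.float32),
        }

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.step = int(state_dict["step"].item())
        self.current_loss = float(state_dict["current_loss"].item())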

"losses": torch.tensor(self.current_loss, dtype=torch.float32),
}

def load_state_dict(self, state_dict) -> None:
@wanchaol: Ditto: load_train_state, to avoid confusion with DCP.save/load_state_dict.

@fegin (author): This naming is required by DCP.

f"{time.monotonic() - begin} seconds"
)

def load(self, step: int = -1) -> bool:
@wanchaol: Why do we have a step arg here? It seems we don't use this arg either; we should remove it for now.

@fegin (author): We do use the step. When more than one checkpoint is saved, users can specify the step to load a specific checkpoint.
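A sketch of how a step argument can select among multiple saved checkpoints, assuming a hypothetical step-<N> subfolder layout and helper name; step == -1 picks the latest.

import os
import re
from typing import Optional


def find_checkpoint_path(folder: str, step: int = -1) -> Optional[str]:
    # Collect the step numbers of every saved checkpoint under `folder`.
    pattern = re.compile(r"step-(\d+)")
    saved_steps = sorted(
        int(m.group(1)) for name in os.listdir(folder) if (m := pattern.fullmatch(name))
    )
    if not saved_steps:
        return None
    chosen = saved_steps[-1] if step == -1 else step
    return os.path.join(folder, f"step-{chosen}") if chosen in saved_steps else None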


@wconstab (Contributor) commented on Feb 2, 2024:

Could you try rebasing before you merge? You'll pick up the linter CI that way and it'll force you to lint your new files. Hopefully not too many conflicts to resolve, since most of the hairy linter changes were in model.py.

@fegin added 3 commits on February 5, 2024, 11:31
@fegin force-pushed the chienchin_enable_checkpoint branch from 0541be2 to fe2e1c6 on February 5, 2024, 19:34
@fegin added 2 commits on February 5, 2024, 11:39
@fegin merged commit 6bd9082 into main on Feb 6, 2024 (3 checks passed).
lessw2020 pushed a commit that referenced this pull request on Apr 18, 2024.