Add support for seed checkpoint creation for meta-init flow
ghstack-source-id: 39c4ec84e56c60ee831d9b861ac118a2d4cedd08
Pull Request resolved: #172
wconstab committed Apr 5, 2024
1 parent a02eb33 commit 2d4de9e
Showing 3 changed files with 66 additions and 2 deletions.
2 changes: 1 addition & 1 deletion run_llama_train.sh
@@ -21,4 +21,4 @@ CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"}

torchrun --nproc_per_node=${NGPU} --rdzv_endpoint="localhost:5972" \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
-    train.py --job.config_file ${CONFIG_FILE}
+    train.py --job.config_file ${CONFIG_FILE} --training.checkpoint_folder /data/users/whc/torchtrain
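For orientation, the folder passed via --training.checkpoint_folder is the same directory the new seed script writes into, and the per-step layout it implies is what the CheckpointManager later scans when resuming. A minimal sketch of that layout (the helper name here is hypothetical; the "step-N" directory naming is taken from the code below):

import os

# Hypothetical helper illustrating the layout implied by seed_checkpoint.py
# and torchtrain/checkpoint.py: each checkpoint lives in
# <checkpoint_folder>/step-<N>, and step-0 is the model-only seed checkpoint.
def checkpoint_id(checkpoint_folder: str, step: int) -> str:
    return os.path.join(checkpoint_folder, f"step-{step}")

print(checkpoint_id("/data/users/whc/torchtrain", 0))  # .../step-0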
61 changes: 61 additions & 0 deletions seed_checkpoint.py
@@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import os

import torch.distributed.checkpoint as DCP

from torchtrain.config_manager import JobConfig
from torchtrain.datasets import create_tokenizer
from torchtrain.float8_linear import build_fp8_linear
from torchtrain.logging_utils import init_logger, logger
from torchtrain.models import model_name_to_cls, model_name_to_tokenizer, models_config

_is_local_logging = True
if "SLURM_JOB_ID" in os.environ:
_is_local_logging = False


def main(job_config: JobConfig):
init_logger()

model_name = job_config.model.name

# build tokenizer
tokenizer_type = model_name_to_tokenizer[model_name]
tokenizer = create_tokenizer(tokenizer_type, job_config.model.tokenizer_path)

# build model (using meta init)
model_cls = model_name_to_cls[model_name]
model_config = models_config[model_name][job_config.model.flavor]
model_config.vocab_size = tokenizer.n_words
logger.info(f"Building {model_name} {job_config.model.flavor} with {model_config}")
model = model_cls.from_model_args(model_config)

# apply fp8 linear module swap
if job_config.training.fp8_linear:
build_fp8_linear(model, job_config)

model.reset_parameters()

checkpoint_id = os.path.join(job_config.training.checkpoint_folder, "step-0")
logger.info(f"Creating seed (step-0) checkpoint in {checkpoint_id}")
DCP.save(
state_dict={
"model": model.state_dict(),
},
checkpoint_id=checkpoint_id,
)


"""
1. how do i serialize enough info about the model config to ensure i don't try to load an incompatible checkpoint later?
- maybe skip this. users responsible to manage their checkpoints, and we can partially help by managing their 'dump folder'?
2. would i apply fp8 before creating the seed or not? I think probably before
3. can i skip optimizer in seed file? i think so. optimizer can later create its states from the model post-sharding
"""
if __name__ == "__main__":
config = JobConfig()
config.parse_args()
main(config)
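
To show where this seed checkpoint fits in the meta-init flow, here is a minimal training-side sketch, not code from this commit: the model is constructed on the meta device, real storage is materialized, and the step-0 checkpoint written above fills the weights. The names model_cls, model_config, and job_config are reused from the script above; the to_empty target device and the trailing load_state_dict are assumptions about how a consumer would wire this up.

import os

import torch
import torch.distributed.checkpoint as DCP

# 1. Build the model on the meta device: parameters have shapes/dtypes but
#    no storage, so even very large models are cheap to construct here.
with torch.device("meta"):
    model = model_cls.from_model_args(model_config)

# (Parallelism / sharding would normally be applied at this point.)

# 2. Materialize real, uninitialized storage for every parameter and buffer.
model = model.to_empty(device="cuda")  # or "cpu"; device choice is up to the trainer

# 3. Fill that storage in place from the step-0 seed checkpoint.
state_dict = {"model": model.state_dict()}
DCP.load(
    state_dict=state_dict,
    checkpoint_id=os.path.join(job_config.training.checkpoint_folder, "step-0"),
)
# DCP loads tensors in place; load_state_dict is a defensive no-op copy here.
model.load_state_dict(state_dict["model"])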
5 changes: 4 additions & 1 deletion torchtrain/checkpoint.py
@@ -123,6 +123,7 @@ def save(self, curr_step: int, force: bool = False) -> None:
)

def load(self, step: int = -1) -> bool:
logger.info(f"Trying Loading a checkpoint from '{self.folder}'")
if not self.folder:
return False
if not os.path.isdir(self.folder):
@@ -140,10 +141,12 @@ def load(self, step: int = -1) -> bool:
return False
step = max(step_counts)

+    # We won't have optimizer states to load if we are loading a seed checkpoint
+    states = {"model": self.states["model"]} if step == 0 else self.states
logger.info(f"Loading the checkpoint at step {step}")
begin = time.monotonic()
dcp.load(
-        self.states,
+        states,
checkpoint_id=self.create_checkpoint_id(step),
)
logger.info(
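
A small, self-contained illustration (plain PyTorch, not the repo's CheckpointManager) of why the step == 0 branch can drop optimizer state: Adam-style optimizers create their per-parameter state lazily on the first step, so an optimizer built after loading a model-only seed checkpoint starts out empty but valid.

import torch

model = torch.nn.Linear(4, 4)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)

# A freshly constructed optimizer has no per-parameter state yet ...
assert len(opt.state) == 0

# ... it is created lazily on the first step, after the (sharded) parameters
# exist -- which is why a seed checkpoint only needs the "model" entry.
model(torch.randn(2, 4)).sum().backward()
opt.step()
assert len(opt.state) == len(list(model.parameters()))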