Enable checkpointing with DCP (#26)
Summary:
This PR enables checkpointing. For now, checkpointing is supported only for
local storage; the checkpoint manager will be able to support remote storage
once DCP adds automatic storage detection.

This PR does not checkpoint the dataloader.

Test Plan:
Changed CHECKPOINT_FOLDER to /tmp/checkpoint_chienchin and ran
./run_llama_train.sh twice. The first run went through all 100 steps and
saved checkpoints. The second run loaded the latest checkpoint back,
detected that the saved step count was already 100, and therefore performed
no training.
fegin authored Feb 6, 2024
1 parent 479a571 commit 3d27c70
Showing 3 changed files with 211 additions and 2 deletions.
7 changes: 6 additions & 1 deletion run_llama_train.sh
@@ -7,6 +7,11 @@ TRAINER_DIR=${1:-/home/$USER/local/torchtrain}
MODEL="debugmodel"
NGPU=8
MP=4
# Set this to a non-empty folder path to enable checkpointing
CHECKPOINT_FOLDER=""
# Adjust this to a longer interval for real training runs. The unit of measurement is steps.
CHECKPOINT_INTERVAL=5

torchrun --nproc_per_node=${NGPU} \
train.py --steps 10
train.py --steps 10 --compile \
--checkpoint-folder=${CHECKPOINT_FOLDER} --checkpoint-interval=${CHECKPOINT_INTERVAL}
146 changes: 146 additions & 0 deletions torchtrain/checkpoint.py
@@ -0,0 +1,146 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import enum
import os
import re
import time
from typing import Any, Dict

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
get_model_state_dict,
get_optimizer_state_dict,
set_model_state_dict,
set_optimizer_state_dict,
)
from torchtrain.logging_utils import rank0_log


class IntervalType(enum.Enum):
SECONDS = enum.auto()
STEPS = enum.auto()


class ModelWrapper:
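# Adapts the model to the Stateful protocol that dcp.save()/dcp.load() expect,
# delegating to DCP's get_model_state_dict/set_model_state_dict helpers so
# FSDP-sharded state is handled correctly.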
def __init__(self, model: nn.Module) -> None:
self.model = model

def state_dict(self) -> Dict[str, Any]:
return get_model_state_dict(self.model)

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
set_model_state_dict(self.model, state_dict)


class OptimizerWrapper:
def __init__(self, model: nn.Module, optim: torch.optim.Optimizer) -> None:
self.model = model
self.optim = optim

def state_dict(self) -> Dict[str, Any]:
return get_optimizer_state_dict(self.model, self.optim)

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
set_optimizer_state_dict(self.model, self.optim, optim_state_dict=state_dict)


class CheckpointManager:
def __init__(
self,
model: nn.Module,
optimizer: torch.optim.Optimizer,
states: Dict[str, Any],
folder: str,
interval_type: IntervalType,
interval: int,
) -> None:
self.folder = folder
self.states = states
self.states.update(
{
"model": ModelWrapper(model),
"optimizer": OptimizerWrapper(model, optimizer),
}
)
self.interval_type = interval_type
self.interval = interval
self.begin = 0
self.work = None
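# A dedicated gloo process group is used for the small CPU all-reduce that
# decides whether the time-based checkpoint interval has elapsed.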
self.pg = dist.new_group(backend="gloo")
self.doit = None

def reset(self) -> None:
self.begin = time.monotonic()

def create_checkpoint_id(self, step: int) -> str:
return os.path.join(self.folder, f"step-{step}")

def save(self, curr_step: int, force: bool = False) -> None:
if not self.folder:
return

if not force:
if self.interval_type == IntervalType.STEPS and not (
curr_step % self.interval == 0
):
return
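# For time-based checkpointing, every rank must agree on whether the interval
# has elapsed, so the decision is combined with an asynchronous all-reduce and
# harvested a few steps later instead of blocking the training loop.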
if self.interval_type == IntervalType.SECONDS:
doit = (time.monotonic() - self.begin) >= self.interval
self.doit = torch.tensor(int(doit))
if self.work is None:
self.work = dist.all_reduce(self.doit, group=self.pg, async_op=True)
return
elif curr_step % 5 == 4:
self.work.wait()
self.work = None
doit = self.doit.item()
self.doit = None
if doit == 0:
return
else:
return

if self.work:
self.work.wait()
self.work = None
self.doit = None

rank0_log(f"Saving a checkpoint in step {curr_step}.")
begin = time.monotonic()
dcp.save(self.states, checkpoint_id=self.create_checkpoint_id(curr_step))
self.reset()
rank0_log(
f"Finished saving the checkpoint at step {curr_step} "
f"in {time.monotonic() - begin:.2f} seconds."
)

def load(self, step: int = -1) -> bool:
if not self.folder:
return False
if not os.path.isdir(self.folder):
return False
if step != -1 and not os.path.isdir(self.create_checkpoint_id(step)):
return False

if step == -1:
step_counts = []
for filename in os.listdir(self.folder):
match = re.search(r"step-(\d+)", filename)
if match:
step_counts.append(int(match.group(1)))
if not step_counts:
return False
step = max(step_counts)

rank0_log("Loading a checkpoint.")
begin = time.monotonic()
dcp.load(
self.states,
checkpoint_id=self.create_checkpoint_id(step),
)
rank0_log(f"Finish loading a checkpoint. {time.monotonic() - begin} seconds.")
return True
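
For reference, the following is a minimal, single-process sketch of how the new CheckpointManager can be exercised on its own. The toy model, folder path, and port are illustrative assumptions; in the actual PR the manager is driven by train.py under torchrun, as shown in the next file.

import os

import torch
import torch.distributed as dist
import torch.nn as nn

from torchtrain.checkpoint import CheckpointManager, IntervalType

# Single-process gloo setup so that dist.new_group() inside the manager works.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

model = nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Take one optimizer step so there is optimizer state worth saving.
model(torch.randn(2, 8)).sum().backward()
optimizer.step()

checkpoint = CheckpointManager(
    model=model,
    optimizer=optimizer,
    states={},                        # extra Stateful objects (e.g. a train state) go here
    folder="/tmp/ckpt_demo",          # an empty string disables checkpointing
    interval_type=IntervalType.STEPS,
    interval=5,
)

checkpoint.reset()
checkpoint.save(curr_step=5)          # 5 % interval == 0, so a checkpoint is written
assert checkpoint.load(step=5)        # restores model/optimizer from /tmp/ckpt_demo/step-5

dist.destroy_process_group()

Because ModelWrapper and OptimizerWrapper are registered in states, dcp.save() and dcp.load() pull the model and optimizer state dicts at save/load time rather than holding copies in memory.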
60 changes: 59 additions & 1 deletion train.py
@@ -4,14 +4,16 @@
import argparse
import os
from dataclasses import dataclass, field
from typing import List, Union
from typing import Any, Dict, List, Union

# torch imports
import torch
import torch.nn.functional as F
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

from torchtrain.checkpoint import CheckpointManager, IntervalType

# torchtrain related
from torchtrain.datasets import create_tokenizer, dataloader_fn
from torchtrain.logging_utils import init_logger, rank0_log
@@ -29,6 +31,18 @@ class TrainState:
current_loss: float = -1
losses: List[float] = field(default_factory=list)

def state_dict(self) -> Dict[str, Any]:
return {
"step": torch.tensor(self.step, dtype=torch.int32),
"current_loss": torch.tensor(self.current_loss, dtype=torch.float32),
"losses": torch.tensor(self.current_loss, dtype=torch.float32),
}

def load_state_dict(self, state_dict) -> None:
self.step = state_dict["step"].item()
self.current_loss = state_dict["current_loss"].item()
self.losses = state_dict["losses"].tolist()


def build_optimizer(model, args):
# build optimizer
@@ -116,7 +130,22 @@ def main(args):
# train loop
model.train()

checkpoint = CheckpointManager(
model=model,
optimizer=optimizer,
states={"train_state": train_state},
folder=args.checkpoint_folder,
interval_type=(
IntervalType.SECONDS
if args.checkpoint_interval_type == "seconds"
else IntervalType.STEPS
),
interval=args.checkpoint_interval,
)
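# load() restores the newest step-<N> checkpoint (if any), including
# train_state, so a resumed run continues from the saved step count.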
checkpoint.load()

with maybe_run_profiler() as torch_profiler:
checkpoint.reset()
while train_state.step < args.steps or args.steps == -1:
train_state.step += 1
# get batch
@@ -161,6 +190,8 @@ def main(args):
)
scheduler.step()

checkpoint.save(train_state.step, force=(train_state.step == args.steps))


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="TorchTrain arg parser.")
@@ -224,6 +255,33 @@ def main(args):
parser.add_argument(
"--compile", action="store_true", help="Whether to compile the model."
)
parser.add_argument(
"--checkpoint-interval",
type=int,
default=3600,
help=(
"Checkpointing interval. The unit of measurement is in seconds or "
"steps depending on --checkpoint-internval-type."
),
)
parser.add_argument(
"--checkpoint-interval-type",
type=str,
default="steps",
help=(
"The checkpointing interval unit of measurement."
"The default value is step."
),
)
parser.add_argument(
"--checkpoint-folder",
type=str,
default="",
help=(
"The folder to store the checkpoints. If this is not specified or "
"is an empty string, checkpointing is disabled."
),
)

args = parser.parse_args()
main(args)
