WIP integrate pippy's tracer frontend
The traced module is burning in a 'meta' device arg for one 'ones' op, which
breaks at runtime after moving the model to 'cuda'.

Haven't worked on loss fn yet.

ghstack-source-id: 47735f666b6086e179699b1bbfb06168b488d4d4
Pull Request resolved: #161
wconstab committed Mar 29, 2024
1 parent 0f49157 commit 58d932d
Showing 4 changed files with 105 additions and 35 deletions.
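
The 'meta' device issue described in the commit message presumably arises because tracing happens while the model and example inputs live on the meta device, so the captured graph can bake device='meta' into factory ops such as 'ones'. Below is a minimal sketch of one way to inspect, and as a stopgap patch, such a baked-in kwarg in a traced torch.fx graph; retarget_device_kwargs is a hypothetical helper, not part of this PR, and the proper fix would be to avoid burning the device in during tracing.

import torch
from torch import fx

def retarget_device_kwargs(gm: fx.GraphModule, device: torch.device) -> fx.GraphModule:
    # Walk the traced graph and rewrite any op that captured a hard-coded
    # device="meta" kwarg during tracing (e.g. a 'ones' factory call).
    for node in gm.graph.nodes:
        if node.op == "call_function" and "device" in node.kwargs:
            if str(node.kwargs["device"]) == "meta":
                new_kwargs = dict(node.kwargs)
                new_kwargs["device"] = device
                node.kwargs = new_kwargs
    gm.recompile()
    return gm

Applied to each stage module after it is moved off meta (e.g. after to_empty(device="cuda")), this would at least let the runtime proceed; it is a workaround sketch only.
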
2 changes: 1 addition & 1 deletion run_llama_train.sh
@@ -11,7 +11,7 @@ TRAINER_DIR=${1:-/home/$USER/local/torchtrain}
# e.g.
# LOG_RANK=0,1 NGPU=4 ./run_llama_train.sh

NGPU=${NGPU:-"8"}
NGPU=${NGPU:-"2"}

# by default log just rank 0 output,
LOG_RANK=${LOG_RANK:-0}
32 changes: 30 additions & 2 deletions torchtrain/parallelisms/parallelize_llama.py
@@ -8,7 +8,7 @@
from typing import Tuple

import torch

from pippy import annotate_split_points, Pipe, PipeSplitWrapper
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
@@ -134,7 +134,31 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
the model must fit on GPU or CPU memory.
"""
if parallel_dims.pp_enabled:
raise NotImplementedError("PP not implemented yet.")
pp_mesh = world_mesh["pp"]
stage_idx = pp_mesh.get_local_rank()
layers_per_rank = len(model.layers) // parallel_dims.pp
for i in range(1, parallel_dims.pp):
annotate_split_points(
model,
{
f"layers.{i * layers_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING
},
)

# Get example input
label_shape = input_shape = (8, 2048) # TODO
input_ids = torch.randint(
model.vocab_size, input_shape, dtype=torch.int64, device="meta"
)
labels = torch.randint(
model.vocab_size, label_shape, dtype=torch.int64, device="meta"
)
print("input_ids: ", input_ids.shape, input_ids.dtype)
print("labels: ", labels.shape, labels.dtype)

# Create a pipeline representation from the model
pipe = Pipe.from_tracing(model, parallel_dims.pp, example_args=(input_ids,))
model = pipe.get_stage_module(stage_idx)

if parallel_dims.tp_enabled:
tp_mesh = world_mesh["tp"]
@@ -233,4 +257,8 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
logger.info(f"Applied {ac_mode} activation checkpointing to the model")
logger.info("Applied FSDP to the model")

if parallel_dims.pp_enabled:
setattr(pipe.split_gm, f"submod_{stage_idx}", model)
return pipe

return model
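
For context on the flow added above: annotate_split_points marks module boundaries by fully qualified name, Pipe.from_tracing traces the model into a split graph with one submodule per stage, and get_stage_module pulls out the submodule this rank will own (the setattr on pipe.split_gm at the end installs the now-parallelized stage module back into the pipe). A self-contained toy sketch of the same flow, using a made-up 4-block model rather than the llama model and the same pippy calls used in this diff (whether the tracer handles this exact toy cleanly is not verified here):

import torch
from pippy import annotate_split_points, Pipe, PipeSplitWrapper

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.ModuleList([torch.nn.Linear(16, 16) for _ in range(4)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

pp_degree = 2
model = Toy()
layers_per_rank = len(model.layers) // pp_degree  # 2 layers per stage
for i in range(1, pp_degree):
    # stage i begins at the submodule named layers.{i * layers_per_rank}
    annotate_split_points(
        model,
        {f"layers.{i * layers_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING},
    )

example_input = torch.randn(8, 16)
pipe = Pipe.from_tracing(model, pp_degree, example_args=(example_input,))
stage0 = pipe.get_stage_module(0)  # each rank keeps only its own stage's submodule
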
104 changes: 73 additions & 31 deletions train.py
@@ -15,6 +15,8 @@

import torch
import torch.nn.functional as F
from pippy.PipelineSchedule import PipelineScheduleGPipe
from pippy.PipelineStage import PipelineStage
from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.distributed.tensor.parallel import loss_parallel
@@ -197,9 +199,39 @@ def main(job_config: JobConfig):
model = models_parallelize_fns[model_name](
model, world_mesh, parallel_dims, job_config
)
# allocate sharded model on GPU and initialize weights via DTensor
model.to_empty(device="cuda")
model.init_weights()

# TODO(whc) everything below needs to become a function that can be applied to each 'virtual stage' of PP, if
# there are virtual stages
if parallel_dims.pp_enabled:
pipe_meta = model
pp_mesh = world_mesh["pp"]
pp_degree = pp_mesh.size()
pp_rank = pp_mesh.get_local_rank()
logger.info(
f"{Color.blue}Extracting pipeline module for stage {pp_rank}{Color.reset}"
)
device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")

model = pipe_meta.get_stage_module(pp_rank)
stage = PipelineStage(
pipe=pipe_meta,
stage_index=pp_rank,
device=device,
group=pp_mesh.get_group(),
)
pp_schedule = PipelineScheduleGPipe(
stage, n_microbatches=parallel_dims.pp, loss_fn=None
)
model.to_empty(device="cuda")
else:
# If PP is enabled, we can't use init_weights. Instead, we have to rely on creating an initial checkpoint
# offline and loading it to get initialization values. This is because the init_weights functions are written
# assuming the whole model (all of its weights, or FQNs) exists on one rank. In PP, init_weights on stage 1
# might crash because it can't find the "embedding" layer, for example.

# allocate sharded model on GPU and initialize weights via DTensor
model.to_empty(device="cuda")
model.init_weights()

# build optimizer after applying parallelisms to the model
optimizer = build_optimizer(model, job_config)
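
The comment in the else branch above explains why init_weights can't simply run per stage: it is written against the full model's FQNs, while a traced stage submodule only holds the parameters its own stage uses. A minimal illustration of a stage-local alternative, initializing only what the stage actually owns after to_empty(), is sketched here; init_stage_weights is hypothetical and is not what this commit does (the commit defers to an offline-created initial checkpoint instead).

import torch

def init_stage_weights(stage_module: torch.nn.Module) -> None:
    # Only touch parameters that exist on this stage's submodule, instead of
    # reaching for specific FQNs (e.g. an embedding) that may live on another
    # stage. Assumes to_empty(device="cuda") has already materialized storage.
    for name, param in stage_module.named_parameters():
        if param.dim() > 1:
            torch.nn.init.normal_(param, mean=0.0, std=0.02)
        else:
            torch.nn.init.zeros_(param)
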
@@ -278,39 +310,49 @@ def main(job_config: JobConfig):

input_ids = input_ids.cuda()
labels = labels.cuda()

print("i", input_ids.shape)
print("l", labels.shape)
optimizer.zero_grad()

# forward
pred = model(input_ids)

with (
loss_parallel()
if parallel_dims.loss_parallel_enabled
else contextlib.nullcontext()
):
loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))

# backward on scaled loss to create scaled gradients
scaler.scale(loss).backward()

# clip gradients (after unscaling gradients of the optimizer's params)
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(
model.parameters(), job_config.training.max_norm, foreach=True
)
if parallel_dims.pp_enabled:
if pp_mesh.get_local_rank() == 0:
pp_schedule.step(input_ids)
elif pp_mesh.get_local_rank() == pp_mesh.size() - 1:
losses = []
pp_schedule.step(target=labels, losses=losses)
else:
pp_schedule.step()
else:
# forward
pred = model(input_ids)

with (
loss_parallel()
if parallel_dims.loss_parallel_enabled
else contextlib.nullcontext()
):
loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))

# backward on scaled loss to create scaled gradients
scaler.scale(loss).backward()

# clip gradients (after unscaling gradients of the optimizer's params)
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(
model.parameters(), job_config.training.max_norm, foreach=True
)

# optimizer step
# If gradients don't contain infs/NaNs, optimizer.step() is then called;
# otherwise, optimizer.step() is skipped.
scaler.step(optimizer)
scheduler.step()
# optimizer step
# If gradients don't contain infs/NaNs, optimizer.step() is then called;
# otherwise, optimizer.step() is skipped.
scaler.step(optimizer)
scheduler.step()

# updates the scale for next iteration
scaler.update()
# updates the scale for next iteration
scaler.update()

current_loss = loss.item()
losses_since_last_log.append(current_loss)
current_loss = loss.item()
losses_since_last_log.append(current_loss)

# log metrics
if (
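
The commit message notes the loss fn hasn't been wired up yet: loss_fn=None above, and the last stage's losses list is collected but unused. Using only the hooks already visible in this hunk (the loss_fn argument to PipelineScheduleGPipe and the target=/losses= kwargs of step), a rough sketch of how it might be wired; pp_loss_fn and run_pp_step are invented names for illustration, not part of this commit.

import torch
import torch.nn.functional as F

def pp_loss_fn(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # same cross-entropy used in the non-PP branch above
    return F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))

def run_pp_step(pp_schedule, pp_mesh, input_ids, labels):
    # First stage feeds inputs; last stage passes targets and collects
    # per-microbatch losses; middle stages just step.
    if pp_mesh.get_local_rank() == 0:
        pp_schedule.step(input_ids)
        return None
    if pp_mesh.get_local_rank() == pp_mesh.size() - 1:
        losses = []
        pp_schedule.step(target=labels, losses=losses)
        return torch.mean(torch.stack(losses))
    pp_schedule.step()
    return None

The schedule would then be constructed with loss_fn=pp_loss_fn instead of None, and the averaged loss reported only from the last stage.
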
2 changes: 1 addition & 1 deletion train_configs/debug_model.toml
@@ -38,7 +38,7 @@ max_norm = 1.0 # grad norm clipping
steps = 10
data_parallel_degree = -1
tensor_parallel_degree = 1
pipeline_parallel_degree = 1
pipeline_parallel_degree = 2
fp8_linear = ""
compile = false
dataset = "alpaca" # supported datasets = alpaca (52K), minipile (1M), c4 (177M)