WIP integrate pippy's tracer frontend
Loss now runs and propagates to the logger, but the optimizer isn't working yet.

ghstack-source-id: 4ede08f5a9d1bc994448cb057bb491d24866d078
Pull Request resolved: #161
wconstab committed Apr 5, 2024
1 parent 2d4de9e commit 4134d08
Showing 5 changed files with 97 additions and 17 deletions.
run_llama_train.sh: 2 changes (1 addition, 1 deletion)
@@ -11,7 +11,7 @@ TRAINER_DIR=${1:-/home/$USER/local/torchtrain}
 # e.g.
 # LOG_RANK=0,1 NGPU=4 ./run_llama_train.sh
 
-NGPU=${NGPU:-"8"}
+NGPU=${NGPU:-"2"}
 
 # by default log just rank 0 output,
 LOG_RANK=${LOG_RANK:-0}
seed_checkpoint.py: 2 changes (1 addition, 1 deletion)
@@ -36,7 +36,7 @@ def main(job_config: JobConfig):
     if job_config.training.fp8_linear:
         build_fp8_linear(model, job_config)
 
-    model.reset_parameters()
+    model.init_weights()
 
     checkpoint_id = os.path.join(job_config.training.checkpoint_folder, "step-0")
     logger.info(f"Creating seed (step-0) checkpoint in {checkpoint_id}")
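Aside: the switch from reset_parameters() to init_weights() matches the initialization pattern train.py relies on in this commit: build the model on the meta device, materialize storage with to_empty(), then run a model-level init_weights() that fills every parameter in place. A minimal toy sketch of that pattern follows; TinyModel and its init are made up for illustration, and only the to_empty() / init_weights() sequence mirrors this diff.

import torch
import torch.nn as nn


class TinyModel(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def init_weights(self):
        # model-level init that touches every parameter in place
        nn.init.trunc_normal_(self.proj.weight, std=0.02)
        nn.init.zeros_(self.proj.bias)


with torch.device("meta"):
    model = TinyModel()          # parameters exist but have no real storage

model.to_empty(device="cpu")     # allocate real (uninitialized) storage
model.init_weights()             # fill in actual initialization values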
torchtrain/parallelisms/parallelize_llama.py: 32 changes (30 additions, 2 deletions)
@@ -8,7 +8,7 @@
 from typing import Tuple
 
 import torch
-
+from pippy import annotate_split_points, Pipe, PipeSplitWrapper
 from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
 from torch.distributed._tensor import Replicate, Shard
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
@@ -134,7 +134,31 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     the model must fit on GPU or CPU memory.
     """
     if parallel_dims.pp_enabled:
-        raise NotImplementedError("PP not implemented yet.")
+        pp_mesh = world_mesh["pp"]
+        stage_idx = pp_mesh.get_local_rank()
+        layers_per_rank = len(model.layers) // parallel_dims.pp
+        for i in range(1, parallel_dims.pp):
+            annotate_split_points(
+                model,
+                {
+                    f"layers.{i * layers_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING
+                },
+            )
+
+        # Get example input
+        label_shape = input_shape = (8, 2048)  # TODO
+        input_ids = torch.randint(
+            model.vocab_size, input_shape, dtype=torch.int64, device="meta"
+        )
+        labels = torch.randint(
+            model.vocab_size, label_shape, dtype=torch.int64, device="meta"
+        )
+        print("input_ids: ", input_ids.shape, input_ids.dtype)
+        print("labels: ", labels.shape, labels.dtype)
+
+        # Create a pipeline representation from the model
+        pipe = Pipe.from_tracing(model, parallel_dims.pp, example_args=(input_ids,))
+        model = pipe.get_stage_module(stage_idx)
 
     if parallel_dims.tp_enabled:
         tp_mesh = world_mesh["tp"]
@@ -233,4 +257,8 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
             logger.info(f"Applied {ac_mode} activation checkpointing to the model")
         logger.info("Applied FSDP to the model")
 
+    if parallel_dims.pp_enabled:
+        setattr(pipe.split_gm, f"submod_{stage_idx}", model)
+        return pipe
+
     return model
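Aside: a standalone sketch of the tracer frontend used above, separate from torchtrain. The toy model, sizes, and single-process setup are made up for illustration; the PiPPy calls (annotate_split_points, Pipe.from_tracing, get_stage_module) mirror the ones added in this diff.

import torch
import torch.nn as nn
from pippy import annotate_split_points, Pipe, PipeSplitWrapper


class ToyModel(nn.Module):
    def __init__(self, dim: int = 16, n_layers: int = 4):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_layers))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


pp_degree = 2
model = ToyModel()
layers_per_rank = len(model.layers) // pp_degree

# Mark a split at the beginning of the first layer owned by each later stage.
for i in range(1, pp_degree):
    annotate_split_points(
        model,
        {f"layers.{i * layers_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING},
    )

# Trace once with example inputs; each pipeline rank then extracts its own stage.
example_input = torch.randn(8, 16)
pipe = Pipe.from_tracing(model, pp_degree, example_args=(example_input,))
stage_module = pipe.get_stage_module(0)  # what rank 0 would run
print(stage_module)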
train.py: 72 changes (62 additions, 10 deletions)
@@ -15,6 +15,8 @@
 
 import torch
 import torch.nn.functional as F
+from pippy.PipelineSchedule import PipelineScheduleGPipe
+from pippy.PipelineStage import PipelineStage
 from torch.distributed.elastic.multiprocessing.errors import record
 from torch.distributed.tensor.parallel import loss_parallel
 
@@ -120,7 +122,9 @@ def main(job_config: JobConfig):
         world_size=world_size,
         enable_loss_parallel=job_config.training.enable_loss_parallel,
     )
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
+    # torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    torch.cuda.set_device(device)
     init_distributed(job_config)
 
     world_mesh = parallel_dims.build_mesh(device_type="cuda")
@@ -139,6 +143,14 @@
         dp_rank = dp_mesh.get_local_rank()
     else:
         dp_degree, dp_rank = 1, 0
+
+    if parallel_dims.pp_enabled:
+        pp_mesh = world_mesh["pp"]
+        pp_degree = pp_mesh.size()
+        pp_rank = pp_mesh.get_local_rank()
+    else:
+        pp_degree, pp_rank = 1, 0
+
     data_loader = build_dataloader_fn(
         job_config.training.dataset,
         job_config.training.dataset_path,
@@ -197,14 +209,38 @@ def loss_fn(pred, labels):
     model = models_parallelize_fns[model_name](
         model, world_mesh, parallel_dims, job_config
     )
-    # allocate sharded model on GPU and initialize weights via DTensor
+    if parallel_dims.pp_enabled:
+        pipe_meta = model
+        model = pipe_meta.get_stage_module(pp_rank)
+
     model.to_empty(device="cuda")
-    model.init_weights()
-
+    # TODO(whc) everything below needs to become a function that can be applied to each 'virtual stage' of PP, if
+    # there are virtual stages
+    if parallel_dims.pp_enabled:
+        stage = PipelineStage(
+            pipe=pipe_meta,
+            stage_index=pp_rank,
+            device=device,
+            group=pp_mesh.get_group(),
+        )
+        pp_schedule = PipelineScheduleGPipe(
+            stage,
+            n_microbatches=parallel_dims.pp,
+            loss_fn=loss_fn,
+        )
+    else:
+        # If PP is enabled, we can't use init_weights. Instead, we have to rely on creating an initial checkpoint
+        # offline and loading it to get initialization values. This is because the init_weights functions are written
+        # assuming the whole model (all its weights, or FQNs) exists on one rank. In PP, init_weights on stage 1
+        # might crash because it can't find the "embedding" layer, for example.
+
+        # allocate sharded model on GPU and initialize weights via DTensor
+        model.init_weights()
+
     # build optimizer after applying parallelisms to the model
     optimizer = build_optimizer(model, job_config)
     scheduler = get_lr_scheduler(optimizer, job_config)
 
     metric_logger = build_metric_logger(job_config)
 
     # torch.compile model for improved performance
@@ -274,13 +310,30 @@ def loss_fn(pred, labels):
 
             input_ids = input_ids.cuda()
             labels = labels.cuda()
 
             optimizer.zero_grad()
-
-            # forward / backward
-            pred = model(input_ids)
-            loss = loss_fn(pred, labels)
-            loss.backward()
+            if parallel_dims.pp_enabled:
+                # pipeline parallel forward / backward inside step() call
+                is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
+
+                if pp_mesh.get_local_rank() == 0:
+                    pp_schedule.step(input_ids)
+                elif is_last_stage:
+                    losses = []
+                    pp_schedule.step(target=labels, losses=losses)
+                else:
+                    pp_schedule.step()
+
+                # accumulate losses across pipeline microbatches
+                current_loss = (
+                    torch.mean(torch.stack(losses)).item() if is_last_stage else -1.0
+                )
+            else:
+                # forward / backward
+                pred = model(input_ids)
+                loss = loss_fn(pred, labels)
+                loss.backward()
+                current_loss = loss.item()
 
             # clip gradients
             torch.nn.utils.clip_grad_norm_(
@@ -291,7 +344,6 @@
             optimizer.step()
             scheduler.step()
 
-            current_loss = loss.item()
             losses_since_last_log.append(current_loss)
 
             # log metrics
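Aside: the per-rank branching around pp_schedule.step() in the training loop above can be read as a small helper. A condensed sketch follows, under the diff's own conventions: run_pp_step is a hypothetical name, pp_schedule and pp_mesh are the objects built earlier in train.py, and the -1.0 sentinel for non-last stages follows the diff.

import torch


def run_pp_step(pp_schedule, pp_mesh, input_ids, labels):
    """Run one GPipe step on this rank; only the last stage returns a real loss."""
    pp_rank = pp_mesh.get_local_rank()
    is_last_stage = pp_rank == pp_mesh.size() - 1

    if pp_rank == 0:
        # First stage feeds the microbatched inputs into the pipeline.
        pp_schedule.step(input_ids)
    elif is_last_stage:
        # Last stage computes the loss; step() fills the list with
        # one loss tensor per microbatch.
        losses = []
        pp_schedule.step(target=labels, losses=losses)
        return torch.mean(torch.stack(losses)).item()
    else:
        # Middle stages only relay activations forward and gradients backward.
        pp_schedule.step()

    # Non-last stages report the same -1.0 sentinel used in the diff.
    return -1.0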
train_configs/debug_model.toml: 6 changes (3 additions, 3 deletions)
@@ -16,9 +16,9 @@ enable_tensorboard = true
 save_tb_folder = "tb"
 
 [checkpoint]
-interval = 3600
+interval = 10
 interval_type = "steps"
-folder = ""
+folder = "ckpt"
 
 [model]
 name = "llama"
@@ -38,7 +38,7 @@ max_norm = 1.0  # grad norm clipping
 steps = 10
 data_parallel_degree = -1
 tensor_parallel_degree = 1
-pipeline_parallel_degree = 1
+pipeline_parallel_degree = 2
 fp8_linear = ""
 compile = false
 dataset = "alpaca"  # supported datasets = alpaca (52K), minipile (1M), c4 (177M)
