From 58d932d35b9bbaa710c6296ba42cad3920b6dd83 Mon Sep 17 00:00:00 2001
From: Will Constable
Date: Fri, 29 Mar 2024 14:59:16 -0700
Subject: [PATCH] WIP integrate pippy's tracer frontend

The traced module is burning in a 'meta' device arg for one 'ones' op, which
breaks runtime after moving the model to 'cuda'. Haven't worked on the loss fn
yet.

ghstack-source-id: 47735f666b6086e179699b1bbfb06168b488d4d4
Pull Request resolved: https://github.com/pytorch/torchtrain/pull/161
---
 run_llama_train.sh                           |   2 +-
 torchtrain/parallelisms/parallelize_llama.py |  32 +++++-
 train.py                                     | 104 +++++++++++++------
 train_configs/debug_model.toml               |   2 +-
 4 files changed, 105 insertions(+), 35 deletions(-)

diff --git a/run_llama_train.sh b/run_llama_train.sh
index a46cf5cf3..0dcc5c277 100755
--- a/run_llama_train.sh
+++ b/run_llama_train.sh
@@ -11,7 +11,7 @@ TRAINER_DIR=${1:-/home/$USER/local/torchtrain}
 
 # e.g.
 # LOG_RANK=0,1 NGPU=4 ./run_llama_train.sh
-NGPU=${NGPU:-"8"}
+NGPU=${NGPU:-"2"}
 
 # by default log just rank 0 output,
 LOG_RANK=${LOG_RANK:-0}
diff --git a/torchtrain/parallelisms/parallelize_llama.py b/torchtrain/parallelisms/parallelize_llama.py
index e64267c50..53292c405 100644
--- a/torchtrain/parallelisms/parallelize_llama.py
+++ b/torchtrain/parallelisms/parallelize_llama.py
@@ -8,7 +8,7 @@
 from typing import Tuple
 
 import torch
-
+from pippy import annotate_split_points, Pipe, PipeSplitWrapper
 from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
 from torch.distributed._tensor import Replicate, Shard
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
@@ -134,7 +134,31 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     the model must fit on GPU or CPU memory.
     """
     if parallel_dims.pp_enabled:
-        raise NotImplementedError("PP not implemented yet.")
+        pp_mesh = world_mesh["pp"]
+        stage_idx = pp_mesh.get_local_rank()
+        layers_per_rank = len(model.layers) // parallel_dims.pp
+        for i in range(1, parallel_dims.pp):
+            annotate_split_points(
+                model,
+                {
+                    f"layers.{i * layers_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING
+                },
+            )
+
+        # Get example input
+        label_shape = input_shape = (8, 2048)  # TODO
+        input_ids = torch.randint(
+            model.vocab_size, input_shape, dtype=torch.int64, device="meta"
+        )
+        labels = torch.randint(
+            model.vocab_size, label_shape, dtype=torch.int64, device="meta"
+        )
+        print("input_ids: ", input_ids.shape, input_ids.dtype)
+        print("labels: ", labels.shape, labels.dtype)
+
+        # Create a pipeline representation from the model
+        pipe = Pipe.from_tracing(model, parallel_dims.pp, example_args=(input_ids,))
+        model = pipe.get_stage_module(stage_idx)
 
     if parallel_dims.tp_enabled:
         tp_mesh = world_mesh["tp"]
@@ -233,4 +257,8 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
         logger.info(f"Applied {ac_mode} activation checkpointing to the model")
     logger.info("Applied FSDP to the model")
 
+    if parallel_dims.pp_enabled:
+        setattr(pipe.split_gm, f"submod_{stage_idx}", model)
+        return pipe
+
     return model
diff --git a/train.py b/train.py
index 849ae7849..cbca73b70 100644
--- a/train.py
+++ b/train.py
@@ -15,6 +15,8 @@
 
 import torch
 import torch.nn.functional as F
+from pippy.PipelineSchedule import PipelineScheduleGPipe
+from pippy.PipelineStage import PipelineStage
 from torch.distributed.elastic.multiprocessing.errors import record
 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 from torch.distributed.tensor.parallel import loss_parallel
@@ -197,9 +199,39 @@ def main(job_config: JobConfig):
     model = models_parallelize_fns[model_name](
         model, world_mesh, parallel_dims, job_config
     )
-    # allocate sharded model on GPU and initialize weights via DTensor
-    model.to_empty(device="cuda")
-    model.init_weights()
+
+    # TODO(whc) everything below needs to become a function that can be applied to each 'virtual stage' of PP, if
+    # there are virtual stages
+    if parallel_dims.pp_enabled:
+        pipe_meta = model
+        pp_mesh = world_mesh["pp"]
+        pp_degree = pp_mesh.size()
+        pp_rank = pp_mesh.get_local_rank()
+        logger.info(
+            f"{Color.blue}Extracting pipeline module for stage {pp_rank}{Color.reset}"
+        )
+        device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
+
+        model = pipe_meta.get_stage_module(pp_rank)
+        stage = PipelineStage(
+            pipe=pipe_meta,
+            stage_index=pp_rank,
+            device=device,
+            group=pp_mesh.get_group(),
+        )
+        pp_schedule = PipelineScheduleGPipe(
+            stage, n_microbatches=parallel_dims.pp, loss_fn=None
+        )
+        model.to_empty(device="cuda")
+    else:
+        # If PP is enabled, we can't use init_weights. Instead, we have to rely on creating an initial checkpoint
+        # offline and loading it to get initialization values. This is because the init_weights functions are written
+        # assuming the whole model (all its weights, or FQNs) exists on one rank. In PP, init_weights on stage 1
+        # might crash because it can't find the "embedding" layer, for example.
+
+        # allocate sharded model on GPU and initialize weights via DTensor
+        model.to_empty(device="cuda")
+        model.init_weights()
 
     # build optimizer after applying parallelisms to the model
     optimizer = build_optimizer(model, job_config)
@@ -278,39 +310,49 @@ def main(job_config: JobConfig):
 
         input_ids = input_ids.cuda()
         labels = labels.cuda()
-
+        print("i", input_ids.shape)
+        print("l", labels.shape)
         optimizer.zero_grad()
 
-        # forward
-        pred = model(input_ids)
-
-        with (
-            loss_parallel()
-            if parallel_dims.loss_parallel_enabled
-            else contextlib.nullcontext()
-        ):
-            loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
-
-        # backward on scaled loss to create scaled gradients
-        scaler.scale(loss).backward()
-
-        # clip gradients (after unscaling gradients of the optimizer's params)
-        scaler.unscale_(optimizer)
-        torch.nn.utils.clip_grad_norm_(
-            model.parameters(), job_config.training.max_norm, foreach=True
-        )
+        if parallel_dims.pp_enabled:
+            if pp_mesh.get_local_rank() == 0:
+                pp_schedule.step(input_ids)
+            elif pp_mesh.get_local_rank() == pp_mesh.size() - 1:
+                losses = []
+                pp_schedule.step(target=labels, losses=losses)
+            else:
+                pp_schedule.step()
+        else:
+            # forward
+            pred = model(input_ids)
+
+            with (
+                loss_parallel()
+                if parallel_dims.loss_parallel_enabled
+                else contextlib.nullcontext()
+            ):
+                loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
+
+            # backward on scaled loss to create scaled gradients
+            scaler.scale(loss).backward()
+
+            # clip gradients (after unscaling gradients of the optimizer's params)
+            scaler.unscale_(optimizer)
+            torch.nn.utils.clip_grad_norm_(
+                model.parameters(), job_config.training.max_norm, foreach=True
+            )
 
-        # optimizer step
-        # If gradients don't contain infs/NaNs, optimizer.step() is then called;
-        # otherwise, optimizer.step() is skipped.
-        scaler.step(optimizer)
-        scheduler.step()
+            # optimizer step
+            # If gradients don't contain infs/NaNs, optimizer.step() is then called;
+            # otherwise, optimizer.step() is skipped.
+            scaler.step(optimizer)
+            scheduler.step()
 
-        # updates the scale for next iteration
-        scaler.update()
+            # updates the scale for next iteration
+            scaler.update()
 
-        current_loss = loss.item()
-        losses_since_last_log.append(current_loss)
+            current_loss = loss.item()
+            losses_since_last_log.append(current_loss)
 
         # log metrics
         if (
diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml
index c84407cdb..c60e4e067 100644
--- a/train_configs/debug_model.toml
+++ b/train_configs/debug_model.toml
@@ -38,7 +38,7 @@ max_norm = 1.0  # grad norm clipping
 steps = 10
 data_parallel_degree = -1
 tensor_parallel_degree = 1
-pipeline_parallel_degree = 1
+pipeline_parallel_degree = 2
 fp8_linear = ""
 compile = false
 dataset = "alpaca"   # supported datasets = alpaca (52K), minipile (1M), c4 (177M)
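
Note (reviewer sketch, not part of the patch): the pippy-specific flow that
parallelize_llama.py gains above can be read in isolation as the short sketch
below. It reuses only the calls that already appear in this diff
(annotate_split_points, PipeSplitWrapper.SplitPoint.BEGINNING, Pipe.from_tracing,
pipe.get_stage_module, pipe.split_gm); the helper name trace_pipeline and its
parameters are invented for illustration, and pippy's API surface varies across
versions, so treat the exact signatures as assumptions taken from this patch
rather than a stable interface.

    import torch
    from pippy import annotate_split_points, Pipe, PipeSplitWrapper


    def trace_pipeline(model, pp_degree, batch_size=8, seq_len=2048):
        """Hypothetical helper mirroring the tracer-frontend usage in this patch.

        Assumes `model` exposes a `layers` container and a `vocab_size`
        attribute, like the Transformer this patch targets.
        """
        # Mark a split point at the first block owned by each pipeline rank > 0,
        # exactly as parallelize_llama does above.
        layers_per_rank = len(model.layers) // pp_degree
        for i in range(1, pp_degree):
            annotate_split_points(
                model,
                {f"layers.{i * layers_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING},
            )

        # Trace with meta tensors so no real memory is allocated. This is also
        # the step the commit message flags: a 'meta' device arg gets burned
        # into the traced graph, which later breaks once the stage module is
        # moved to 'cuda'.
        input_ids = torch.randint(
            model.vocab_size, (batch_size, seq_len), dtype=torch.int64, device="meta"
        )
        return Pipe.from_tracing(model, pp_degree, example_args=(input_ids,))


    # Each pipeline rank then pulls out its own stage module, applies TP/FSDP to
    # it, and re-installs the wrapped module, as the patch does:
    #     stage_module = pipe.get_stage_module(pp_rank)
    #     ...
    #     setattr(pipe.split_gm, f"submod_{pp_rank}", stage_module)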