From 4134d0811cd5296e886fc6911e5a49613054d7e9 Mon Sep 17 00:00:00 2001
From: Will Constable
Date: Fri, 5 Apr 2024 09:39:55 -0700
Subject: [PATCH] WIP integrate pippy's tracer frontend

Loss now runs and propagates to logger, but optimizer isn't working

ghstack-source-id: 4ede08f5a9d1bc994448cb057bb491d24866d078
Pull Request resolved: https://github.com/pytorch/torchtrain/pull/161
---
 run_llama_train.sh                           |  2 +-
 seed_checkpoint.py                           |  2 +-
 torchtrain/parallelisms/parallelize_llama.py | 32 ++++++++-
 train.py                                     | 72 +++++++++++++++++---
 train_configs/debug_model.toml               |  6 +-
 5 files changed, 97 insertions(+), 17 deletions(-)

diff --git a/run_llama_train.sh b/run_llama_train.sh
index a46cf5cf3..0dcc5c277 100755
--- a/run_llama_train.sh
+++ b/run_llama_train.sh
@@ -11,7 +11,7 @@ TRAINER_DIR=${1:-/home/$USER/local/torchtrain}
 
 # e.g.
 # LOG_RANK=0,1 NGPU=4 ./run_llama_train.sh
-NGPU=${NGPU:-"8"}
+NGPU=${NGPU:-"2"}
 
 # by default log just rank 0 output,
 LOG_RANK=${LOG_RANK:-0}
diff --git a/seed_checkpoint.py b/seed_checkpoint.py
index 55501b521..949e70e1d 100644
--- a/seed_checkpoint.py
+++ b/seed_checkpoint.py
@@ -36,7 +36,7 @@ def main(job_config: JobConfig):
     if job_config.training.fp8_linear:
         build_fp8_linear(model, job_config)
 
-    model.reset_parameters()
+    model.init_weights()
 
     checkpoint_id = os.path.join(job_config.training.checkpoint_folder, "step-0")
     logger.info(f"Creating seed (step-0) checkpoint in {checkpoint_id}")
diff --git a/torchtrain/parallelisms/parallelize_llama.py b/torchtrain/parallelisms/parallelize_llama.py
index e64267c50..53292c405 100644
--- a/torchtrain/parallelisms/parallelize_llama.py
+++ b/torchtrain/parallelisms/parallelize_llama.py
@@ -8,7 +8,7 @@ from typing import Tuple
 
 import torch
-
+from pippy import annotate_split_points, Pipe, PipeSplitWrapper
 from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
 from torch.distributed._tensor import Replicate, Shard
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
@@ -134,7 +134,31 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     the model must fit on GPU or CPU memory.
     """
     if parallel_dims.pp_enabled:
-        raise NotImplementedError("PP not implemented yet.")
+        pp_mesh = world_mesh["pp"]
+        stage_idx = pp_mesh.get_local_rank()
+        layers_per_rank = len(model.layers) // parallel_dims.pp
+        for i in range(1, parallel_dims.pp):
+            annotate_split_points(
+                model,
+                {
+                    f"layers.{i * layers_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING
+                },
+            )
+
+        # Get example input
+        label_shape = input_shape = (8, 2048)  # TODO
+        input_ids = torch.randint(
+            model.vocab_size, input_shape, dtype=torch.int64, device="meta"
+        )
+        labels = torch.randint(
+            model.vocab_size, label_shape, dtype=torch.int64, device="meta"
+        )
+        logger.info(f"input_ids: {input_ids.shape}, {input_ids.dtype}")
+        logger.info(f"labels: {labels.shape}, {labels.dtype}")
+
+        # Create a pipeline representation from the model
+        pipe = Pipe.from_tracing(model, parallel_dims.pp, example_args=(input_ids,))
+        model = pipe.get_stage_module(stage_idx)
 
     if parallel_dims.tp_enabled:
         tp_mesh = world_mesh["tp"]
@@ -233,4 +257,8 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
             logger.info(f"Applied {ac_mode} activation checkpointing to the model")
         logger.info("Applied FSDP to the model")
 
+    if parallel_dims.pp_enabled:
+        setattr(pipe.split_gm, f"submod_{stage_idx}", model)
+        return pipe
+
     return model
diff --git a/train.py b/train.py
index 73a882479..2050c87c0 100644
--- a/train.py
+++ b/train.py
@@ -15,6 +15,8 @@
 
 import torch
 import torch.nn.functional as F
+from pippy.PipelineSchedule import PipelineScheduleGPipe
+from pippy.PipelineStage import PipelineStage
 from torch.distributed.elastic.multiprocessing.errors import record
 from torch.distributed.tensor.parallel import loss_parallel
 
@@ -120,7 +122,9 @@ def main(job_config: JobConfig):
         world_size=world_size,
         enable_loss_parallel=job_config.training.enable_loss_parallel,
     )
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
+    # torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    torch.cuda.set_device(device)
     init_distributed(job_config)
 
     world_mesh = parallel_dims.build_mesh(device_type="cuda")
@@ -139,6 +143,14 @@ def main(job_config: JobConfig):
         dp_rank = dp_mesh.get_local_rank()
     else:
         dp_degree, dp_rank = 1, 0
+
+    if parallel_dims.pp_enabled:
+        pp_mesh = world_mesh["pp"]
+        pp_degree = pp_mesh.size()
+        pp_rank = pp_mesh.get_local_rank()
+    else:
+        pp_degree, pp_rank = 1, 0
+
     data_loader = build_dataloader_fn(
         job_config.training.dataset,
         job_config.training.dataset_path,
@@ -197,14 +209,38 @@ def loss_fn(pred, labels):
     model = models_parallelize_fns[model_name](
         model, world_mesh, parallel_dims, job_config
     )
-    # allocate sharded model on GPU and initialize weights via DTensor
+    if parallel_dims.pp_enabled:
+        pipe_meta = model
+        model = pipe_meta.get_stage_module(pp_rank)
+
     model.to_empty(device="cuda")
-    model.init_weights()
+
+    # TODO(whc) everything below needs to become a function that can be applied to each
+    # 'virtual stage' of PP, if there are virtual stages
+    if parallel_dims.pp_enabled:
+        stage = PipelineStage(
+            pipe=pipe_meta,
+            stage_index=pp_rank,
+            device=device,
+            group=pp_mesh.get_group(),
+        )
+        pp_schedule = PipelineScheduleGPipe(
+            stage,
+            n_microbatches=parallel_dims.pp,
+            loss_fn=loss_fn,
+        )
+    else:
+        # If PP is enabled, we can't use init_weights. Instead, we have to rely on creating an initial checkpoint
+        # offline and loading it to get initialization values. This is because the init_weights functions are written
+        # assuming the whole model (all its weights, or FQNs) exists on one rank. In PP, init_weights on stage 1
+        # might crash because it can't find the "embedding" layer, for example.
+
+        # allocate sharded model on GPU and initialize weights via DTensor
+        model.init_weights()
 
     # build optimizer after applying parallelisms to the model
     optimizer = build_optimizer(model, job_config)
     scheduler = get_lr_scheduler(optimizer, job_config)
-
     metric_logger = build_metric_logger(job_config)
 
     # torch.compile model for improved performance
@@ -274,13 +310,30 @@ def loss_fn(pred, labels):
 
             input_ids = input_ids.cuda()
             labels = labels.cuda()
 
-            optimizer.zero_grad()
 
-            # forward / backward
-            pred = model(input_ids)
-            loss = loss_fn(pred, labels)
-            loss.backward()
+            if parallel_dims.pp_enabled:
+                # pipeline parallel forward / backward inside step() call
+                is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
+
+                if pp_mesh.get_local_rank() == 0:
+                    pp_schedule.step(input_ids)
+                elif is_last_stage:
+                    losses = []
+                    pp_schedule.step(target=labels, losses=losses)
+                else:
+                    pp_schedule.step()
+
+                # accumulate losses across pipeline microbatches
+                current_loss = (
+                    torch.mean(torch.stack(losses)).item() if is_last_stage else -1.0
+                )
+            else:
+                # forward / backward
+                pred = model(input_ids)
+                loss = loss_fn(pred, labels)
+                loss.backward()
+                current_loss = loss.item()
 
             # clip gradients
             torch.nn.utils.clip_grad_norm_(
@@ -291,7 +344,6 @@ def loss_fn(pred, labels):
             optimizer.step()
             scheduler.step()
 
-            current_loss = loss.item()
             losses_since_last_log.append(current_loss)
 
             # log metrics
diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml
index c84407cdb..b0f9fead1 100644
--- a/train_configs/debug_model.toml
+++ b/train_configs/debug_model.toml
@@ -16,9 +16,9 @@ enable_tensorboard = true
 save_tb_folder = "tb"
 
 [checkpoint]
-interval = 3600
+interval = 10
 interval_type = "steps"
-folder = ""
+folder = "ckpt"
 
 [model]
 name = "llama"
@@ -38,7 +38,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 10
 data_parallel_degree = -1
 tensor_parallel_degree = 1
-pipeline_parallel_degree = 1
+pipeline_parallel_degree = 2
 fp8_linear = ""
 compile = false
 dataset = "alpaca" # supported datasets = alpaca (52K), minipile (1M), c4 (177M)
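
Not part of the patch: below is a minimal, self-contained sketch of the PiPPy tracer-frontend flow the diff above wires into torchtrain (annotate_split_points -> Pipe.from_tracing -> get_stage_module -> PipelineStage -> PipelineScheduleGPipe). It assumes two ranks launched with torchrun, a CPU/gloo process group, and the pippy build this PR targets; the toy model, tensor shapes, loss and optimizer are illustrative stand-ins rather than torchtrain code, and the optimizer is only a guess at where it would slot in (the commit message notes that path is still broken).

import os

import torch
import torch.distributed as dist
import torch.nn as nn
from pippy import annotate_split_points, Pipe, PipeSplitWrapper
from pippy.PipelineSchedule import PipelineScheduleGPipe
from pippy.PipelineStage import PipelineStage


class ToyModel(nn.Module):
    # Stand-in for the llama model; only the "layers.N" FQNs matter for the split points.
    def __init__(self, dim: int = 16, n_layers: int = 4):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(n_layers)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def main():
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])  # sketch assumes 2 ranks == 2 stages
    dist.init_process_group("gloo")
    device = torch.device("cpu")

    model = ToyModel()
    # Split before layers.2 so each rank owns half the layers, mirroring the
    # layers.{i * layers_per_rank} split points annotated in parallelize_llama.
    annotate_split_points(model, {"layers.2": PipeSplitWrapper.SplitPoint.BEGINNING})

    # Trace once with example inputs to build the pipeline representation.
    example_input = torch.randn(8, 16)
    pipe = Pipe.from_tracing(model, world_size, example_args=(example_input,))

    # Each rank takes its own stage module and wraps it in a stage + GPipe schedule,
    # the same objects train.py builds when parallel_dims.pp_enabled is set.
    stage_module = pipe.get_stage_module(rank)
    optimizer = torch.optim.SGD(stage_module.parameters(), lr=1e-2)  # illustrative only
    stage = PipelineStage(pipe=pipe, stage_index=rank, device=device, group=dist.group.WORLD)
    schedule = PipelineScheduleGPipe(stage, n_microbatches=world_size, loss_fn=nn.MSELoss())

    inputs = torch.randn(8, 16)
    target = torch.randn(8, 16)
    losses = []
    optimizer.zero_grad()
    if rank == 0:
        schedule.step(inputs)  # first stage feeds the pipeline
    else:
        schedule.step(target=target, losses=losses)  # last stage computes the loss
    optimizer.step()

    if rank == world_size - 1:
        print("mean microbatch loss:", torch.mean(torch.stack(losses)).item())

    dist.destroy_process_group()


if __name__ == "__main__":
    main()

The sketch would be launched with something like `torchrun --nproc-per-node 2 pp_sketch.py` (the filename is a placeholder); the two-rank case keeps the stage logic simple because rank 0 is always the first stage and rank 1 the last, which is the same assumption the training-loop branch in train.py makes for feeding inputs versus collecting losses.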