From 6797dbff64a12961ae025a1da89d1a4cfab6e1e2 Mon Sep 17 00:00:00 2001
From: Will Constable
Date: Mon, 6 May 2024 12:01:45 -0700
Subject: [PATCH] WIP apply PP manually

ghstack-source-id: 3f2cc8b539fd88d8ba5099f6bc026ae31ad1005a
Pull Request resolved: https://github.com/pytorch/torchtitan/pull/308
---
 torchtitan/parallelisms/parallelize_llama.py | 100 ++++++++++++++++++-
 train.py                                     |  46 +++++++--
 2 files changed, 133 insertions(+), 13 deletions(-)

diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 27f1f28ef..9a79d5f17 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -8,7 +8,7 @@
 # llama model, i.e. activation checkpointing, etc.
 
 from collections import defaultdict
-from typing import Tuple
+from typing import List, Tuple
 
 import torch
 
@@ -138,7 +138,103 @@ def get_tp_parallel_strategy(
     return RowwiseParallel, ColwiseParallel
 
 
-def apply_pipeline_parallelism(model, world_mesh, parallel_dims, job_config: JobConfig):
+class DummyModule(torch.nn.Module):
+    # Stand-in that preserves the position of layers owned by other pipeline
+    # stages; it passes the activation through unchanged.
+    def forward(self, h, *args):
+        return h
+
+
+class TransformerChunk(torch.nn.Module):
+    def __init__(
+        self,
+        orig_model,  # : Transformer,
+        this_stage_layer_names: List[str],
+    ):
+        super().__init__()
+        self.tok_embeddings = None
+        if "tok_embeddings" in this_stage_layer_names:
+            self.tok_embeddings = orig_model.tok_embeddings
+        self.freqs_cis = orig_model.freqs_cis
+        # preserve FQNs of the original model by preserving its structure
+        # (including each layer's position in the layers[] list) - use dummy modules
+        self.layers = orig_model.layers
+        for i in range(len(self.layers)):
+            if f"layers.{i}" not in this_stage_layer_names:
+                self.layers[i] = DummyModule()
+        self.norm = None
+        if "norm" in this_stage_layer_names:
+            self.norm = orig_model.norm
+        self.output = None
+        if "output" in this_stage_layer_names:
+            self.output = orig_model.output
+
+    def forward(self, input):
+        """
+        Copy-paste of the original Transformer.forward, with conditionals added
+        to handle the cases where this rank doesn't have the embedding or
+        doesn't have the output layers.
+        """
+        if self.tok_embeddings:
+            h = self.tok_embeddings(input)
+        else:
+            # non-first stages receive the hidden state (bs, seqlen, dim) directly
+            h = input
+        _, seqlen, _ = h.shape
+
+        freqs_cis = self.freqs_cis[:seqlen]
+
+        for layer in self.layers:
+            h = layer(h, freqs_cis)
+        output = h
+
+        if self.norm:
+            h = self.norm(h)
+            output = h
+
+        if self.output:
+            output = self.output(h).float()
+        return output
+
+
+def extract_pipeline_stage_models_manual(
+    model, world_mesh, parallel_dims, job_config: JobConfig, device
+):
+    """
+    This API gets an individual torch.nn.Module object for each pipeline stage owned by this rank
+    (including virtual stages, once they are supported).
+
+    The SPMD parallelisms (DP/TP) and activation checkpointing should then be
+    applied to each returned stage module.
+    """
+    assert (
+        parallel_dims.pp_enabled
+    ), "can't apply pipeline parallelism if it is not enabled"
+
+    pp_mesh = world_mesh["pp"]
+    pp_rank = pp_mesh.get_local_rank()
+    pp_size = pp_mesh.size()
+    stage_idx = pp_rank  # TODO support virtual stages
+    layers_per_rank = len(model.layers) // parallel_dims.pp
+    layer_offset = layers_per_rank * pp_rank
+    this_stage_layer_names = [
+        f"layers.{i + layer_offset}" for i in range(layers_per_rank)
+    ]
+    if pp_rank == 0:
+        this_stage_layer_names.insert(0, "tok_embeddings")
+    elif pp_rank == pp_size - 1:
+        this_stage_layer_names.append("norm")
+        this_stage_layer_names.append("output")
+
+    stage_model = TransformerChunk(model, this_stage_layer_names)
+
+    # note for the PiPPy API: it would be nice if we could get an fx.Graph out of
+    # PipeInfo, and to be able to construct PipeInfo manually, so that the same
+    # _PipelineStage ctor could be used in both the tracer and manual cases.
+
+    return (stage_model,)
+
+
+def apply_pipeline_parallelism_tracer(
+    model, world_mesh, parallel_dims, job_config: JobConfig
+):
     assert (
         parallel_dims.pp_enabled
     ), "can't apply pipeline parallelism if it is not enabled"
diff --git a/train.py b/train.py
index a96854829..81088a987 100644
--- a/train.py
+++ b/train.py
@@ -22,8 +22,9 @@
 
 # TODO(whc) this can be removed after pippy migration into pytorch core is complete.
 try:
-    from pippy import ScheduleGPipe
-    from pippy.PipelineStage import _PipelineStage
+    from pippy import ManualPipelineStage, ScheduleGPipe
+
+    # from pippy.PipelineStage import _PipelineStage
 except ImportError as exc:
     raise ImportError(
         "pippy is not installed. Please install it to use pipeline parallelism. "
@@ -224,10 +225,12 @@ def loss_fn(pred, labels):
 
     if parallel_dims.pp_enabled:
         # TODO(whc) now i need to figure out how to align this with the `model_parallelize_fns[model_name] pattern`
-        from torchtitan.parallelisms.parallelize_llama import apply_pipeline_parallelism
+        from torchtitan.parallelisms.parallelize_llama import (
+            extract_pipeline_stage_models_manual,
+        )
 
-        model, pipe_info = apply_pipeline_parallelism(
-            model, world_mesh, parallel_dims, job_config
+        stage_models = extract_pipeline_stage_models_manual(
+            model, world_mesh, parallel_dims, job_config, device
         )
 
     # apply PT-D DP/TP parallelisms and activation checkpointing
@@ -240,13 +243,34 @@ def loss_fn(pred, labels):
 
     # TODO(whc) everything below needs to become a function that can be applied to each 'virtual stage' of PP, if
     # there are virtual stages
     if parallel_dims.pp_enabled:
-        stage = _PipelineStage(
-            stage_module=model,
-            stage_index=pp_rank,
-            pipe_info=pipe_info,
-            device=device,
-            group=pp_mesh.get_group(),
+        # stage = _PipelineStage(
+        #     stage_module=model,
+        #     stage_index=pp_rank,
+        #     pipe_info=pipe_info,
+        #     device=device,
+        #     group=pp_mesh.get_group(),
+        # )
+        assert len(stage_models) == 1, "virtual stages NYI"
+        stage_model = stage_models[0]
+        chunks = parallel_dims.pp
+        pp_mesh = world_mesh["pp"]
+        pp_rank = pp_mesh.get_local_rank()
+        pp_size = pp_mesh.size()
+        stage_idx = pp_rank  # TODO support virtual stages
+        # Get an example microbatch input (on the meta device) for stage init.
+        # TODO: non-first stages actually receive hidden states of shape
+        # (bs, seqlen, dim) rather than token ids, so this only matches stage 0.
+        input_shape = (job_config.training.batch_size, job_config.training.seq_len)
+        input_ids = torch.randint(
+            model.vocab_size, input_shape, dtype=torch.int64, device="meta"
+        )
+        stage = ManualPipelineStage(
+            stage_model,
+            pp_rank,
+            pp_size,
+            device,
+            chunks,
+            input_args=input_ids.chunk(chunks)[0],
         )
         pp_schedule = ScheduleGPipe(
             stage, n_microbatches=parallel_dims.pp,
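
Note: a minimal sketch of how the schedule built above might be driven inside the training loop once this WIP lands. It assumes pippy's schedule step() accepts positional microbatch inputs on the first stage and target=/losses= keyword arguments on the last stage (the ManualPipelineStage/ScheduleGPipe APIs imported in this patch); pp_mesh, pp_schedule, input_ids, and labels refer to variables already used in train.py.

    # First stage feeds token ids; last stage supplies targets and collects
    # per-microbatch losses; middle stages only advance the pipeline.
    is_first_pp_rank = pp_mesh.get_local_rank() == 0
    is_last_pp_rank = pp_mesh.get_local_rank() == pp_mesh.size() - 1

    if is_first_pp_rank:
        pp_schedule.step(input_ids)
    elif is_last_pp_rank:
        losses = []
        pp_schedule.step(target=labels, losses=losses)
        loss = torch.mean(torch.stack(losses))
    else:
        pp_schedule.step()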