diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py
index d69dad67..2c050c45 100644
--- a/torchtitan/models/llama/model.py
+++ b/torchtitan/models/llama/model.py
@@ -182,6 +182,7 @@ def forward(
             torch.Tensor: Output tensor after attention.

         """
+        print(f"transformer layer got input shape {x.shape}")
         bs, seqlen, _ = x.shape
         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 27f1f28e..af2838d1 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -8,7 +8,7 @@
 # llama model, i.e. activation checkpointing, etc.

 from collections import defaultdict
-from typing import Tuple
+from typing import List, Tuple

 import torch

@@ -138,7 +138,108 @@ def get_tp_parallel_strategy(
     return RowwiseParallel, ColwiseParallel


-def apply_pipeline_parallelism(model, world_mesh, parallel_dims, job_config: JobConfig):
+class DummyTransformerLayer(torch.nn.Module):
+    def forward(self, input, freqs_cis):
+        return input
+
+
+class TransformerChunk(torch.nn.Module):
+    def __init__(
+        self,
+        orig_model,  # : Transformer,
+        this_stage_layer_names: List[str],
+        device,
+    ):
+        super().__init__()
+        self.tok_embeddings = None
+        if "tok_embeddings" in this_stage_layer_names:
+            self.tok_embeddings = orig_model.tok_embeddings
+
+        with torch.device(device):
+            self.freqs_cis = orig_model._precompute_freqs_cis()
+
+        # preserve FQNs of original model by preserving structure
+        # (including preserving position in layers[] list) - use dummy module
+        self.layers = orig_model.layers
+        for i in range(len(self.layers)):
+            if f"layers.{i}" not in this_stage_layer_names:
+                self.layers[i] = DummyTransformerLayer()
+        self.norm = None
+        if "norm" in this_stage_layer_names:
+            self.norm = orig_model.norm
+        self.output = None
+        if "output" in this_stage_layer_names:
+            self.output = orig_model.output
+
+    def forward(self, input):
+        """
+        Copypaste of original Transformer.forward, with conditionals and unpacking added
+        such that we handle the cases where this rank doesn't have the embedding, or doesn't have
+        the output layers.
+        """
+        if self.tok_embeddings:
+            h = self.tok_embeddings(input)
+        else:
+            h = input
+        _, seqlen, _ = h.shape
+
+        freqs_cis = self.freqs_cis[:seqlen]
+
+        for layer in self.layers:
+            h = layer(h, freqs_cis)
+        output = h
+
+        if self.norm:
+            h = self.norm(h)
+            output = h
+
+        if self.output:
+            output = self.output(h).float()
+        return output
+
+
+def extract_pipeline_stage_models_manual(
+    model, world_mesh, parallel_dims, job_config: JobConfig, device
+):
+    """
+    This API gets individual torch.nn.Module objects for each pipeline stage (including virtual stages).
+
+    The SPMD parallelisms should be applied to each stage model returned by this function.
+    """
+    assert (
+        parallel_dims.pp_enabled
+    ), "can't apply pipeline parallelism if it is not enabled"
+
+    pp_mesh = world_mesh["pp"]
+    pp_rank = pp_mesh.get_local_rank()
+    pp_size = pp_mesh.size()
+    stage_idx = pp_rank  # TODO support virtual stages
+    layers_per_rank = len(model.layers) // parallel_dims.pp
+    layer_offset = layers_per_rank * pp_rank
+    this_stage_layer_names = [
+        f"layers.{i + layer_offset}" for i in range(layers_per_rank)
+    ]
+    if pp_rank == 0:
+        this_stage_layer_names.insert(0, "tok_embeddings")
+        assert "layers.0" in this_stage_layer_names
+    elif pp_rank == pp_size - 1:
+        this_stage_layer_names.append("norm")
+        this_stage_layer_names.append("output")
+        assert "layers.1" in this_stage_layer_names
+
+    stage_model = TransformerChunk(model, this_stage_layer_names, device)
+    # Create a pipeline representation from the model
+
+    # note for PipPy API
+    # it would be nice if we could get fx.graph out of PipeInfo and then make it possible to manually construct PipeInfo
+    # and then use the same _PipelineStage ctor in either tracer or graph cases.
+
+    return (stage_model,)
+
+
+def apply_pipeline_parallelism_tracer(
+    model, world_mesh, parallel_dims, job_config: JobConfig
+):
     assert (
         parallel_dims.pp_enabled
     ), "can't apply pipeline parallelism if it is not enabled"
diff --git a/train.py b/train.py
index a9685482..3e5b98a7 100644
--- a/train.py
+++ b/train.py
@@ -22,8 +22,9 @@

 # TODO(whc) this can be removed after pippy migration into pytorch core is complete.
 try:
-    from pippy import ScheduleGPipe
-    from pippy.PipelineStage import _PipelineStage
+    from pippy import ManualPipelineStage, ScheduleGPipe
+
+    # from pippy.PipelineStage import _PipelineStage
 except ImportError as exc:
     raise ImportError(
         "pippy is not installed. Please install it to use pipeline parallelism. "
" @@ -224,28 +225,68 @@ def loss_fn(pred, labels): if parallel_dims.pp_enabled: # TODO(whc) now i need to figure out how to align this with the `model_parallelize_fns[model_name] pattern` - from torchtitan.parallelisms.parallelize_llama import apply_pipeline_parallelism + from torchtitan.parallelisms.parallelize_llama import ( + extract_pipeline_stage_models_manual, + ) - model, pipe_info = apply_pipeline_parallelism( - model, world_mesh, parallel_dims, job_config + stage_models = extract_pipeline_stage_models_manual( + model, world_mesh, parallel_dims, job_config, device ) + stage_models = [ + models_parallelize_fns[model_name]( + model, world_mesh, parallel_dims, job_config + ) + for model in stage_models + ] - # apply PT-D DP/TP parallelisms and activation checkpointing - model = models_parallelize_fns[model_name]( - model, world_mesh, parallel_dims, job_config - ) + else: + # apply PT-D DP/TP parallelisms and activation checkpointing + model = models_parallelize_fns[model_name]( + model, world_mesh, parallel_dims, job_config + ) model.to_empty(device="cuda") # TODO(whc) everything below needs to become a function that can be applied to each 'virtual stage' of PP, if # there are virtual stages if parallel_dims.pp_enabled: - stage = _PipelineStage( - stage_module=model, - stage_index=pp_rank, - pipe_info=pipe_info, - device=device, - group=pp_mesh.get_group(), + # stage = _PipelineStage( + # stage_module=model, + # stage_index=pp_rank, + # pipe_info=pipe_info, + # device=device, + # group=pp_mesh.get_group(), + # ) + assert len(stage_models) == 1, "virtual stages NYI" + stage_model = stage_models[0] + chunks = parallel_dims.pp + pp_mesh = world_mesh["pp"] + pp_rank = pp_mesh.get_local_rank() + pp_size = pp_mesh.size() + stage_idx = pp_rank # TODO support virtual stages + # Get example input + if pp_rank == 0: + input_shape = (job_config.training.batch_size, job_config.training.seq_len) + input_ids = torch.randint( + model.vocab_size, input_shape, dtype=torch.int64, device="meta" + ) + else: + input_shape = ( + job_config.training.batch_size, + job_config.training.seq_len, + model_config.dim, + ) + input_ids = torch.randint( + model.vocab_size, input_shape, dtype=torch.float32, device="meta" + ) + stage = ManualPipelineStage( + stage_model, + pp_rank, + pp_size, + device, + chunks, + input_args=input_ids.chunk(chunks)[0], + group=pp_mesh.get_group("pp"), ) pp_schedule = ScheduleGPipe( stage,