WIP apply PP manually
ghstack-source-id: 3f2cc8b539fd88d8ba5099f6bc026ae31ad1005a
Pull Request resolved: #308
wconstab committed May 6, 2024
1 parent afdfbe3 commit 6797dbf
Showing 2 changed files with 133 additions and 13 deletions.
100 changes: 98 additions & 2 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -8,7 +8,7 @@
# llama model, i.e. activation checkpointing, etc.

from collections import defaultdict
from typing import Tuple
from typing import List, Tuple

import torch

@@ -138,7 +138,103 @@ def get_tp_parallel_strategy(
return RowwiseParallel, ColwiseParallel


def apply_pipeline_parallelism(model, world_mesh, parallel_dims, job_config: JobConfig):
class DummyModule(torch.nn.Module):
def forward(self, *args):
return args


class TransformerChunk(torch.nn.Module):
def __init__(
self,
orig_model, # : Transformer,
this_stage_layer_names: List[str],
):
super().__init__()
self.tok_embeddings = None
if "tok_embeddings" in this_stage_layer_names:
self.tok_embeddings = orig_model.tok_embeddings
self.freqs_cis = orig_model.freqs_cis
# preserve FQNs of the original model by preserving its structure
# (including each layer's position in the layers list); layers not on this stage become DummyModule
self.layers = orig_model.layers
for i in range(len(self.layers)):
if f"layers.{i}" not in this_stage_layer_names:
self.layers[i] = DummyModule()
self.norm = None
if "norm" in this_stage_layer_names:
self.norm = orig_model.norm
self.output = None
if "output" in this_stage_layer_names:
self.output = orig_model.output

def forward(self, input):
"""
Copy of the original Transformer.forward, with conditionals and unpacking added
so that we handle the cases where this rank doesn't have the embedding or doesn't
have the output layers.
"""
if self.tok_embeddings:
h = self.tok_embeddings(input)
_, seqlen, _ = h.shape
else:
h = input
_, seqlen, _ = h.shape  # h is the previous stage's hidden state of shape (bs, seqlen, dim)

freqs_cis = self.freqs_cis[:seqlen]

for layer in self.layers:
h = layer(h, freqs_cis)
output = h

if self.norm:
h = self.norm(h)
output = h

if self.output:
output = self.output(h).float()
return output


def extract_pipeline_stage_models_manual(
model, world_mesh, parallel_dims, job_config: JobConfig, device
):
"""
This API extracts an individual torch.nn.Module for each pipeline stage (including virtual stages).
The SPMD parallelisms should then be applied to each returned stage module.
"""
assert (
parallel_dims.pp_enabled
), "can't apply pipeline parallelism if it is not enabled"

pp_mesh = world_mesh["pp"]
pp_rank = pp_mesh.get_local_rank()
pp_size = pp_mesh.size()
stage_idx = pp_rank # TODO support virtual stages
layers_per_rank = len(model.layers) // parallel_dims.pp
layer_offset = layers_per_rank * pp_rank
this_stage_layer_names = [
f"layers.{i + layer_offset}" for i in range(layers_per_rank)
]
if pp_rank == 0:
this_stage_layer_names.insert(0, "tok_embeddings")
elif pp_rank == pp_size - 1:
this_stage_layer_names.append("norm")
this_stage_layer_names.append("output")

stage_model = TransformerChunk(model, this_stage_layer_names)
# Create a pipeline representation from the model

# note for the PiPPy API:
# it would be nice if we could get an fx.Graph out of PipeInfo, make it possible to manually construct a PipeInfo,
# and then use the same _PipelineStage ctor in either the tracer or graph case.

return (stage_model,)


def apply_pipeline_parallelism_tracer(
model, world_mesh, parallel_dims, job_config: JobConfig
):
assert (
parallel_dims.pp_enabled
), "can't apply pipeline parallelism if it is not enabled"
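For illustration (not part of this commit), here is a minimal standalone sketch of the layer-name partitioning that extract_pipeline_stage_models_manual performs, assuming a hypothetical 8-layer model split across pp=4 ranks; each resulting name list is what TransformerChunk uses to decide which submodules to keep and which to replace with DummyModule:

# Hypothetical example: 8 transformer layers split across pp=4 pipeline ranks.
n_layers, pp_size = 8, 4
layers_per_rank = n_layers // pp_size
for pp_rank in range(pp_size):
    layer_offset = layers_per_rank * pp_rank
    names = [f"layers.{i + layer_offset}" for i in range(layers_per_rank)]
    if pp_rank == 0:
        names.insert(0, "tok_embeddings")  # first stage owns the embedding
    elif pp_rank == pp_size - 1:
        names += ["norm", "output"]  # last stage owns the norm and output head
    print(pp_rank, names)
# rank 0: ['tok_embeddings', 'layers.0', 'layers.1']
# rank 1: ['layers.2', 'layers.3']
# rank 2: ['layers.4', 'layers.5']
# rank 3: ['layers.6', 'layers.7', 'norm', 'output']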
46 changes: 35 additions & 11 deletions train.py
@@ -22,8 +22,9 @@

# TODO(whc) this can be removed after pippy migration into pytorch core is complete.
try:
from pippy import ScheduleGPipe
from pippy.PipelineStage import _PipelineStage
from pippy import ManualPipelineStage, ScheduleGPipe

# from pippy.PipelineStage import _PipelineStage
except ImportError as exc:
raise ImportError(
"pippy is not installed. Please install it to use pipeline parallelism. "
@@ -224,10 +225,12 @@ def loss_fn(pred, labels):

if parallel_dims.pp_enabled:
# TODO(whc) now I need to figure out how to align this with the `model_parallelize_fns[model_name]` pattern
from torchtitan.parallelisms.parallelize_llama import apply_pipeline_parallelism
from torchtitan.parallelisms.parallelize_llama import (
extract_pipeline_stage_models_manual,
)

model, pipe_info = apply_pipeline_parallelism(
model, world_mesh, parallel_dims, job_config
stage_models = extract_pipeline_stage_models_manual(
model, world_mesh, parallel_dims, job_config, device
)

# apply PT-D DP/TP parallelisms and activation checkpointing
@@ -240,13 +243,34 @@ def loss_fn(pred, labels):
# TODO(whc) everything below needs to become a function that can be applied to each 'virtual stage' of PP, if
# there are virtual stages
if parallel_dims.pp_enabled:
stage = _PipelineStage(
stage_module=model,
stage_index=pp_rank,
pipe_info=pipe_info,
device=device,
group=pp_mesh.get_group(),
# stage = _PipelineStage(
# stage_module=model,
# stage_index=pp_rank,
# pipe_info=pipe_info,
# device=device,
# group=pp_mesh.get_group(),
# )
assert len(stage_models) == 1, "virtual stages NYI"
stage_model = stage_models[0]
chunks = parallel_dims.pp
pp_mesh = world_mesh["pp"]
pp_rank = pp_mesh.get_local_rank()
pp_size = pp_mesh.size()
stage_idx = pp_rank # TODO support virtual stages
# Get example input
input_shape = (job_config.training.batch_size, job_config.training.seq_len)
input_ids = torch.randint(
model.vocab_size, input_shape, dtype=torch.int64, device="meta"
)
stage = ManualPipelineStage(
stage_model,
pp_rank,
pp_size,
device,
chunks,
input_args=input_ids.chunk(chunks)[0],
)
pp_schedule = ScheduleGPipe(
stage,
n_microbatches=parallel_dims.pp,
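Beyond what this hunk shows, here is a rough, hedged sketch (not from this commit) of how pp_schedule would typically be driven in the training loop. It assumes the ScheduleGPipe above was also constructed with a loss_fn (the call is cut off here), that step() accepts target/losses keyword arguments as in pippy at the time, and that input_ids and labels come from the surrounding train step:

# Hedged sketch: one training iteration under pipeline parallelism.
is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1

if pp_mesh.get_local_rank() == 0:
    # first stage feeds the (microbatched) token ids
    pp_schedule.step(input_ids)
elif is_last_stage:
    # last stage supplies targets so the schedule can run loss_fn and backward
    losses = []
    pp_schedule.step(target=labels, losses=losses)
    loss = torch.mean(torch.stack(losses))
else:
    # middle stages only relay activations
    pp_schedule.step()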
