WIP apply PP manually
runs 1D ok now. still broken for DP+PP (haven't tried TP)

ghstack-source-id: 52da24a8b23e8631275171cf25b4e62575aced35
Pull Request resolved: #308
wconstab committed May 7, 2024
1 parent 95fbe06 commit 99d0b08
Showing 3 changed files with 160 additions and 17 deletions.
1 change: 1 addition & 0 deletions torchtitan/models/llama/model.py
@@ -182,6 +182,7 @@ def forward(
torch.Tensor: Output tensor after attention.
"""
print(f"transformer layer got input shape {x.shape}")
bs, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

105 changes: 103 additions & 2 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -8,7 +8,7 @@
# llama model, i.e. activation checkpointing, etc.

from collections import defaultdict
from typing import Tuple
from typing import List, Tuple

import torch

@@ -138,7 +138,108 @@ def get_tp_parallel_strategy(
return RowwiseParallel, ColwiseParallel


def apply_pipeline_parallelism(model, world_mesh, parallel_dims, job_config: JobConfig):
class DummyTransformerLayer(torch.nn.Module):
def forward(self, input, freqs_cis):
return input


class TransformerChunk(torch.nn.Module):
def __init__(
self,
orig_model, # : Transformer,
this_stage_layer_names: List[str],
device,
):
super().__init__()
self.tok_embeddings = None
if "tok_embeddings" in this_stage_layer_names:
self.tok_embeddings = orig_model.tok_embeddings

with torch.device(device):
self.freqs_cis = orig_model._precompute_freqs_cis()

# preserve FQNs of original model by preserving structure
# (including preserving position in layers[] list)- use dummy module
self.layers = orig_model.layers
for i in range(len(self.layers)):
if f"layers.{i}" not in this_stage_layer_names:
self.layers[i] = DummyTransformerLayer()
self.norm = None
if "norm" in this_stage_layer_names:
self.norm = orig_model.norm
self.output = None
if "output" in this_stage_layer_names:
self.output = orig_model.output

def forward(self, input):
"""
Copypaste of original Transformer.forward, with conditionals and unpacking added
such that we handle the cases where this rank doesn't have the embedding, or doesn't have
the output layers.
"""
if self.tok_embeddings:
h = self.tok_embeddings(input)
else:
h = input
_, seqlen, _ = h.shape

freqs_cis = self.freqs_cis[:seqlen]

for layer in self.layers:
h = layer(h, freqs_cis)
output = h

if self.norm:
h = self.norm(h)
output = h

if self.output:
output = self.output(h).float()
return output


def extract_pipeline_stage_models_manual(
model, world_mesh, parallel_dims, job_config: JobConfig, device
):
"""
This API gets individual torch.nn.Module objects for each pipeline stage (including virtual stages).
The SPMD parallelisms (DP/TP) should then be applied to each returned stage module separately.
"""
assert (
parallel_dims.pp_enabled
), "can't apply pipeline parallelism if it is not enabled"

pp_mesh = world_mesh["pp"]
pp_rank = pp_mesh.get_local_rank()
pp_size = pp_mesh.size()
stage_idx = pp_rank # TODO support virtual stages
layers_per_rank = len(model.layers) // parallel_dims.pp
layer_offset = layers_per_rank * pp_rank
this_stage_layer_names = [
f"layers.{i + layer_offset}" for i in range(layers_per_rank)
]
if pp_rank == 0:
this_stage_layer_names.insert(0, "tok_embeddings")
assert "layers.0" in this_stage_layer_names
elif pp_rank == pp_size - 1:
this_stage_layer_names.append("norm")
this_stage_layer_names.append("output")
assert "layers.1" in this_stage_layer_names

stage_model = TransformerChunk(model, this_stage_layer_names, device)
# Create a pipeline representation from the model

# note for PipPy API
# it would be nice if we could get fx.graph out of PipeInfo and then make it possible to manually construct PipeInfo
# and then use the same _PipelineStage ctor in either tracer or graph cases.

return (stage_model,)


def apply_pipeline_parallelism_tracer(
model, world_mesh, parallel_dims, job_config: JobConfig
):
assert (
parallel_dims.pp_enabled
), "can't apply pipeline parallelism if it is not enabled"
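For reference, a minimal standalone sketch (not part of this commit) of the layer-to-stage partitioning performed by extract_pipeline_stage_models_manual above; the helper name split_layer_names and the example values (8 layers, pp_size = 2) are illustrative assumptions, not code from this PR:

def split_layer_names(num_layers: int, pp_rank: int, pp_size: int) -> list:
    # Mirrors the partitioning above: an even split of transformer layers per
    # rank, plus tok_embeddings on the first stage and norm/output on the last.
    layers_per_rank = num_layers // pp_size
    layer_offset = layers_per_rank * pp_rank
    names = [f"layers.{i + layer_offset}" for i in range(layers_per_rank)]
    if pp_rank == 0:
        names.insert(0, "tok_embeddings")
    elif pp_rank == pp_size - 1:
        names.extend(["norm", "output"])
    return names

# split_layer_names(8, 0, 2) -> ["tok_embeddings", "layers.0", ..., "layers.3"]
# split_layer_names(8, 1, 2) -> ["layers.4", ..., "layers.7", "norm", "output"]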
71 changes: 56 additions & 15 deletions train.py
@@ -22,8 +22,9 @@

# TODO(whc) this can be removed after pippy migration into pytorch core is complete.
try:
from pippy import ScheduleGPipe
from pippy.PipelineStage import _PipelineStage
from pippy import ManualPipelineStage, ScheduleGPipe

# from pippy.PipelineStage import _PipelineStage
except ImportError as exc:
raise ImportError(
"pippy is not installed. Please install it to use pipeline parallelism. "
@@ -224,28 +225,68 @@ def loss_fn(pred, labels):

if parallel_dims.pp_enabled:
# TODO(whc) now i need to figure out how to align this with the `model_parallelize_fns[model_name] pattern`
from torchtitan.parallelisms.parallelize_llama import apply_pipeline_parallelism
from torchtitan.parallelisms.parallelize_llama import (
extract_pipeline_stage_models_manual,
)

model, pipe_info = apply_pipeline_parallelism(
model, world_mesh, parallel_dims, job_config
stage_models = extract_pipeline_stage_models_manual(
model, world_mesh, parallel_dims, job_config, device
)
stage_models = [
models_parallelize_fns[model_name](
model, world_mesh, parallel_dims, job_config
)
for model in stage_models
]

# apply PT-D DP/TP parallelisms and activation checkpointing
model = models_parallelize_fns[model_name](
model, world_mesh, parallel_dims, job_config
)
else:
# apply PT-D DP/TP parallelisms and activation checkpointing
model = models_parallelize_fns[model_name](
model, world_mesh, parallel_dims, job_config
)

model.to_empty(device="cuda")

# TODO(whc) everything below needs to become a function that can be applied to each 'virtual stage' of PP, if
# there are virtual stages
if parallel_dims.pp_enabled:
stage = _PipelineStage(
stage_module=model,
stage_index=pp_rank,
pipe_info=pipe_info,
device=device,
group=pp_mesh.get_group(),
# stage = _PipelineStage(
# stage_module=model,
# stage_index=pp_rank,
# pipe_info=pipe_info,
# device=device,
# group=pp_mesh.get_group(),
# )
assert len(stage_models) == 1, "virtual stages NYI"
stage_model = stage_models[0]
chunks = parallel_dims.pp
pp_mesh = world_mesh["pp"]
pp_rank = pp_mesh.get_local_rank()
pp_size = pp_mesh.size()
stage_idx = pp_rank # TODO support virtual stages
# Get example input
if pp_rank == 0:
input_shape = (job_config.training.batch_size, job_config.training.seq_len)
input_ids = torch.randint(
model.vocab_size, input_shape, dtype=torch.int64, device="meta"
)
else:
input_shape = (
job_config.training.batch_size,
job_config.training.seq_len,
model_config.dim,
)
input_ids = torch.randint(
model.vocab_size, input_shape, dtype=torch.float32, device="meta"
)
stage = ManualPipelineStage(
stage_model,
pp_rank,
pp_size,
device,
chunks,
input_args=input_ids.chunk(chunks)[0],
group=pp_mesh.get_group("pp"),
)
pp_schedule = ScheduleGPipe(
stage,
