Add Pipeline Parallel (and 2D PP+FSDP) support
- uses pipeline tracer frontend to extract a graph and partition it into
  chunks per stage
- hardcodes one schedule (1F1B) for now (need to expose option to switch
  schedule and test other schedules)
- supports 2D parallelism currently, 3D (TP) is work in progress

ghstack-source-id: cbbb628fd823d579064a8038e6511ec77457ef19
Pull Request resolved: #161
wconstab committed May 2, 2024
1 parent d293e5e commit 5535c3b
Showing 4 changed files with 179 additions and 23 deletions.
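
Before the per-file diffs, here is a condensed sketch of the flow this commit wires up, assembled from the patch below. It is an illustration rather than a runnable script: `model`, `pp_mesh`, `device`, `labels`, and `loss_fn` stand in for objects that train.py and parallelize_llama.py build elsewhere, and the real code threads these steps through `parallelize_llama` and the training loop rather than a single function.

    import torch

    from pippy import pipeline, SplitPoint, PipelineStage, ScheduleGPipe

    pp_degree = pp_mesh.size()
    pp_rank = pp_mesh.get_local_rank()

    # Split points: give each pipeline rank an equal slice of the transformer layers.
    layers_per_rank = len(model.layers) // pp_degree
    split_spec = {
        f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING
        for i in range(1, pp_degree)
    }

    # Trace the model on meta-device example inputs and partition it into stage submodules.
    input_ids = torch.randint(
        model.vocab_size, (8, 2048), dtype=torch.int64, device="meta"
    )
    pipe = pipeline(model, pp_degree, example_args=(input_ids,), split_spec=split_spec)
    stage_module = pipe.get_stage_module(pp_rank)  # this rank's chunk of the model

    # Wrap the local chunk in a runtime stage and a schedule (hardcoded for now).
    stage = PipelineStage(
        pipe=pipe, stage_index=pp_rank, device=device, group=pp_mesh.get_group()
    )
    pp_schedule = ScheduleGPipe(stage, n_microbatches=pp_degree, loss_fn=loss_fn)

    # One training step: the first stage feeds inputs, the last stage receives labels
    # and collects per-microbatch losses, and middle stages just relay activations.
    losses = []
    if pp_rank == 0:
        pp_schedule.step(input_ids)
    elif pp_rank == pp_degree - 1:
        pp_schedule.step(target=labels, losses=losses)
    else:
        pp_schedule.step()
    loss = torch.mean(torch.stack(losses)) if pp_rank == pp_degree - 1 else torch.tensor(-1.0)

For the 2D (PP+FSDP) case, parallelize_llama.py below additionally wraps the per-stage submodule with fully_shard and swaps it back into pipe.split_gm before returning the pipe object.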
1 change: 1 addition & 0 deletions .github/workflows/unit_test_4gpu.yaml
@@ -36,6 +36,7 @@ jobs:
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
python -m pip install -r requirements.txt
python -m pip install -r dev-requirements.txt
python -m pip install git+https://github.com/pytorch/pippy
python -m pip install -e .
- name: Run test_runner.py
run: python ./test_runner.py
52 changes: 45 additions & 7 deletions test_runner.py
@@ -26,6 +26,7 @@ class OverrideDefinitions:

override_args: Sequence[Sequence[str]] = tuple(tuple(" "))
test_descr: str = "default"
requires_seed_checkpoint: bool = False


CONFIG_DIR = "./train_configs"
@@ -85,25 +86,62 @@ class OverrideDefinitions:
],
"Checkpoint Integration Test - Save Model Weights Only bf16",
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--training.pipeline_parallel_degree 4",
"--training.data_parallel_degree 1",
"--model.norm_type rmsnorm", # TODO fix fused_rmsnorm issue
],
],
"PP 1D test",
requires_seed_checkpoint=True,
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--training.pipeline_parallel_degree 2",
"--training.data_parallel_degree 2",
"--model.norm_type rmsnorm", # TODO fix fused_rmsnorm issue
],
],
"PP+DP 2D test",
requires_seed_checkpoint=True,
),
]


def _run_cmd(cmd):
return subprocess.run(
[cmd],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
shell=True,
)


def run_test(test_flavor: OverrideDefinitions, full_path: str):
# run_test supports a sequence of tests.
for override_arg in test_flavor.override_args:

cmd = f"CONFIG_FILE={full_path} NGPU=4 LOG_RANK=0,1,2,3 ./run_llama_train.sh"
if override_arg:
cmd += " " + " ".join(override_arg)
print(
f"=====Integration test, flavor : {test_flavor.test_descr}, command : {cmd}====="
)
result = subprocess.run(
[cmd],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
shell=True,
)

if test_flavor.requires_seed_checkpoint:
print("Creating seed checkpoint")
result = _run_cmd(
f"CONFIG_FILE={full_path} ./create_seed_checkpoint.sh --checkpoint.folder {test_checkpoint_dir}"
)
print(result.stdout)

result = _run_cmd(cmd)
print(result.stdout)
if result.returncode != 0:
raise Exception(
58 changes: 53 additions & 5 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -12,6 +12,15 @@

import torch

# TODO(whc) this can be removed after pippy migration into pytorch core is complete.
try:
from pippy import pipeline, SplitPoint
except ImportError as exc:
raise ImportError(
"pippy is not installed. Please install it to use pipeline parallelism. "
"`pip install git+https://github.com/pytorch/pippy`"
) from exc

from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
@@ -137,7 +146,34 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
the model must fit on GPU or CPU memory.
"""
if parallel_dims.pp_enabled:
raise NotImplementedError("PP not implemented yet.")

if job_config.model.norm_type == "fused_rmsnorm":
# TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode
# coming from ` if dy.stride(-1) != 1:` in fused_rmsnorm
raise NotImplementedError(
"fused_rmsnorm not yet compatible with Pipeline Tracer (strides error). Please use layernorm or rmsnorm."
)
pp_mesh = world_mesh["pp"]
stage_idx = pp_mesh.get_local_rank()
layers_per_rank = len(model.layers) // parallel_dims.pp
split_spec = {
f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING
for i in range(1, parallel_dims.pp)
}
# Get example input
label_shape = input_shape = (8, 2048) # TODO
input_ids = torch.randint(
model.vocab_size, input_shape, dtype=torch.int64, device="meta"
)
labels = torch.randint(
model.vocab_size, label_shape, dtype=torch.int64, device="meta"
)

# Create a pipeline representation from the model
pipe = pipeline(
model, parallel_dims.pp, example_args=(input_ids,), split_spec=split_spec
)
model = pipe.get_stage_module(stage_idx)

if parallel_dims.tp_enabled:
if job_config.model.norm_type == "fused_rmsnorm":
@@ -215,27 +251,39 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
# TODO: Expose `reduce_dtype` as a config option.
mp_policy = MixedPrecisionPolicy(
param_dtype=torch.bfloat16, reduce_dtype=torch.float32
# TODO(whc) need to fix PP + FSDP-mixed-precision
# tracer for PP assumes f32 and is caught off guard when runtime FSDP interacts using bf16 inputs
# param_dtype=torch.bfloat16, reduce_dtype=torch.float32
param_dtype=torch.float32,
reduce_dtype=torch.float32,
)
ac_mode = job_config.activation_checkpoint.mode
fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
for layer_id, transformer_block in enumerate(model.layers):
for layer_name, transformer_block in model.layers.named_children():
if job_config.activation_checkpoint.mode in ("full", "selective"):
transformer_block = checkpoint_wrapper(
transformer_block, job_config.activation_checkpoint
)
# As an optimization, do not reshard after forward for the last
# transformer block since FSDP would prefetch it immediately
reshard_after_forward = layer_id < len(model.layers) - 1
# reshard_after_forward = layer_id < len(model.layers) - 1
# TODO(whc) need to correctly handle layer ids on the pp-split module
reshard_after_forward = True
fully_shard(
transformer_block,
**fsdp_config,
reshard_after_forward=reshard_after_forward,
)
model.layers[layer_id] = transformer_block
# model.layers[layer_id] = transformer_block
# TODO(whc)
setattr(model.layers, layer_name, transformer_block)
model = fully_shard(model, **fsdp_config)
if ac_mode in ("full", "selective"):
logger.info(f"Applied {ac_mode} activation checkpointing to the model")
logger.info("Applied FSDP to the model")

if parallel_dims.pp_enabled:
setattr(pipe.split_gm, f"submod_{stage_idx}", model)
return pipe

return model
91 changes: 80 additions & 11 deletions train.py
@@ -19,6 +19,16 @@

import torch
import torch.nn.functional as F

# TODO(whc) this can be removed after pippy migration into pytorch core is complete.
try:
from pippy import PipelineStage, ScheduleGPipe
except ImportError as exc:
raise ImportError(
"pippy is not installed. Please install it to use pipeline parallelism. "
"`pip install git+https://github.com/pytorch/pippy`"
) from exc

from torch.distributed import destroy_process_group
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.elastic.multiprocessing.errors import record
@@ -126,7 +136,8 @@ def main(job_config: JobConfig):
world_size=world_size,
enable_loss_parallel=job_config.training.enable_loss_parallel,
)
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
torch.cuda.set_device(device)
init_distributed(job_config)

world_mesh = parallel_dims.build_mesh(device_type="cuda")
@@ -144,6 +155,15 @@
dp_rank = dp_mesh.get_local_rank()
else:
dp_degree, dp_rank = 1, 0

if parallel_dims.pp_enabled:
pp_mesh = world_mesh["pp"]
pp_degree = pp_mesh.size()
pp_rank = pp_mesh.get_local_rank()

else:
pp_degree, pp_rank = 1, 0

data_loader = build_hf_data_loader(
job_config.training.dataset,
job_config.training.dataset_path,
@@ -205,9 +225,34 @@ def loss_fn(pred, labels):
model = models_parallelize_fns[model_name](
model, world_mesh, parallel_dims, job_config
)
# allocate sharded model on GPU and initialize weights via DTensor
if parallel_dims.pp_enabled:
pipe_meta = model
model = pipe_meta.get_stage_module(pp_rank)

model.to_empty(device="cuda")
model.init_weights()

# TODO(whc) everything below needs to become a function that can be applied to each 'virtual stage' of PP, if
# there are virtual stages
if parallel_dims.pp_enabled:
stage = PipelineStage(
pipe=pipe_meta,
stage_index=pp_rank,
device=device,
group=pp_mesh.get_group(),
)
pp_schedule = ScheduleGPipe(
stage,
n_microbatches=parallel_dims.pp,
loss_fn=loss_fn,
)
else:
# if PP is enabled, we can't use init_weights. Instead, we have to rely on creating an initial checkpoint offline
# and loading it to get initialization values. This is because the init_weights functions are written assuming
# the whole model (all its weights, or FQNs) exists on one rank. In PP, init_weights on stage 1 might crash
# because it can't find the "embedding" layer, for example.

# allocate sharded model on GPU and initialize weights via DTensor
model.init_weights()

gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
logger.info(
@@ -219,7 +264,6 @@ def loss_fn(pred, labels):
# build optimizer after applying parallelisms to the model
optimizer = build_optimizer(model, job_config)
scheduler = get_lr_scheduler(optimizer, job_config)

metric_logger = build_metric_logger(job_config)

# torch.compile model for improved performance
@@ -255,7 +299,13 @@ def loss_fn(pred, labels):
logger.info("Created seed checkpoint")
return

checkpoint.load()
checkpoint_loaded = checkpoint.load()

if parallel_dims.pp_enabled and not checkpoint_loaded:
raise RuntimeError(
"Pipeline Parallelism requires meta-initialization and loading seed checkpoint. "
"Please run `./create_seed_checkpoint.sh` and rerun training with `--checkpoint.enable_checkpoint`"
)

# plot losses loaded from checkpoint (if any) to TensorBoard
# NOTE: Loss info after the last log step before checkpoint saving will not be plotted.
@@ -295,14 +345,33 @@ def loss_fn(pred, labels):

input_ids = input_ids.cuda()
labels = labels.cuda()

optimizer.zero_grad()

# forward / backward
with loss_parallel_ctx():
pred = model(input_ids)
loss = loss_fn(pred, labels)
loss.backward()
if parallel_dims.pp_enabled:
# pipeline parallel forward / backward inside step() call
is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1

with loss_parallel_ctx():
if pp_mesh.get_local_rank() == 0:
pp_schedule.step(input_ids)
elif is_last_stage:
losses = []
pp_schedule.step(target=labels, losses=losses)
else:
pp_schedule.step()

# accumulate losses across pipeline microbatches
loss = (
torch.mean(torch.stack(losses))
if is_last_stage
else torch.Tensor([-1.0])
)
else:
# Non-PP forward / backward
with loss_parallel_ctx():
pred = model(input_ids)
loss = loss_fn(pred, labels)
loss.backward()

# clip gradients
torch.nn.utils.clip_grad_norm_(