Add Pipeline Parallel (and 2D PP+FSDP) support
- uses pipeline tracer frontend to extract a graph and partition it into
  chunks per stage
- hardcodes one schedule (1F1B) for now (need to expose an option to switch
  schedules and test other schedules; see the sketch below)
- supports 2D parallelism currently, 3D (TP) is work in progress
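
The diff wires up pippy's ScheduleGPipe in train.py. A minimal sketch of what exposing the schedule choice could look like, assuming a hypothetical config string (e.g. "gpipe" or "1f1b") and that pippy ships a Schedule1F1B with the same constructor signature as ScheduleGPipe; this is illustrative only, not something this commit adds:

from pippy import ScheduleGPipe

# Assumption: Schedule1F1B exists in pippy with the same
# (stage, n_microbatches=..., loss_fn=...) constructor as ScheduleGPipe.
try:
    from pippy import Schedule1F1B
except ImportError:
    Schedule1F1B = None


def build_pipeline_schedule(schedule_name, stage, n_microbatches, loss_fn):
    # Map a hypothetical config value to a pippy schedule class.
    schedules = {"gpipe": ScheduleGPipe}
    if Schedule1F1B is not None:
        schedules["1f1b"] = Schedule1F1B
    if schedule_name not in schedules:
        raise ValueError(f"unknown or unavailable pipeline schedule: {schedule_name}")
    return schedules[schedule_name](
        stage, n_microbatches=n_microbatches, loss_fn=loss_fn
    )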

ghstack-source-id: 0616a1c0d40f8e51ddfc1b2d330dbddc491e00e2
Pull Request resolved: #161
wconstab committed May 7, 2024
1 parent f72a2a0 commit 95fbe06
Showing 4 changed files with 199 additions and 27 deletions.
1 change: 1 addition & 0 deletions .github/workflows/unit_test_4gpu.yaml
@@ -36,6 +36,7 @@ jobs:
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
python -m pip install -r requirements.txt
python -m pip install -r dev-requirements.txt
python -m pip install git+https://github.com/pytorch/pippy
- name: Run test_runner.py
run: python ./test_runner.py
- name: Upload Coverage to Codecov
65 changes: 57 additions & 8 deletions test_runner.py
@@ -26,6 +26,8 @@ class OverrideDefinitions:

override_args: Sequence[Sequence[str]] = tuple(tuple(" "))
test_descr: str = "default"
requires_seed_checkpoint: bool = False
ngpu: int = 4


CONFIG_DIR = "./train_configs"
@@ -85,25 +87,72 @@ class OverrideDefinitions:
],
"Checkpoint Integration Test - Save Model Weights Only bf16",
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
f"--checkpoint.folder {test_checkpoint_dir}_pp",
"--training.pipeline_parallel_degree 2",
"--training.data_parallel_degree 1",
"--model.norm_type rmsnorm", # TODO fix fused_rmsnorm issue
],
],
"PP 1D test",
requires_seed_checkpoint=True,
ngpu=2,
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
f"--checkpoint.folder {test_checkpoint_dir}_pp_dp",
"--training.pipeline_parallel_degree 2",
"--training.data_parallel_degree 2",
"--model.norm_type rmsnorm", # TODO fix fused_rmsnorm issue
],
],
"PP+DP 2D test",
requires_seed_checkpoint=True,
),
]


def _run_cmd(cmd):
return subprocess.run(
[cmd],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
shell=True,
)


def run_test(test_flavor: OverrideDefinitions, full_path: str):
# run_test supports sequence of tests.
for override_arg in test_flavor.override_args:
cmd = f"CONFIG_FILE={full_path} NGPU=4 LOG_RANK=0,1,2,3 ./run_llama_train.sh"

cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK=0,1,2,3 ./run_llama_train.sh"
if override_arg:
cmd += " " + " ".join(override_arg)
print(
f"=====Integration test, flavor : {test_flavor.test_descr}, command : {cmd}====="
)
result = subprocess.run(
[cmd],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
shell=True,
)

if test_flavor.requires_seed_checkpoint:
checkpoint_folder_arg = None
for arg in override_arg:
if "--checkpoint.folder" in arg:
checkpoint_folder_arg = arg
assert (
checkpoint_folder_arg is not None
), "Can't use seed checkpoint if folder is not specified"
print("Creating seed checkpoint")
result = _run_cmd(
f"CONFIG_FILE={full_path} ./create_seed_checkpoint.sh {checkpoint_folder_arg}"
)
print(result.stdout)

result = _run_cmd(cmd)
print(result.stdout)
if result.returncode != 0:
raise Exception(
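
Note that the updated command parameterizes NGPU via test_flavor.ngpu but still hardcodes LOG_RANK=0,1,2,3. A small sketch of deriving the rank list from the GPU count instead (illustrative; the helper name is hypothetical):

def build_test_cmd(full_path: str, ngpu: int) -> str:
    # Log from every rank that actually exists, e.g. "0,1" for the 2-GPU PP test.
    log_ranks = ",".join(str(r) for r in range(ngpu))
    return f"CONFIG_FILE={full_path} NGPU={ngpu} LOG_RANK={log_ranks} ./run_llama_train.sh"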
60 changes: 53 additions & 7 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -12,6 +12,15 @@

import torch

# TODO(whc) this can be removed after pippy migration into pytorch core is complete.
try:
from pippy import pipeline, SplitPoint
except ImportError as exc:
raise ImportError(
"pippy is not installed. Please install it to use pipeline parallelism. "
"`pip install git+https://github.com/pytorch/pippy`"
) from exc

from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
@@ -129,15 +138,45 @@ def get_tp_parallel_strategy(
return RowwiseParallel, ColwiseParallel


def apply_pipeline_parallelism(model, world_mesh, parallel_dims, job_config: JobConfig):
assert (
parallel_dims.pp_enabled
), "can't apply pipeline parallelism if it is not enabled"

if job_config.model.norm_type == "fused_rmsnorm":
# TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode
# coming from ` if dy.stride(-1) != 1:` in fused_rmsnorm
raise NotImplementedError(
"fused_rmsnorm not yet compatible with Pipeline Tracer (strides error). Please use layernorm or rmsnorm."
)
pp_mesh = world_mesh["pp"]
stage_idx = pp_mesh.get_local_rank()
layers_per_rank = len(model.layers) // parallel_dims.pp
split_spec = {
f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING
for i in range(1, parallel_dims.pp)
}
# Get example input
input_shape = (job_config.training.batch_size, job_config.training.seq_len)
input_ids = torch.randint(
model.vocab_size, input_shape, dtype=torch.int64, device="meta"
)

# Create a pipeline representation from the model
pipe = pipeline(
model, parallel_dims.pp, example_args=(input_ids,), split_spec=split_spec
)
model = pipe.get_stage_module(stage_idx)
return model, pipe.pipe_info


def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
"""
Apply parallelisms and activation checkpointing to the model.
Apply SPMD parallelisms and activation checkpointing to the model.
NOTE: The passed-in model preferably should be on meta device. Otherwise,
the model must fit on GPU or CPU memory.
"""
if parallel_dims.pp_enabled:
raise NotImplementedError("PP not implemented yet.")

if parallel_dims.tp_enabled:
if job_config.model.norm_type == "fused_rmsnorm":
@@ -215,24 +254,31 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
# TODO: Expose `reduce_dtype` as a config option.
mp_policy = MixedPrecisionPolicy(
param_dtype=torch.bfloat16, reduce_dtype=torch.float32
# TODO(whc) need to fix PP + FSDP-mixed-precision
# the PP tracer assumes fp32 and is caught off guard when FSDP runs with bf16 inputs at runtime
# param_dtype=torch.bfloat16, reduce_dtype=torch.float32
param_dtype=torch.float32,
reduce_dtype=torch.float32,
)
ac_mode = job_config.activation_checkpoint.mode
fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
for layer_id, transformer_block in enumerate(model.layers):
for layer_name, transformer_block in model.layers.named_children():
if job_config.activation_checkpoint.mode in ("full", "selective"):
transformer_block = checkpoint_wrapper(
transformer_block, job_config.activation_checkpoint
)
# As an optimization, do not reshard after forward for the last
# transformer block since FSDP would prefetch it immediately
reshard_after_forward = layer_id < len(model.layers) - 1
# reshard_after_forward = layer_id < len(model.layers) - 1
# TODO(whc) need to correctly handle layer ids on the pp-split module
reshard_after_forward = True
fully_shard(
transformer_block,
**fsdp_config,
reshard_after_forward=reshard_after_forward,
)
model.layers[layer_id] = transformer_block
model.layers.add_module(layer_name, transformer_block)

model = fully_shard(model, **fsdp_config)
if ac_mode in ("full", "selective"):
logger.info(f"Applied {ac_mode} activation checkpointing to the model")
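
For concreteness, a worked example of the split that apply_pipeline_parallelism computes, assuming a hypothetical model with 8 transformer layers and pipeline_parallel_degree 2 (the numbers are illustrative, not taken from the repo's config files):

n_layers = 8   # hypothetical len(model.layers)
pp = 2         # hypothetical parallel_dims.pp
layers_per_rank = n_layers // pp                                   # 4
split_points = [f"layers.{i * layers_per_rank}" for i in range(1, pp)]
print(split_points)                                                # ['layers.4']
# pippy traces the model with meta-device example inputs and cuts it at
# SplitPoint.BEGINNING of layers.4: stage 0 keeps the embedding plus layers 0-3,
# stage 1 keeps layers 4-7 plus the final norm/output, and each rank extracts
# only its own chunk via pipe.get_stage_module(stage_idx).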
100 changes: 88 additions & 12 deletions train.py
@@ -19,6 +19,17 @@

import torch
import torch.nn.functional as F

# TODO(whc) this can be removed after pippy migration into pytorch core is complete.
try:
from pippy import ScheduleGPipe
from pippy.PipelineStage import _PipelineStage
except ImportError as exc:
raise ImportError(
"pippy is not installed. Please install it to use pipeline parallelism. "
"`pip install git+https://github.com/pytorch/pippy`"
) from exc

from torch.distributed import destroy_process_group
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.elastic.multiprocessing.errors import record
@@ -126,7 +137,8 @@ def main(job_config: JobConfig):
world_size=world_size,
enable_loss_parallel=job_config.training.enable_loss_parallel,
)
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
torch.cuda.set_device(device)
init_distributed(job_config)

world_mesh = parallel_dims.build_mesh(device_type="cuda")
@@ -144,6 +156,15 @@
dp_rank = dp_mesh.get_local_rank()
else:
dp_degree, dp_rank = 1, 0

if parallel_dims.pp_enabled:
pp_mesh = world_mesh["pp"]
pp_degree = pp_mesh.size()
pp_rank = pp_mesh.get_local_rank()

else:
pp_degree, pp_rank = 1, 0

data_loader = build_hf_data_loader(
job_config.training.dataset,
job_config.training.dataset_path,
@@ -201,13 +222,44 @@ def loss_fn(pred, labels):
# obtain the peak flops of bf16 type for MFU calculation
gpu_peak_flops = get_peak_flops(gpu_memory_monitor.device_name)

# apply PT-D parallelisms and activation checkpointing
if parallel_dims.pp_enabled:
# TODO(whc) figure out how to align this with the `models_parallelize_fns[model_name]` pattern
from torchtitan.parallelisms.parallelize_llama import apply_pipeline_parallelism

model, pipe_info = apply_pipeline_parallelism(
model, world_mesh, parallel_dims, job_config
)

# apply PT-D DP/TP parallelisms and activation checkpointing
model = models_parallelize_fns[model_name](
model, world_mesh, parallel_dims, job_config
)
# allocate sharded model on GPU and initialize weights via DTensor

model.to_empty(device="cuda")
model.init_weights()

# TODO(whc) everything below needs to become a function that can be applied to each 'virtual stage' of PP, if
# there are virtual stages
if parallel_dims.pp_enabled:
stage = _PipelineStage(
stage_module=model,
stage_index=pp_rank,
pipe_info=pipe_info,
device=device,
group=pp_mesh.get_group(),
)
pp_schedule = ScheduleGPipe(
stage,
n_microbatches=parallel_dims.pp,
loss_fn=loss_fn,
)
else:
# if PP is enabled, we can't use init_weights. Instead, we have to rely on creating an initial checkpoint
# offline and loading it to get initialization values. This is because the init_weights functions are written
# assuming the whole model (all its weights, or FQNs) exists on one rank. In PP, init_weights on stage 1 might
# crash because it can't find the "embedding" layer, for example.

# allocate sharded model on GPU and initialize weights via DTensor
model.init_weights()

gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
logger.info(
@@ -219,7 +271,6 @@ def loss_fn(pred, labels):
# build optimizer after applying parallelisms to the model
optimizer = build_optimizer(model, job_config)
scheduler = get_lr_scheduler(optimizer, job_config)

metric_logger = build_metric_logger(job_config)

# torch.compile model for improved performance
@@ -257,7 +308,13 @@ def loss_fn(pred, labels):
logger.info("Created seed checkpoint")
return

checkpoint.load()
checkpoint_loaded = checkpoint.load()

if parallel_dims.pp_enabled and not checkpoint_loaded:
raise RuntimeError(
"Pipeline Parallelism requires meta-initialization and loading seed checkpoint. "
"Please run `./create_seed_checkpoint.sh` and rerun training with `--checkpoint.enable_checkpoint`"
)

# plot losses loaded from checkpoint (if any) to TensorBoard
# NOTE: Loss info after the last log step before checkpoint saving will not be plotted.
@@ -299,14 +356,33 @@ def loss_fn(pred, labels):

input_ids = input_ids.cuda()
labels = labels.cuda()

optimizer.zero_grad()

# forward / backward
with loss_parallel_ctx():
pred = model(input_ids)
loss = loss_fn(pred, labels)
loss.backward()
if parallel_dims.pp_enabled:
# pipeline parallel forward / backward inside step() call
is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1

with loss_parallel_ctx():
if pp_mesh.get_local_rank() == 0:
pp_schedule.step(input_ids)
elif is_last_stage:
losses = []
pp_schedule.step(target=labels, losses=losses)
else:
pp_schedule.step()

# accumulate losses across pipeline microbatches
loss = (
torch.mean(torch.stack(losses))
if is_last_stage
else torch.Tensor([-1.0])
)
else:
# Non-PP forward / backward
with loss_parallel_ctx():
pred = model(input_ids)
loss = loss_fn(pred, labels)
loss.backward()

# clip gradients
torch.nn.utils.clip_grad_norm_(
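
The per-stage branching above (first stage feeds input_ids, last stage passes labels and collects per-microbatch losses, middle stages just step) could be factored into a helper so the training loop stays uniform. A minimal sketch using only the calls that appear in this diff; the helper name and the factoring are not part of the commit:

import torch


def run_pp_step(pp_schedule, pp_mesh, input_ids, labels):
    is_first_stage = pp_mesh.get_local_rank() == 0
    is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1

    losses = []
    if is_first_stage:
        pp_schedule.step(input_ids)
    elif is_last_stage:
        pp_schedule.step(target=labels, losses=losses)
    else:
        pp_schedule.step()

    # Average the microbatch losses on the last stage; other stages report a
    # placeholder value, mirroring the loop above.
    return torch.mean(torch.stack(losses)) if is_last_stage else torch.tensor(-1.0)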
