Add Pipeline Parallel (and 2D PP+FSDP) support #161
Changes from 31 commits
torchtitan/parallelisms/parallelize_llama.py
@@ -12,6 +12,15 @@
 import torch
 
+# TODO(whc) this can be removed after pippy migration into pytorch core is complete.
+try:
+    from pippy import pipeline, SplitPoint
+except ImportError as exc:
+    raise ImportError(
+        "pippy is not installed. Please install it to use pipeline parallelism. "
+        "`pip install git+https://github.com/pytorch/pippy`"
+    ) from exc
+
 from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
 from torch.distributed._tensor import Replicate, Shard
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
@@ -129,15 +138,48 @@ def get_tp_parallel_strategy(
     return RowwiseParallel, ColwiseParallel
 
 
+def apply_pipeline_parallelism(model, world_mesh, parallel_dims, job_config: JobConfig):
+    assert (
+        parallel_dims.pp_enabled
+    ), "can't apply pipeline parallelism if it is not enabled"
+
+    if job_config.model.norm_type == "fused_rmsnorm":
+        # TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode
+        # coming from ` if dy.stride(-1) != 1:` in fused_rmsnorm
+        raise NotImplementedError(
+            "fused_rmsnorm not yet compatible with Pipeline Tracer (strides error). Please use layernorm or rmsnorm."
+        )
+    pp_mesh = world_mesh["pp"]
+    stage_idx = pp_mesh.get_local_rank()
+    layers_per_rank = len(model.layers) // parallel_dims.pp
+    split_spec = {
+        f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING
+        for i in range(1, parallel_dims.pp)
+    }
+    # Get example input
+    label_shape = input_shape = (8, 2048)  # TODO
Review discussion (hardcoded example-input shape):
- @kwen2501 any ideas for a clean way to do this in torchtrain? Do we expect people to get a batch out of their dataloader and then reset it, or do we expect people to hardcode it?
- I think what I might do is pass input_shape directly from train.py, where I can set input_shape = (job_config.batch_size, job_config.seq_len) or something. Is that clean enough?
- OK, pushed a variation on this. Not sure if it's better to hide this inside parallelize, since we already have the job config, or to make it explicit from train.py that we are passing input_shape in for some reason.
- Either way sounds okay to me -- eventually, the shape comes from the config.
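To make the proposal above concrete, here is a minimal sketch of deriving the example-input shape from the job config instead of hardcoding (8, 2048); the helper name and the exact config attribute paths (job_config.training.batch_size / seq_len) are assumptions, not part of this diff:

import torch

def build_example_inputs(job_config, vocab_size):
    # Hypothetical helper: derive the pipeline tracer's example inputs from the
    # job config rather than a hardcoded shape. Meta tensors keep this free of
    # real memory allocation, matching the diff above.
    input_shape = (job_config.training.batch_size, job_config.training.seq_len)
    input_ids = torch.randint(vocab_size, input_shape, dtype=torch.int64, device="meta")
    labels = torch.randint(vocab_size, input_shape, dtype=torch.int64, device="meta")
    return input_ids, labels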
+    input_ids = torch.randint(
+        model.vocab_size, input_shape, dtype=torch.int64, device="meta"
+    )
+    labels = torch.randint(
+        model.vocab_size, label_shape, dtype=torch.int64, device="meta"
+    )
+
+    # Create a pipeline representation from the model
+    pipe = pipeline(
+        model, parallel_dims.pp, example_args=(input_ids,), split_spec=split_spec
+    )
+    model = pipe.get_stage_module(stage_idx)
+    return model, pipe.pipe_info
+
+
 def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     """
-    Apply parallelisms and activation checkpointing to the model.
+    Apply SPMD parallelisms and activation checkpointing to the model.
 
     NOTE: The passed-in model preferably should be on meta device. Otherwise,
     the model must fit on GPU or CPU memory.
     """
-    if parallel_dims.pp_enabled:
-        raise NotImplementedError("PP not implemented yet.")
-
     if parallel_dims.tp_enabled:
         if job_config.model.norm_type == "fused_rmsnorm":
@@ -215,24 +257,31 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
         assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
         # TODO: Expose `reduce_dtype` as a config option.
         mp_policy = MixedPrecisionPolicy(
-            param_dtype=torch.bfloat16, reduce_dtype=torch.float32
+            # TODO(whc) need to fix PP + FSDP-mixed-precision
+            # tracer for PP assumes f32 and is caught off guard when runtime FSDP interacts using bf16 inputs
+            # param_dtype=torch.bfloat16, reduce_dtype=torch.float32
+            param_dtype=torch.float32,
Review discussion (forcing param_dtype to fp32):
- We shouldn't change this by default; it would make the FSDP and FSDP + TP cases use fp32 instead of bf16.
- I wonder if supporting bf16 should be a criterion for landing. I would imagine that training with FSDP + PP in fp32 is not really viable efficiency-wise (at least for larger jobs).
- I think we should fix this before landing the PP change. There was a possible way to fix this in the tracer, but I lost track of it; I will dig it up.
(One way to keep bf16 as the non-PP default is sketched after this hunk.)
+            reduce_dtype=torch.float32,
         )
         ac_mode = job_config.activation_checkpoint.mode
         fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
-        for layer_id, transformer_block in enumerate(model.layers):
+        for layer_name, transformer_block in model.layers.named_children():
             if job_config.activation_checkpoint.mode in ("full", "selective"):
                 transformer_block = checkpoint_wrapper(
                     transformer_block, job_config.activation_checkpoint
                 )
             # As an optimization, do not reshard after forward for the last
             # transformer block since FSDP would prefetch it immediately
-            reshard_after_forward = layer_id < len(model.layers) - 1
+            # reshard_after_forward = layer_id < len(model.layers) - 1
+            # TODO(whc) need to fix correctly handle layer-ids on pp-split module
+            reshard_after_forward = True
             fully_shard(
                 transformer_block,
                 **fsdp_config,
                 reshard_after_forward=reshard_after_forward,
             )
-            model.layers[layer_id] = transformer_block
+            model.layers.add_module(layer_name, transformer_block)
 
         model = fully_shard(model, **fsdp_config)
         if ac_mode in ("full", "selective"):
             logger.info(f"Applied {ac_mode} activation checkpointing to the model")
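Following up on the mixed-precision thread above, a minimal sketch (not part of the diff) of keeping bf16 as the default FSDP policy and only falling back to fp32 while the PP tracer limitation stands; the helper function is hypothetical, and parallel_dims.pp_enabled is the flag used elsewhere in this PR:

import torch
from torch.distributed._composable.fsdp import MixedPrecisionPolicy

def make_mp_policy(pp_enabled: bool) -> MixedPrecisionPolicy:
    # Sketch: keep bf16 params for FSDP / FSDP+TP runs, and only force fp32
    # when pipeline parallelism is enabled, until the tracer handles bf16 inputs.
    param_dtype = torch.float32 if pp_enabled else torch.bfloat16
    return MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=torch.float32)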
train.py
@@ -19,6 +19,17 @@
 import torch
 import torch.nn.functional as F
 
+# TODO(whc) this can be removed after pippy migration into pytorch core is complete.
+try:
+    from pippy import ScheduleGPipe
+    from pippy.PipelineStage import _PipelineStage
+except ImportError as exc:
+    raise ImportError(
+        "pippy is not installed. Please install it to use pipeline parallelism. "
+        "`pip install git+https://github.com/pytorch/pippy`"
+    ) from exc
+
 from torch.distributed import destroy_process_group
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.distributed.elastic.multiprocessing.errors import record
@@ -126,7 +137,8 @@ def main(job_config: JobConfig):
         world_size=world_size,
         enable_loss_parallel=job_config.training.enable_loss_parallel,
     )
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
+    torch.cuda.set_device(device)
     init_distributed(job_config)
 
     world_mesh = parallel_dims.build_mesh(device_type="cuda")
@@ -144,6 +156,15 @@ def main(job_config: JobConfig):
         dp_rank = dp_mesh.get_local_rank()
     else:
         dp_degree, dp_rank = 1, 0
 
+    if parallel_dims.pp_enabled:
+        pp_mesh = world_mesh["pp"]
+        pp_degree = pp_mesh.size()
+        pp_rank = pp_mesh.get_local_rank()
+    else:
+        pp_degree, pp_rank = 1, 0
+
     data_loader = build_hf_data_loader(
         job_config.training.dataset,
         job_config.training.dataset_path,
@@ -201,13 +222,44 @@ def loss_fn(pred, labels):
     # obtain the peak flops of bf16 type for MFU calculation
     gpu_peak_flops = get_peak_flops(gpu_memory_monitor.device_name)
 
-    # apply PT-D parallelisms and activation checkpointing
+    if parallel_dims.pp_enabled:
+        # TODO(whc) now i need to figure out how to align this with the `models_parallelize_fns[model_name]` pattern
+        from torchtitan.parallelisms.parallelize_llama import apply_pipeline_parallelism
+
+        model, pipe_info = apply_pipeline_parallelism(
+            model, world_mesh, parallel_dims, job_config
+        )
+
+    # apply PT-D DP/TP parallelisms and activation checkpointing
     model = models_parallelize_fns[model_name](
         model, world_mesh, parallel_dims, job_config
     )
     # allocate sharded model on GPU and initialize weights via DTensor
     model.to_empty(device="cuda")
-    model.init_weights()
+
+    # TODO(whc) everything below needs to become a function that can be applied to each 'virtual stage' of PP, if
+    # there are virtual stages
+    if parallel_dims.pp_enabled:
+        stage = _PipelineStage(
+            stage_module=model,
+            stage_index=pp_rank,
+            pipe_info=pipe_info,
+            device=device,
+            group=pp_mesh.get_group(),
Review discussion (where stage creation should live):
- Wondering if we should put the stage creation into parallelize_llama; IMO we only need pp_schedule in train.py.
- Yeah, I think this question and Ke's suggestion about returning a PipelineStage from parallelize_llama are better taken in the context of a follow-up PR that also adds support for looped schedules. Looped schedules further complicate things because the PP logic first needs to chunk up the model, then apply the DP/TP portion of parallelize_llama on each chunk, and finally pass all the chunks into the schedule. In the end, I might prefer to separate PP out from parallelize_llama and have a flow where we take the return value of the PP apply function and iteratively call parallelize_llama on those chunks.
(That flow is sketched after this hunk.)
+        )
+        pp_schedule = ScheduleGPipe(
+            stage,
+            n_microbatches=parallel_dims.pp,
+            loss_fn=loss_fn,
+        )
+    else:
+        # If PP is enabled, we can't use init_weights. Instead, we have to rely on offline creation of an initial
+        # checkpoint and loading it to get initialization values. This is because the init_weights functions are
+        # written assuming the whole model (all its weights, or FQNs) exist on one rank. In PP, init_weights on
+        # stage 1 might crash because it can't find the "embedding" layer, for example.
+
+        # allocate sharded model on GPU and initialize weights via DTensor
+        model.init_weights()
 
     gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
     logger.info(
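To make the flow described in that thread concrete, here is a rough sketch under stated assumptions: apply_pipeline_parallelism splits first, each chunk then goes through parallelize_llama for DP/TP and activation checkpointing, and train.py only receives the schedule. All names except the wrapper function come from the diff above; the wrapper itself is hypothetical.

# Hypothetical refactor sketched in the review discussion: split first, then
# parallelize the chunk, then build the schedule; train.py only sees pp_schedule.
def build_pipeline_schedule(model, world_mesh, parallel_dims, job_config, loss_fn, device):
    pp_mesh = world_mesh["pp"]
    # 1) pipeline-split the model into this rank's stage module
    model, pipe_info = apply_pipeline_parallelism(model, world_mesh, parallel_dims, job_config)
    # 2) apply the SPMD (DP/TP) parallelisms and activation checkpointing to the chunk
    model = parallelize_llama(model, world_mesh, parallel_dims, job_config)
    # 3) wrap the chunk into a stage and hand only the schedule back to the caller
    stage = _PipelineStage(
        stage_module=model,
        stage_index=pp_mesh.get_local_rank(),
        pipe_info=pipe_info,
        device=device,
        group=pp_mesh.get_group(),
    )
    return ScheduleGPipe(stage, n_microbatches=parallel_dims.pp, loss_fn=loss_fn)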
@@ -219,7 +271,6 @@ def loss_fn(pred, labels):
     # build optimizer after applying parallelisms to the model
     optimizer = build_optimizer(model, job_config)
     scheduler = get_lr_scheduler(optimizer, job_config)
-
     metric_logger = build_metric_logger(job_config)
 
     # torch.compile model for improved performance
@@ -257,7 +308,13 @@ def loss_fn(pred, labels):
         logger.info("Created seed checkpoint")
         return
 
-    checkpoint.load()
+    checkpoint_loaded = checkpoint.load()
+
+    if parallel_dims.pp_enabled and not checkpoint_loaded:
+        raise RuntimeError(
+            "Pipeline Parallelism requires meta-initialization and loading seed checkpoint. "
+            "Please run `./create_seed_checkpoint.sh` and rerun training with `--checkpoint.enable_checkpoint`"
+        )
 
     # plot losses loaded from checkpoint (if any) to TensorBoard
     # NOTE: Loss info after the last log step before checkpoint saving will not be plotted.
@@ -299,14 +356,33 @@ def loss_fn(pred, labels):
 
         input_ids = input_ids.cuda()
         labels = labels.cuda()
 
         optimizer.zero_grad()
 
-        # forward / backward
-        with loss_parallel_ctx():
-            pred = model(input_ids)
-            loss = loss_fn(pred, labels)
-            loss.backward()
+        if parallel_dims.pp_enabled:
+            # pipeline parallel forward / backward inside step() call
+            is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
+
+            with loss_parallel_ctx():
+                if pp_mesh.get_local_rank() == 0:
+                    pp_schedule.step(input_ids)
+                elif is_last_stage:
+                    losses = []
+                    pp_schedule.step(target=labels, losses=losses)
+                else:
+                    pp_schedule.step()
+
+            # accumulate losses across pipeline microbatches
+            loss = (
+                torch.mean(torch.stack(losses))
+                if is_last_stage
+                else torch.Tensor([-1.0])
Review discussion (the default -1 loss value):
- Why do we need the default -1 value? Is it for logging purposes?
- Oh, yeah, I could make it a None, but then I have to update the logger to not log at all. Maybe that's actually a better way to do it; let me try that.
- OK, so what I could do is alter the metrics code so that on non-last-stage ranks we omit printing loss, or print "loss: None" instead of -1. The change will add more lines of code, since I need to deal with several places that expect loss and global_[avg/mean]_loss to be valid numbers. I agree in principle that's the "right" fix, but I'm not sure it's worth the LOC / complexity, and I don't totally hate the -1. Another option I considered is to skip the whole '# log metrics' code block on non-last-stage ranks; I ruled this out, since it is still useful to log MFU and memory for the other ranks. So let me know what you want to do here @wanchaol.
(The "omit loss on non-last-stage ranks" option is sketched after this hunk.)
+            )
+        else:
+            # Non-PP forward / backward
+            with loss_parallel_ctx():
+                pred = model(input_ids)
+                loss = loss_fn(pred, labels)
+                loss.backward()
 
         # clip gradients
         torch.nn.utils.clip_grad_norm_(
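Picking up the logging question from the thread above, a minimal sketch of the "omit loss on non-last-stage ranks" option. The surrounding names (metric_logger, train_state, gpu_mem_stats, is_last_stage) come from this file, but the exact metric keys and dictionary layout are assumptions rather than the PR's implementation:

# Sketch: only the last PP stage produces a real loss, so other ranks omit the
# loss entry but still log rank-local stats such as MFU and peak memory.
metrics = {
    "mfu(%)": mfu,
    "memory/max_reserved(GiB)": gpu_mem_stats.max_reserved_gib,
}
if (not parallel_dims.pp_enabled) or is_last_stage:
    metrics["loss_metrics/global_avg_loss"] = loss.item()
metric_logger.log(metrics, step=train_state.step)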
Review discussion (how split_spec is computed):
- I'm new to the PP API and have a question: if layers_per_rank = 5 and parallel_dims.pp = 2, what should split_spec be? My straightforward thought is that SplitPoint.BEGINNING should be placed at i = 1, 3, 5, but according to the code it's just i = 1.
- parallel_dims.pp refers to the number of pipeline stages we split the model into. For example, if model.layers has 10 layers, then 10 // 2 = 5, so we put 5 layers per stage (i.e. layers_per_rank = 5). Hence we make a single cut at model.layers.5 -- that is, (nRanks - 1) split points.
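A small worked example of that split_spec computation, using the illustrative numbers from the answer above (10 layers, pp = 2); the layer count is made up for the example:

from pippy import SplitPoint  # same import used in parallelize_llama.py above

pp = 2                               # number of pipeline stages (parallel_dims.pp)
num_layers = 10                      # illustrative len(model.layers)
layers_per_rank = num_layers // pp   # 5 layers per stage

# One SplitPoint per boundary between stages: pp - 1 = 1 cut, at layers.5
split_spec = {
    f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING
    for i in range(1, pp)
}
# split_spec == {"layers.5": SplitPoint.BEGINNING}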