Add Pipeline Parallel (and 2D PP+FSDP) support #161

Closed · wants to merge 34 commits · Changes from 31 commits
1 change: 1 addition & 0 deletions .github/workflows/unit_test_4gpu.yaml
@@ -36,6 +36,7 @@ jobs:
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
python -m pip install -r requirements.txt
python -m pip install -r dev-requirements.txt
python -m pip install git+https://github.com/pytorch/pippy
- name: Run test_runner.py
run: python ./test_runner.py
- name: Upload Coverage to Codecov
65 changes: 57 additions & 8 deletions test_runner.py
@@ -26,6 +26,8 @@ class OverrideDefinitions:

override_args: Sequence[Sequence[str]] = tuple(tuple(" "))
test_descr: str = "default"
requires_seed_checkpoint: bool = False
ngpu: int = 4


CONFIG_DIR = "./train_configs"
@@ -85,25 +87,72 @@ class OverrideDefinitions:
],
"Checkpoint Integration Test - Save Model Weights Only bf16",
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
f"--checkpoint.folder {test_checkpoint_dir}_pp",
"--training.pipeline_parallel_degree 2",
"--training.data_parallel_degree 1",
"--model.norm_type rmsnorm", # TODO fix fused_rmsnorm issue
],
],
"PP 1D test",
requires_seed_checkpoint=True,
ngpu=2,
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
f"--checkpoint.folder {test_checkpoint_dir}_pp_dp",
"--training.pipeline_parallel_degree 2",
"--training.data_parallel_degree 2",
"--model.norm_type rmsnorm", # TODO fix fused_rmsnorm issue
],
],
"PP+DP 2D test",
requires_seed_checkpoint=True,
),
]


def _run_cmd(cmd):
return subprocess.run(
[cmd],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
shell=True,
)


def run_test(test_flavor: OverrideDefinitions, full_path: str):
# run_test supports a sequence of tests.
for override_arg in test_flavor.override_args:
cmd = f"CONFIG_FILE={full_path} NGPU=4 LOG_RANK=0,1,2,3 ./run_llama_train.sh"

cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK=0,1,2,3 ./run_llama_train.sh"
if override_arg:
cmd += " " + " ".join(override_arg)
print(
f"=====Integration test, flavor : {test_flavor.test_descr}, command : {cmd}====="
)
result = subprocess.run(
[cmd],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
shell=True,
)

if test_flavor.requires_seed_checkpoint:
checkpoint_folder_arg = None
for arg in override_arg:
if "--checkpoint.folder" in arg:
checkpoint_folder_arg = arg
assert (
checkpoint_folder_arg is not None
), "Can't use seed checkpoint if folder is not specified"
print("Creating seed checkpoint")
result = _run_cmd(
f"CONFIG_FILE={full_path} ./create_seed_checkpoint.sh {checkpoint_folder_arg}"
)
print(result.stdout)

result = _run_cmd(cmd)
print(result.stdout)
if result.returncode != 0:
raise Exception(
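For readability, here is a standalone sketch of how run_test composes the seed-checkpoint and training commands for the "PP 1D test" flavor added above. The concrete values of full_path and test_checkpoint_dir are assumptions for illustration, not taken from this PR.

# Standalone illustration of the command composition in run_test (sketch only).
full_path = "./train_configs/debug_model.toml"     # assumed config path
test_checkpoint_dir = "./outputs/test_checkpoint"  # assumed folder prefix

override_arg = [
    "--checkpoint.enable_checkpoint",
    f"--checkpoint.folder {test_checkpoint_dir}_pp",
    "--training.pipeline_parallel_degree 2",
    "--training.data_parallel_degree 1",
    "--model.norm_type rmsnorm",
]
ngpu = 2  # matches the flavor's ngpu field

# Because requires_seed_checkpoint=True, the --checkpoint.folder override is reused
# to create a seed checkpoint before the training run starts.
checkpoint_folder_arg = next(a for a in override_arg if "--checkpoint.folder" in a)
seed_cmd = f"CONFIG_FILE={full_path} ./create_seed_checkpoint.sh {checkpoint_folder_arg}"

# The training command then runs with the requested GPU count and all overrides appended.
train_cmd = (
    f"CONFIG_FILE={full_path} NGPU={ngpu} LOG_RANK=0,1,2,3 ./run_llama_train.sh "
    + " ".join(override_arg)
)
print(seed_cmd)
print(train_cmd)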
63 changes: 56 additions & 7 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -12,6 +12,15 @@

import torch

# TODO(whc) this can be removed after pippy migration into pytorch core is complete.
try:
from pippy import pipeline, SplitPoint
except ImportError as exc:
raise ImportError(
"pippy is not installed. Please install it to use pipeline parallelism. "
"`pip install git+https://github.com/pytorch/pippy`"
) from exc

from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
@@ -129,15 +138,48 @@ def get_tp_parallel_strategy(
return RowwiseParallel, ColwiseParallel


def apply_pipeline_parallelism(model, world_mesh, parallel_dims, job_config: JobConfig):
assert (
parallel_dims.pp_enabled
), "can't apply pipeline parallelism if it is not enabled"

if job_config.model.norm_type == "fused_rmsnorm":
# TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode
# coming from ` if dy.stride(-1) != 1:` in fused_rmsnorm
raise NotImplementedError(
"fused_rmsnorm not yet compatible with Pipeline Tracer (strides error). Please use layernorm or rmsnorm."
)
pp_mesh = world_mesh["pp"]
stage_idx = pp_mesh.get_local_rank()
layers_per_rank = len(model.layers) // parallel_dims.pp
split_spec = {
f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING
for i in range(1, parallel_dims.pp)
Review comment (Contributor):
I'm new to the PP API and have a question: if layers_per_rank = 5 and parallel_dims.pp = 2, what should the split_spec be? My straightforward thought is that SplitPoint.BEGINNING should contain i = 1, 3, 5, but according to the code it's just i = 1.

Reply (kwen2501, Contributor, May 7, 2024):
parallel_dims.pp refers to the number of pipeline stages we split the model into. For example, if model.layers = 10, then 10 // 2 = 5 and we put 5 layers per stage (i.e. layers_per_rank = 5). Hence we make a cut at model.layers.5 -- that is, (nRanks - 1) split points in total. (A standalone worked example follows this function below.)

}
# Get example input
label_shape = input_shape = (8, 2048) # TODO
Review comment (Contributor Author):
@kwen2501 any ideas for a clean way to do this in torchtrain? Do we expect people to get a batch out of their dataloader and then reset it, or do we expect people to hardcode it?

Reply (Contributor Author):
I think what I might do is directly pass input_shape from train.py, and in train.py I can set input_shape = (job_config.batch_size, job_config.seq_len) or something. Is that clean enough?

Reply (Contributor Author):
OK, pushed a variation on this. Not sure if it's better to hide this inside parallelize since we already have the job config, or to make it explicit from train.py that we are passing input_shape in for some reason.

Reply (Contributor):
Either way sounds okay to me -- eventually, the shape comes from the config.

input_ids = torch.randint(
model.vocab_size, input_shape, dtype=torch.int64, device="meta"
)
labels = torch.randint(
model.vocab_size, label_shape, dtype=torch.int64, device="meta"
)

# Create a pipeline representation from the model
pipe = pipeline(
model, parallel_dims.pp, example_args=(input_ids,), split_spec=split_spec
)
model = pipe.get_stage_module(stage_idx)
return model, pipe.pipe_info
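
To make the split-point discussion above concrete, here is a tiny standalone sketch (not part of the PR). The layer count and degree are chosen for illustration, and a plain string stands in for pippy's SplitPoint.BEGINNING so the snippet runs without pippy installed.

# Worked example of the split_spec computed in apply_pipeline_parallelism (sketch only).
# With an 8-layer model and pipeline_parallel_degree 4:
#   layers_per_rank = 8 // 4 = 2
#   cuts land at the beginning of layers.2, layers.4 and layers.6,
#   i.e. pp - 1 = 3 split points, producing 4 stages of 2 layers each.
n_layers, pp = 8, 4
layers_per_rank = n_layers // pp
split_spec = {
    f"layers.{i * layers_per_rank}": "SplitPoint.BEGINNING"  # stand-in for pippy's enum
    for i in range(1, pp)
}
print(split_spec)
# {'layers.2': 'SplitPoint.BEGINNING', 'layers.4': 'SplitPoint.BEGINNING', 'layers.6': 'SplitPoint.BEGINNING'}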


def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
"""
Apply parallelisms and activation checkpointing to the model.
Apply SPMD parallelisms and activation checkpointing to the model.

NOTE: The passed-in model preferably should be on meta device. Otherwise,
the model must fit on GPU or CPU memory.
"""
if parallel_dims.pp_enabled:
raise NotImplementedError("PP not implemented yet.")

if parallel_dims.tp_enabled:
if job_config.model.norm_type == "fused_rmsnorm":
@@ -215,24 +257,31 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
# TODO: Expose `reduce_dtype` as a config option.
mp_policy = MixedPrecisionPolicy(
param_dtype=torch.bfloat16, reduce_dtype=torch.float32
# TODO(whc) need to fix PP + FSDP-mixed-precision
# tracer for PP assumes f32 and is caught off guard when runtime FSDP interacts using bf16 inputs
# param_dtype=torch.bfloat16, reduce_dtype=torch.float32
param_dtype=torch.float32,
Review comment (Contributor):
We shouldn't change this by default -- it would make the FSDP and FSDP + TP cases use fp32 instead of bf16.

Reply (Contributor):
I wonder if supporting bf16 should be a criterion for landing. I would imagine that training with FSDP + PP in fp32 is not really viable efficiency-wise (at least for larger jobs).

Reply (Contributor Author):
I think we should fix this before landing the PP change. I think there was a possible way to fix this in the tracer, but I lost track of it; will dig it up. (A sketch of keeping bf16 for the non-PP paths follows this file's diff.)

reduce_dtype=torch.float32,
)
ac_mode = job_config.activation_checkpoint.mode
fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
for layer_id, transformer_block in enumerate(model.layers):
for layer_name, transformer_block in model.layers.named_children():
if job_config.activation_checkpoint.mode in ("full", "selective"):
transformer_block = checkpoint_wrapper(
transformer_block, job_config.activation_checkpoint
)
# As an optimization, do not reshard after forward for the last
# transformer block since FSDP would prefetch it immediately
reshard_after_forward = layer_id < len(model.layers) - 1
# reshard_after_forward = layer_id < len(model.layers) - 1
# TODO(whc) need to fix this to correctly handle layer ids on the pp-split module
reshard_after_forward = True
fully_shard(
transformer_block,
**fsdp_config,
reshard_after_forward=reshard_after_forward,
)
model.layers[layer_id] = transformer_block
model.layers.add_module(layer_name, transformer_block)

model = fully_shard(model, **fsdp_config)
if ac_mode in ("full", "selective"):
logger.info(f"Applied {ac_mode} activation checkpointing to the model")
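As a follow-up to the mixed-precision thread above, a minimal sketch of one possible interim shape (not what this PR lands): keep bf16 parameters for the pure FSDP / FSDP+TP paths and fall back to fp32 only while the PP tracer cannot handle bf16 inputs. The helper name is hypothetical.

import torch
from torch.distributed._composable.fsdp import MixedPrecisionPolicy


def make_mp_policy(pp_enabled: bool) -> MixedPrecisionPolicy:
    # Hypothetical helper: fp32 params only when PP is enabled, until the tracer
    # issue noted in the TODO(whc) comment above is fixed; bf16 otherwise.
    param_dtype = torch.float32 if pp_enabled else torch.bfloat16
    return MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=torch.float32)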
100 changes: 88 additions & 12 deletions train.py
@@ -19,6 +19,17 @@

import torch
import torch.nn.functional as F

# TODO(whc) this can be removed after pippy migration into pytorch core is complete.
try:
from pippy import ScheduleGPipe
from pippy.PipelineStage import _PipelineStage
except ImportError as exc:
raise ImportError(
"pippy is not installed. Please install it to use pipeline parallelism. "
"`pip install git+https://github.com/pytorch/pippy`"
) from exc

from torch.distributed import destroy_process_group
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.elastic.multiprocessing.errors import record
@@ -126,7 +137,8 @@ def main(job_config: JobConfig):
world_size=world_size,
enable_loss_parallel=job_config.training.enable_loss_parallel,
)
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
torch.cuda.set_device(device)
init_distributed(job_config)

world_mesh = parallel_dims.build_mesh(device_type="cuda")
@@ -144,6 +156,15 @@
dp_rank = dp_mesh.get_local_rank()
else:
dp_degree, dp_rank = 1, 0

if parallel_dims.pp_enabled:
pp_mesh = world_mesh["pp"]
pp_degree = pp_mesh.size()
pp_rank = pp_mesh.get_local_rank()

else:
pp_degree, pp_rank = 1, 0

data_loader = build_hf_data_loader(
job_config.training.dataset,
job_config.training.dataset_path,
@@ -201,13 +222,44 @@ def loss_fn(pred, labels):
# obtain the peak flops of bf16 type for MFU calculation
gpu_peak_flops = get_peak_flops(gpu_memory_monitor.device_name)

# apply PT-D parallelisms and activation checkpointing
if parallel_dims.pp_enabled:
# TODO(whc) now I need to figure out how to align this with the `models_parallelize_fns[model_name]` pattern
from torchtitan.parallelisms.parallelize_llama import apply_pipeline_parallelism

model, pipe_info = apply_pipeline_parallelism(
model, world_mesh, parallel_dims, job_config
)

# apply PT-D DP/TP parallelisms and activation checkpointing
model = models_parallelize_fns[model_name](
model, world_mesh, parallel_dims, job_config
)
# allocate sharded model on GPU and initialize weights via DTensor

model.to_empty(device="cuda")
model.init_weights()

# TODO(whc) everything below needs to become a function that can be applied to each 'virtual stage' of PP, if
# there are virtual stages
if parallel_dims.pp_enabled:
stage = _PipelineStage(
stage_module=model,
stage_index=pp_rank,
pipe_info=pipe_info,
device=device,
group=pp_mesh.get_group(),
Review comment (Contributor):
Wondering if we should put the stage creation into parallelize_llama; IMO we only need pp_schedule in train.py.

Reply (Contributor Author):
Yeah, I think this question and Ke's suggestion about returning a PipelineStage from parallelize_llama are better taken in the context of a next PR that also adds support for looped schedules.

Looped schedules further complicate things because the PP logic first needs to chunk up the model, then apply the DP/TP portion of parallelize_llama on each chunk, and finally pass all the chunks into the schedule.

I think in the end I might prefer to separate PP out from parallelize_llama, and have a flow where we take the return of the PP apply function and iteratively call parallelize_llama on those chunks.

)
pp_schedule = ScheduleGPipe(
stage,
n_microbatches=parallel_dims.pp,
loss_fn=loss_fn,
)
else:
# If PP is enabled, we can't use init_weights; instead we have to rely on creating an initial checkpoint
# offline and loading it to get initialization values. This is because the init_weights functions are
# written assuming the whole model (all its weights, or FQNs) exists on one rank. With PP, init_weights
# on stage 1 might crash because it can't find the "embedding" layer, for example.

# allocate sharded model on GPU and initialize weights via DTensor
model.init_weights()

gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
logger.info(
@@ -219,7 +271,6 @@ def loss_fn(pred, labels):
# build optimizer after applying parallelisms to the model
optimizer = build_optimizer(model, job_config)
scheduler = get_lr_scheduler(optimizer, job_config)

metric_logger = build_metric_logger(job_config)

# torch.compile model for improved performance
@@ -257,7 +308,13 @@ def loss_fn(pred, labels):
logger.info("Created seed checkpoint")
return

checkpoint.load()
checkpoint_loaded = checkpoint.load()

if parallel_dims.pp_enabled and not checkpoint_loaded:
raise RuntimeError(
"Pipeline Parallelism requires meta-initialization and loading seed checkpoint. "
"Please run `./create_seed_checkpoint.sh` and rerun training with `--checkpoint.enable_checkpoint`"
)

# plot losses loaded from checkpoint (if any) to TensorBoard
# NOTE: Loss info after the last log step before checkpoint saving will not be plotted.
@@ -299,14 +356,33 @@ def loss_fn(pred, labels):

input_ids = input_ids.cuda()
labels = labels.cuda()

optimizer.zero_grad()

# forward / backward
with loss_parallel_ctx():
pred = model(input_ids)
loss = loss_fn(pred, labels)
loss.backward()
if parallel_dims.pp_enabled:
# pipeline parallel forward / backward inside step() call
is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1

with loss_parallel_ctx():
if pp_mesh.get_local_rank() == 0:
pp_schedule.step(input_ids)
elif is_last_stage:
losses = []
pp_schedule.step(target=labels, losses=losses)
else:
pp_schedule.step()

# accumulate losses across pipeline microbatches
loss = (
torch.mean(torch.stack(losses))
if is_last_stage
else torch.Tensor([-1.0])
Review comment (Contributor):
Why do we need the default -1 value? Is it for logging purposes?

Reply (Contributor Author):
Oh, yeah, I could make it a None, but then I'd have to update the logger to not log at all. Maybe that's actually a better way to do it; let me try that.

Reply (wconstab, Contributor Author, May 2, 2024):
OK -- so what I could do is alter the metrics code so that on non-last-stage ranks we omit printing loss, or print "loss: None" instead of -1.

The change would add more lines of code, since I need to deal with several places that expect loss and global_[avg/mean]_loss to be valid numbers:

  • avoid writing them into the metrics dict
  • replace their format string with a string value instead of a float value in the logger.info
  • avoid calling loss.item() in the first place

I agree in principle that's the "right" fix, but I'm not sure it's worth the LOC / complexity. I don't totally hate the -1 thing. (A sketch of the "loss: None" option follows at the end of this diff.)

Another option I considered is to skip the whole '# log metrics' code block on non-last-stage ranks. I ruled this out, since it is still useful to log MFU and memory on other ranks.

So let me know what you want to do here, @wanchaol.

)
else:
# Non-PP forward / backward
with loss_parallel_ctx():
pred = model(input_ids)
loss = loss_fn(pred, labels)
loss.backward()

# clip gradients
torch.nn.utils.clip_grad_norm_(
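
Relating to the loss-logging thread above, here is a rough sketch of the "loss: None on non-last stages" alternative that was discussed. It is illustrative only; the metrics key and logger call are assumptions, not torchtitan's actual metrics code.

from typing import Optional

import torch


def log_step_loss(step: int, loss: Optional[torch.Tensor], metrics: dict, logger) -> None:
    # Only the last PP stage holds a real loss; other stages pass None.
    if loss is not None:
        metrics["loss/current"] = loss.item()  # assumed metrics key
        loss_str = f"{loss.item():7.4f}"
    else:
        # Skip the loss metric entirely and print a placeholder instead of -1.
        loss_str = "None"
    logger.info(f"step: {step:2}  loss: {loss_str}")
    # MFU / memory metrics would still be logged on every rank after this point.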