Skip to content

Commit

Permalink
Add Pipeline Parallel (and 2D PP+FSDP) support
Browse files Browse the repository at this point in the history
runs PP+DP and PP+TP without issue,
runs PP+TP+DP with decreasing loss, but fails DCP save

Supports only simple schedules currently, gpipe and 1f1b.

Ads cmdline/toml arg for specifiying split points, in a unified
way between tracer or manual frontend.

  e.g. user can specifiy "layers.2,layers.4" as split points.

Currently uses manual frontend by default, but allows specifying
tracer frontend.  Tracer frontend requires working around additional
compatibility limitations, indicated by raising assertions, and is
not ready for wider use  yet.

ghstack-source-id: e49b659e66f4101cef58ad717a80521f5b172347
Pull Request resolved: #318
  • Loading branch information
wconstab committed May 10, 2024
1 parent f5a3ad7 commit c107fa9
Show file tree
Hide file tree
Showing 7 changed files with 534 additions and 34 deletions.
1 change: 1 addition & 0 deletions .github/workflows/unit_test_4gpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ jobs:
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
python -m pip install -r requirements.txt
python -m pip install -r dev-requirements.txt
python -m pip install git+https://github.com/pytorch/pippy
- name: Run test_runner.py
run: python ./test_runner.py
- name: Upload Coverage to Codecov
Expand Down
97 changes: 89 additions & 8 deletions test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class OverrideDefinitions:

override_args: Sequence[Sequence[str]] = tuple(tuple(" "))
test_descr: str = "default"
requires_seed_checkpoint: bool = False
ngpu: int = 4


CONFIG_DIR = "./train_configs"
Expand Down Expand Up @@ -85,25 +87,104 @@ class OverrideDefinitions:
],
"Checkpoint Integration Test - Save Model Weights Only bf16",
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
f"--checkpoint.folder {test_checkpoint_dir}_pp",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.1",
"--training.data_parallel_degree 1",
"--model.norm_type rmsnorm", # TODO fix fused_rmsnorm issue
],
],
"PP 1D test",
requires_seed_checkpoint=True,
ngpu=2,
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
f"--checkpoint.folder {test_checkpoint_dir}_pp_dp",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.1",
"--training.data_parallel_degree 2",
"--model.norm_type fused_rmsnorm",
],
],
"PP+DP 2D test",
requires_seed_checkpoint=True,
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
f"--checkpoint.folder {test_checkpoint_dir}_pp_tp",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.1",
"--training.tensor_parallel_degree 2",
"--model.norm_type rmsnorm", # TODO fix fused_rmsnorm issue
],
],
"PP+TP 2D test",
requires_seed_checkpoint=True,
),
# oh.. not enough GPUs?
# OverrideDefinitions(
# [
# [
# "--checkpoint.enable_checkpoint",
# f"--checkpoint.folder {test_checkpoint_dir}_pp_dp_tp",
# "--experimental.pipeline_parallel_degree 2",
# "--experimental.pipeline_parallel_split_points layers.1",
# "--training.data_parallel_degree 2",
# "--training.tensor_parallel_degree 2",
# "--model.norm_type rmsnorm", # TODO fix fused_rmsnorm issue
# ],
# ],
# "PP+DP+TP 3D test",
# requires_seed_checkpoint=True,
# ),
]


def _run_cmd(cmd):
return subprocess.run(
[cmd],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
shell=True,
)


def run_test(test_flavor: OverrideDefinitions, full_path: str):
# run_test supports sequence of tests.
for override_arg in test_flavor.override_args:
cmd = f"CONFIG_FILE={full_path} NGPU=4 LOG_RANK=0,1,2,3 ./run_llama_train.sh"

cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK=0,1,2,3 ./run_llama_train.sh"
if override_arg:
cmd += " " + " ".join(override_arg)
print(
f"=====Integration test, flavor : {test_flavor.test_descr}, command : {cmd}====="
)
result = subprocess.run(
[cmd],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
shell=True,
)

if test_flavor.requires_seed_checkpoint:
checkpoint_folder_arg = None
for arg in override_arg:
if "--checkpoint.folder" in arg:
checkpoint_folder_arg = arg
assert (
checkpoint_folder_arg is not None
), "Can't use seed checkpoint if folder is not specified"
print("Creating seed checkpoint")
result = _run_cmd(
f"CONFIG_FILE={full_path} ./create_seed_checkpoint.sh {checkpoint_folder_arg}"
)
print(result.stdout)

result = _run_cmd(cmd)
print(result.stdout)
if result.returncode != 0:
raise Exception(
Expand Down
69 changes: 68 additions & 1 deletion torchtitan/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@
from torchtitan.logging_utils import logger


def string_list(raw_arg):
s = raw_arg.split(",")
print(s)
return s


class JobConfig:
"""
A helper class to manage the train configuration.
Expand Down Expand Up @@ -202,11 +208,68 @@ def __init__(self):
help="Whether to apply loss parallel when sequence parallel is enabled",
)
self.parser.add_argument(
"--training.pipeline_parallel_degree",
"--experimental.pipeline_parallel_degree",
type=int,
default=1,
help="Pipeline Parallelism degree. 1 means disabled.",
)
self.parser.add_argument(
"--experimental.pipeline_parallel_stages_per_rank",
type=int,
default=1,
help="""
Pipeline Parallelism number of stages per rank (a.k.a. virtual stages)
For simple schedules, this should be 1.
For looped schedules, this can be greater than one.
If the number of stages produced by splitting does not match the expected number of stages,
an error will be raised for sanity.""",
)
self.parser.add_argument(
"--experimental.pipeline_parallel_split_points",
type=string_list,
nargs="+",
default=[],
help="""
Specify comma-separated names of modules to use as the beginning of a split point.
e.g. "layers.0,layers.2" will cause the model to be split into 3 stages,
the first containing all the layers up to layers.0,
the second containing layers.0 and up to layers.2,
the third containing layers.2 and all the remaining layers.
Note: fully-automated splitting may be enabled in the future,
but currently the split points must be specified manually for both manual and tracer.""",
)
self.parser.add_argument(
"--experimental.pipeline_parallel_schedule",
type=str,
choices=["1f1b", "gpipe"],
default="1f1b",
help="""
Specify the Pipeline Parallel schedule to use.
The schedule must be compatible with the split points and stages_per_rank.
Looped schedules are not yet supported in torchtitan.""",
)
self.parser.add_argument(
"--experimental.pipeline_parallel_split_mode",
type=str,
choices=["manual", "tracer"],
default="manual",
help="""
Specify the split method (e.g. the Pipeline Parallelism Front End)
"manual" means each rank will construct an nn.Module with the appropriate layers and .forward
implementation manually, and then wrap it in a PipelineStage.
"tracer" means the full model will be initialized (via meta device) and then traced into a graph,
split via the provided split points, unflattened into an nn.Module,
and finally wrapped in a PipelineStage. tracer frontend is currently more experimental.""",
)
self.parser.add_argument(
"--training.compile",
action="store_true",
Expand Down Expand Up @@ -408,6 +471,10 @@ def parse_args_from_command_line(
aux_parser.add_argument(
"--" + arg, action="store_true" if val else "store_false"
)
elif arg == "experimental.pipeline_parallel_split_points":
# type inference breaks here, since the type is just 'list' and it ends up flattening
# e.g. from ["layers.0", "layers.1"] into ["l", "a", "y", "e", "r", "s", ".0", ...]
aux_parser.add_argument("--" + arg, type=string_list)
else:
aux_parser.add_argument("--" + arg, type=type(val))

Expand Down
2 changes: 1 addition & 1 deletion torchtitan/models/llama/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Ten
"""
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
assert freqs_cis.shape == (x.shape[1], x.shape[-1]), (freqs_cis.shape, x.shape)
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)

Expand Down
Loading

0 comments on commit c107fa9

Please sign in to comment.