Add Pipeline Parallel (and 2D PP+FSDP) support #318

Merged: 36 commits, May 21, 2024
Changes from 28 commits
2 changes: 1 addition & 1 deletion create_seed_checkpoint.sh
@@ -25,7 +25,7 @@ LOG_RANK=0
CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"}

seed_checkpoint="--checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint"
force_1d="--training.data_parallel_degree 1 --training.tensor_parallel_degree 1 --training.pipeline_parallel_degree 1"
force_1d="--training.data_parallel_degree 1 --training.tensor_parallel_degree 1 --experimental.pipeline_parallel_degree 1"
overrides=""
if [ $# -ne 0 ]; then
overrides="$*"
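The degree flag moves into the new experimental namespace here; the script still forces every parallelism degree to 1 so the seed checkpoint is written from a single, unsharded model. A minimal sketch of how the integration tests invoke it, mirroring the _run_cmd call added to test_runner.py below (create_seed_checkpoint is an illustrative helper name, not part of this PR):

import subprocess

def create_seed_checkpoint(config_file: str, dump_folder: str) -> None:
    # Equivalent to the seed-checkpoint step in run_test(): the shell script applies
    # the force_1d overrides shown above before writing the initial checkpoint.
    cmd = (
        f"CONFIG_FILE={config_file} ./create_seed_checkpoint.sh "
        f"--job.dump_folder {dump_folder}"
    )
    subprocess.run([cmd], shell=True, check=True, text=True)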
142 changes: 132 additions & 10 deletions test_runner.py
@@ -25,6 +25,8 @@ class OverrideDefinitions:

override_args: Sequence[Sequence[str]] = tuple(tuple(" "))
test_descr: str = "default"
requires_seed_checkpoint: bool = False
ngpu: int = 4


def build_test_list(args):
@@ -35,6 +37,82 @@ def build_test_list(args):
"""
integration_tests_flavors = defaultdict(list)
integration_tests_flavors["debug_model.toml"] = [
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
f"--job.dump_folder {args.output_dir}/pp_1f1b/",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.1",
"--experimental.pipeline_parallel_schedule 1f1b",
"--training.data_parallel_degree 1",
"--model.norm_type fused_rmsnorm",
],
],
"PP 1D test 1f1b",
requires_seed_checkpoint=True,
ngpu=2,
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
f"--job.dump_folder {args.output_dir}/pp_gpipe/",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.1",
"--experimental.pipeline_parallel_schedule gpipe",
"--training.data_parallel_degree 1",
"--model.norm_type fused_rmsnorm",
],
],
"PP 1D test gpipe",
requires_seed_checkpoint=True,
ngpu=2,
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
f"--job.dump_folder {args.output_dir}/pp_dp_1f1b/",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.1",
"--experimental.pipeline_parallel_schedule 1f1b",
"--training.data_parallel_degree 2",
"--model.norm_type fused_rmsnorm",
],
],
"PP+DP 1f1b 2D test",
requires_seed_checkpoint=True,
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
f"--job.dump_folder {args.output_dir}/pp_dp_gpipe/",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.1",
"--experimental.pipeline_parallel_schedule gpipe",
"--training.data_parallel_degree 2",
"--model.norm_type fused_rmsnorm",
],
],
"PP+DP gpipe 2D test",
requires_seed_checkpoint=True,
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
f"--job.dump_folder {args.output_dir}/pp_tp/",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.1",
"--training.tensor_parallel_degree 2",
"--model.norm_type rmsnorm", # TODO fix fused_rmsnorm issue
],
],
"PP+TP 2D test",
requires_seed_checkpoint=True,
),
OverrideDefinitions(
[
[
@@ -96,26 +174,62 @@ def build_test_list(args):
],
"Checkpoint Integration Test - Save Model Weights Only bf16",
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
f"--job.dump_folder {args.output_dir}/pp_dp_tp/",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.1",
"--training.data_parallel_degree 2",
"--training.tensor_parallel_degree 2",
"--model.norm_type rmsnorm", # TODO fix fused_rmsnorm issue
],
],
"PP+DP+TP 3D test",
requires_seed_checkpoint=True,
ngpu=8,
),
]
return integration_tests_flavors
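
For reference, the first flavor above ("PP 1D test 1f1b") expands inside run_test() into the launch command below; <output_dir> stands in for the runner's positional output_dir argument and the flag values are copied from that flavor's override list:

example_cmd = (  # illustrative only; run_test() builds this string at runtime
    "CONFIG_FILE=./train_configs/debug_model.toml NGPU=2 LOG_RANK=0,1,2,3 "
    "./run_llama_train.sh "
    "--checkpoint.enable_checkpoint "
    "--job.dump_folder <output_dir>/pp_1f1b/ "
    "--experimental.pipeline_parallel_degree 2 "
    "--experimental.pipeline_parallel_split_points layers.1 "
    "--experimental.pipeline_parallel_schedule 1f1b "
    "--training.data_parallel_degree 1 "
    "--model.norm_type fused_rmsnorm"
)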


def _run_cmd(cmd):
return subprocess.run(
[cmd],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
shell=True,
)


def run_test(test_flavor: OverrideDefinitions, full_path: str):
# run_test supports sequence of tests.
for override_arg in test_flavor.override_args:
cmd = f"CONFIG_FILE={full_path} NGPU=4 LOG_RANK=0,1,2,3 ./run_llama_train.sh"

cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK=0,1,2,3 ./run_llama_train.sh"
if override_arg:
cmd += " " + " ".join(override_arg)
print(
f"=====Integration test, flavor : {test_flavor.test_descr}, command : {cmd}====="
)
result = subprocess.run(
[cmd],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
shell=True,
)

if test_flavor.requires_seed_checkpoint:
dump_folder_arg = None
for arg in override_arg:
if "--job.dump_folder" in arg:
dump_folder_arg = arg
assert (
dump_folder_arg is not None
), "Can't use seed checkpoint if folder is not specified"
print("Creating seed checkpoint")
result = _run_cmd(
f"CONFIG_FILE={full_path} ./create_seed_checkpoint.sh {dump_folder_arg}"
)
print(result.stdout)

result = _run_cmd(cmd)
print(result.stdout)
if result.returncode != 0:
raise Exception(
@@ -135,13 +249,21 @@ def run_tests(args):
)
if is_integration_test:
for test_flavor in integration_tests_flavors[config_file]:
run_test(test_flavor, full_path)
if (args.ngpu == 8 and test_flavor.ngpu == 8) or (
args.ngpu == 4 and test_flavor.ngpu <= 4
):
run_test(test_flavor, full_path)
else:
print(
f"Skipping test {test_flavor} due to num_gpu mismatch {test_flavor.ngpu}"
)


def main():
parser = argparse.ArgumentParser()
parser.add_argument("output_dir")
parser.add_argument("output_dir", type=str)
parser.add_argument("--config_dir", default="./train_configs")
parser.add_argument("--ngpu", default=4, type=int)
args = parser.parse_args()

if not os.path.exists(args.output_dir):
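The ngpu requested by each new flavor is just the product of its parallel degrees; a small sketch of that accounting (required_ngpu is an illustrative helper, not part of this PR):

def required_ngpu(pp: int, dp: int, tp: int) -> int:
    # A pipeline * data * tensor parallel mesh needs pp * dp * tp ranks in total.
    return pp * dp * tp

assert required_ngpu(2, 1, 1) == 2  # PP 1D tests (1f1b, gpipe)
assert required_ngpu(2, 2, 1) == 4  # PP+DP 2D tests (the dataclass default of 4 GPUs)
assert required_ngpu(2, 1, 2) == 4  # PP+TP 2D test
assert required_ngpu(2, 2, 2) == 8  # PP+DP+TP 3D test, selected only when --ngpu 8 is passed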
58 changes: 56 additions & 2 deletions torchtitan/config_manager.py
@@ -17,6 +17,10 @@
from torchtitan.logging_utils import logger


def string_list(raw_arg):
return raw_arg.split(",")


class JobConfig:
"""
A helper class to manage the train configuration.
@@ -202,10 +206,56 @@ def __init__(self):
help="Whether to apply loss parallel when sequence parallel is enabled",
)
self.parser.add_argument(
"--training.pipeline_parallel_degree",
"--experimental.pipeline_parallel_degree",
type=int,
default=1,
help="Pipeline Parallelism degree. 1 means disabled.",
help="""
Pipeline Parallelism degree, or number of ranks. 1 means disabled.
If using looped schedules, this still specifies the number of physical ranks, not the number
of stages. Stages per rank are inferred from split points degree, and schedule.""",
)
self.parser.add_argument(
"--experimental.pipeline_parallel_split_points",
type=string_list,
nargs="+",
default=[],
help="""
Specify comma-separated names of modules to use as the beginning of a split point.
e.g. "layers.0,layers.2" will cause the model to be split into 3 stages,
the first containing all the layers up to layers.0,
the second containing layers.0 and up to layers.2,
the third containing layers.2 and all the remaining layers.
Note: fully-automated splitting may be enabled in the future,
but currently the split points must be specified manually for both manual and tracer.""",
)
self.parser.add_argument(
"--experimental.pipeline_parallel_schedule",
type=str,
choices=["1f1b", "gpipe"],
default="1f1b",
help="""
Specify the Pipeline Parallel schedule to use.
The schedule must be compatible with the split points and stages_per_rank.
Looped schedules are not yet supported in torchtitan.""",
)
self.parser.add_argument(
"--experimental.pipeline_parallel_split_mode",
type=str,
choices=["manual", "tracer"],
default="manual",
help="""
Specify the split method (e.g. the Pipeline Parallelism Front End)
"manual" means each rank will construct an nn.Module with the appropriate layers and .forward
implementation manually, and then wrap it in a PipelineStage.
"tracer" means the full model will be initialized (via meta device) and then traced into a graph,
split via the provided split points, unflattened into an nn.Module,
and finally wrapped in a PipelineStage. tracer frontend is currently more experimental.""",
)
self.parser.add_argument(
"--training.compile",
@@ -408,6 +458,10 @@ def parse_args_from_command_line(
aux_parser.add_argument(
"--" + arg, action="store_true" if val else "store_false"
)
elif arg == "experimental.pipeline_parallel_split_points":
# type inference breaks here, since the type is just 'list' and it ends up flattening
# e.g. from ["layers.0", "layers.1"] into ["l", "a", "y", "e", "r", "s", ".0", ...]
aux_parser.add_argument("--" + arg, type=string_list)
else:
aux_parser.add_argument("--" + arg, type=type(val))

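To make the new options concrete (the layer indices below are examples, not defaults from this PR): string_list keeps a comma-separated CLI value as a list of module names, and each named module starts a new stage, as described in the pipeline_parallel_split_points help text above.

def string_list(raw_arg):  # as defined at the top of config_manager.py
    return raw_arg.split(",")

assert string_list("layers.4,layers.8") == ["layers.4", "layers.8"]  # three stages
# With --experimental.pipeline_parallel_degree 2 and a single split point, e.g.
#   --experimental.pipeline_parallel_split_points layers.4
# stage 0 holds everything before layers.4 and stage 1 holds layers.4 through the end
# of the model, matching the help text above.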
6 changes: 5 additions & 1 deletion torchtitan/parallelisms/__init__.py
@@ -9,12 +9,16 @@

from torch.distributed.device_mesh import init_device_mesh
from torchtitan.logging_utils import logger
from torchtitan.parallelisms.parallelize_llama import parallelize_llama
from torchtitan.parallelisms.parallelize_llama import parallelize_llama, pipeline_llama

models_parallelize_fns = {
"llama2": parallelize_llama,
"llama3": parallelize_llama,
}
models_pipelining_fns = {
"llama2": pipeline_llama,
"llama3": pipeline_llama,
}


@dataclass
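A short sketch of how the new registry pairs with the existing one; the call signature of pipeline_llama is not shown in this diff, so only the lookup is illustrated:

from torchtitan.parallelisms import models_parallelize_fns, models_pipelining_fns

model_name = "llama3"
parallelize_fn = models_parallelize_fns[model_name]  # -> parallelize_llama
pipeline_fn = models_pipelining_fns[model_name]      # -> pipeline_llama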