diff --git a/create_seed_checkpoint.sh b/create_seed_checkpoint.sh
index 38bab219..1abc77ec 100755
--- a/create_seed_checkpoint.sh
+++ b/create_seed_checkpoint.sh
@@ -25,7 +25,7 @@ LOG_RANK=0
 CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"}

 seed_checkpoint="--checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint"
-force_1d="--training.data_parallel_degree 1 --training.tensor_parallel_degree 1 --training.pipeline_parallel_degree 1"
+force_1d="--training.data_parallel_degree 1 --training.tensor_parallel_degree 1 --experimental.pipeline_parallel_degree 1"
 overrides=""
 if [ $# -ne 0 ]; then
     overrides="$*"
diff --git a/test_runner.py b/test_runner.py
index ca9d1320..dfd4a987 100755
--- a/test_runner.py
+++ b/test_runner.py
@@ -11,6 +11,8 @@
 from dataclasses import dataclass
 from typing import Sequence

+from torchtitan.logging_utils import logger
+
 try:
     import tomllib
 except ModuleNotFoundError:
@@ -25,6 +27,8 @@ class OverrideDefinitions:
     override_args: Sequence[Sequence[str]] = tuple(tuple(" "))
     test_descr: str = "default"
+    requires_seed_checkpoint: bool = False
+    ngpu: int = 4


 def build_test_list(args):
     """
     ...
     """
     integration_tests_flavors = defaultdict(list)
     integration_tests_flavors["debug_model.toml"] = [
+        OverrideDefinitions(
+            [
+                [
+                    "--checkpoint.enable_checkpoint",
+                    f"--job.dump_folder {args.output_dir}/pp_1f1b/",
+                    "--experimental.pipeline_parallel_degree 2",
+                    "--experimental.pipeline_parallel_split_points layers.1",
+                    "--experimental.pipeline_parallel_schedule 1f1b",
+                    "--training.data_parallel_degree 1",
+                ],
+            ],
+            "PP 1D test 1f1b",
+            requires_seed_checkpoint=True,
+            ngpu=2,
+        ),
+        OverrideDefinitions(
+            [
+                [
+                    "--checkpoint.enable_checkpoint",
+                    f"--job.dump_folder {args.output_dir}/pp_gpipe/",
+                    "--experimental.pipeline_parallel_degree 2",
+                    "--experimental.pipeline_parallel_split_points layers.1",
+                    "--experimental.pipeline_parallel_schedule gpipe",
+                    "--training.data_parallel_degree 1",
+                ],
+            ],
+            "PP 1D test gpipe",
+            requires_seed_checkpoint=True,
+            ngpu=2,
+        ),
+        OverrideDefinitions(
+            [
+                [
+                    "--checkpoint.enable_checkpoint",
+                    f"--job.dump_folder {args.output_dir}/pp_dp_1f1b/",
+                    "--experimental.pipeline_parallel_degree 2",
+                    "--experimental.pipeline_parallel_split_points layers.1",
+                    "--experimental.pipeline_parallel_schedule 1f1b",
+                    "--training.data_parallel_degree 2",
+                ],
+            ],
+            "PP+DP 1f1b 2D test",
+            requires_seed_checkpoint=True,
+        ),
+        OverrideDefinitions(
+            [
+                [
+                    "--checkpoint.enable_checkpoint",
+                    f"--job.dump_folder {args.output_dir}/pp_dp_gpipe/",
+                    "--experimental.pipeline_parallel_degree 2",
+                    "--experimental.pipeline_parallel_split_points layers.1",
+                    "--experimental.pipeline_parallel_schedule gpipe",
+                    "--training.data_parallel_degree 2",
+                ],
+            ],
+            "PP+DP gpipe 2D test",
+            requires_seed_checkpoint=True,
+        ),
+        OverrideDefinitions(
+            [
+                [
+                    "--checkpoint.enable_checkpoint",
+                    f"--job.dump_folder {args.output_dir}/pp_tp/",
+                    "--experimental.pipeline_parallel_degree 2",
+                    "--experimental.pipeline_parallel_split_points layers.1",
+                    "--training.tensor_parallel_degree 2",
+                    "--model.norm_type rmsnorm",  # fused_rmsnorm not yet compatible with TP
+                ],
+            ],
+            "PP+TP 2D test",
+            requires_seed_checkpoint=True,
+        ),
         OverrideDefinitions(
             [
                 [
                     "--checkpoint.enable_checkpoint",
@@ -100,23 +176,43 @@ def build_test_list(args):
     return integration_tests_flavors


+def _run_cmd(cmd):
+    return subprocess.run(
+        [cmd],
+        stdout=subprocess.PIPE,
+        stderr=subprocess.STDOUT,
+        text=True,
+        shell=True,
+    )
+
+
 def run_test(test_flavor: OverrideDefinitions, full_path: str):
     # run_test supports sequence of tests.
     for override_arg in test_flavor.override_args:
-        cmd = f"CONFIG_FILE={full_path} NGPU=4 LOG_RANK=0,1,2,3 ./run_llama_train.sh"
+
+        cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK=0,1,2,3 ./run_llama_train.sh"
         if override_arg:
             cmd += " " + " ".join(override_arg)
-        print(
+        logger.info(
            f"=====Integration test, flavor : {test_flavor.test_descr}, command : {cmd}====="
         )
-        result = subprocess.run(
-            [cmd],
-            stdout=subprocess.PIPE,
-            stderr=subprocess.STDOUT,
-            text=True,
-            shell=True,
-        )
-        print(result.stdout)
+
+        if test_flavor.requires_seed_checkpoint:
+            dump_folder_arg = None
+            for arg in override_arg:
+                if "--job.dump_folder" in arg:
+                    dump_folder_arg = arg
+            assert (
+                dump_folder_arg is not None
+            ), "Can't use seed checkpoint if folder is not specified"
+            logger.info("Creating seed checkpoint")
+            result = _run_cmd(
+                f"CONFIG_FILE={full_path} ./create_seed_checkpoint.sh {dump_folder_arg}"
+            )
+            logger.info(result.stdout)
+
+        result = _run_cmd(cmd)
+        logger.info(result.stdout)
         if result.returncode != 0:
             raise Exception(
                 f"Integration test failed, flavor : {test_flavor.test_descr}, command : {cmd}"
diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 1a3e36d4..da80b425 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -25,6 +25,10 @@
 }


+def string_list(raw_arg):
+    return raw_arg.split(",")
+
+
 class JobConfig:
     """
     A helper class to manage the train configuration.
@@ -210,10 +214,68 @@ def __init__(self):
             help="Whether to apply loss parallel when sequence parallel is enabled",
         )
         self.parser.add_argument(
-            "--training.pipeline_parallel_degree",
+            "--experimental.pipeline_parallel_degree",
             type=int,
             default=1,
-            help="Pipeline Parallelism degree. 1 means disabled.",
+            help="""
+                Pipeline Parallelism degree, or number of ranks. 1 means disabled.
+                If using looped schedules, this still specifies the number of physical ranks, not the number
+                of stages. The number of stages per rank is inferred from the split points and the schedule.""",
+        )
+        self.parser.add_argument(
+            "--experimental.pipeline_parallel_split_points",
+            type=string_list,
+            nargs="+",
+            default=[],
+            help="""
+                Specify comma-separated names of modules to use as the beginning of a split point.
+
+                e.g. "layers.0,layers.2" will cause the model to be split into 3 stages,
+                the first containing all the layers up to layers.0,
+                the second containing layers.0 and up to layers.2,
+                the third containing layers.2 and all the remaining layers.
+
+                Note: fully-automated splitting may be enabled in the future,
+                but currently the split points must be specified manually for both the manual and tracer frontends.""",
+        )
+        self.parser.add_argument(
+            "--experimental.pipeline_parallel_schedule",
+            type=str,
+            choices=["1f1b", "gpipe"],
+            default="1f1b",
+            help="""
+                Specify the Pipeline Parallel schedule to use.

+                The schedule must be compatible with the split points and stages_per_rank.
+
+                Looped schedules are not yet supported in torchtitan.""",
+        )
+        self.parser.add_argument(
+            "--experimental.pipeline_parallel_split_mode",
+            type=str,
+            choices=["manual", "tracer"],
+            default="manual",
+            help="""
+                Specify the split method (i.e. the Pipeline Parallelism frontend to use).
+
+                "manual" means each rank will construct an nn.Module with the appropriate layers and .forward
+                implementation manually, and then wrap it in a PipelineStage.
+
+                "tracer" means the full model will be initialized (via meta device) and then traced into a graph,
+                split via the provided split points, unflattened into an nn.Module,
+                and finally wrapped in a PipelineStage. The tracer frontend is currently more experimental than
+                the manual frontend.""",
+        )
+        self.parser.add_argument(
+            "--experimental.pipeline_parallel_microbatches",
+            type=int,
+            default=None,
+            help="""
+                How many microbatches to split the global training batch into when using pipeline parallelism.
+
+                The global training batch size must be evenly divisible by the number of microbatches.
+
+                The default value will be the number of pipeline stages, if unspecified.
+            """,
         )
         self.parser.add_argument(
             "--training.mixed_precision_param",
@@ -437,6 +499,11 @@ def parse_args_from_command_line(
                 aux_parser.add_argument(
                     "--" + arg, action="store_true" if val else "store_false"
                 )
+            elif arg == "experimental.pipeline_parallel_split_points":
+                # without this special case, type inference breaks here,
+                # since the inferred type is just 'list' and it ends up flattening
+                # e.g. from ["layers.0", "layers.1"] into ["l", "a", "y", "e", "r", "s", ".0", ...]
+                aux_parser.add_argument("--" + arg, type=string_list)
             else:
                 aux_parser.add_argument("--" + arg, type=type(val))

diff --git a/torchtitan/parallelisms/__init__.py b/torchtitan/parallelisms/__init__.py
index e791b832..7e1b21c7 100644
--- a/torchtitan/parallelisms/__init__.py
+++ b/torchtitan/parallelisms/__init__.py
@@ -9,12 +9,16 @@
 from torch.distributed.device_mesh import init_device_mesh

 from torchtitan.logging_utils import logger
-from torchtitan.parallelisms.parallelize_llama import parallelize_llama
+from torchtitan.parallelisms.parallelize_llama import parallelize_llama, pipeline_llama

 models_parallelize_fns = {
     "llama2": parallelize_llama,
     "llama3": parallelize_llama,
 }
+models_pipelining_fns = {
+    "llama2": pipeline_llama,
+    "llama3": pipeline_llama,
+}


 @dataclass
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 0bd0a966..61cf79fe 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -8,7 +8,7 @@
 # llama model, i.e. activation checkpointing, etc.
 from collections import defaultdict
-from typing import Tuple
+from typing import Dict, Tuple

 import torch

@@ -18,6 +18,11 @@
     checkpoint_wrapper as ptd_checkpoint_wrapper,
     CheckpointImpl,
 )
+from torch.distributed.pipelining import pipeline, SplitPoint
+from torch.distributed.pipelining._PipelineStage import (
+    _PipelineStage,
+    ManualPipelineStage,
+)
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     parallelize_module,
@@ -31,7 +36,6 @@
 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.logging_utils import logger

-
 # for selective AC
 no_recompute_list = {
     torch.ops.aten.mm.default,
@@ -129,15 +133,165 @@ def get_tp_parallel_strategy(
     return RowwiseParallel, ColwiseParallel


+def pipeline_llama(
+    model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict
+):
+    if job_config.experimental.pipeline_parallel_split_mode == "manual":
+        return pipeline_llama_manual(
+            model, world_mesh, parallel_dims, job_config, device, model_config
+        )
+    elif job_config.experimental.pipeline_parallel_split_mode == "tracer":
+        return pipeline_llama_tracer(
+            model, world_mesh, parallel_dims, job_config, device, model_config
+        )
+    else:
+        raise NotImplementedError(
+            f"{job_config.experimental.pipeline_parallel_split_mode} is not a valid split mode"
+        )
+
+
+def _llama_trace_input(job_config, model_config, device="meta"):
+    """Get meta tensors with the right input shapes used for tracing"""
+    tokens_shape = (job_config.training.batch_size, job_config.training.seq_len)
+    tokens = torch.randint(
+        model_config.vocab_size, tokens_shape, dtype=torch.int64, device=device
+    )
+    return (tokens,)
+
+
+def pipeline_llama_manual(
+    model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict
+):
+    """
+    This API extracts one torch.nn.Module object for the part of the model configured to run inside this stage.
+
+    It wraps the model chunk in a ManualPipelineStage object and returns both the stage and model objects.
+
+    The stage object is used to create a pipeline schedule, and the model object can be used for applying SPMD
+    parallelism.
+    """
+    pp_mesh = world_mesh["pp"]
+    pp_rank = pp_mesh.get_local_rank()
+    pp_size = pp_mesh.size()
+    microbatches = (
+        job_config.experimental.pipeline_parallel_microbatches or parallel_dims.pp
+    )
+    stage_idx = pp_rank
+
+    splits = job_config.experimental.pipeline_parallel_split_points
+    start_layer = splits[stage_idx - 1] if stage_idx > 0 else None
+    stop_layer = splits[stage_idx] if stage_idx < pp_size - 1 else None
+
+    if pp_rank > 0:
+        model.tok_embeddings = None
+
+    drop_layers = True
+    for name in list(model.layers.keys()):
+        # we keep layers in a contiguous region between start (inclusive) and stop (exclusive)
+        if start_layer is None or f"layers.{name}" == start_layer:
+            drop_layers = False
+        if stop_layer is not None and f"layers.{name}" == stop_layer:
+            drop_layers = True
+        if drop_layers:
+            del model.layers[name]
+
+    if pp_rank < pp_size - 1:
+        model.norm = None
+        model.output = None
+
+    logger.info(f"PP rank {pp_rank} is using this model chunk\n{model}")
+
+    # TODO(whc) once ManualPipelineStage supports lazy shape inference, we can leave model on meta device longer and
+    # get rid of the input shape hardcoded here. For now, it should not be a big deal since we only materialize the
+    # layers of the model that map to this stage, not the whole model.
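+    # The example input/output tensors constructed below describe a single microbatch's activation shapes and
+    # dtypes for this stage; ManualPipelineStage presumably uses them to size its inter-stage send/recv buffers
+    # (hence the shape hardcoding noted in the TODO above).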
+    mp_arg = job_config.training.mixed_precision_param
+    mp_dtype = TORCH_DTYPE_MAP[mp_arg] if parallel_dims.dp_enabled else torch.float32
+    batch_size = job_config.training.batch_size
+    local_seq_len = int(job_config.training.seq_len // parallel_dims.tp)
+    layers_io_shape = (batch_size, local_seq_len, model_config.dim)
+    output_layer_shape = (batch_size, local_seq_len, model_config.vocab_size)
+    if pp_rank == 0:
+        # first layer
+        input = torch.randint(
+            model_config.vocab_size,
+            size=(batch_size, job_config.training.seq_len),
+            dtype=torch.int64,
+            device=device,
+        )
+    else:
+        # later layers (assume all start w/ a transformer layer)
+        input = torch.rand(layers_io_shape, dtype=mp_dtype, device=device)
+
+    if pp_rank == pp_size - 1:
+        # last layer
+        output = torch.rand(output_layer_shape, dtype=torch.float32, device=device)
+    else:
+        # earlier layers (assume all end in a transformer layer)
+        output = torch.rand(layers_io_shape, dtype=mp_dtype, device=device)
+
+    model.to_empty(device=device)
+    stage = ManualPipelineStage(
+        model,
+        pp_rank,
+        pp_size,
+        device,
+        microbatches,
+        input_args=input.chunk(microbatches)[0],
+        output_args=output.chunk(microbatches)[0],
+        group=pp_mesh.get_group("pp"),
+    )
+    return (stage, model)
+
+
+def pipeline_llama_tracer(
+    model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict
+):
+    if job_config.model.norm_type == "fused_rmsnorm":
+        # TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode
+        # coming from ` if dy.stride(-1) != 1:` in fused_rmsnorm
+        raise NotImplementedError(
+            "fused_rmsnorm not yet compatible with Pipeline Tracer (strides error). Please use layernorm or rmsnorm."
+        )
+
+    # TODO(whc) maybe we can just fix this by feeding bf16 into the tracer for its input shapes?
+    raise NotImplementedError(
+        "pipeline tracer doesn't work with fsdp mixed precision currently. "
+        "To work around, edit fsdp mixed precision config to use fp32."
+    )
+    pp_mesh = world_mesh["pp"]
+    pp_rank = pp_mesh.get_local_rank()
+    stage_idx = pp_mesh.get_local_rank()
+    layers_per_rank = len(model.layers) // parallel_dims.pp
+    split_spec = {
+        f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING
+        for i in range(1, parallel_dims.pp)
+    }
+
+    # Create a pipeline representation from the model
+    pipe = pipeline(
+        model,
+        job_config.experimental.pipeline_parallel_microbatches or parallel_dims.pp,
+        example_args=_llama_trace_input(job_config, model_config),
+        split_spec=split_spec,
+    )
+    model = pipe.get_stage_module(stage_idx)
+    stage = _PipelineStage(
+        stage_module=model,
+        stage_index=pp_rank,
+        pipe_info=pipe.pipe_info,
+        device=device,
+        group=pp_mesh.get_group(),
+    )
+    return (stage, model)
+
+
 def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     """
-    Apply parallelisms and activation checkpointing to the model.
+    Apply SPMD parallelisms and activation checkpointing to the model.

     NOTE: The passed-in model preferably should be on meta device. Otherwise,
     the model must fit on GPU or CPU memory.
""" - if parallel_dims.pp_enabled: - raise NotImplementedError("PP not implemented yet.") if parallel_dims.tp_enabled: if job_config.model.norm_type == "fused_rmsnorm": @@ -221,15 +375,22 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): transformer_block, job_config.activation_checkpoint ) # As an optimization, do not reshard after forward for the last - # transformer block since FSDP would prefetch it immediately - reshard_after_forward = int(layer_id) < len(model.layers) - 1 + # transformer block since FSDP would prefetch it immediately. + # When using Pipeline Parallelism, generally zero-2 is best so as to avoid repeated reshardings + # per microbatch. + reshard_after_forward = ( + int(layer_id) < len(model.layers) - 1 and not parallel_dims.pp_enabled + ) fully_shard( transformer_block, **fsdp_config, reshard_after_forward=reshard_after_forward, ) model.layers[layer_id] = transformer_block - model = fully_shard(model, **fsdp_config) + + model = fully_shard( + model, **fsdp_config, reshard_after_forward=not parallel_dims.pp_enabled + ) if ac_mode in ("full", "selective"): logger.info(f"Applied {ac_mode} activation checkpointing to the model") logger.info("Applied FSDP to the model") diff --git a/torchtitan/parallelisms/pipelining_utils.py b/torchtitan/parallelisms/pipelining_utils.py new file mode 100644 index 00000000..24752e4b --- /dev/null +++ b/torchtitan/parallelisms/pipelining_utils.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from torch.distributed.pipelining import Schedule1F1B, ScheduleGPipe +from torchtitan.logging_utils import logger + + +def build_pipeline_schedule(job_config, parallel_dims, stage, loss_fn): + if job_config.experimental.pipeline_parallel_schedule == "1f1b": + schedule_class = Schedule1F1B + elif job_config.experimental.pipeline_parallel_schedule == "gpipe": + schedule_class = ScheduleGPipe + else: + raise NotImplementedError( + f"{job_config.experimental.pipeline_parallel_schedule} is not implemented" + ) + logger.info( + f"Using pipeline schedule {job_config.experimental.pipeline_parallel_schedule}" + ) + return schedule_class( + stage, + n_microbatches=stage.chunks, + loss_fn=loss_fn, + ) diff --git a/train.py b/train.py index a0bb337e..e13acb3d 100644 --- a/train.py +++ b/train.py @@ -32,7 +32,12 @@ from torchtitan.lr_scheduling import get_lr_scheduler from torchtitan.metrics import build_gpu_memory_monitor, build_metric_logger from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config -from torchtitan.parallelisms import models_parallelize_fns, ParallelDims +from torchtitan.parallelisms import ( + models_parallelize_fns, + models_pipelining_fns, + ParallelDims, +) +from torchtitan.parallelisms.pipelining_utils import build_pipeline_schedule from torchtitan.profiling import maybe_enable_profiling from torchtitan.utils import ( Color, @@ -122,11 +127,12 @@ def main(job_config: JobConfig): parallel_dims = ParallelDims( dp=job_config.training.data_parallel_degree, tp=job_config.training.tensor_parallel_degree, - pp=job_config.training.pipeline_parallel_degree, + pp=job_config.experimental.pipeline_parallel_degree, world_size=world_size, enable_loss_parallel=job_config.training.enable_loss_parallel, ) - torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + device = 
torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}") + torch.cuda.set_device(device) init_distributed(job_config) world_mesh = parallel_dims.build_mesh(device_type="cuda") @@ -144,6 +150,10 @@ def main(job_config: JobConfig): dp_rank = dp_mesh.get_local_rank() else: dp_degree, dp_rank = 1, 0 + + if parallel_dims.pp_enabled: + pp_mesh = world_mesh["pp"] + data_loader = build_hf_data_loader( job_config.training.dataset, job_config.training.dataset_path, @@ -201,13 +211,26 @@ def loss_fn(pred, labels): # obtain the peak flops of bf16 type for MFU calculation gpu_peak_flops = get_peak_flops(gpu_memory_monitor.device_name) - # apply PT-D parallelisms and activation checkpointing + if parallel_dims.pp_enabled: + stage, model = models_pipelining_fns[model_name]( + model, world_mesh, parallel_dims, job_config, device, model_config + ) + + # apply PT-D DP/TP parallelisms and activation checkpointing model = models_parallelize_fns[model_name]( model, world_mesh, parallel_dims, job_config ) - # allocate sharded model on GPU and initialize weights via DTensor + model.to_empty(device="cuda") - model.init_weights() + + if parallel_dims.pp_enabled: + pp_schedule = build_pipeline_schedule(job_config, parallel_dims, stage, loss_fn) + else: + # If PP is enabled, we can't rely on init_weights, because some layers are missing. + # In the future, we may make init_weights handle missing layers, but also have to consider RNG seed propagation. + + # allocate sharded model on GPU and initialize weights via DTensor + model.init_weights() gpu_mem_stats = gpu_memory_monitor.get_peak_stats() logger.info( @@ -258,7 +281,13 @@ def loss_fn(pred, labels): logger.info("Created seed checkpoint") return - checkpoint.load() + checkpoint_loaded = checkpoint.load() + + if parallel_dims.pp_enabled and not checkpoint_loaded: + raise RuntimeError( + "Pipeline Parallelism requires meta-initialization and loading seed checkpoint. " + "Please run `./create_seed_checkpoint.sh` and rerun training with `--checkpoint.enable_checkpoint`" + ) # plot losses loaded from checkpoint (if any) to TensorBoard # NOTE: Loss info after the last log step before checkpoint saving will not be ploted. 
@@ -300,14 +329,33 @@ def loss_fn(pred, labels):
             input_ids = input_ids.cuda()
             labels = labels.cuda()

-            optimizer.zero_grad()
-            # forward / backward
-            with loss_parallel_ctx():
-                pred = model(input_ids)
-                loss = loss_fn(pred, labels)
-                loss.backward()
+            if parallel_dims.pp_enabled:
+                # pipeline parallel forward / backward inside step() call
+                is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
+
+                with loss_parallel_ctx():
+                    if pp_mesh.get_local_rank() == 0:
+                        pp_schedule.step(input_ids)
+                    elif is_last_stage:
+                        losses = []
+                        pp_schedule.step(target=labels, losses=losses)
+                    else:
+                        pp_schedule.step()
+
+                # accumulate losses across pipeline microbatches
+                loss = (
+                    torch.mean(torch.stack(losses))
+                    if is_last_stage
+                    else torch.Tensor([-1.0])
+                )
+            else:
+                # Non-PP forward / backward
+                with loss_parallel_ctx():
+                    pred = model(input_ids)
+                    loss = loss_fn(pred, labels)
+                    loss.backward()

             # clip gradients
             torch.nn.utils.clip_grad_norm_(
diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml
index b8ec566f..a0925de9 100644
--- a/train_configs/debug_model.toml
+++ b/train_configs/debug_model.toml
@@ -36,11 +36,13 @@ max_norm = 1.0  # grad norm clipping
 steps = 10
 data_parallel_degree = -1
 tensor_parallel_degree = 1
-pipeline_parallel_degree = 1
 fp8_linear = ""
 compile = false
 dataset = "c4_mini"  # supported datasets: c4_mini (45K), c4 (177M)

+[experimental]
+pipeline_parallel_degree = 1
+
 [checkpoint]
 enable_checkpoint = false
 folder = "checkpoint"
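
Below (not part of the patch): a minimal standalone sketch of how the manual frontend's split-point logic in pipeline_llama_manual assigns transformer layers to pipeline ranks. The helper name layers_kept_by_rank is hypothetical and it operates on plain name lists rather than the model's module dict, but the keep/drop rules mirror the loop in the patch.

# Sketch only: mirrors the drop_layers loop in pipeline_llama_manual.
def layers_kept_by_rank(layer_names, splits, pp_rank, pp_size):
    start_layer = splits[pp_rank - 1] if pp_rank > 0 else None
    stop_layer = splits[pp_rank] if pp_rank < pp_size - 1 else None
    kept = []
    drop_layers = True
    for name in layer_names:
        # keep layers in a contiguous region between start (inclusive) and stop (exclusive)
        if start_layer is None or name == start_layer:
            drop_layers = False
        if stop_layer is not None and name == stop_layer:
            drop_layers = True
        if not drop_layers:
            kept.append(name)
    return kept

# With split point "layers.1" and 2 ranks, as used by the PP integration tests above:
print(layers_kept_by_rank(["layers.0", "layers.1", "layers.2"], ["layers.1"], 0, 2))  # ['layers.0']
print(layers_kept_by_rank(["layers.0", "layers.1", "layers.2"], ["layers.1"], 1, 2))  # ['layers.1', 'layers.2']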