diff --git a/.github/workflows/unit_test_4gpu.yaml b/.github/workflows/unit_test_4gpu.yaml index 0088bb3ee..76ac024b8 100644 --- a/.github/workflows/unit_test_4gpu.yaml +++ b/.github/workflows/unit_test_4gpu.yaml @@ -36,6 +36,7 @@ jobs: pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 python -m pip install -r requirements.txt python -m pip install -r dev-requirements.txt + python -m pip install git+https://github.com/pytorch/pippy - name: Run test_runner.py run: python ./test_runner.py - name: Upload Coverage to Codecov diff --git a/test_runner.py b/test_runner.py index 80d75ad86..fad572bd1 100755 --- a/test_runner.py +++ b/test_runner.py @@ -26,6 +26,8 @@ class OverrideDefinitions: override_args: Sequence[Sequence[str]] = tuple(tuple(" ")) test_descr: str = "default" + requires_seed_checkpoint: bool = False + ngpu: int = 4 CONFIG_DIR = "./train_configs" @@ -85,25 +87,104 @@ class OverrideDefinitions: ], "Checkpoint Integration Test - Save Model Weights Only bf16", ), + OverrideDefinitions( + [ + [ + "--checkpoint.enable_checkpoint", + f"--checkpoint.folder {test_checkpoint_dir}_pp", + "--experimental.pipeline_parallel_degree 2", + "--experimental.pipeline_parallel_split_points layers.1", + "--training.data_parallel_degree 1", + "--model.norm_type rmsnorm", # TODO fix fused_rmsnorm issue + ], + ], + "PP 1D test", + requires_seed_checkpoint=True, + ngpu=2, + ), + OverrideDefinitions( + [ + [ + "--checkpoint.enable_checkpoint", + f"--checkpoint.folder {test_checkpoint_dir}_pp_dp", + "--experimental.pipeline_parallel_degree 2", + "--experimental.pipeline_parallel_split_points layers.1", + "--training.data_parallel_degree 2", + "--model.norm_type fused_rmsnorm", + ], + ], + "PP+DP 2D test", + requires_seed_checkpoint=True, + ), + OverrideDefinitions( + [ + [ + "--checkpoint.enable_checkpoint", + f"--checkpoint.folder {test_checkpoint_dir}_pp_tp", + "--experimental.pipeline_parallel_degree 2", + "--experimental.pipeline_parallel_split_points layers.1", + "--training.tensor_parallel_degree 2", + "--model.norm_type rmsnorm", # TODO fix fused_rmsnorm issue + ], + ], + "PP+TP 2D test", + requires_seed_checkpoint=True, + ), + # oh.. not enough GPUs? + # OverrideDefinitions( + # [ + # [ + # "--checkpoint.enable_checkpoint", + # f"--checkpoint.folder {test_checkpoint_dir}_pp_dp_tp", + # "--experimental.pipeline_parallel_degree 2", + # "--experimental.pipeline_parallel_split_points layers.1", + # "--training.data_parallel_degree 2", + # "--training.tensor_parallel_degree 2", + # "--model.norm_type rmsnorm", # TODO fix fused_rmsnorm issue + # ], + # ], + # "PP+DP+TP 3D test", + # requires_seed_checkpoint=True, + # ), ] +def _run_cmd(cmd): + return subprocess.run( + [cmd], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + shell=True, + ) + + def run_test(test_flavor: OverrideDefinitions, full_path: str): # run_test supports sequence of tests. 
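    # Flavors flagged with requires_seed_checkpoint first run ./create_seed_checkpoint.sh
    # against the same --checkpoint.folder before launching training, because the PP flavors
    # above cannot use init_weights() and must load their initial weights from disk.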
for override_arg in test_flavor.override_args: - cmd = f"CONFIG_FILE={full_path} NGPU=4 LOG_RANK=0,1,2,3 ./run_llama_train.sh" + + cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK=0,1,2,3 ./run_llama_train.sh" if override_arg: cmd += " " + " ".join(override_arg) print( f"=====Integration test, flavor : {test_flavor.test_descr}, command : {cmd}=====" ) - result = subprocess.run( - [cmd], - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - text=True, - shell=True, - ) + + if test_flavor.requires_seed_checkpoint: + checkpoint_folder_arg = None + for arg in override_arg: + if "--checkpoint.folder" in arg: + checkpoint_folder_arg = arg + assert ( + checkpoint_folder_arg is not None + ), "Can't use seed checkpoint if folder is not specified" + print("Creating seed checkpoint") + result = _run_cmd( + f"CONFIG_FILE={full_path} ./create_seed_checkpoint.sh {checkpoint_folder_arg}" + ) + print(result.stdout) + + result = _run_cmd(cmd) print(result.stdout) if result.returncode != 0: raise Exception( diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 1de3c82c9..4a5e375e6 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -17,6 +17,12 @@ from torchtitan.logging_utils import logger +def string_list(raw_arg): + s = raw_arg.split(",") + print(s) + return s + + class JobConfig: """ A helper class to manage the train configuration. @@ -202,11 +208,68 @@ def __init__(self): help="Whether to apply loss parallel when sequence parallel is enabled", ) self.parser.add_argument( - "--training.pipeline_parallel_degree", + "--experimental.pipeline_parallel_degree", type=int, default=1, help="Pipeline Parallelism degree. 1 means disabled.", ) + self.parser.add_argument( + "--experimental.pipeline_parallel_stages_per_rank", + type=int, + default=1, + help=""" + Pipeline Parallelism number of stages per rank (a.k.a. virtual stages) + + For simple schedules, this should be 1. + + For looped schedules, this can be greater than one. + + If the number of stages produced by splitting does not match the expected number of stages, + an error will be raised for sanity.""", + ) + self.parser.add_argument( + "--experimental.pipeline_parallel_split_points", + type=string_list, + nargs="+", + default=[], + help=""" + Specify comma-separated names of modules to use as the beginning of a split point. + + e.g. "layers.0,layers.2" will cause the model to be split into 3 stages, + the first containing all the layers up to layers.0, + the second containing layers.0 and up to layers.2, + the third containing layers.2 and all the remaining layers. + + Note: fully-automated splitting may be enabled in the future, + but currently the split points must be specified manually for both manual and tracer.""", + ) + self.parser.add_argument( + "--experimental.pipeline_parallel_schedule", + type=str, + choices=["1f1b", "gpipe"], + default="1f1b", + help=""" + Specify the Pipeline Parallel schedule to use. + + The schedule must be compatible with the split points and stages_per_rank. + + Looped schedules are not yet supported in torchtitan.""", + ) + self.parser.add_argument( + "--experimental.pipeline_parallel_split_mode", + type=str, + choices=["manual", "tracer"], + default="manual", + help=""" + Specify the split method (e.g. the Pipeline Parallelism Front End) + + "manual" means each rank will construct an nn.Module with the appropriate layers and .forward + implementation manually, and then wrap it in a PipelineStage. 
+ + "tracer" means the full model will be initialized (via meta device) and then traced into a graph, + split via the provided split points, unflattened into an nn.Module, + and finally wrapped in a PipelineStage. tracer frontend is currently more experimental.""", + ) self.parser.add_argument( "--training.compile", action="store_true", @@ -408,6 +471,10 @@ def parse_args_from_command_line( aux_parser.add_argument( "--" + arg, action="store_true" if val else "store_false" ) + elif arg == "experimental.pipeline_parallel_split_points": + # type inference breaks here, since the type is just 'list' and it ends up flattening + # e.g. from ["layers.0", "layers.1"] into ["l", "a", "y", "e", "r", "s", ".0", ...] + aux_parser.add_argument("--" + arg, type=string_list) else: aux_parser.add_argument("--" + arg, type=type(val)) diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py index d69dad67a..d666c0b46 100644 --- a/torchtitan/models/llama/model.py +++ b/torchtitan/models/llama/model.py @@ -76,7 +76,7 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Ten """ ndim = x.ndim assert 0 <= 1 < ndim - assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + assert freqs_cis.shape == (x.shape[1], x.shape[-1]), (freqs_cis.shape, x.shape) shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] return freqs_cis.view(*shape) diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 779be60e8..38de480fe 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -8,10 +8,26 @@ # llama model, i.e. activation checkpointing, etc. from collections import defaultdict -from typing import Tuple +from typing import Dict, List, Tuple import torch +# TODO(whc) this can be removed after pippy migration into pytorch core is complete. +try: + from pippy import ( + ManualPipelineStage, + pipeline, + Schedule1F1B, + ScheduleGPipe, + SplitPoint, + ) + from pippy.PipelineStage import _PipelineStage +except ImportError as exc: + raise ImportError( + "pippy is not installed. Please install it to use pipeline parallelism. " + "`pip install git+https://github.com/pytorch/pippy`" + ) from exc + from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy from torch.distributed._tensor import Replicate, Shard from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( @@ -25,7 +41,7 @@ RowwiseParallel, SequenceParallel, ) - +from torch.nn import ModuleDict from torch.utils.checkpoint import _pt2_selective_checkpoint_context_fn_gen, checkpoint from torchtitan.config_manager import JobConfig @@ -129,15 +145,281 @@ def get_tp_parallel_strategy( return RowwiseParallel, ColwiseParallel +class TransformerChunk(torch.nn.Module): + def __init__( + self, + orig_model, # : Transformer, + this_stage_layer_names: List[str], + device, + input_seqlen: int, + ): + super().__init__() + self.tok_embeddings = None + + # inferring seqlen from forward(input) only works on stage0, bc on later stages + # the hidden state input may have reduced seqlen due to TP. We need to use the + # original (full) seqlen for freqs_cis to be correct. 
+ self.input_seqlen = input_seqlen + + if "tok_embeddings" in this_stage_layer_names: + self.tok_embeddings = orig_model.tok_embeddings + + with torch.device(device): + self.freqs_cis = orig_model._precompute_freqs_cis() + + # preserve FQNs of original model by preserving structure + # (including preserving position in layers[] list)- use dummy module + self.layers = ModuleDict() + for name in this_stage_layer_names: + if "layers." in name: + idx = name.split(".")[-1] + self.layers[idx] = orig_model.layers[int(idx)] + self.norm = None + if "norm" in this_stage_layer_names: + self.norm = orig_model.norm + self.output = None + if "output" in this_stage_layer_names: + self.output = orig_model.output + + def forward(self, input): + """ + Copypaste of original Transformer.forward, with conditionals and unpacking added + such that we handle the cases where this rank doesn't have the embedding, or doesn't have + the output layers. + """ + if self.tok_embeddings: + h = self.tok_embeddings(input) + else: + h = input + + freqs_cis = self.freqs_cis[0 : self.input_seqlen] + + for layer in self.layers.values(): + h = layer(h, freqs_cis) + output = h + + if self.norm: + h = self.norm(h) + output = h + + if self.output: + output = self.output(h).float() + return output + + +def apply_pipeline_parallelism( + model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict +): + if job_config.experimental.pipeline_parallel_split_mode == "manual": + return apply_pipeline_parallelism_manual( + model, world_mesh, parallel_dims, job_config, device, model_config + ) + elif job_config.experimental.pipeline_parallel_split_mode == "tracer": + return apply_pipeline_parallelism_tracer( + model, world_mesh, parallel_dims, job_config, device, model_config + ) + else: + raise NotImplementedError( + f"{job_config.experimental.pipeline_parallel_split_mode} is not a valid split mode" + ) + + +def build_pipeline_schedule(job_config, parallel_dims, stage, loss_fn): + if job_config.experimental.pipeline_parallel_schedule == "1f1b": + schedule_class = Schedule1F1B + elif job_config.experimental.pipeline_parallel_schedule == "gpipe": + schedule_class = ScheduleGPipe + else: + raise NotImplementedError( + f"{job_config.experimental.pipeline_parallel_schedule} is not implemented" + ) + return schedule_class( + stage, + n_microbatches=parallel_dims.pp, + loss_fn=loss_fn, + ) + + +def _llama_fqns(num_layers): + return ( + [ + "tok_embeddings", + ] + + [f"layers.{i}" for i in range(num_layers)] + + [ + "norm", + "output", + ] + ) + + +def split_stage_fqns(fqns, split_points, stage_id): + """Helper for splitting ordered list of layer names into layers per stage. + + split_points is a list of layer names, each layer will be the first layer in a stage + """ + stages = [] + cur = [] + + for name in fqns: + if name in split_points: + assert len( + cur + ), f"{name} is not a valid split point, do not specify the first layer of stage 0" + stages.append(cur) + cur = [] + cur.append(name) + + stages.append(cur) + print(f"Split using points {split_points}, got statges {stages}") + return stages[stage_id] + + +def apply_pipeline_parallelism_manual( + model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict +): + """ + This API gets individual torch.nn.Module objects for each pipeline stage (including virtual stages). 
+ + The SPMD parallelisms should be applied to + """ + pp_mesh = world_mesh["pp"] + pp_rank = pp_mesh.get_local_rank() + pp_size = pp_mesh.size() + # heuristically == PP dim but should be a config + microbatches = parallel_dims.pp + stage_idx = pp_rank # TODO support virtual stages + layers_per_rank = len(model.layers) // parallel_dims.pp + layer_offset = layers_per_rank * pp_rank + this_stage_layer_names = [ + f"layers.{i + layer_offset}" for i in range(layers_per_rank) + ] + if pp_rank == 0: + this_stage_layer_names.insert(0, "tok_embeddings") + assert "layers.0" in this_stage_layer_names + elif pp_rank == pp_size - 1: + this_stage_layer_names.append("norm") + this_stage_layer_names.append("output") + assert "layers.1" in this_stage_layer_names + + fqns = _llama_fqns(len(model.layers)) + new_names = split_stage_fqns( + fqns, job_config.experimental.pipeline_parallel_split_points, pp_rank + ) + assert len(new_names) == len(this_stage_layer_names), ( + len(new_names), + len(this_stage_layer_names), + ) + for n, n_ in zip(new_names, this_stage_layer_names): + assert n == n_, (n, n_) + + input_seqlen = 2048 # TODO hack + + model = TransformerChunk(model, this_stage_layer_names, device, input_seqlen) + # Create a pipeline representation from the model + + # TODO(whc) once ManualPipelineStage supports lazy shape inference, we can leave model on meta device longer and + # get rid of the input shape hardcoded here. For now, it should not be a big deal since we only materialize the + # layers of the model that map to this stage, not the whole model. + + # Get example input + if pp_rank == 0: + input_shape = (job_config.training.batch_size, job_config.training.seq_len) + input = torch.randint( + model_config.vocab_size, input_shape, dtype=torch.int64, device=device + ) + + # HACK- can't use shape inference via execution of the PP stage inside ManualPipelineStage API, becuase the + # real output shapes will change after applying TP. So we hardcode output shapes here, and thus bypass doing + # shape inference. + # the real fix is to use lazy shape inference during first PP forward, and not need to specify anything here. 
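        # (concretely: stage 0 consumes int64 token ids of shape (batch, seq_len), and every
        # stage boundary carries fp32 hidden states of shape (batch, seq_len // tp, dim);
        # chunking these along dim 0 yields the per-microbatch input_args/output_args below)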
+ output_shape = ( + job_config.training.batch_size, + int(job_config.training.seq_len // parallel_dims.tp), + model_config.dim, + ) + output = torch.empty(output_shape, dtype=torch.float32, device=device) + else: + # TODO(whc) can we rely on shape inference so that user doesn't have to compute TP impact on seq_len + input_shape = ( + job_config.training.batch_size, + int(job_config.training.seq_len // parallel_dims.tp), + model_config.dim, + ) + input = torch.randint( + model_config.vocab_size, input_shape, dtype=torch.float32, device=device + ) + # TODO wrong shape, need to consider output layer + output_shape = ( + job_config.training.batch_size, + int(job_config.training.seq_len // parallel_dims.tp), + model_config.dim, + ) + output = torch.empty(output_shape, dtype=torch.float32, device=device) + + model.to_empty(device=device) + stage = ManualPipelineStage( + model, + pp_rank, + pp_size, + device, + microbatches, + input_args=input.chunk(microbatches)[0], + output_args=output.chunk(microbatches)[0], + group=pp_mesh.get_group("pp"), + ) + return (stage, model) + + +def apply_pipeline_parallelism_tracer( + model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict +): + assert ( + parallel_dims.pp_enabled + ), "can't apply pipeline parallelism if it is not enabled" + + if job_config.model.norm_type == "fused_rmsnorm": + # TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode + # coming from ` if dy.stride(-1) != 1:` in fused_rmsnorm + raise NotImplementedError( + "fused_rmsnorm not yet compatible with Pipeline Tracer (strides error). Please use layernorm or rmsnorm." + ) + pp_mesh = world_mesh["pp"] + pp_rank = pp_mesh.get_local_rank() + stage_idx = pp_mesh.get_local_rank() + layers_per_rank = len(model.layers) // parallel_dims.pp + split_spec = { + f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING + for i in range(1, parallel_dims.pp) + } + # Get example input + input_shape = (job_config.training.batch_size, job_config.training.seq_len) + input_ids = torch.randint( + model.vocab_size, input_shape, dtype=torch.int64, device="meta" + ) + + # Create a pipeline representation from the model + pipe = pipeline( + model, parallel_dims.pp, example_args=(input_ids,), split_spec=split_spec + ) + model = pipe.get_stage_module(stage_idx) + stage = _PipelineStage( + stage_module=model, + stage_index=pp_rank, + pipe_info=pipe.pipe_info, + device=device, + group=pp_mesh.get_group(), + ) + return (stage, model) + + def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): """ - Apply parallelisms and activation checkpointing to the model. + Apply SPMD parallelisms and activation checkpointing to the model. NOTE: The passed-in model preferably should be on meta device. Otherwise, the model must fit on GPU or CPU memory. 
""" - if parallel_dims.pp_enabled: - raise NotImplementedError("PP not implemented yet.") if parallel_dims.tp_enabled: if job_config.model.norm_type == "fused_rmsnorm": @@ -172,7 +454,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): ) # Apply tensor + sequence parallelism to every transformer block - for layer_id, transformer_block in enumerate(model.layers): + for layer_name, transformer_block in model.layers.named_children(): layer_plan = { "attention": PrepareModuleInput( input_layouts=(Shard(1), None), @@ -211,24 +493,32 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names # TODO: Expose `reduce_dtype` as a config option. mp_policy = MixedPrecisionPolicy( - param_dtype=torch.bfloat16, reduce_dtype=torch.float32 + # TODO(whc) need to fix PP + FSDP-mixed-precision + # tracer for PP assumes f32 and is caught off guard when runtime FSDP interacts using bf16 inputs + # param_dtype=torch.bfloat16, reduce_dtype=torch.float32 + param_dtype=torch.float32, + reduce_dtype=torch.float32, ) ac_mode = job_config.activation_checkpoint.mode fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} - for layer_id, transformer_block in enumerate(model.layers): + for layer_name, transformer_block in model.layers.named_children(): if job_config.activation_checkpoint.mode in ("full", "selective"): transformer_block = checkpoint_wrapper( transformer_block, job_config.activation_checkpoint ) # As an optimization, do not reshard after forward for the last # transformer block since FSDP would prefetch it immediately - reshard_after_forward = layer_id < len(model.layers) - 1 + # reshard_after_forward = layer_id < len(model.layers) - 1 + # TODO(whc) need to fix correctly handle layer-ids on pp-split module + reshard_after_forward = True fully_shard( transformer_block, **fsdp_config, reshard_after_forward=reshard_after_forward, ) - model.layers[layer_id] = transformer_block + model.layers.add_module(layer_name, transformer_block) + + # TODO(whc) do we need reshard_after_forward setting here too? 
model = fully_shard(model, **fsdp_config) if ac_mode in ("full", "selective"): logger.info(f"Applied {ac_mode} activation checkpointing to the model") diff --git a/train.py b/train.py index 318c7174e..c0f3e517a 100644 --- a/train.py +++ b/train.py @@ -19,7 +19,9 @@ import torch import torch.nn.functional as F + from torch.distributed import destroy_process_group +from torch.distributed._composable.fsdp.fully_shard import FSDPModule from torch.distributed.checkpoint.stateful import Stateful from torch.distributed.elastic.multiprocessing.errors import record from torch.distributed.tensor.parallel import loss_parallel @@ -122,11 +124,12 @@ def main(job_config: JobConfig): parallel_dims = ParallelDims( dp=job_config.training.data_parallel_degree, tp=job_config.training.tensor_parallel_degree, - pp=job_config.training.pipeline_parallel_degree, + pp=job_config.experimental.pipeline_parallel_degree, world_size=world_size, enable_loss_parallel=job_config.training.enable_loss_parallel, ) - torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}") + torch.cuda.set_device(device) init_distributed(job_config) world_mesh = parallel_dims.build_mesh(device_type="cuda") @@ -144,6 +147,15 @@ def main(job_config: JobConfig): dp_rank = dp_mesh.get_local_rank() else: dp_degree, dp_rank = 1, 0 + + if parallel_dims.pp_enabled: + pp_mesh = world_mesh["pp"] + pp_degree = pp_mesh.size() + pp_rank = pp_mesh.get_local_rank() + + else: + pp_degree, pp_rank = 1, 0 + data_loader = build_hf_data_loader( job_config.training.dataset, job_config.training.dataset_path, @@ -201,13 +213,32 @@ def loss_fn(pred, labels): # obtain the peak flops of bf16 type for MFU calculation gpu_peak_flops = get_peak_flops(gpu_memory_monitor.device_name) - # apply PT-D parallelisms and activation checkpointing + if parallel_dims.pp_enabled: + from torchtitan.parallelisms.parallelize_llama import apply_pipeline_parallelism + + stage, model = apply_pipeline_parallelism( + model, world_mesh, parallel_dims, job_config, device, model_config + ) + + # apply PT-D DP/TP parallelisms and activation checkpointing model = models_parallelize_fns[model_name]( model, world_mesh, parallel_dims, job_config ) - # allocate sharded model on GPU and initialize weights via DTensor + model.to_empty(device="cuda") - model.init_weights() + + if parallel_dims.pp_enabled: + from torchtitan.parallelisms.parallelize_llama import build_pipeline_schedule + + pp_schedule = build_pipeline_schedule(job_config, parallel_dims, stage, loss_fn) + else: + # if PP is enabled, we can't use init_weights. instead, we have to rely on offline creating an initial checkpoint + # and loading it to get initialization values. This is becuase the init_weights functions are written assuming + # the whole model (all its weights, or FQNs) exist on one rank. In PP, the init_weights on stage1 might crash + # becuase it can't find "embedding" layer, for example. 
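        # The PP branch above therefore leaves parameters empty (to_empty() above) and relies
        # on the seed checkpoint; the check after checkpoint.load() below raises if PP is
        # enabled but no checkpoint was actually loaded.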
+ + # allocate sharded model on GPU and initialize weights via DTensor + model.init_weights() gpu_mem_stats = gpu_memory_monitor.get_peak_stats() logger.info( @@ -216,10 +247,13 @@ def loss_fn(pred, labels): f"({gpu_mem_stats.max_reserved_pct:.2f}%)" ) + if isinstance(model, FSDPModule) and parallel_dims.pp_enabled: + # reshard now to counteract an issue where FSDP's states got advanced during PP stage shape inference + model.reshard() + # build optimizer after applying parallelisms to the model optimizer = build_optimizer(model, job_config) scheduler = get_lr_scheduler(optimizer, job_config) - metric_logger = build_metric_logger(job_config) # torch.compile model for improved performance @@ -257,7 +291,13 @@ def loss_fn(pred, labels): logger.info("Created seed checkpoint") return - checkpoint.load() + checkpoint_loaded = checkpoint.load() + + if parallel_dims.pp_enabled and not checkpoint_loaded: + raise RuntimeError( + "Pipeline Parallelism requires meta-initialization and loading seed checkpoint. " + "Please run `./create_seed_checkpoint.sh` and rerun training with `--checkpoint.enable_checkpoint`" + ) # plot losses loaded from checkpoint (if any) to TensorBoard # NOTE: Loss info after the last log step before checkpoint saving will not be ploted. @@ -299,14 +339,33 @@ def loss_fn(pred, labels): input_ids = input_ids.cuda() labels = labels.cuda() - optimizer.zero_grad() - # forward / backward - with loss_parallel_ctx(): - pred = model(input_ids) - loss = loss_fn(pred, labels) - loss.backward() + if parallel_dims.pp_enabled: + # pipeline parallel forward / backward inside step() call + is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1 + + with loss_parallel_ctx(): + if pp_mesh.get_local_rank() == 0: + pp_schedule.step(input_ids) + elif is_last_stage: + losses = [] + pp_schedule.step(target=labels, losses=losses) + else: + schedule.step() + + # accumulate losses across pipeline microbatches + loss = ( + torch.mean(torch.stack(losses)) + if is_last_stage + else torch.Tensor([-1.0]) + ) + else: + # Non-PP forward / backward + with loss_parallel_ctx(): + pred = model(input_ids) + loss = loss_fn(pred, labels) + loss.backward() # clip gradients torch.nn.utils.clip_grad_norm_( diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index 4541fec7b..009348b5c 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -36,11 +36,13 @@ max_norm = 1.0 # grad norm clipping steps = 10 data_parallel_degree = -1 tensor_parallel_degree = 1 -pipeline_parallel_degree = 1 fp8_linear = "" compile = false dataset = "c4_mini" # supported datasets: c4_mini (45K), c4 (177M) +[experimental] +pipeline_parallel_degree = 1 + [checkpoint] enable_checkpoint = false folder = "checkpoint"
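# ---------------------------------------------------------------------------------------------
# Illustration (not part of the patch): how --experimental.pipeline_parallel_split_points maps
# module FQNs onto pipeline stages. This mirrors the split_stage_fqns helper added in
# torchtitan/parallelisms/parallelize_llama.py; the 8-layer model below is purely hypothetical.

def split_stage_fqns(fqns, split_points, stage_id):
    """Each name in split_points becomes the first module of a new stage."""
    stages, cur = [], []
    for name in fqns:
        if name in split_points:
            assert len(cur), f"{name} cannot be the first module of stage 0"
            stages.append(cur)
            cur = []
        cur.append(name)
    stages.append(cur)
    return stages[stage_id]


fqns = ["tok_embeddings"] + [f"layers.{i}" for i in range(8)] + ["norm", "output"]

# e.g. pipeline_parallel_degree 2 with split point "layers.4":
print(split_stage_fqns(fqns, ["layers.4"], 0))
# ['tok_embeddings', 'layers.0', 'layers.1', 'layers.2', 'layers.3']
print(split_stage_fqns(fqns, ["layers.4"], 1))
# ['layers.4', 'layers.5', 'layers.6', 'layers.7', 'norm', 'output']
#
# The manual frontend then builds a TransformerChunk containing exactly these modules for the
# local pp_rank and wraps it in a ManualPipelineStage.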