From 5c7d155f2abdfcdd14de29b81a6b066669060385 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Tue, 21 May 2024 16:11:39 -0700 Subject: [PATCH] Add Pipeline Parallel (and 2D PP+FSDP) support runs PP+DP and PP+TP without issue, runs PP+TP+DP with decreasing loss, but fails DCP save Supports only simple schedules currently, gpipe and 1f1b. Ads cmdline/toml arg for specifiying split points, in a unified way between tracer or manual frontend. e.g. user can specifiy "layers.2,layers.4" as split points. Currently uses manual frontend by default, but allows specifying tracer frontend. Tracer frontend requires working around additional compatibility limitations, indicated by raising assertions, and is not ready for wider use yet. ghstack-source-id: d7e0a1342bc97d6f1bba9e647234d90688ad708f Pull Request resolved: https://github.com/pytorch/torchtitan/pull/318 --- create_seed_checkpoint.sh | 2 +- test_runner.py | 116 ++++++++++-- torchtitan/config_manager.py | 71 +++++++- torchtitan/parallelisms/__init__.py | 6 +- torchtitan/parallelisms/parallelize_llama.py | 177 ++++++++++++++++++- torchtitan/parallelisms/pipelining_utils.py | 26 +++ train.py | 74 ++++++-- train_configs/debug_model.toml | 4 +- 8 files changed, 440 insertions(+), 36 deletions(-) create mode 100644 torchtitan/parallelisms/pipelining_utils.py diff --git a/create_seed_checkpoint.sh b/create_seed_checkpoint.sh index 38bab219f..1abc77ec5 100755 --- a/create_seed_checkpoint.sh +++ b/create_seed_checkpoint.sh @@ -25,7 +25,7 @@ LOG_RANK=0 CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"} seed_checkpoint="--checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint" -force_1d="--training.data_parallel_degree 1 --training.tensor_parallel_degree 1 --training.pipeline_parallel_degree 1" +force_1d="--training.data_parallel_degree 1 --training.tensor_parallel_degree 1 --experimental.pipeline_parallel_degree 1" overrides="" if [ $# -ne 0 ]; then overrides="$*" diff --git a/test_runner.py b/test_runner.py index ca9d13209..dfd4a987e 100755 --- a/test_runner.py +++ b/test_runner.py @@ -11,6 +11,8 @@ from dataclasses import dataclass from typing import Sequence +from torchtitan.logging_utils import logger + try: import tomllib except ModuleNotFoundError: @@ -25,6 +27,8 @@ class OverrideDefinitions: override_args: Sequence[Sequence[str]] = tuple(tuple(" ")) test_descr: str = "default" + requires_seed_checkpoint: bool = False + ngpu: int = 4 def build_test_list(args): @@ -35,6 +39,78 @@ def build_test_list(args): """ integration_tests_flavors = defaultdict(list) integration_tests_flavors["debug_model.toml"] = [ + OverrideDefinitions( + [ + [ + "--checkpoint.enable_checkpoint", + f"--job.dump_folder {args.output_dir}/pp_1f1b/", + "--experimental.pipeline_parallel_degree 2", + "--experimental.pipeline_parallel_split_points layers.1", + "--experimental.pipeline_parallel_schedule 1f1b", + "--training.data_parallel_degree 1", + ], + ], + "PP 1D test 1f1b", + requires_seed_checkpoint=True, + ngpu=2, + ), + OverrideDefinitions( + [ + [ + "--checkpoint.enable_checkpoint", + f"--job.dump_folder {args.output_dir}/pp_gpipe/", + "--experimental.pipeline_parallel_degree 2", + "--experimental.pipeline_parallel_split_points layers.1", + "--experimental.pipeline_parallel_schedule gpipe", + "--training.data_parallel_degree 1", + ], + ], + "PP 1D test gpipe", + requires_seed_checkpoint=True, + ngpu=2, + ), + OverrideDefinitions( + [ + [ + "--checkpoint.enable_checkpoint", + f"--job.dump_folder {args.output_dir}/pp_dp_1f1b/", + "--experimental.pipeline_parallel_degree 2", + "--experimental.pipeline_parallel_split_points layers.1", + "--experimental.pipeline_parallel_schedule 1f1b", + "--training.data_parallel_degree 2", + ], + ], + "PP+DP 1f1b 2D test", + requires_seed_checkpoint=True, + ), + OverrideDefinitions( + [ + [ + "--checkpoint.enable_checkpoint", + f"--job.dump_folder {args.output_dir}/pp_dp_gpipe/", + "--experimental.pipeline_parallel_degree 2", + "--experimental.pipeline_parallel_split_points layers.1", + "--experimental.pipeline_parallel_schedule gpipe", + "--training.data_parallel_degree 2", + ], + ], + "PP+DP gpipe 2D test", + requires_seed_checkpoint=True, + ), + OverrideDefinitions( + [ + [ + "--checkpoint.enable_checkpoint", + f"--job.dump_folder {args.output_dir}/pp_tp/", + "--experimental.pipeline_parallel_degree 2", + "--experimental.pipeline_parallel_split_points layers.1", + "--training.tensor_parallel_degree 2", + "--model.norm_type rmsnorm", # fused_rmsnorm not yet compatible with TP + ], + ], + "PP+TP 2D test", + requires_seed_checkpoint=True, + ), OverrideDefinitions( [ [ @@ -100,23 +176,43 @@ def build_test_list(args): return integration_tests_flavors +def _run_cmd(cmd): + return subprocess.run( + [cmd], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + shell=True, + ) + + def run_test(test_flavor: OverrideDefinitions, full_path: str): # run_test supports sequence of tests. for override_arg in test_flavor.override_args: - cmd = f"CONFIG_FILE={full_path} NGPU=4 LOG_RANK=0,1,2,3 ./run_llama_train.sh" + + cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK=0,1,2,3 ./run_llama_train.sh" if override_arg: cmd += " " + " ".join(override_arg) - print( + logger.info( f"=====Integration test, flavor : {test_flavor.test_descr}, command : {cmd}=====" ) - result = subprocess.run( - [cmd], - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - text=True, - shell=True, - ) - print(result.stdout) + + if test_flavor.requires_seed_checkpoint: + dump_folder_arg = None + for arg in override_arg: + if "--job.dump_folder" in arg: + dump_folder_arg = arg + assert ( + dump_folder_arg is not None + ), "Can't use seed checkpoint if folder is not specified" + logger.info("Creating seed checkpoint") + result = _run_cmd( + f"CONFIG_FILE={full_path} ./create_seed_checkpoint.sh {dump_folder_arg}" + ) + logger.info(result.stdout) + + result = _run_cmd(cmd) + logger.info(result.stdout) if result.returncode != 0: raise Exception( f"Integration test failed, flavor : {test_flavor.test_descr}, command : {cmd}" diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 1a3e36d40..da80b4255 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -25,6 +25,10 @@ } +def string_list(raw_arg): + return raw_arg.split(",") + + class JobConfig: """ A helper class to manage the train configuration. @@ -210,10 +214,68 @@ def __init__(self): help="Whether to apply loss parallel when sequence parallel is enabled", ) self.parser.add_argument( - "--training.pipeline_parallel_degree", + "--experimental.pipeline_parallel_degree", type=int, default=1, - help="Pipeline Parallelism degree. 1 means disabled.", + help=""" + Pipeline Parallelism degree, or number of ranks. 1 means disabled. + If using looped schedules, this still specifies the number of physical ranks, not the number + of stages. Stages per rank are inferred from split points degree, and schedule.""", + ) + self.parser.add_argument( + "--experimental.pipeline_parallel_split_points", + type=string_list, + nargs="+", + default=[], + help=""" + Specify comma-separated names of modules to use as the beginning of a split point. + + e.g. "layers.0,layers.2" will cause the model to be split into 3 stages, + the first containing all the layers up to layers.0, + the second containing layers.0 and up to layers.2, + the third containing layers.2 and all the remaining layers. + + Note: fully-automated splitting may be enabled in the future, + but currently the split points must be specified manually for both manual and tracer.""", + ) + self.parser.add_argument( + "--experimental.pipeline_parallel_schedule", + type=str, + choices=["1f1b", "gpipe"], + default="1f1b", + help=""" + Specify the Pipeline Parallel schedule to use. + + The schedule must be compatible with the split points and stages_per_rank. + + Looped schedules are not yet supported in torchtitan.""", + ) + self.parser.add_argument( + "--experimental.pipeline_parallel_split_mode", + type=str, + choices=["manual", "tracer"], + default="manual", + help=""" + Specify the split method (e.g. the Pipeline Parallelism Front End) + + "manual" means each rank will construct an nn.Module with the appropriate layers and .forward + implementation manually, and then wrap it in a PipelineStage. + + "tracer" means the full model will be initialized (via meta device) and then traced into a graph, + split via the provided split points, unflattened into an nn.Module, + and finally wrapped in a PipelineStage. tracer frontend is currently more experimental.""", + ) + self.parser.add_argument( + "--experimental.pipeline_parallel_microbatches", + type=int, + default=None, + help=""" + How many microbatches to split the global training batch into when using pipeline parallelism. + + The global training batch size must be evenly divisible by the number of microbatches. + + The default value will be the number of pipeline stages, if unspecified. + """, ) self.parser.add_argument( "--training.mixed_precision_param", @@ -437,6 +499,11 @@ def parse_args_from_command_line( aux_parser.add_argument( "--" + arg, action="store_true" if val else "store_false" ) + elif arg == "experimental.pipeline_parallel_split_points": + # without this special case, type inference breaks here, + # since the inferred type is just 'list' and it ends up flattening + # e.g. from ["layers.0", "layers.1"] into ["l", "a", "y", "e", "r", "s", ".0", ...] + aux_parser.add_argument("--" + arg, type=string_list) else: aux_parser.add_argument("--" + arg, type=type(val)) diff --git a/torchtitan/parallelisms/__init__.py b/torchtitan/parallelisms/__init__.py index e791b832a..7e1b21c79 100644 --- a/torchtitan/parallelisms/__init__.py +++ b/torchtitan/parallelisms/__init__.py @@ -9,12 +9,16 @@ from torch.distributed.device_mesh import init_device_mesh from torchtitan.logging_utils import logger -from torchtitan.parallelisms.parallelize_llama import parallelize_llama +from torchtitan.parallelisms.parallelize_llama import parallelize_llama, pipeline_llama models_parallelize_fns = { "llama2": parallelize_llama, "llama3": parallelize_llama, } +models_pipelining_fns = { + "llama2": pipeline_llama, + "llama3": pipeline_llama, +} @dataclass diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 0bd0a9661..61cf79fe3 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -8,7 +8,7 @@ # llama model, i.e. activation checkpointing, etc. from collections import defaultdict -from typing import Tuple +from typing import Dict, Tuple import torch @@ -18,6 +18,11 @@ checkpoint_wrapper as ptd_checkpoint_wrapper, CheckpointImpl, ) +from torch.distributed.pipelining import pipeline, SplitPoint +from torch.distributed.pipelining._PipelineStage import ( + _PipelineStage, + ManualPipelineStage, +) from torch.distributed.tensor.parallel import ( ColwiseParallel, parallelize_module, @@ -31,7 +36,6 @@ from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP from torchtitan.logging_utils import logger - # for selective AC no_recompute_list = { torch.ops.aten.mm.default, @@ -129,15 +133,165 @@ def get_tp_parallel_strategy( return RowwiseParallel, ColwiseParallel +def pipeline_llama( + model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict +): + if job_config.experimental.pipeline_parallel_split_mode == "manual": + return pipeline_llama_manual( + model, world_mesh, parallel_dims, job_config, device, model_config + ) + elif job_config.experimental.pipeline_parallel_split_mode == "tracer": + return pipeline_llama_tracer( + model, world_mesh, parallel_dims, job_config, device, model_config + ) + else: + raise NotImplementedError( + f"{job_config.experimental.pipeline_parallel_split_mode} is not a valid split mode" + ) + + +def _llama_trace_input(job_config, model_config, device="meta"): + """Get meta tensors with the right input shapes used for tracing""" + tokens_shape = (job_config.training.batch_size, job_config.training.seq_len) + tokens = torch.randint( + model_config.vocab_size, tokens_shape, dtype=torch.int64, device=device + ) + return (tokens,) + + +def pipeline_llama_manual( + model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict +): + """ + This API extracts one torch.nn.Module objects for the part of the model configured to run inside this stage. + + It wraps the model chunk in a ManualPipelineStage object and returns both the stage and model objects. + + The stage object is used to create a pipeline schedule, and the model object can be used for applying SPMD + parallelism. + """ + pp_mesh = world_mesh["pp"] + pp_rank = pp_mesh.get_local_rank() + pp_size = pp_mesh.size() + microbatches = ( + job_config.experimental.pipeline_parallel_microbatches or parallel_dims.pp + ) + stage_idx = pp_rank + + splits = job_config.experimental.pipeline_parallel_split_points + start_layer = splits[stage_idx - 1] if stage_idx > 0 else None + stop_layer = splits[stage_idx] if stage_idx < pp_size - 1 else None + + if pp_rank > 0: + model.tok_embeddings = None + + drop_layers = True + for name in list(model.layers.keys()): + # we keep layers in a contiguous region between start (inclusive) and stop (exclusive) + if start_layer is None or f"layers.{name}" == start_layer: + drop_layers = False + if stop_layer is not None and f"layers.{name}" == stop_layer: + drop_layers = True + if drop_layers: + del model.layers[name] + + if pp_rank < pp_size - 1: + model.norm = None + model.output = None + + logger.info(f"PP rank {pp_rank} is using this model chunk\n{model}") + + # TODO(whc) once ManualPipelineStage supports lazy shape inference, we can leave model on meta device longer and + # get rid of the input shape hardcoded here. For now, it should not be a big deal since we only materialize the + # layers of the model that map to this stage, not the whole model. + mp_arg = job_config.training.mixed_precision_param + mp_dtype = TORCH_DTYPE_MAP[mp_arg] if parallel_dims.dp_enabled else torch.float32 + batch_size = job_config.training.batch_size + local_seq_len = int(job_config.training.seq_len // parallel_dims.tp) + layers_io_shape = (batch_size, local_seq_len, model_config.dim) + output_layer_shape = (batch_size, local_seq_len, model_config.vocab_size) + if pp_rank == 0: + # first layer + input = torch.randint( + model_config.vocab_size, + size=(batch_size, job_config.training.seq_len), + dtype=torch.int64, + device=device, + ) + else: + # later layers (assume all start w/ a transformer layer) + input = torch.rand(layers_io_shape, dtype=mp_dtype, device=device) + + if pp_rank == pp_size - 1: + # last layer + output = torch.rand(output_layer_shape, dtype=torch.float32, device=device) + else: + # earlier layers (assume all end in a transformer layer) + output = torch.rand(layers_io_shape, dtype=mp_dtype, device=device) + + model.to_empty(device=device) + stage = ManualPipelineStage( + model, + pp_rank, + pp_size, + device, + microbatches, + input_args=input.chunk(microbatches)[0], + output_args=output.chunk(microbatches)[0], + group=pp_mesh.get_group("pp"), + ) + return (stage, model) + + +def pipeline_llama_tracer( + model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict +): + if job_config.model.norm_type == "fused_rmsnorm": + # TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode + # coming from ` if dy.stride(-1) != 1:` in fused_rmsnorm + raise NotImplementedError( + "fused_rmsnorm not yet compatible with Pipeline Tracer (strides error). Please use layernorm or rmsnorm." + ) + + # TODO(whc) maybe we can just fix this by feeding bf16 into the tracer for its input shapes? + raise NotImplementedError( + "pipeline tracer doesn't work with fsdp mixed precision currently. " + "To work around, edit fsdp mixed precision config to use fp32." + ) + pp_mesh = world_mesh["pp"] + pp_rank = pp_mesh.get_local_rank() + stage_idx = pp_mesh.get_local_rank() + layers_per_rank = len(model.layers) // parallel_dims.pp + split_spec = { + f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING + for i in range(1, parallel_dims.pp) + } + + # Create a pipeline representation from the model + pipe = pipeline( + model, + job_config.experimental.pipeline_parallel_microbatches or parallel_dims.pp, + example_args=_llama_trace_input(job_config, model_config), + split_spec=split_spec, + ) + model = pipe.get_stage_module(stage_idx) + stage = _PipelineStage( + stage_module=model, + stage_index=pp_rank, + pipe_info=pipe.pipe_info, + device=device, + group=pp_mesh.get_group(), + ) + return (stage, model) + + def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): """ - Apply parallelisms and activation checkpointing to the model. + Apply SPMD parallelisms and activation checkpointing to the model. NOTE: The passed-in model preferably should be on meta device. Otherwise, the model must fit on GPU or CPU memory. """ - if parallel_dims.pp_enabled: - raise NotImplementedError("PP not implemented yet.") if parallel_dims.tp_enabled: if job_config.model.norm_type == "fused_rmsnorm": @@ -221,15 +375,22 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): transformer_block, job_config.activation_checkpoint ) # As an optimization, do not reshard after forward for the last - # transformer block since FSDP would prefetch it immediately - reshard_after_forward = int(layer_id) < len(model.layers) - 1 + # transformer block since FSDP would prefetch it immediately. + # When using Pipeline Parallelism, generally zero-2 is best so as to avoid repeated reshardings + # per microbatch. + reshard_after_forward = ( + int(layer_id) < len(model.layers) - 1 and not parallel_dims.pp_enabled + ) fully_shard( transformer_block, **fsdp_config, reshard_after_forward=reshard_after_forward, ) model.layers[layer_id] = transformer_block - model = fully_shard(model, **fsdp_config) + + model = fully_shard( + model, **fsdp_config, reshard_after_forward=not parallel_dims.pp_enabled + ) if ac_mode in ("full", "selective"): logger.info(f"Applied {ac_mode} activation checkpointing to the model") logger.info("Applied FSDP to the model") diff --git a/torchtitan/parallelisms/pipelining_utils.py b/torchtitan/parallelisms/pipelining_utils.py new file mode 100644 index 000000000..24752e4b0 --- /dev/null +++ b/torchtitan/parallelisms/pipelining_utils.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from torch.distributed.pipelining import Schedule1F1B, ScheduleGPipe +from torchtitan.logging_utils import logger + + +def build_pipeline_schedule(job_config, parallel_dims, stage, loss_fn): + if job_config.experimental.pipeline_parallel_schedule == "1f1b": + schedule_class = Schedule1F1B + elif job_config.experimental.pipeline_parallel_schedule == "gpipe": + schedule_class = ScheduleGPipe + else: + raise NotImplementedError( + f"{job_config.experimental.pipeline_parallel_schedule} is not implemented" + ) + logger.info( + f"Using pipeline schedule {job_config.experimental.pipeline_parallel_schedule}" + ) + return schedule_class( + stage, + n_microbatches=stage.chunks, + loss_fn=loss_fn, + ) diff --git a/train.py b/train.py index 318c7174e..90a745e5a 100644 --- a/train.py +++ b/train.py @@ -32,7 +32,12 @@ from torchtitan.lr_scheduling import get_lr_scheduler from torchtitan.metrics import build_gpu_memory_monitor, build_metric_logger from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config -from torchtitan.parallelisms import models_parallelize_fns, ParallelDims +from torchtitan.parallelisms import ( + models_parallelize_fns, + models_pipelining_fns, + ParallelDims, +) +from torchtitan.parallelisms.pipelining_utils import build_pipeline_schedule from torchtitan.profiling import maybe_enable_profiling from torchtitan.utils import ( Color, @@ -122,11 +127,12 @@ def main(job_config: JobConfig): parallel_dims = ParallelDims( dp=job_config.training.data_parallel_degree, tp=job_config.training.tensor_parallel_degree, - pp=job_config.training.pipeline_parallel_degree, + pp=job_config.experimental.pipeline_parallel_degree, world_size=world_size, enable_loss_parallel=job_config.training.enable_loss_parallel, ) - torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}") + torch.cuda.set_device(device) init_distributed(job_config) world_mesh = parallel_dims.build_mesh(device_type="cuda") @@ -144,6 +150,10 @@ def main(job_config: JobConfig): dp_rank = dp_mesh.get_local_rank() else: dp_degree, dp_rank = 1, 0 + + if parallel_dims.pp_enabled: + pp_mesh = world_mesh["pp"] + data_loader = build_hf_data_loader( job_config.training.dataset, job_config.training.dataset_path, @@ -201,13 +211,26 @@ def loss_fn(pred, labels): # obtain the peak flops of bf16 type for MFU calculation gpu_peak_flops = get_peak_flops(gpu_memory_monitor.device_name) - # apply PT-D parallelisms and activation checkpointing + if parallel_dims.pp_enabled: + stage, model = models_pipelining_fns[model_name]( + model, world_mesh, parallel_dims, job_config, device, model_config + ) + + # apply PT-D DP/TP parallelisms and activation checkpointing model = models_parallelize_fns[model_name]( model, world_mesh, parallel_dims, job_config ) - # allocate sharded model on GPU and initialize weights via DTensor + model.to_empty(device="cuda") - model.init_weights() + + if parallel_dims.pp_enabled: + pp_schedule = build_pipeline_schedule(job_config, parallel_dims, stage, loss_fn) + else: + # If PP is enabled, we can't rely on init_weights, because some layers are missing. + # In the future, we may make init_weights handle missing layers, but also have to consider RNG seed propagation. + + # allocate sharded model on GPU and initialize weights via DTensor + model.init_weights() gpu_mem_stats = gpu_memory_monitor.get_peak_stats() logger.info( @@ -257,7 +280,13 @@ def loss_fn(pred, labels): logger.info("Created seed checkpoint") return - checkpoint.load() + checkpoint_loaded = checkpoint.load() + + if parallel_dims.pp_enabled and not checkpoint_loaded: + raise RuntimeError( + "Pipeline Parallelism requires meta-initialization and loading seed checkpoint. " + "Please run `./create_seed_checkpoint.sh` and rerun training with `--checkpoint.enable_checkpoint`" + ) # plot losses loaded from checkpoint (if any) to TensorBoard # NOTE: Loss info after the last log step before checkpoint saving will not be ploted. @@ -299,14 +328,33 @@ def loss_fn(pred, labels): input_ids = input_ids.cuda() labels = labels.cuda() - optimizer.zero_grad() - # forward / backward - with loss_parallel_ctx(): - pred = model(input_ids) - loss = loss_fn(pred, labels) - loss.backward() + if parallel_dims.pp_enabled: + # pipeline parallel forward / backward inside step() call + is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1 + + with loss_parallel_ctx(): + if pp_mesh.get_local_rank() == 0: + pp_schedule.step(input_ids) + elif is_last_stage: + losses = [] + pp_schedule.step(target=labels, losses=losses) + else: + pp_schedule.step() + + # accumulate losses across pipeline microbatches + loss = ( + torch.mean(torch.stack(losses)) + if is_last_stage + else torch.Tensor([-1.0]) + ) + else: + # Non-PP forward / backward + with loss_parallel_ctx(): + pred = model(input_ids) + loss = loss_fn(pred, labels) + loss.backward() # clip gradients torch.nn.utils.clip_grad_norm_( diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index 4541fec7b..009348b5c 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -36,11 +36,13 @@ max_norm = 1.0 # grad norm clipping steps = 10 data_parallel_degree = -1 tensor_parallel_degree = 1 -pipeline_parallel_degree = 1 fp8_linear = "" compile = false dataset = "c4_mini" # supported datasets: c4_mini (45K), c4 (177M) +[experimental] +pipeline_parallel_degree = 1 + [checkpoint] enable_checkpoint = false folder = "checkpoint"