diff --git a/.github/workflows/unit_test_4gpu.yaml b/.github/workflows/unit_test_4gpu.yaml
index 0088bb3e..76ac024b 100644
--- a/.github/workflows/unit_test_4gpu.yaml
+++ b/.github/workflows/unit_test_4gpu.yaml
@@ -36,6 +36,7 @@ jobs:
           pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
           python -m pip install -r requirements.txt
           python -m pip install -r dev-requirements.txt
+          python -m pip install git+https://github.com/pytorch/pippy
       - name: Run test_runner.py
         run: python ./test_runner.py
       - name: Upload Coverage to Codecov
diff --git a/test_runner.py b/test_runner.py
index 80d75ad8..3ece3436 100755
--- a/test_runner.py
+++ b/test_runner.py
@@ -26,6 +26,8 @@ class OverrideDefinitions:
 
     override_args: Sequence[Sequence[str]] = tuple(tuple(" "))
     test_descr: str = "default"
+    requires_seed_checkpoint: bool = False
+    ngpu: int = 4
 
 
 CONFIG_DIR = "./train_configs"
@@ -85,25 +87,72 @@ class OverrideDefinitions:
         ],
         "Checkpoint Integration Test - Save Model Weights Only bf16",
     ),
+    OverrideDefinitions(
+        [
+            [
+                "--checkpoint.enable_checkpoint",
+                f"--checkpoint.folder {test_checkpoint_dir}_pp",
+                "--training.pipeline_parallel_degree 2",
+                "--training.data_parallel_degree 1",
+                "--model.norm_type rmsnorm",  # TODO fix fused_rmsnorm issue
+            ],
+        ],
+        "PP 1D test",
+        requires_seed_checkpoint=True,
+        ngpu=2,
+    ),
+    OverrideDefinitions(
+        [
+            [
+                "--checkpoint.enable_checkpoint",
+                f"--checkpoint.folder {test_checkpoint_dir}_pp_dp",
+                "--training.pipeline_parallel_degree 2",
+                "--training.data_parallel_degree 2",
+                "--model.norm_type rmsnorm",  # TODO fix fused_rmsnorm issue
+            ],
+        ],
+        "PP+DP 2D test",
+        requires_seed_checkpoint=True,
+    ),
 ]
 
 
+def _run_cmd(cmd):
+    return subprocess.run(
+        [cmd],
+        stdout=subprocess.PIPE,
+        stderr=subprocess.STDOUT,
+        text=True,
+        shell=True,
+    )
+
+
 def run_test(test_flavor: OverrideDefinitions, full_path: str):
     # run_test supports sequence of tests.
     for override_arg in test_flavor.override_args:
-        cmd = f"CONFIG_FILE={full_path} NGPU=4 LOG_RANK=0,1,2,3 ./run_llama_train.sh"
+
+        cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK=0,1,2,3 ./run_llama_train.sh"
         if override_arg:
             cmd += " " + " ".join(override_arg)
         print(
             f"=====Integration test, flavor : {test_flavor.test_descr}, command : {cmd}====="
         )
-        result = subprocess.run(
-            [cmd],
-            stdout=subprocess.PIPE,
-            stderr=subprocess.STDOUT,
-            text=True,
-            shell=True,
-        )
+
+        if test_flavor.requires_seed_checkpoint:
+            checkpoint_folder_arg = None
+            for arg in override_arg:
+                if "--checkpoint.folder" in arg:
+                    checkpoint_folder_arg = arg
+            assert (
+                checkpoint_folder_arg is not None
+            ), "Can't use seed checkpoint if folder is not specified"
+            print("Creating seed checkpoint")
+            result = _run_cmd(
+                f"CONFIG_FILE={full_path} ./create_seed_checkpoint.sh {checkpoint_folder_arg}"
+            )
+            print(result.stdout)
+
+        result = _run_cmd(cmd)
         print(result.stdout)
         if result.returncode != 0:
             raise Exception(
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index fca776a7..27f1f28e 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -12,6 +12,15 @@
 
 import torch
 
+# TODO(whc) this can be removed after pippy migration into pytorch core is complete.
+try:
+    from pippy import pipeline, SplitPoint
+except ImportError as exc:
+    raise ImportError(
+        "pippy is not installed. Please install it to use pipeline parallelism. "
+        "`pip install git+https://github.com/pytorch/pippy`"
+    ) from exc
+
 from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
 from torch.distributed._tensor import Replicate, Shard
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
@@ -129,15 +138,45 @@ def get_tp_parallel_strategy(
     return RowwiseParallel, ColwiseParallel
 
 
+def apply_pipeline_parallelism(model, world_mesh, parallel_dims, job_config: JobConfig):
+    assert (
+        parallel_dims.pp_enabled
+    ), "can't apply pipeline parallelism if it is not enabled"
+
+    if job_config.model.norm_type == "fused_rmsnorm":
+        # TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode
+        # coming from `if dy.stride(-1) != 1:` in fused_rmsnorm
+        raise NotImplementedError(
+            "fused_rmsnorm not yet compatible with Pipeline Tracer (strides error). Please use layernorm or rmsnorm."
+        )
+    pp_mesh = world_mesh["pp"]
+    stage_idx = pp_mesh.get_local_rank()
+    layers_per_rank = len(model.layers) // parallel_dims.pp
+    split_spec = {
+        f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING
+        for i in range(1, parallel_dims.pp)
+    }
+    # Get example input
+    input_shape = (job_config.training.batch_size, job_config.training.seq_len)
+    input_ids = torch.randint(
+        model.vocab_size, input_shape, dtype=torch.int64, device="meta"
+    )
+
+    # Create a pipeline representation from the model
+    pipe = pipeline(
+        model, parallel_dims.pp, example_args=(input_ids,), split_spec=split_spec
+    )
+    model = pipe.get_stage_module(stage_idx)
+    return model, pipe.pipe_info
+
+
 def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     """
-    Apply parallelisms and activation checkpointing to the model.
+    Apply SPMD parallelisms and activation checkpointing to the model.
 
     NOTE: The passed-in model preferably should be on meta device. Otherwise,
     the model must fit on GPU or CPU memory.
     """
-    if parallel_dims.pp_enabled:
-        raise NotImplementedError("PP not implemented yet.")
 
     if parallel_dims.tp_enabled:
         if job_config.model.norm_type == "fused_rmsnorm":
@@ -215,24 +254,31 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
         assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
         # TODO: Expose `reduce_dtype` as a config option.
         mp_policy = MixedPrecisionPolicy(
-            param_dtype=torch.bfloat16, reduce_dtype=torch.float32
+            # TODO(whc) need to fix PP + FSDP-mixed-precision
+            # tracer for PP assumes f32 and is caught off guard when runtime FSDP interacts using bf16 inputs
+            # param_dtype=torch.bfloat16, reduce_dtype=torch.float32
+            param_dtype=torch.float32,
+            reduce_dtype=torch.float32,
         )
         ac_mode = job_config.activation_checkpoint.mode
         fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
-        for layer_id, transformer_block in enumerate(model.layers):
+        for layer_name, transformer_block in model.layers.named_children():
             if job_config.activation_checkpoint.mode in ("full", "selective"):
                 transformer_block = checkpoint_wrapper(
                     transformer_block, job_config.activation_checkpoint
                 )
             # As an optimization, do not reshard after forward for the last
             # transformer block since FSDP would prefetch it immediately
-            reshard_after_forward = layer_id < len(model.layers) - 1
+            # reshard_after_forward = layer_id < len(model.layers) - 1
+            # TODO(whc) need to fix: correctly handle layer ids on the pp-split module
+            reshard_after_forward = True
             fully_shard(
                 transformer_block,
                 **fsdp_config,
                 reshard_after_forward=reshard_after_forward,
             )
-            model.layers[layer_id] = transformer_block
+            model.layers.add_module(layer_name, transformer_block)
+
         model = fully_shard(model, **fsdp_config)
         if ac_mode in ("full", "selective"):
             logger.info(f"Applied {ac_mode} activation checkpointing to the model")
diff --git a/train.py b/train.py
index 5fee20c0..a9685482 100644
--- a/train.py
+++ b/train.py
@@ -19,6 +19,17 @@
 
 import torch
 import torch.nn.functional as F
+
+# TODO(whc) this can be removed after pippy migration into pytorch core is complete.
+try:
+    from pippy import ScheduleGPipe
+    from pippy.PipelineStage import _PipelineStage
+except ImportError as exc:
+    raise ImportError(
+        "pippy is not installed. Please install it to use pipeline parallelism. "
+        "`pip install git+https://github.com/pytorch/pippy`"
+    ) from exc
+
 from torch.distributed import destroy_process_group
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.distributed.elastic.multiprocessing.errors import record
@@ -126,7 +137,8 @@ def main(job_config: JobConfig):
         world_size=world_size,
         enable_loss_parallel=job_config.training.enable_loss_parallel,
     )
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
+    torch.cuda.set_device(device)
     init_distributed(job_config)
 
     world_mesh = parallel_dims.build_mesh(device_type="cuda")
@@ -144,6 +156,15 @@ def main(job_config: JobConfig):
         dp_rank = dp_mesh.get_local_rank()
     else:
         dp_degree, dp_rank = 1, 0
+
+    if parallel_dims.pp_enabled:
+        pp_mesh = world_mesh["pp"]
+        pp_degree = pp_mesh.size()
+        pp_rank = pp_mesh.get_local_rank()
+
+    else:
+        pp_degree, pp_rank = 1, 0
+
     data_loader = build_hf_data_loader(
         job_config.training.dataset,
         job_config.training.dataset_path,
@@ -201,13 +222,44 @@ def loss_fn(pred, labels):
     # obtain the peak flops of bf16 type for MFU calculation
     gpu_peak_flops = get_peak_flops(gpu_memory_monitor.device_name)
 
-    # apply PT-D parallelisms and activation checkpointing
+    if parallel_dims.pp_enabled:
+        # TODO(whc) now I need to figure out how to align this with the `model_parallelize_fns[model_name]` pattern
+        from torchtitan.parallelisms.parallelize_llama import apply_pipeline_parallelism
+
+        model, pipe_info = apply_pipeline_parallelism(
+            model, world_mesh, parallel_dims, job_config
+        )
+
+    # apply PT-D DP/TP parallelisms and activation checkpointing
     model = models_parallelize_fns[model_name](
         model, world_mesh, parallel_dims, job_config
     )
-    # allocate sharded model on GPU and initialize weights via DTensor
+
     model.to_empty(device="cuda")
-    model.init_weights()
+
+    # TODO(whc) everything below needs to become a function that can be applied to each 'virtual stage' of PP, if
+    # there are virtual stages
+    if parallel_dims.pp_enabled:
+        stage = _PipelineStage(
+            stage_module=model,
+            stage_index=pp_rank,
+            pipe_info=pipe_info,
+            device=device,
+            group=pp_mesh.get_group(),
+        )
+        pp_schedule = ScheduleGPipe(
+            stage,
+            n_microbatches=parallel_dims.pp,
+            loss_fn=loss_fn,
+        )
+    else:
+        # If PP is enabled, we can't use init_weights; instead, we rely on creating an initial checkpoint offline
+        # and loading it to get initialization values. This is because the init_weights functions are written assuming
+        # the whole model (all its weights, or FQNs) exists on one rank. In PP, init_weights on stage 1 might crash
+        # because it can't find the "embedding" layer, for example.
+
+        # allocate sharded model on GPU and initialize weights via DTensor
+        model.init_weights()
 
     gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
     logger.info(
@@ -219,7 +271,6 @@ def loss_fn(pred, labels):
     # build optimizer after applying parallelisms to the model
     optimizer = build_optimizer(model, job_config)
     scheduler = get_lr_scheduler(optimizer, job_config)
-
     metric_logger = build_metric_logger(job_config)
 
     # torch.compile model for improved performance
@@ -257,7 +308,13 @@ def loss_fn(pred, labels):
             logger.info("Created seed checkpoint")
         return
 
-    checkpoint.load()
+    checkpoint_loaded = checkpoint.load()
+
+    if parallel_dims.pp_enabled and not checkpoint_loaded:
+        raise RuntimeError(
+            "Pipeline Parallelism requires meta-initialization and loading seed checkpoint. "
+            "Please run `./create_seed_checkpoint.sh` and rerun training with `--checkpoint.enable_checkpoint`"
+        )
 
     # plot losses loaded from checkpoint (if any) to TensorBoard
     # NOTE: Loss info after the last log step before checkpoint saving will not be ploted.
@@ -299,14 +356,33 @@ def loss_fn(pred, labels):
             input_ids = input_ids.cuda()
             labels = labels.cuda()
 
-            optimizer.zero_grad()
-
-            # forward / backward
-            with loss_parallel_ctx():
-                pred = model(input_ids)
-                loss = loss_fn(pred, labels)
-                loss.backward()
+            if parallel_dims.pp_enabled:
+                # pipeline parallel forward / backward inside step() call
+                is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
+
+                with loss_parallel_ctx():
+                    if pp_mesh.get_local_rank() == 0:
+                        pp_schedule.step(input_ids)
+                    elif is_last_stage:
+                        losses = []
+                        pp_schedule.step(target=labels, losses=losses)
+                    else:
+                        pp_schedule.step()
+
+                # accumulate losses across pipeline microbatches
+                loss = (
+                    torch.mean(torch.stack(losses))
+                    if is_last_stage
+                    else torch.Tensor([-1.0])
+                )
+            else:
+                # Non-PP forward / backward
+                with loss_parallel_ctx():
+                    pred = model(input_ids)
+                    loss = loss_fn(pred, labels)
+                    loss.backward()
 
             # clip gradients
             torch.nn.utils.clip_grad_norm_(