From b6834677a07c62b5cc6663dbefb297574e0e73d1 Mon Sep 17 00:00:00 2001
From: Tianyu Liu
Date: Mon, 4 Mar 2024 22:04:28 -0800
Subject: [PATCH] enable loss parallel in SP

[ghstack-poisoned]
---
 torchtrain/config_manager.py                 |  8 +++++++-
 torchtrain/parallelisms/parallelize_llama.py | 10 ++++++----
 train.py                                     | 14 +++++++++-----
 3 files changed, 22 insertions(+), 10 deletions(-)

diff --git a/torchtrain/config_manager.py b/torchtrain/config_manager.py
index be007979..77d8593f 100644
--- a/torchtrain/config_manager.py
+++ b/torchtrain/config_manager.py
@@ -185,7 +185,13 @@ def init_args_from_command_line(
             "--training.sequence_parallel_degree",
             type=int,
             default=1,
-            help="Sequence Parallelism degree. 1 means disabled.",
+            help="Sequence Parallelism degree. 1 means disabled.",
+        )
+        parser.add_argument(
+            "--training.enable_loss_parallel",
+            default=True,
+            action="store_true",
+            help="whether to enable loss parallel when sequence parallel is enabled",
         )
         parser.add_argument(
             "--training.pipeline_parallel_degree",
diff --git a/torchtrain/parallelisms/parallelize_llama.py b/torchtrain/parallelisms/parallelize_llama.py
index d11fac9f..cc82118e 100644
--- a/torchtrain/parallelisms/parallelize_llama.py
+++ b/torchtrain/parallelisms/parallelize_llama.py
@@ -42,7 +42,7 @@ def distribute_rmsnorm(module, device_mesh):

     # temp sharding API until PTD API is added
-    def prepare_input_fn(inputs, device_mesh):
+    def prepare_input_fn(mod, inputs, device_mesh):
         if isinstance(inputs[0], DTensor):
             return inputs
         elif isinstance(inputs[0], torch.Tensor):
@@ -134,7 +134,6 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
         # First:
         # 1. parallelize the first embedding and the last linear proj layer
         # 2. shard the first layer of transformer block
-        # TODO: enable loss parallel once it's ready
         model = parallelize_module(
             model,
             tp_mesh,
@@ -144,7 +143,10 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                 ),
                 "output": ColwiseParallel(
                     input_layouts=Shard(0),
-                    output_layouts=Replicate(),
+                    output_layouts=Shard(-1)
+                    if job_config.training.enable_loss_parallel
+                    else Replicate(),
+                    use_local_output=not job_config.training.enable_loss_parallel,
                 ),
                 "layers.0": PrepareModuleInput(
                     input_layouts=(Replicate(), None),
@@ -212,7 +214,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):

         with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
             for layer_id, transformer_block in enumerate(model.layers):
-                # apply selective AC
+                # apply AC/selective AC
                 transformer_block = checkpoint_wrapper(
                     transformer_block, job_config.training.enable_selective_ac
                 )
diff --git a/train.py b/train.py
index 9c8e2f7b..269ac522 100644
--- a/train.py
+++ b/train.py
@@ -1,6 +1,7 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

+import contextlib
 import os
 from dataclasses import dataclass, field

@@ -14,6 +15,7 @@ import torch.nn.functional as F

 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+from torch.distributed.tensor.parallel import loss_parallel

 from torchtrain.checkpoint import CheckpointManager, IntervalType
 from torchtrain.config_manager import JobConfig
@@ -214,13 +216,15 @@ def main(job_config: JobConfig):

             start_timer.record()
             pred = model(input_ids)
-            tok_loss = F.cross_entropy(
-                pred.flatten(0, 1), labels.flatten(0, 1), reduction="none"
+
+            loss_parallel_enabled = (
+                parallel_dims.sp_enabled and job_config.training.enable_loss_parallel
             )
-            loss = tok_loss.mean()
+            with loss_parallel() if loss_parallel_enabled else contextlib.nullcontext():
+                loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))

-            # backward on scaled loss to create scaled gradients
-            scaler.scale(loss).backward()
+                # backward on scaled loss to create scaled gradients
+                scaler.scale(loss).backward()

             # clip gradients (after unscaling gradients of the optimizer's params)
             scaler.unscale_(optimizer)
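
A minimal standalone sketch of the pattern the train.py hunk adopts (not part of the patch; the compute_loss helper name is illustrative). When loss parallel is enabled, the "output" projection keeps its result as a DTensor sharded on the vocab dimension (output_layouts=Shard(-1) with use_local_output=False), and the loss_parallel() context lets F.cross_entropy and the subsequent backward run directly on those sharded logits; otherwise a plain nullcontext() is used and the loss is computed on a replicated local tensor.

import contextlib

import torch.nn.functional as F
from torch.distributed.tensor.parallel import loss_parallel


def compute_loss(pred, labels, loss_parallel_enabled: bool):
    # Illustrative helper mirroring the train.py change above (not in the patch).
    # With loss parallel enabled, `pred` is expected to be a DTensor whose last
    # (vocab) dimension is sharded, as produced by
    # ColwiseParallel(output_layouts=Shard(-1), use_local_output=False).
    ctx = loss_parallel() if loss_parallel_enabled else contextlib.nullcontext()
    with ctx:
        # Inside the context, cross_entropy and its backward operate on the
        # sharded logits without all-gathering them across the TP group.
        loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
        loss.backward()
    return loss

Note that in the patch itself the backward call (scaler.scale(loss).backward()) is indented under the with block, so both the loss computation and its backward run under the loss-parallel context whenever it is active.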