From 970dd636bb53676415831f91a2d49645912bbeb3 Mon Sep 17 00:00:00 2001
From: Tianyu Liu
Date: Tue, 5 Mar 2024 21:52:06 -0800
Subject: [PATCH] enable loss parallel in SP

ghstack-source-id: a0c8b4454f75ad1cd9824ac89a1df0182f6a7d8c
Pull Request resolved: https://github.com/pytorch/torchtrain/pull/112
---
 torchtrain/config_manager.py                 |  8 +++++++-
 torchtrain/parallelisms/__init__.py          |  5 +++++
 torchtrain/parallelisms/parallelize_llama.py |  6 ++++--
 train.py                                     | 14 ++++++++------
 4 files changed, 24 insertions(+), 9 deletions(-)

diff --git a/torchtrain/config_manager.py b/torchtrain/config_manager.py
index 01ad47f9..41d7007b 100644
--- a/torchtrain/config_manager.py
+++ b/torchtrain/config_manager.py
@@ -185,7 +185,13 @@ def init_args_from_command_line(
             "--training.sequence_parallel_degree",
             type=int,
             default=1,
-            help="Sequence Parallelism degree. 1 means disabled.",
+            help="Sequence Parallelism degree. 1 means disabled.",
+        )
+        parser.add_argument(
+            "--training.enable_loss_parallel",
+            default=True,
+            action="store_true",
+            help="whether to enable loss parallel when sequence parallel is enabled",
         )
         parser.add_argument(
             "--training.pipeline_parallel_degree",
diff --git a/torchtrain/parallelisms/__init__.py b/torchtrain/parallelisms/__init__.py
index 1c9a5641..fdcd938d 100644
--- a/torchtrain/parallelisms/__init__.py
+++ b/torchtrain/parallelisms/__init__.py
@@ -23,6 +23,7 @@ class ParallelDims:
     sp: int
     pp: int
     world_size: int
+    enable_loss_parallel: bool
 
     def __post_init__(self):
         self._validate()
@@ -63,6 +64,10 @@ def sp_enabled(self):
     def pp_enabled(self):
         return self.pp > 1
 
+    @property
+    def loss_parallel_enabled(self):
+        return self.sp > 1 and self.enable_loss_parallel
+
     @cached_property
     def model_parallel_size(self):
         return self.sp * self.pp
diff --git a/torchtrain/parallelisms/parallelize_llama.py b/torchtrain/parallelisms/parallelize_llama.py
index 3f7a236d..34252b6b 100644
--- a/torchtrain/parallelisms/parallelize_llama.py
+++ b/torchtrain/parallelisms/parallelize_llama.py
@@ -135,7 +135,6 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
         # First:
         # 1. parallelize the first embedding and the last linear proj layer
         # 2. shard the first layer of transformer block
-        # TODO: enable loss parallel once it's ready
         model = parallelize_module(
             model,
             tp_mesh,
@@ -145,7 +144,10 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                 ),
                 "output": ColwiseParallel(
                     input_layouts=Shard(0),
-                    output_layouts=Replicate(),
+                    output_layouts=Shard(-1)
+                    if parallel_dims.loss_parallel_enabled
+                    else Replicate(),
+                    use_local_output=not parallel_dims.loss_parallel_enabled,
                 ),
                 "layers.0": PrepareModuleInput(
                     input_layouts=(Replicate(), None),
diff --git a/train.py b/train.py
index 33f9dc4e..80056ce1 100644
--- a/train.py
+++ b/train.py
@@ -1,6 +1,7 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+import contextlib
 import os
 from dataclasses import dataclass, field
 
@@ -14,6 +15,7 @@
 import torch.nn.functional as F
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+from torch.distributed.tensor.parallel import loss_parallel
 
 from torchtrain.checkpoint import CheckpointManager, IntervalType
 from torchtrain.config_manager import JobConfig
@@ -92,6 +94,7 @@ def main(job_config: JobConfig):
         sp=job_config.training.sequence_parallel_degree,
         pp=job_config.training.pipeline_parallel_degree,
         world_size=world_size,
+        enable_loss_parallel=job_config.training.enable_loss_parallel,
     )
     world_mesh = parallel_dims.build_mesh(device_type="cuda")
     rank0_log(f"Starting job: {job_config.job.description}")
@@ -216,13 +219,12 @@ def main(job_config: JobConfig):
             start_timer.record()
 
             pred = model(input_ids)
-            tok_loss = F.cross_entropy(
-                pred.flatten(0, 1), labels.flatten(0, 1), reduction="none"
-            )
-            loss = tok_loss.mean()
 
-            # backward on scaled loss to create scaled gradients
-            scaler.scale(loss).backward()
+            with loss_parallel() if parallel_dims.loss_parallel_enabled else contextlib.nullcontext():
+                loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
+
+                # backward on scaled loss to create scaled gradients
+                scaler.scale(loss).backward()
 
             # clip gradients (after unscaling gradients of the optimizer's params)
             scaler.unscale_(optimizer)
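Note (not part of the patch): the sketch below isolates the loss-parallel pattern this change wires up, for anyone who wants to try it outside the repo. It is an assumption-laden toy, not the project's training loop: it presumes two GPUs launched with torchrun, uses a single nn.Linear as a stand-in for the model's final output projection, and a local enable_loss_parallel flag standing in for the new --training.enable_loss_parallel option. The idea mirrors parallelize_llama.py and train.py above: keep the logits as a DTensor sharded on the class (vocab) dimension, then compute cross-entropy and run backward inside loss_parallel() so the vocab dimension is never all-gathered.

    # Toy loss-parallel example; assumes: torchrun --nproc_per_node=2 this_file.py
    import contextlib
    import os

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.distributed._tensor import Replicate, Shard
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        loss_parallel,
        parallelize_module,
    )

    enable_loss_parallel = True  # stands in for --training.enable_loss_parallel

    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    tp_mesh = init_device_mesh("cuda", (2,))  # 1-D tensor-parallel mesh

    # Stand-in for the final output projection ("output" in parallelize_llama.py).
    model = nn.Linear(16, 32, bias=False).cuda()

    # With loss parallel on, keep the logits as a DTensor sharded on the class
    # (vocab) dimension instead of replicating them into a plain local tensor.
    model = parallelize_module(
        model,
        tp_mesh,
        ColwiseParallel(
            output_layouts=Shard(-1) if enable_loss_parallel else Replicate(),
            use_local_output=not enable_loss_parallel,
        ),
    )

    torch.manual_seed(0)  # keep inputs and labels identical across ranks
    inp = torch.randn(4, 8, 16, device="cuda")            # (batch, seq, dim)
    labels = torch.randint(0, 32, (4, 8), device="cuda")  # (batch, seq)

    pred = model(inp)  # class-sharded DTensor when loss parallel is enabled

    # Both the cross-entropy and the backward pass must run inside
    # loss_parallel(); it computes the loss from the sharded logits directly.
    maybe_loss_parallel = (
        loss_parallel() if enable_loss_parallel else contextlib.nullcontext()
    )
    with maybe_loss_parallel:
        loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
        loss.backward()

With enable_loss_parallel set to False, the logits come back replicated as a regular tensor and the same cross_entropy call runs outside the context manager, which is equivalent to the replicated code path train.py used before this patch.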