enable loss parallel in SP
ghstack-source-id: a0c8b4454f75ad1cd9824ac89a1df0182f6a7d8c
Pull Request resolved: #112
tianyu-l committed Mar 6, 2024
1 parent 8635d74 commit 970dd63
Showing 4 changed files with 24 additions and 9 deletions.
8 changes: 7 additions & 1 deletion torchtrain/config_manager.py
@@ -185,7 +185,13 @@ def init_args_from_command_line(
"--training.sequence_parallel_degree",
type=int,
default=1,
help="Sequence Parallelism degree. 1 means disabled.",
help="Sequence Parallelism degree. 1 means disabled.",
)
parser.add_argument(
"--training.enable_loss_parallel",
default=True,
action="store_true",
help="whether to enable loss parallel when sequence parallel is enabled",
)
parser.add_argument(
"--training.pipeline_parallel_degree",
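
As defined above, the new option is on by default, and because default=True is combined with action="store_true", passing the flag on the command line simply keeps it True; a separate store_false-style option would be needed to switch it off. A minimal standalone argparse sketch (independent of torchtrain's config_manager) illustrating that behavior:

    import argparse

    # Hypothetical standalone parser reproducing only the flag defined in this diff.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--training.enable_loss_parallel",
        default=True,
        action="store_true",
        help="whether to enable loss parallel when sequence parallel is enabled",
    )

    # The dest keeps the dot, so the value is read with getattr.
    args = parser.parse_args([])
    print(getattr(args, "training.enable_loss_parallel"))  # True (the default)

    args = parser.parse_args(["--training.enable_loss_parallel"])
    print(getattr(args, "training.enable_loss_parallel"))  # still True
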
5 changes: 5 additions & 0 deletions torchtrain/parallelisms/__init__.py
@@ -23,6 +23,7 @@ class ParallelDims:
sp: int
pp: int
world_size: int
enable_loss_parallel: bool

def __post_init__(self):
self._validate()
@@ -63,6 +64,10 @@ def sp_enabled(self):
def pp_enabled(self):
return self.pp > 1

@property
def loss_parallel_enabled(self):
return self.sp > 1 and self.enable_loss_parallel

@cached_property
def model_parallel_size(self):
return self.sp * self.pp
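
Taken together, the new field and property mean loss parallel is only reported as enabled when sequence parallelism is actually in use. A trimmed-down sketch of the ParallelDims pieces touched by this diff (validation, the other parallel dims, and mesh construction are omitted):

    from dataclasses import dataclass
    from functools import cached_property

    @dataclass
    class ParallelDims:
        sp: int
        pp: int
        world_size: int
        enable_loss_parallel: bool

        @property
        def sp_enabled(self):
            return self.sp > 1

        @property
        def loss_parallel_enabled(self):
            # loss parallel only takes effect when sequence parallel is active
            return self.sp > 1 and self.enable_loss_parallel

        @cached_property
        def model_parallel_size(self):
            return self.sp * self.pp

    # With SP degree 1 the flag has no effect; with SP degree > 1 it enables loss parallel.
    assert not ParallelDims(sp=1, pp=1, world_size=8, enable_loss_parallel=True).loss_parallel_enabled
    assert ParallelDims(sp=2, pp=1, world_size=8, enable_loss_parallel=True).loss_parallel_enabled
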
6 changes: 4 additions & 2 deletions torchtrain/parallelisms/parallelize_llama.py
@@ -135,7 +135,6 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
# First:
# 1. parallelize the first embedding and the last linear proj layer
# 2. shard the first layer of transformer block
# TODO: enable loss parallel once it's ready
model = parallelize_module(
model,
tp_mesh,
@@ -145,7 +144,10 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
),
"output": ColwiseParallel(
input_layouts=Shard(0),
output_layouts=Replicate(),
output_layouts=Shard(-1)
if parallel_dims.loss_parallel_enabled
else Replicate(),
use_local_output=not parallel_dims.loss_parallel_enabled,
),
"layers.0": PrepareModuleInput(
input_layouts=(Replicate(), None),
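
The effect of the new "output" plan entry: when loss parallel is enabled, the final projection keeps its logits as a DTensor sharded on the last (vocab/class) dimension instead of all-gathering them into a replicated local tensor, so the cross-entropy in train.py can run directly on the shards. An annotated sketch of just that entry (import locations for Shard/Replicate can differ across PyTorch versions; loss_parallel_enabled stands in for parallel_dims.loss_parallel_enabled):

    from torch.distributed._tensor import Replicate, Shard
    from torch.distributed.tensor.parallel import ColwiseParallel

    loss_parallel_enabled = True  # stands in for parallel_dims.loss_parallel_enabled

    output_plan = ColwiseParallel(
        # the final projection receives sequence-sharded activations
        input_layouts=Shard(0),
        # keep the logits sharded on the class/vocab dimension rather than replicating them
        output_layouts=Shard(-1) if loss_parallel_enabled else Replicate(),
        # return a DTensor rather than a plain local tensor, so the sharded
        # logits can flow into the loss_parallel() cross-entropy in train.py
        use_local_output=not loss_parallel_enabled,
    )
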
14 changes: 8 additions & 6 deletions train.py
@@ -1,6 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import contextlib
import os

from dataclasses import dataclass, field
@@ -14,6 +15,7 @@
import torch.nn.functional as F
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.distributed.tensor.parallel import loss_parallel

from torchtrain.checkpoint import CheckpointManager, IntervalType
from torchtrain.config_manager import JobConfig
@@ -92,6 +94,7 @@ def main(job_config: JobConfig):
sp=job_config.training.sequence_parallel_degree,
pp=job_config.training.pipeline_parallel_degree,
world_size=world_size,
enable_loss_parallel=job_config.training.enable_loss_parallel,
)
world_mesh = parallel_dims.build_mesh(device_type="cuda")
rank0_log(f"Starting job: {job_config.job.description}")
@@ -216,13 +219,12 @@ def main(job_config: JobConfig):
start_timer.record()

pred = model(input_ids)
tok_loss = F.cross_entropy(
pred.flatten(0, 1), labels.flatten(0, 1), reduction="none"
)
loss = tok_loss.mean()

# backward on scaled loss to create scaled gradients
scaler.scale(loss).backward()
with loss_parallel() if parallel_dims.loss_parallel_enabled else contextlib.nullcontext():
loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))

# backward on scaled loss to create scaled gradients
scaler.scale(loss).backward()

# clip gradients (after unscaling gradients of the optimizer's params)
scaler.unscale_(optimizer)
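
Both the cross-entropy forward and the backward pass now run under loss_parallel(), which lets F.cross_entropy consume logits sharded on the class dimension without materializing the full vocabulary on every rank. A minimal sketch of the pattern, with gradient scaling omitted and the sharded inputs assumed to come from the parallelized model:

    import contextlib

    import torch
    import torch.nn.functional as F
    from torch.distributed.tensor.parallel import loss_parallel


    def backward_with_optional_loss_parallel(
        pred: torch.Tensor, labels: torch.Tensor, loss_parallel_enabled: bool
    ) -> torch.Tensor:
        # Sketch helper, not part of train.py. When loss parallel is enabled,
        # `pred` is expected to be a DTensor whose class (vocab) dimension is
        # sharded, as produced by the ColwiseParallel output plan above.
        ctx = loss_parallel() if loss_parallel_enabled else contextlib.nullcontext()
        with ctx:
            # cross-entropy is computed against each rank's vocab shard
            loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
            # the backward pass must run under the same context so the gradient
            # of the sharded logits is produced correctly
            loss.backward()
        return loss
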
