enable loss parallel in SP #112

Merged
merged 2 commits on Mar 6, 2024

8 changes: 7 additions & 1 deletion torchtrain/config_manager.py
@@ -185,7 +185,13 @@ def init_args_from_command_line(
"--training.sequence_parallel_degree",
type=int,
default=1,
help="Sequence Parallelism degree. 1 means disabled.",
help="Sequence Parallelism degree. 1 means disabled.",
)
parser.add_argument(
"--training.enable_loss_parallel",
default=True,
action="store_true",
help="whether to enable loss parallel when sequence parallel is enabled",
Contributor

Is there any downside to enabling loss parallel?

Contributor Author
@tianyu-l Mar 5, 2024

Here are my considerations:

  1. Our implementation of loss parallel decomposes large aten ops (log_softmax, nll_loss) into a series of torch ops, which adds some computation overhead.
  2. Loss parallel in general issues multiple collectives (three "small" all-reduces) instead of one "big" all-gather. Whether "small" beats "big" depends on the vocab size, although a normal vocab size should be large enough to justify the memory & communication gains.

For these reasons, I feel we can keep the option but have it turned on by default. I'm OK with removing the option as well.

)
parser.add_argument(
"--training.pipeline_parallel_degree",
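To illustrate the tradeoff described in the reply above, here is a minimal, hypothetical sketch (not part of this PR) of what `loss_parallel()` enables: `F.cross_entropy` can consume logits that stay sharded on the class/vocab dimension as a DTensor, so the loss is computed with a few small all-reduces instead of first all-gathering the full vocab dimension. Import paths assume a recent PyTorch (older releases expose DTensor under `torch.distributed._tensor`); the file name and launch command are made up, e.g. `torchrun --nproc_per_node=2 loss_parallel_sketch.py`.

```python
# Hypothetical sketch: cross-entropy on vocab-sharded logits under loss_parallel().
# Assumes a recent PyTorch and a multi-process launch (e.g. via torchrun).
import os

import torch
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor
from torch.distributed.tensor.parallel import loss_parallel


def main():
    # One 1-D mesh over all ranks; init_device_mesh also initializes the
    # default process group if it is not set up yet.
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", "0")))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    mesh = init_device_mesh("cuda", (world_size,))

    tokens, vocab = 8, 32
    logits = torch.randn(tokens, vocab, device="cuda", requires_grad=True)
    labels = torch.randint(vocab, (tokens,), device="cuda")

    # Shard the class (vocab) dimension across the mesh, matching the
    # output_layouts=Shard(-1) plan this PR adds in parallelize_llama.py.
    dist_logits = distribute_tensor(logits, mesh, placements=[Shard(1)])

    # Both the loss and its backward must run inside the loss_parallel() context.
    with loss_parallel():
        loss = F.cross_entropy(dist_logits, labels, reduction="mean")
        loss.backward()


if __name__ == "__main__":
    main()
```

Because the logits stay sharded end to end, the memory and communication savings grow with the vocab size, which is the point made in item 2 above.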
5 changes: 5 additions & 0 deletions torchtrain/parallelisms/__init__.py
@@ -23,6 +23,7 @@ class ParallelDims:
sp: int
pp: int
world_size: int
enable_loss_parallel: bool

def __post_init__(self):
self._validate()
@@ -63,6 +64,10 @@ def sp_enabled(self):
def pp_enabled(self):
return self.pp > 1

@property
def loss_parallel_enabled(self):
return self.sp > 1 and self.enable_loss_parallel

@cached_property
def model_parallel_size(self):
return self.sp * self.pp
6 changes: 4 additions & 2 deletions torchtrain/parallelisms/parallelize_llama.py
@@ -135,7 +135,6 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
# First:
# 1. parallelize the first embedding and the last linear proj layer
# 2. shard the first layer of transformer block
# TODO: enable loss parallel once it's ready
model = parallelize_module(
model,
tp_mesh,
@@ -145,7 +144,10 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
),
"output": ColwiseParallel(
input_layouts=Shard(0),
output_layouts=Replicate(),
output_layouts=Shard(-1)
if parallel_dims.loss_parallel_enabled
else Replicate(),
use_local_output=not parallel_dims.loss_parallel_enabled,
),
"layers.0": PrepareModuleInput(
input_layouts=(Replicate(), None),
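As a companion to the plan change above, here is a hypothetical end-to-end sketch (not from this repo; the module and sizes are invented) showing how a `ColwiseParallel` output projection with `output_layouts=Shard(-1)` and `use_local_output=False` hands a vocab-sharded DTensor straight to `loss_parallel()`. The same import-path caveat as in the earlier sketch applies.

```python
# Hypothetical sketch: a ColwiseParallel output projection feeding loss_parallel().
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    loss_parallel,
    parallelize_module,
)


def main():
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", "0")))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    tp_mesh = init_device_mesh("cuda", (world_size,))

    dim, vocab = 64, 256
    output_proj = nn.Linear(dim, vocab, bias=False).cuda()

    # Keep the projection output as a DTensor sharded on the vocab dim (-1),
    # mirroring the "output" plan this PR changes in parallelize_llama.py.
    output_proj = parallelize_module(
        output_proj,
        tp_mesh,
        ColwiseParallel(output_layouts=Shard(-1), use_local_output=False),
    )

    x = torch.randn(8, dim, device="cuda")
    labels = torch.randint(vocab, (8,), device="cuda")

    pred = output_proj(x)  # DTensor sharded on the class (vocab) dimension
    with loss_parallel():
        # No gather of the full vocab dimension before the loss.
        loss = F.cross_entropy(pred, labels)
        loss.backward()


if __name__ == "__main__":
    main()
```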
14 changes: 8 additions & 6 deletions train.py
@@ -1,6 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import contextlib
import os

from dataclasses import dataclass, field
@@ -14,6 +15,7 @@
import torch.nn.functional as F
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.distributed.tensor.parallel import loss_parallel

from torchtrain.checkpoint import CheckpointManager, IntervalType
from torchtrain.config_manager import JobConfig
@@ -92,6 +94,7 @@ def main(job_config: JobConfig):
sp=job_config.training.sequence_parallel_degree,
pp=job_config.training.pipeline_parallel_degree,
world_size=world_size,
enable_loss_parallel=job_config.training.enable_loss_parallel,
)
world_mesh = parallel_dims.build_mesh(device_type="cuda")
rank0_log(f"Starting job: {job_config.job.description}")
@@ -216,13 +219,12 @@ def main(job_config: JobConfig):
start_timer.record()

pred = model(input_ids)
tok_loss = F.cross_entropy(
pred.flatten(0, 1), labels.flatten(0, 1), reduction="none"
)
loss = tok_loss.mean()

# backward on scaled loss to create scaled gradients
scaler.scale(loss).backward()
with loss_parallel() if parallel_dims.loss_parallel_enabled else contextlib.nullcontext():
loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))

# backward on scaled loss to create scaled gradients
scaler.scale(loss).backward()

# clip gradients (after unscaling gradients of the optimizer's params)
scaler.unscale_(optimizer)