Update on "enable loss parallel in SP"
In this PR, loss parallel is enabled by default when sequence parallel is enabled.

Below are empirical results from training for 10 steps with `sequence_parallel_degree = 4`.
debug_model, with loss parallel disabled
```
Average iter time: 0.1862 seconds
Peak Memory: Reserved 12.9%, Alloc 9.97%, Active: 10.31%
```
debug_model, with loss parallel enabled
```
Average iter time: 0.1055 seconds
Peak Memory: Reserved 3.64%, Alloc 3.18%, Active: 3.25%
```

llama 7B, with loss parallel disabled
```
Average iter time: 6.4379 seconds
Peak Memory: Reserved 34.24%, Alloc 23.64%, Active: 25.89%
```
llama 7B, with loss parallel enabled
```
Average iter time: 6.4778 seconds
Peak Memory: Reserved 27.11%, Alloc 20.51%, Active: 20.51%
```
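
The memory savings line up with what loss parallel does: the final projection's output stays sharded on the vocabulary dimension, and the cross-entropy is computed directly on that sharded DTensor instead of all-gathering the full `[batch, seq, vocab]` logits first. A minimal sketch of the pattern this PR switches `train.py` to (the `compute_loss` wrapper is illustrative and not part of the PR; `pred` and `labels` come from the training loop):

```python
import contextlib

import torch.nn.functional as F
from torch.distributed.tensor.parallel import loss_parallel


def compute_loss(pred, labels, loss_parallel_enabled: bool):
    # With loss parallel on, `pred` is a DTensor sharded on its last (vocab)
    # dim across the SP group, and cross_entropy runs on the shards directly.
    # Otherwise fall back to a no-op context and a plain local cross_entropy.
    ctx = loss_parallel() if loss_parallel_enabled else contextlib.nullcontext()
    with ctx:
        return F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
```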



[ghstack-poisoned]
tianyu-l committed Mar 6, 2024
2 parents b683467 + a725c01 commit 8cf7f9c
Showing 5 changed files with 39 additions and 20 deletions.
8 changes: 8 additions & 0 deletions test/test_job_config.py
@@ -1,6 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import tempfile

import pytest
from torchtrain.config_manager import JobConfig

@@ -20,3 +22,9 @@ def test_job_file_does_not_exist(self):
with pytest.raises(FileNotFoundError):
config = JobConfig()
config.parse_args(["--job.config_file", "ohno.toml"])

def test_empty_config_file(self):
with tempfile.NamedTemporaryFile() as fp:
config = JobConfig()
config.parse_args(["--job.config_file", fp.name])
assert config.job.description
1 change: 0 additions & 1 deletion torchtrain/config_manager.py
@@ -233,7 +233,6 @@ def init_args_from_command_line(
)
parser.add_argument(
"--training.enable_selective_ac",
default=False,
action="store_true",
help="whether to enable selective activation checkpointing",
)
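
The dropped `default=False` was redundant: with `action="store_true"`, argparse already defaults the flag to `False` when it is not passed. A quick self-contained check, reusing the dotted option name from the diff:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--training.enable_selective_ac",
    action="store_true",
    help="whether to enable selective activation checkpointing",
)
args = parser.parse_args([])
# store_true flags default to False, so the explicit default was a no-op.
assert getattr(args, "training.enable_selective_ac") is False
```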
5 changes: 5 additions & 0 deletions torchtrain/parallelisms/__init__.py
@@ -23,6 +23,7 @@ class ParallelDims:
sp: int
pp: int
world_size: int
enable_loss_parallel: bool

def __post_init__(self):
self._validate()
@@ -63,6 +64,10 @@ def sp_enabled(self):
def pp_enabled(self):
return self.pp > 1

@property
def loss_parallel_enabled(self):
return self.sp > 1 and self.enable_loss_parallel

@cached_property
def model_parallel_size(self):
return self.sp * self.pp
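
The new `loss_parallel_enabled` property gates loss parallel on sequence parallelism actually being in use, so a config that sets `enable_loss_parallel = true` with `sp = 1` silently stays on the plain cross-entropy path. A standalone sketch of the same gating (the class name is illustrative; the field and property names follow the diff above):

```python
from dataclasses import dataclass


@dataclass
class ParallelDimsSketch:
    """Minimal stand-in for torchtrain's ParallelDims, keeping only the new gating."""

    sp: int
    enable_loss_parallel: bool

    @property
    def loss_parallel_enabled(self) -> bool:
        # Loss parallel only applies when the logits are actually sharded,
        # i.e. when sequence parallelism is active (sp > 1).
        return self.sp > 1 and self.enable_loss_parallel


assert ParallelDimsSketch(sp=4, enable_loss_parallel=True).loss_parallel_enabled
assert not ParallelDimsSketch(sp=1, enable_loss_parallel=True).loss_parallel_enabled
```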
18 changes: 11 additions & 7 deletions torchtrain/parallelisms/parallelize_llama.py
@@ -122,15 +122,16 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
NOTE: the model passed in should preferably be a meta device model;
otherwise the model needs to be small enough to fit on GPU or in CPU memory.
# TODO: apply SP
"""
# apply PTD parallelisms
if parallel_dims.pp_enabled:
raise NotImplementedError("PP not implemented yet.")

# First we apply Sequence Parallelism if it's enabled
if parallel_dims.sp_enabled:
# First we apply Sequence Parallelism if it's enabled
tp_mesh = world_mesh["sp"] if world_mesh.ndim > 1 else world_mesh
tp_mesh = world_mesh["sp"]
sp_degree = job_config.training.sequence_parallel_degree

# First:
# 1. parallelize the first embedding and the last linear proj layer
# 2. shard the first layer of transformer block
@@ -144,9 +145,9 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
"output": ColwiseParallel(
input_layouts=Shard(0),
output_layouts=Shard(-1)
if job_config.training.enable_loss_parallel
if parallel_dims.loss_parallel_enabled
else Replicate(),
use_local_output=not job_config.training.enable_loss_parallel,
use_local_output=not parallel_dims.loss_parallel_enabled,
),
"layers.0": PrepareModuleInput(
input_layouts=(Replicate(), None),
@@ -156,6 +157,9 @@
},
)

# shard the RMSNorm layer before last linear proj layer
distribute_rmsnorm(model.norm, tp_mesh)

# apply sequence parallelism to every transformer block
for layer_id, transformer_block in enumerate(model.layers):
layer_plan = {
@@ -194,8 +198,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
rank0_log("Applied Sequence Parallelism to the model...")

if parallel_dims.dp_enabled:
dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
dp_mesh = world_mesh["dp"]

fsdp_config = {
"mixed_precision": MixedPrecision(
@@ -227,6 +230,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):

rank0_log("Applied FSDP to the model...")
else:
meta_to_real_init_fn(model)
model.cuda()

# we have now moved from meta to device,
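
The core of the change is in the tensor-parallel plan for the final `output` projection: with loss parallel enabled, its output is kept as a DTensor sharded on the last (vocab) dimension instead of being replicated and converted back to a local tensor, which is what lets `loss_parallel()` compute the cross-entropy without gathering the logits. A hedged sketch of just that piece, assuming a 1-D `tp_mesh` device mesh and a module with an `output` linear layer (not a drop-in for the full `parallelize_llama`; import paths match the PyTorch APIs this repo uses):

```python
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module


def shard_output_projection(model, tp_mesh, loss_parallel_enabled: bool):
    """Column-shard the last linear layer; keep logits sharded when loss parallel is on."""
    plan = {
        "output": ColwiseParallel(
            input_layouts=Shard(0),
            # Shard(-1): logits stay sharded on the vocab dim for loss_parallel();
            # Replicate(): all-gather the logits so a plain cross_entropy works.
            output_layouts=Shard(-1) if loss_parallel_enabled else Replicate(),
            # Keep a DTensor output when loss parallel consumes it downstream.
            use_local_output=not loss_parallel_enabled,
        ),
    }
    return parallelize_module(model, tp_mesh, plan)
```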
27 changes: 15 additions & 12 deletions train.py
@@ -94,6 +94,7 @@ def main(job_config: JobConfig):
sp=job_config.training.sequence_parallel_degree,
pp=job_config.training.pipeline_parallel_degree,
world_size=world_size,
enable_loss_parallel=job_config.training.enable_loss_parallel,
)
world_mesh = parallel_dims.build_mesh(device_type="cuda")
rank0_log(f"Starting job: {job_config.job.description}")
@@ -104,11 +105,13 @@ def main(job_config: JobConfig):
tokenizer = create_tokenizer(tokenizer_type, job_config.model.tokenizer_path)

# build dataloader
# need dp world size and rank
dp_mesh = world_mesh["dp"]
dp_degree = dp_mesh.size()
dp_rank = dp_mesh.get_local_rank()
build_dataloader_fn = dataloader_fn[job_config.training.dataset]
if parallel_dims.dp_enabled:
dp_mesh = world_mesh["dp"]
dp_degree = dp_mesh.size()
dp_rank = dp_mesh.get_local_rank()
else:
dp_degree, dp_rank = 1, 0
data_loader = build_dataloader_fn(
tokenizer,
job_config.training.batch_size,
@@ -217,10 +220,7 @@ def main(job_config: JobConfig):

pred = model(input_ids)

loss_parallel_enabled = (
parallel_dims.sp_enabled and job_config.training.enable_loss_parallel
)
with loss_parallel() if loss_parallel_enabled else contextlib.nullcontext():
with loss_parallel() if parallel_dims.loss_parallel_enabled else contextlib.nullcontext():
loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))

# backward on scaled loss to create scaled gradients
@@ -259,10 +259,13 @@ def main(job_config: JobConfig):
np.mean(losses_since_last_log),
np.max(losses_since_last_log),
)
global_avg_loss, global_max_loss = (
dist_mean(avg_loss, dp_mesh),
dist_max(max_loss, dp_mesh),
)
if parallel_dims.dp_enabled:
global_avg_loss, global_max_loss = (
dist_mean(avg_loss, dp_mesh),
dist_max(max_loss, dp_mesh),
)
else:
global_avg_loss, global_max_loss = avg_loss, max_loss

time_delta = timer() - time_last_log
wps = nwords_since_last_log / (
