Enable FSDP2 cpu offloading (#624)
Resolves #620.
Add config: `--training.enable_cpu_offload`

Command: `CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh`

For the non-PP case:
<img width="611" alt="Screenshot 2024-10-23 at 1 45 56 PM"
src="https://github.com/user-attachments/assets/8692f8a6-c0f3-460e-8eb6-7f7195bed370">

For the PP case:
<img width="587" alt="cpu offload+pp"
src="https://github.com/user-attachments/assets/73e40861-47e2-4845-a41c-4bfea2860109">
mori360 authored on Oct 28, 2024 · commit 193ce98 · 1 parent 603889a
Showing 6 changed files with 52 additions and 10 deletions.
11 changes: 11 additions & 0 deletions test_runner.py
@@ -351,6 +351,17 @@ def build_test_list():
             "fsdp2_mem_tracker",
             ngpu=4,
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--experimental.pipeline_parallel_degree 2",
+                    "--training.enable_cpu_offload True",
+                ],
+            ],
+            "Enable CPU Offload with PP",
+            "enable_cpu_offload+PP",
+            ngpu=4,
+        ),
     ]
     return integration_tests_flavors

7 changes: 7 additions & 0 deletions torchtitan/config_manager.py
@@ -253,6 +253,13 @@ def __init__(self):
                 can be negative.
                 1 means disabled.""",
         )
+        self.parser.add_argument(
+            "--training.enable_cpu_offload",
+            type=bool,
+            default=False,
+            help="""
+                Whether to apply CPU offloading of parameters, gradients, and optimizer states in FSDP""",
+        )
         self.parser.add_argument(
             "--training.tensor_parallel_degree",
             type=int,
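For illustration (not part of this commit), a minimal sketch of turning the option on programmatically, mirroring the `--training.enable_cpu_offload True` override used in the integration test above; it assumes `JobConfig.parse_args` behaves as in `config_manager.py` and that the torchtitan repo is importable:

```python
# Hypothetical usage sketch, not from this commit.
from torchtitan.config_manager import JobConfig

config = JobConfig()
# Same CLI override string that the new integration test passes to the trainer.
config.parse_args(["--training.enable_cpu_offload", "True"])

if config.training.enable_cpu_offload:
    print("FSDP2 CPU offloading enabled: params, grads, and optimizer states kept on CPU")
```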
8 changes: 6 additions & 2 deletions torchtitan/models/llama/model.py
@@ -379,7 +379,10 @@ def __init__(self, model_args: ModelArgs):
         self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
         self.init_weights()
 
-    def init_weights(self):
+    def init_weights(
+        self,
+        buffer_device: Optional[torch.device] = None,
+    ):
         """
         [Note: On ``init_weights`` vs. ``reset_parameters``]
         Modules may define ``reset_parameters`` to initialize parameter values.
@@ -391,7 +394,8 @@ def init_weights(self):
         ``init_weights``. We only call it in the constructor of this
         ``Transformer`` root module to avoid reinitializing tensors.
         """
-        with torch.device(self.freqs_cis.device):
+        buffer_device = buffer_device or self.freqs_cis.device
+        with torch.device(buffer_device):
             self.freqs_cis = self._precompute_freqs_cis()
         if self.tok_embeddings is not None:
             nn.init.normal_(self.tok_embeddings.weight)
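For illustration (not from the repo), a toy, self-contained sketch of the `buffer_device` fallback pattern introduced above: calling `init_weights()` with no argument re-creates buffers on their current device as before, while an explicit device lets CPU-offloaded runs keep buffers such as `freqs_cis` on the GPU:

```python
# Toy model illustrating the buffer_device fallback; names here are illustrative only.
from typing import Optional

import torch
import torch.nn as nn


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("freqs", torch.zeros(8), persistent=False)
        self.proj = nn.Linear(8, 8)

    def init_weights(self, buffer_device: Optional[torch.device] = None):
        # Fall back to the buffer's current device so existing call sites
        # (init_weights() with no argument) behave exactly as before.
        buffer_device = buffer_device or self.freqs.device
        with torch.device(buffer_device):
            self.freqs = torch.arange(8, dtype=torch.float32)
        nn.init.normal_(self.proj.weight)


m = ToyModel()
m.init_weights()  # buffer re-created on its current device (CPU here)
# m.init_weights(buffer_device=torch.device("cuda"))  # would place the buffer on GPU
```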
10 changes: 9 additions & 1 deletion torchtitan/parallelisms/parallelize_llama.py
@@ -13,7 +13,11 @@
 import torch.nn as nn
 
 from torch.distributed import DeviceMesh
-from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
+from torch.distributed._composable.fsdp import (
+    CPUOffloadPolicy,
+    fully_shard,
+    MixedPrecisionPolicy,
+)
 from torch.distributed._composable.replicate import replicate
 from torch.distributed._tensor import Replicate, Shard
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
@@ -100,6 +104,7 @@ def parallelize_llama(
             reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
             tp_enabled=parallel_dims.tp_enabled,
             pp_enabled=parallel_dims.pp_enabled,
+            cpu_offload=job_config.training.enable_cpu_offload,
         )
 
     if parallel_dims.dp_replicate_enabled:
@@ -315,12 +320,15 @@ def apply_fsdp(
     reduce_dtype: torch.dtype,
     tp_enabled: bool,
     pp_enabled: bool,
+    cpu_offload: bool = False,
 ):
     """
     Apply data parallelism to the model. FSDP2 is used here.
     """
     mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
     fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
+    if cpu_offload:
+        fsdp_config["offload_policy"] = CPUOffloadPolicy()
 
     # TODO: remove this check once PyTorch 2.5 is released. We can safely assume
     # that users won't use a nightly build which is older than 20240809 by then.
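For illustration, a condensed sketch of how the offload policy reaches `fully_shard`, mirroring the `fsdp_config` dict built in `apply_fsdp` above; the helper name and dtypes below are illustrative, not from the repo:

```python
# Condensed sketch of the wiring above; helper name and dtypes are illustrative.
import torch
from torch.distributed._composable.fsdp import (
    CPUOffloadPolicy,
    fully_shard,
    MixedPrecisionPolicy,
)


def shard_with_optional_offload(module, dp_mesh, cpu_offload: bool = False):
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16, reduce_dtype=torch.float32
    )
    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
    if cpu_offload:
        # Keep sharded parameters, gradients, and optimizer states on CPU;
        # FSDP2 streams them to the GPU as needed for compute.
        fsdp_config["offload_policy"] = CPUOffloadPolicy()
    fully_shard(module, **fsdp_config)
    return module
```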
6 changes: 5 additions & 1 deletion torchtitan/utils.py
@@ -174,8 +174,12 @@ def init_distributed(job_config):
     # such as those in tensor parallelism
     os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
 
+    backend = "nccl"
+    if job_config.training.enable_cpu_offload:
+        backend = "cuda:nccl,cpu:gloo"
     torch.distributed.init_process_group(
-        "nccl", timeout=timedelta(seconds=job_config.comm.init_timeout_seconds)
+        backend=backend,
+        timeout=timedelta(seconds=job_config.comm.init_timeout_seconds),
     )
 
 
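The backend change above registers Gloo for CPU tensors alongside NCCL for CUDA tensors, so collectives that touch CPU-offloaded state still have a backend to run on. For illustration, a minimal standalone sketch of the same per-device backend string (the timeout default below is illustrative, not a torchtitan value):

```python
# Minimal sketch of the per-device backend selection; timeout default is illustrative.
from datetime import timedelta

import torch.distributed as dist


def init_distributed(enable_cpu_offload: bool, init_timeout_seconds: int = 300) -> None:
    # NCCL serves CUDA tensors; Gloo is added for collectives on CPU tensors.
    backend = "cuda:nccl,cpu:gloo" if enable_cpu_offload else "nccl"
    dist.init_process_group(
        backend=backend,
        timeout=timedelta(seconds=init_timeout_seconds),
    )
```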
20 changes: 14 additions & 6 deletions train.py
@@ -138,6 +138,17 @@ def loss_fn(pred, labels):
     if job_config.training.compile:
         loss_fn = torch.compile(loss_fn)
 
+    # move sharded model to CPU/GPU and initialize weights via DTensor
+    if job_config.checkpoint.create_seed_checkpoint:
+        init_device = "cpu"
+        buffer_device = None
+    elif job_config.training.enable_cpu_offload:
+        init_device = "cpu"
+        buffer_device = "cuda"
+    else:
+        init_device = "cuda"
+        buffer_device = None
+
     # apply parallelisms and initialization
     if parallel_dims.pp_enabled:
         # apply PT-D Pipeline Parallel
@@ -151,17 +162,14 @@ def loss_fn(pred, labels):
         for m in model_parts:
             # apply SPMD-style PT-D techniques
             models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
-            m.to_empty(device="cuda")
-            m.init_weights()
+            m.to_empty(device=init_device)
+            m.init_weights(buffer_device=buffer_device)
             m.train()
     else:
         # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
         models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)
-
-        # move sharded model to CPU/GPU and initialize weights via DTensor
-        init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda"
         model.to_empty(device=init_device)
-        model.init_weights()
+        model.init_weights(buffer_device=buffer_device)
         model.train()
 
         model_parts = [model]
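For illustration, a hypothetical helper (not in this commit) condensing the device-selection logic added above: seed-checkpoint creation initializes on CPU, CPU offloading initializes parameters on CPU while keeping buffers on CUDA, and the default path stays entirely on CUDA:

```python
# Hypothetical helper, not from train.py; it condenses the branches above.
def materialize_and_init(model, *, create_seed_checkpoint: bool, enable_cpu_offload: bool):
    if create_seed_checkpoint:
        init_device, buffer_device = "cpu", None      # seed checkpoints are written from CPU
    elif enable_cpu_offload:
        init_device, buffer_device = "cpu", "cuda"    # params on CPU, buffers (freqs_cis) on GPU
    else:
        init_device, buffer_device = "cuda", None     # default: everything on GPU

    model.to_empty(device=init_device)                # materialize sharded params without copying data
    model.init_weights(buffer_device=buffer_device)   # DTensor-aware re-initialization
    model.train()
    return model
```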
