Enable FSDP2 CPU offloading #624

Merged: 17 commits, Oct 28, 2024
11 changes: 11 additions & 0 deletions test_runner.py
@@ -351,6 +351,17 @@ def build_test_list():
"fsdp2_mem_tracker",
ngpu=4,
),
OverrideDefinitions(
[
[
"--experimental.pipeline_parallel_degree 2",
"--training.enable_cpu_offload True",
],
],
"Enable CPU Offload with PP",
"enable_cpu_offload+PP",
ngpu=4,
),
@mori360 (Contributor, Author) commented on Oct 28, 2024:

    Testing with PP; we could remove PP later if it isn't necessary in the CI test.
    ]
    return integration_tests_flavors

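Each OverrideDefinitions entry expands into one CI training run with the listed flags appended to the launch command. A rough, hypothetical reconstruction of what this flavor executes on 4 GPUs (test_runner.py's real command assembly may differ; the torchrun invocation is illustrative only):

overrides = [
    "--experimental.pipeline_parallel_degree 2",
    "--training.enable_cpu_offload True",
]
ngpu = 4
# Hypothetical launch line, not quoted from test_runner.py.
print(f"torchrun --nproc_per_node={ngpu} train.py " + " ".join(overrides))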
7 changes: 7 additions & 0 deletions torchtitan/config_manager.py
@@ -253,6 +253,13 @@ def __init__(self):
            can be negative.
            1 means disabled.""",
        )
+       self.parser.add_argument(
+           "--training.enable_cpu_offload",
+           type=bool,
+           default=False,
+           help="""
+           Whether to apply CPU offloading of parameters, gradients, and optimizer states in FSDP""",
+       )
        self.parser.add_argument(
            "--training.tensor_parallel_degree",
            type=int,
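A minimal sketch of exercising the new flag through JobConfig, assuming parse_args accepts an explicit argument list as it does elsewhere in torchtitan:

from torchtitan.config_manager import JobConfig

config = JobConfig()
config.parse_args(["--training.enable_cpu_offload", "True"])
assert config.training.enable_cpu_offload

# Caveat: argparse's type=bool converts via bool(<raw string>), and any
# non-empty string is truthy, so "--training.enable_cpu_offload False"
# would also parse as True; the flag is off only at its default.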
8 changes: 6 additions & 2 deletions torchtitan/models/llama/model.py
@@ -379,7 +379,10 @@ def __init__(self, model_args: ModelArgs):
        self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
        self.init_weights()

-   def init_weights(self):
+   def init_weights(
+       self,
+       buffer_device: Optional[torch.device] = None,
+   ):
"""
[Note: On ``init_weights`` vs. ``reset_parameters``]
Modules may define ``reset_parameters`` to initialize parameter values.
@@ -391,7 +394,8 @@ def init_weights(self):
        ``init_weights``. We only call it in the constructor of this
        ``Transformer`` root module to avoid reinitializing tensors.
        """
-       with torch.device(self.freqs_cis.device):
+       buffer_device = buffer_device or self.freqs_cis.device
+       with torch.device(buffer_device):
            self.freqs_cis = self._precompute_freqs_cis()
        if self.tok_embeddings is not None:
            nn.init.normal_(self.tok_embeddings.weight)
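The reason init_weights grows a buffer_device argument: with CPU offloading, parameters are materialized on CPU, but freqs_cis is a buffer rather than an FSDP-managed parameter, so recomputing it directly on GPU avoids a host-to-device copy on every forward. A minimal sketch of the intended flow, assuming the Transformer/ModelArgs defined in this file and illustrative hyperparameters (the FSDP wrapping step is elided):

import torch
from torchtitan.models.llama import ModelArgs, Transformer

# Illustrative small config; vocab_size is normally set from the tokenizer.
args = ModelArgs(dim=256, n_layers=2, n_heads=8, vocab_size=1000)
with torch.device("meta"):
    model = Transformer(args)          # no real allocation yet

# apply_fsdp(model, dp_mesh, ..., cpu_offload=True) would run here.

model.to_empty(device="cpu")           # materialize (sharded) params on CPU
model.init_weights(buffer_device=torch.device("cuda"))  # freqs_cis on GPU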
10 changes: 9 additions & 1 deletion torchtitan/parallelisms/parallelize_llama.py
@@ -13,7 +13,11 @@
import torch.nn as nn

from torch.distributed import DeviceMesh
-from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
+from torch.distributed._composable.fsdp import (
+    CPUOffloadPolicy,
+    fully_shard,
+    MixedPrecisionPolicy,
+)
from torch.distributed._composable.replicate import replicate
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
@@ -100,6 +104,7 @@ def parallelize_llama(
        reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
        tp_enabled=parallel_dims.tp_enabled,
        pp_enabled=parallel_dims.pp_enabled,
+       cpu_offload=job_config.training.enable_cpu_offload,
    )

if parallel_dims.dp_replicate_enabled:
@@ -315,12 +320,15 @@ def apply_fsdp(
    reduce_dtype: torch.dtype,
    tp_enabled: bool,
    pp_enabled: bool,
+   cpu_offload: bool = False,
):
    """
    Apply data parallelism to the model. FSDP2 is used here.
    """
    mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
+   if cpu_offload:
+       fsdp_config["offload_policy"] = CPUOffloadPolicy()

# TODO: remove this check once PyTorch 2.5 is released. We can safely assume
# that users won't use a nightly build which is older than 20240809 by then.
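Boiled down, offloading is one extra entry in the keyword arguments passed to fully_shard. A hedged, standalone sketch of the wrapping pattern (dtypes are illustrative; shard_with_offload is a hypothetical helper, and the model is assumed to expose a layers ModuleDict as the llama model here does):

import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import (
    CPUOffloadPolicy,
    MixedPrecisionPolicy,
    fully_shard,
)
from torch.distributed.device_mesh import DeviceMesh

def shard_with_offload(model: nn.Module, dp_mesh: DeviceMesh) -> nn.Module:
    """Shard `model` with FSDP2, keeping shards on CPU between collectives."""
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16, reduce_dtype=torch.float32
    )
    fsdp_config = {
        "mesh": dp_mesh,
        "mp_policy": mp_policy,
        # Sharded parameters, gradients, and optimizer states stay on CPU;
        # FSDP2 streams shards to GPU around all-gather/reduce-scatter.
        "offload_policy": CPUOffloadPolicy(),
    }
    for layer in model.layers.values():   # per-block wrapping, as in the PR
        fully_shard(layer, **fsdp_config)
    fully_shard(model, **fsdp_config)
    return model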
6 changes: 5 additions & 1 deletion torchtitan/utils.py
@@ -174,8 +174,12 @@ def init_distributed(job_config):
    # such as those in tensor parallelism
    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

+   backend = "nccl"
+   if job_config.training.enable_cpu_offload:
+       backend = "cuda:nccl,cpu:gloo"
    torch.distributed.init_process_group(
-       "nccl", timeout=timedelta(seconds=job_config.comm.init_timeout_seconds)
+       backend=backend,
+       timeout=timedelta(seconds=job_config.comm.init_timeout_seconds),
    )


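The combined backend string is PyTorch's device-type-to-backend mapping: collectives dispatch by tensor device, so CUDA tensors use NCCL while the CPU-resident tensors introduced by offloading use Gloo. A small sketch of that routing, assuming torchrun-style environment variables are set:

import torch
import torch.distributed as dist

# Assumes RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT are set (e.g. by torchrun).
dist.init_process_group(backend="cuda:nccl,cpu:gloo")

t_cpu = torch.ones(2)                  # CPU tensor -> routed to Gloo
dist.all_reduce(t_cpu)
t_gpu = torch.ones(2, device="cuda")   # CUDA tensor -> routed to NCCL
dist.all_reduce(t_gpu)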
20 changes: 14 additions & 6 deletions train.py
@@ -138,6 +138,17 @@ def loss_fn(pred, labels):
    if job_config.training.compile:
        loss_fn = torch.compile(loss_fn)

+   # move sharded model to CPU/GPU and initialize weights via DTensor
+   if job_config.checkpoint.create_seed_checkpoint:
+       init_device = "cpu"
+       buffer_device = None
+   elif job_config.training.enable_cpu_offload:
+       init_device = "cpu"
+       buffer_device = "cuda"
+   else:
+       init_device = "cuda"
+       buffer_device = None

# apply parallelisms and initialization
if parallel_dims.pp_enabled:
# apply PT-D Pipeline Parallel
@@ -151,17 +162,14 @@ def loss_fn(pred, labels):
        for m in model_parts:
            # apply SPMD-style PT-D techniques
            models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
-           m.to_empty(device="cuda")
-           m.init_weights()
+           m.to_empty(device=init_device)
+           m.init_weights(buffer_device=buffer_device)
            m.train()
    else:
        # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
        models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)

-       # move sharded model to CPU/GPU and initialize weights via DTensor
-       init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda"
        model.to_empty(device=init_device)
-       model.init_weights()
+       model.init_weights(buffer_device=buffer_device)
        model.train()

        model_parts = [model]
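The new branching amounts to a three-row decision table. A hypothetical helper equivalent to the inlined logic above (pick_init_devices is not part of the PR):

from typing import Optional, Tuple

def pick_init_devices(
    create_seed_checkpoint: bool, enable_cpu_offload: bool
) -> Tuple[str, Optional[str]]:
    """Return (init_device, buffer_device), mirroring train.py's branching."""
    if create_seed_checkpoint:
        return "cpu", None      # seed checkpoints are written from CPU
    if enable_cpu_offload:
        return "cpu", "cuda"    # params on CPU, freqs_cis buffer on GPU
    return "cuda", None

assert pick_init_devices(True, False) == ("cpu", None)
assert pick_init_devices(False, True) == ("cpu", "cuda")
assert pick_init_devices(False, False) == ("cuda", None)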