diff --git a/test_runner.py b/test_runner.py
index 61031d74..d98b49b9 100755
--- a/test_runner.py
+++ b/test_runner.py
@@ -351,6 +351,17 @@ def build_test_list():
             "fsdp2_mem_tracker",
             ngpu=4,
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--experimental.pipeline_parallel_degree 2",
+                    "--training.enable_cpu_offload True",
+                ],
+            ],
+            "Enable CPU Offload with PP",
+            "enable_cpu_offload+PP",
+            ngpu=4,
+        ),
     ]
 
     return integration_tests_flavors
diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 7ef1a9b2..defc010e 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -253,6 +253,13 @@ def __init__(self):
             can be negative. 1 means disabled.""",
         )
+        self.parser.add_argument(
+            "--training.enable_cpu_offload",
+            type=bool,
+            default=False,
+            help="""
+            Whether to apply CPU offloading of parameters, gradients, and optimizer states in FSDP""",
+        )
         self.parser.add_argument(
             "--training.tensor_parallel_degree",
             type=int,
diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py
index 315430cf..a3bae18a 100644
--- a/torchtitan/models/llama/model.py
+++ b/torchtitan/models/llama/model.py
@@ -379,7 +379,10 @@ def __init__(self, model_args: ModelArgs):
         self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
         self.init_weights()
 
-    def init_weights(self):
+    def init_weights(
+        self,
+        buffer_device: Optional[torch.device] = None,
+    ):
         """
         [Note: On ``init_weights`` vs. ``reset_parameters``]
         Modules may define ``reset_parameters`` to initialize parameter values.
@@ -391,7 +394,8 @@ def init_weights(self):
         ``init_weights``. We only call it in the constructor of this
         ``Transformer`` root module to avoid reinitializing tensors.
         """
-        with torch.device(self.freqs_cis.device):
+        buffer_device = buffer_device or self.freqs_cis.device
+        with torch.device(buffer_device):
             self.freqs_cis = self._precompute_freqs_cis()
         if self.tok_embeddings is not None:
             nn.init.normal_(self.tok_embeddings.weight)
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index f5b20ebc..ed23936b 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -13,7 +13,11 @@
 import torch.nn as nn
 from torch.distributed import DeviceMesh
-from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
+from torch.distributed._composable.fsdp import (
+    CPUOffloadPolicy,
+    fully_shard,
+    MixedPrecisionPolicy,
+)
 from torch.distributed._composable.replicate import replicate
 from torch.distributed._tensor import Replicate, Shard
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
@@ -100,6 +104,7 @@ def parallelize_llama(
             reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
             tp_enabled=parallel_dims.tp_enabled,
             pp_enabled=parallel_dims.pp_enabled,
+            cpu_offload=job_config.training.enable_cpu_offload,
         )
 
         if parallel_dims.dp_replicate_enabled:
@@ -315,12 +320,15 @@ def apply_fsdp(
     reduce_dtype: torch.dtype,
     tp_enabled: bool,
     pp_enabled: bool,
+    cpu_offload: bool = False,
 ):
     """
     Apply data parallelism to the model. FSDP2 is used here.
     """
     mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
     fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
+    if cpu_offload:
+        fsdp_config["offload_policy"] = CPUOffloadPolicy()
 
     # TODO: remove this check once PyTorch 2.5 is released. We can safely assume
     # that users won't use a nightly build which is older than 20240809 by then.
diff --git a/torchtitan/utils.py b/torchtitan/utils.py
index 79e1073a..8456fdb3 100644
--- a/torchtitan/utils.py
+++ b/torchtitan/utils.py
@@ -174,8 +174,12 @@ def init_distributed(job_config):
     # such as those in tensor parallelism
     os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
 
+    backend = "nccl"
+    if job_config.training.enable_cpu_offload:
+        backend = "cuda:nccl,cpu:gloo"
     torch.distributed.init_process_group(
-        "nccl", timeout=timedelta(seconds=job_config.comm.init_timeout_seconds)
+        backend=backend,
+        timeout=timedelta(seconds=job_config.comm.init_timeout_seconds),
     )
diff --git a/train.py b/train.py
index 31f66537..bc04dad0 100644
--- a/train.py
+++ b/train.py
@@ -138,6 +138,17 @@ def loss_fn(pred, labels):
     if job_config.training.compile:
         loss_fn = torch.compile(loss_fn)
 
+    # move sharded model to CPU/GPU and initialize weights via DTensor
+    if job_config.checkpoint.create_seed_checkpoint:
+        init_device = "cpu"
+        buffer_device = None
+    elif job_config.training.enable_cpu_offload:
+        init_device = "cpu"
+        buffer_device = "cuda"
+    else:
+        init_device = "cuda"
+        buffer_device = None
+
     # apply parallelisms and initialization
     if parallel_dims.pp_enabled:
         # apply PT-D Pipeline Parallel
@@ -151,17 +162,14 @@
         for m in model_parts:
             # apply SPMD-style PT-D techniques
             models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
-            m.to_empty(device="cuda")
-            m.init_weights()
+            m.to_empty(device=init_device)
+            m.init_weights(buffer_device=buffer_device)
             m.train()
     else:
         # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
         models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)
-
-        # move sharded model to CPU/GPU and initialize weights via DTensor
-        init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda"
         model.to_empty(device=init_device)
-        model.init_weights()
+        model.init_weights(buffer_device=buffer_device)
         model.train()
 
         model_parts = [model]
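
For reviewers, a minimal self-contained sketch of what the new cpu_offload path boils down to. It only assumes the FSDP2 prototype API in torch.distributed._composable.fsdp (a recent nightly or PyTorch 2.4+); the helper name shard_model and the dtypes are illustrative and not part of this change.

# Sketch only: mirrors the behavior of apply_fsdp() above under the stated
# assumptions; not the patched torchtitan code itself.
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import (
    CPUOffloadPolicy,
    MixedPrecisionPolicy,
    fully_shard,
)
from torch.distributed.device_mesh import DeviceMesh


def shard_model(model: nn.Module, dp_mesh: DeviceMesh, cpu_offload: bool = False) -> None:
    # Base FSDP2 config: shard over dp_mesh with bf16 compute and fp32 reduction
    # (illustrative dtypes; torchtitan reads them from the job config).
    fsdp_config = {
        "mesh": dp_mesh,
        "mp_policy": MixedPrecisionPolicy(
            param_dtype=torch.bfloat16, reduce_dtype=torch.float32
        ),
    }
    if cpu_offload:
        # Keep sharded parameters, gradients, and optimizer states on CPU;
        # FSDP2 copies parameters to GPU around forward/backward. This is also
        # why init_distributed() above now registers a gloo backend for CPU
        # ("cuda:nccl,cpu:gloo") instead of plain "nccl".
        fsdp_config["offload_policy"] = CPUOffloadPolicy()
    fully_shard(model, **fsdp_config)

With offloading enabled, train.py initializes weights on CPU (init_device = "cpu") while persistent buffers such as freqs_cis are still created on CUDA via the new buffer_device argument to Transformer.init_weights().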