diff --git a/test_runner.py b/test_runner.py
index 4fc95172..1229dcd1 100755
--- a/test_runner.py
+++ b/test_runner.py
@@ -57,12 +57,10 @@ def build_test_list():
             [
                 [
                     "--training.compile",
-                    "--training.enable_cpu_offload True",
                 ],
             ],
             "1D compile",
             "1d_compile",
-            ngpu=2,
         ),
         OverrideDefinitions(
             [
@@ -353,6 +351,16 @@ def build_test_list():
             "fsdp2_mem_tracker",
             ngpu=4,
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--training.enable_cpu_offload True",
+                ],
+            ],
+            "Enable CPU Offload",
+            "enable_cpu_offload",
+            ngpu=2,
+        ),
     ]
 
     return integration_tests_flavors
diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index c5feec19..defc010e 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -258,9 +258,7 @@ def __init__(self):
             type=bool,
             default=False,
             help="""
-                The `enable_cpu_offload` argument specifies whether to have offloading policy
-                for FSDP. If True, CPU offload of parameters, gradients, and optimizer states
-                will be supported.""",
+                Whether to apply CPU offloading of parameters, gradients, and optimizer states in FSDP""",
         )
         self.parser.add_argument(
             "--training.tensor_parallel_degree",
diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py
index 8644629c..a3bae18a 100644
--- a/torchtitan/models/llama/model.py
+++ b/torchtitan/models/llama/model.py
@@ -394,12 +394,9 @@ def init_weights(
         ``init_weights``. We only call it in the constructor of this ``Transformer``
         root module to avoid reinitializing tensors.
         """
-        if buffer_device is not None:
-            with torch.device(buffer_device):
-                self.freqs_cis = self._precompute_freqs_cis()
-        else:
-            with torch.device(self.freqs_cis.device):
-                self.freqs_cis = self._precompute_freqs_cis()
+        buffer_device = buffer_device or self.freqs_cis.device
+        with torch.device(buffer_device):
+            self.freqs_cis = self._precompute_freqs_cis()
         if self.tok_embeddings is not None:
             nn.init.normal_(self.tok_embeddings.weight)
         for layer in self.layers.values():
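
The new integration test exercises `--training.enable_cpu_offload` on its own (on 2 GPUs) rather than piggybacking on the 1D compile test. Below is a minimal sketch of how such a flag typically maps onto FSDP2's offload policy; the `apply_fsdp` helper and the `model.layers` ModuleDict layout are illustrative assumptions rather than torchtitan's actual wiring, and the import path is the private FSDP2 namespace from PyTorch releases of this era:

```python
import torch
from torch.distributed._composable.fsdp import (
    CPUOffloadPolicy,
    OffloadPolicy,
    fully_shard,
)

def apply_fsdp(model: torch.nn.Module, enable_cpu_offload: bool) -> None:
    """Shard a model with FSDP2, optionally offloading to CPU.

    Requires torch.distributed to be initialized (e.g. via torchrun).
    """
    # CPUOffloadPolicy keeps sharded parameters, gradients, and optimizer
    # states in CPU memory, moving them to GPU around forward/backward;
    # OffloadPolicy() is the no-offload default.
    policy = CPUOffloadPolicy() if enable_cpu_offload else OffloadPolicy()
    for layer in model.layers.values():  # assumes a ModuleDict of blocks
        fully_shard(layer, offload_policy=policy)
    fully_shard(model, offload_policy=policy)
```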
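
The `init_weights` cleanup in model.py folds the duplicated if/else into one code path. It leans on two behaviors: `torch.device` has worked as a context manager since PyTorch 2.0, redirecting factory-function allocations to that device, and `a or b` falls back to `b` when `a` is None (a `torch.device` is always truthy, so the `or` fallback is equivalent to the removed `is not None` check). A self-contained illustration, with `make_buffer` as a hypothetical stand-in for the real method:

```python
import torch

def make_buffer(buffer_device=None) -> torch.Tensor:
    existing = torch.empty(4)  # stands in for self.freqs_cis
    # None falls through to the existing buffer's device; an explicit
    # device (always truthy) wins.
    device = buffer_device or existing.device
    with torch.device(device):  # factory calls below allocate on `device`
        return torch.arange(4, dtype=torch.float32)

print(make_buffer().device)   # cpu: fell back to existing.device
# make_buffer("cuda").device  # -> cuda:0, when a GPU is available
```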