
Commit

modify config help, update condition logic
mori360 committed Oct 25, 2024
1 parent d6840dd commit ab1e258
Showing 3 changed files with 14 additions and 11 deletions.
12 changes: 10 additions & 2 deletions test_runner.py
@@ -57,12 +57,10 @@ def build_test_list():
             [
                 [
                     "--training.compile",
-                    "--training.enable_cpu_offload True",
                 ],
             ],
             "1D compile",
             "1d_compile",
-            ngpu=2,
         ),
         OverrideDefinitions(
             [
@@ -353,6 +351,16 @@ def build_test_list():
             "fsdp2_mem_tracker",
             ngpu=4,
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--training.enable_cpu_offload True",
+                ],
+            ],
+            "Enable CPU Offload",
+            "enable_cpu_offload",
+            ngpu=2,
+        ),
     ]
     return integration_tests_flavors

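For context, each OverrideDefinitions entry above bundles a list of CLI override sets, a display name, a test id, and a GPU count. A rough sketch of how a runner could turn the new entry into a launch command; the config path, the run_llama_train.sh launcher, and the environment variables are assumptions about the repo's conventions, not taken from this diff:

import subprocess

def run_override_test(override_sets, test_name, test_id, ngpu):
    # One training invocation per override set in the entry.
    for overrides in override_sets:
        cmd = (
            f"CONFIG_FILE=./train_configs/debug_model.toml NGPU={ngpu} "
            f"./run_llama_train.sh {' '.join(overrides)}"
        )
        print(f"[{test_id}] {test_name}: {cmd}")
        subprocess.run(cmd, shell=True, check=True)

# Mirrors the "Enable CPU Offload" entry added above.
run_override_test(
    [["--training.enable_cpu_offload True"]],
    "Enable CPU Offload",
    "enable_cpu_offload",
    ngpu=2,
)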
4 changes: 1 addition & 3 deletions torchtitan/config_manager.py
@@ -258,9 +258,7 @@ def __init__(self):
             type=bool,
             default=False,
             help="""
-                The `enable_cpu_offload` argument specifies whether to have offloading policy
-                for FSDP. If True, CPU offload of parameters, gradients, and optimizer states
-                will be supported.""",
+                Whether to apply CPU offloading of parameters, gradients, and optimizer states in FSDP""",
         )
         self.parser.add_argument(
             "--training.tensor_parallel_degree",
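For reference, a boolean flag like this is typically consumed when the model is sharded: if set, an offload policy is passed to FSDP so parameters, gradients, and optimizer states live in pinned CPU memory between uses. A hedged sketch using PyTorch's composable FSDP2 API (fully_shard / CPUOffloadPolicy); how torchtitan actually threads the flag through is not shown in this diff:

import torch.nn as nn
from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard

def apply_fsdp(model: nn.Module, dp_mesh, enable_cpu_offload: bool) -> nn.Module:
    fsdp_config = {"mesh": dp_mesh}
    if enable_cpu_offload:
        # Keep sharded parameters, gradients, and optimizer states in pinned
        # CPU memory, streaming them to GPU only around compute.
        fsdp_config["offload_policy"] = CPUOffloadPolicy()
    # Shard each transformer block, then the root module.
    for layer in model.layers.values():
        fully_shard(layer, **fsdp_config)
    fully_shard(model, **fsdp_config)
    return model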
9 changes: 3 additions & 6 deletions torchtitan/models/llama/model.py
@@ -394,12 +394,9 @@ def init_weights(
         ``init_weights``. We only call it in the constructor of this
         ``Transformer`` root module to avoid reinitializing tensors.
         """
-        if buffer_device is not None:
-            with torch.device(buffer_device):
-                self.freqs_cis = self._precompute_freqs_cis()
-        else:
-            with torch.device(self.freqs_cis.device):
-                self.freqs_cis = self._precompute_freqs_cis()
+        buffer_device = buffer_device or self.freqs_cis.device
+        with torch.device(buffer_device):
+            self.freqs_cis = self._precompute_freqs_cis()
         if self.tok_embeddings is not None:
             nn.init.normal_(self.tok_embeddings.weight)
         for layer in self.layers.values():
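The rewrite collapses two duplicated with-blocks into the `x or default` idiom plus torch.device used as a context manager, under which tensor factory functions allocate on that device. A minimal standalone sketch of the same pattern (names here are illustrative, not from the repo):

import torch

def make_buffer(buffer_device=None, current_device=torch.device("cpu")):
    # `or` falls back to the existing device when no override is passed.
    # Safe here: torch.device objects are always truthy, so only None
    # (or another falsy value) triggers the fallback.
    buffer_device = buffer_device or current_device
    # Inside this context, factory calls such as torch.zeros allocate on
    # buffer_device without an explicit device= argument.
    with torch.device(buffer_device):
        return torch.zeros(8)

print(make_buffer().device)  # cpu
if torch.cuda.is_available():
    print(make_buffer(torch.device("cuda")).device)  # cuda:0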
