Disable integration test between optimizer-in-backward and pp #793

Open: wants to merge 2 commits into base: main
16 changes: 12 additions & 4 deletions tests/integration_tests.py
@@ -378,14 +378,22 @@ def build_test_list():
         [
             [
                 "--checkpoint.enable_checkpoint",
-                "--experimental.pipeline_parallel_degree 2",
+                "--training.tensor_parallel_degree=2",
+                "--experimental.context_parallel_degree=2",
                 "--training.enable_cpu_offload",
                 "--optimizer.early_step_in_backward",
             ],
+            [
+                "--training.tensor_parallel_degree=2",
+                "--experimental.context_parallel_degree=2",
+                "--training.data_parallel_replicate_degree=2",
+                "--training.enable_cpu_offload",
+                "--optimizer.early_step_in_backward",
+            ],
         ],
-        "Enable CPU Offload with PP",
-        "enable_cpu_offload+PP",
-        ngpu=4,
+        "Enable CPU Offload, Optimizer in backward with TP, DP, CP",
+        "cpu_offload+opt_in_bwd+TP+DP+CP",
+        ngpu=8,
     ),
     OverrideDefinitions(
         [
33 changes: 22 additions & 11 deletions torchtitan/optimizer.py
@@ -81,30 +81,37 @@ def __init__(
     ) -> None:
         self.optimizers = []
         self.model_parts = model_parts
+        optim_dict = {}
mori360 (Contributor, Author) commented on Jan 16, 2025:
Collect all optimizers in optim_dict; this avoids the bug where, when len(self.model_parts) > 1, the hooks only stay valid for the optimizers of the last entry in self.model_parts (and prepares for future support of multi-schedule PP).

A reviewer (Contributor) replied:
The fix for the optim dict LGTM.
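To make the failure mode concrete, here is a minimal, self-contained sketch (hypothetical two-part model, learning rate, and helper names; not the torchtitan code) of the pattern described above and of the fix this diff adopts. When optim_dict is rebound on every loop iteration, every optim_hook closure resolves the name at call time and therefore only sees the dict built for the last model part.

```python
import torch
import torch.nn as nn


def buggy_setup(model_parts):
    # Anti-pattern described in the comment: optim_dict is rebound on every
    # iteration, and every optim_hook closure looks the name up at call time,
    # so all hooks end up using the dict built for the *last* model part.
    # Running backward would then raise KeyError for params of earlier parts.
    for model in model_parts:
        optim_dict = {p: torch.optim.Adam([p], lr=1e-3) for p in model.parameters()}

        def optim_hook(param) -> None:
            optim_dict[param].step()
            optim_dict[param].zero_grad()

        for p in model.parameters():
            p.register_post_accumulate_grad_hook(optim_hook)


def fixed_setup(model_parts):
    # Fix adopted in this diff: accumulate every parameter's optimizer into a
    # single dict first, then define one hook over the complete mapping.
    optim_dict = {}
    for model in model_parts:
        optim_dict.update(
            {p: torch.optim.Adam([p], lr=1e-3) for p in model.parameters()}
        )

    def optim_hook(param) -> None:
        optim_dict[param].step()
        optim_dict[param].zero_grad()

    for model in model_parts:
        for p in model.parameters():
            if p.requires_grad:
                p.register_post_accumulate_grad_hook(optim_hook)


if __name__ == "__main__":
    parts = [nn.Linear(4, 4), nn.Linear(4, 4)]  # hypothetical two "model parts"
    fixed_setup(parts)
    loss = sum(m(torch.randn(2, 4)).sum() for m in parts)
    loss.backward()  # each param's hook steps and zeroes its own optimizer
```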
         for model in self.model_parts:
             if name == "Adam":
                 # TODO: make the optimizer options configurable by toml/cmd args
-                optim_dict = {
-                    param: torch.optim.Adam([param], **optimizer_kwargs)
-                    for param in model.parameters()
-                }
+                optim_dict.update(
+                    {
+                        param: torch.optim.Adam([param], **optimizer_kwargs)
+                        for param in model.parameters()
+                    }
+                )
             elif name == "AdamW":
-                optim_dict = {
-                    param: torch.optim.AdamW([param], **optimizer_kwargs)
-                    for param in model.parameters()
-                }
+                optim_dict.update(
+                    {
+                        param: torch.optim.AdamW([param], **optimizer_kwargs)
+                        for param in model.parameters()
+                    }
+                )
             else:
                 raise NotImplementedError(f"Optimizer {name} not added.")

-            def optim_hook(param) -> None:
-                optim_dict[param].step()
-                optim_dict[param].zero_grad()
+        def optim_hook(param) -> None:
+            optim_dict[param].step()
+            optim_dict[param].zero_grad()
+
+        for model in self.model_parts:
             for param in model.parameters():
                 if param.requires_grad:
                     param.register_post_accumulate_grad_hook(optim_hook)

             self.optimizers.extend([optim_dict[param] for param in model.parameters()])

         self._validate_length(
             sum(
                 len([param for param in model.parameters()])
@@ -127,6 +134,10 @@ def build_optimizers(
     step() and zero_grad() method for all the child optimizers.
     """
     optim_in_bwd = job_config.optimizer.early_step_in_backward
+    if optim_in_bwd and job_config.experimental.pipeline_parallel_degree > 1:
+        raise NotImplementedError(
+            "OptimizersInBackwardContainer is not supported with pipeline parallelism"
+        )
     name = job_config.optimizer.name
     lr = job_config.optimizer.lr
     fused = job_config.optimizer.fused
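For completeness, a minimal sketch of what the new guard in build_optimizers amounts to, using a hypothetical SimpleNamespace stand-in for the job config (the field names mirror the ones shown in the diff): requesting optimizer-in-backward together with a pipeline-parallel degree above 1 now fails fast with NotImplementedError instead of producing an unsupported setup.

```python
from types import SimpleNamespace

# Hypothetical stand-in for the job-config fields referenced in the diff.
job_config = SimpleNamespace(
    optimizer=SimpleNamespace(
        name="AdamW", lr=3e-4, fused=False, early_step_in_backward=True
    ),
    experimental=SimpleNamespace(pipeline_parallel_degree=2),
)

try:
    # Same up-front check that build_optimizers now performs.
    if (
        job_config.optimizer.early_step_in_backward
        and job_config.experimental.pipeline_parallel_degree > 1
    ):
        raise NotImplementedError(
            "OptimizersInBackwardContainer is not supported with pipeline parallelism"
        )
except NotImplementedError as err:
    print(f"Config rejected: {err}")
```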