Enable FSDP2 cpu offloading #624
Conversation
train.py
Outdated
@@ -169,13 +166,15 @@ def loss_fn(pred, labels):
        else "cuda"
    )
    model.to_empty(device=init_device)
    model.init_weights()
    if job_config.training.enable_cpu_offload:
        with torch.device("cuda"):
I somehow think the `model.init_weights(buffer_device="cuda")` change sounds better. It is straightforward about what we want to achieve, while making the minimum change to the code.
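The two styles under review can be contrasted with a toy stand-in (no torch). `Model`, `device`, and `_default_device` are made up for this sketch; only the call shapes mirror the discussion: passing `buffer_device` explicitly versus relying on an ambient device context like `with torch.device("cuda"):`.

```python
import contextlib

# Hypothetical ambient default device, standing in for torch's.
_default_device = "cpu"

@contextlib.contextmanager
def device(name):
    # Rough analogue of `with torch.device("cuda"):` from the patch:
    # temporarily overrides the ambient default device.
    global _default_device
    prev, _default_device = _default_device, name
    try:
        yield
    finally:
        _default_device = prev

class Model:
    def init_weights(self, buffer_device=None):
        # Explicit argument wins; otherwise fall back to the ambient default.
        return buffer_device if buffer_device is not None else _default_device

m = Model()
# Style A (reviewer's suggestion): explicit and self-documenting.
assert m.init_weights(buffer_device="cuda") == "cuda"
# Style B (original patch): relies on an ambient device context.
with device("cuda"):
    assert m.init_weights() == "cuda"
assert m.init_weights() == "cpu"  # context no longer active
```

The explicit-argument form keeps the device decision visible at the call site, which is why it reads as the smaller, clearer change.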
Looks good in general. I had some inline comments. In particular, let's figure out whether CPU offloading and PP should coexist; if so, we should add support for that as well.
train.py
Outdated
init_device = (
    "cpu"
    if job_config.checkpoint.create_seed_checkpoint
    or job_config.training.enable_cpu_offload
Don't we need to do the same for the PP case (several lines above)? Or are we assuming that if PP is used, CPU offloading is not an option?
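The selection logic in the snippet above can be sketched as a pure function. `choose_init_device` is a hypothetical name, not code from the PR; whether the PP init path needs the same CPU treatment is the open question in this thread, so the sketch covers only the branch shown.

```python
def choose_init_device(create_seed_checkpoint: bool,
                       enable_cpu_offload: bool) -> str:
    # Mirrors the reviewed snippet: weights are materialized on CPU when
    # writing a seed checkpoint or when FSDP2 CPU offloading is enabled;
    # otherwise they are created directly on cuda. (The PP case, initialized
    # several lines above in train.py, is not modeled here.)
    if create_seed_checkpoint or enable_cpu_offload:
        return "cpu"
    return "cuda"

assert choose_init_device(False, True) == "cpu"   # CPU offload → init on CPU
assert choose_init_device(True, False) == "cpu"   # seed checkpoint → CPU
assert choose_init_device(False, False) == "cuda"
```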
"Enable CPU Offload with PP",
"enable_cpu_offload+PP",
ngpu=4,
),
Test with PP; we could remove PP later if it's not necessary in the CI test.
lgtm! thanks!
resolve pytorch#620

Add config: `--training.enable_cpu_offload`

Command: `CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh`

For the non-PP case: (screenshot: memory usage with CPU offloading enabled)

For the PP case: (screenshot: CPU offload + PP)
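Since the PR exposes the feature as a `--training.*` flag driven by the TOML config, the corresponding entry in a config file such as `train_configs/llama3_8b.toml` would plausibly look like the fragment below. This is a sketch of the knob's shape, not the PR's exact file contents.

```toml
[training]
# Enable FSDP2 CPU offloading of parameters/gradients (this PR's new flag).
enable_cpu_offload = true
```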