Can I load from non-FSDP optimizer state with FSDP2? #765

Closed
syncdoth opened this issue Dec 31, 2024 · 3 comments
Labels: question (Further information is requested)

Comments

@syncdoth

I have been running training in a different framework with FSDP1, where I saved the states with FULL_STATE_DICT, so the optimizer states are in a plain torch.save format. I'd love to resume from this checkpoint. Is this currently supported by FSDP2 / DCP? When I naively tried dcp.load, it resulted in a shard-index-out-of-range error.
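For context, a minimal sketch of how a checkpoint like this is typically produced with FSDP1 (the `model`/`optimizer` variables, the file name, and the rank-0-only save are assumptions on my side, not the exact training code):

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    FullOptimStateDictConfig,
    FullStateDictConfig,
    StateDictType,
)

# Gather unsharded (full) state dicts, offloaded to CPU on rank 0 only.
with FSDP.state_dict_type(
    model,
    StateDictType.FULL_STATE_DICT,
    FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
    FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
):
    model_sd = model.state_dict()
    optim_sd = FSDP.optim_state_dict(model, optimizer)

# The result is a plain dict of regular tensors, hence the torch.save format.
if dist.get_rank() == 0:
    torch.save(optim_sd, "optimizer_full_state.pt")
```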

@syncdoth syncdoth changed the title Can I load from non-FSDP optimizer state? Can I load from non-FSDP optimizer state with FSDP2? Dec 31, 2024
@awgu
Contributor

awgu commented Dec 31, 2024

There should be a way to load it with DCP. cc: @fegin @mori360

Full state dicts are state dicts without FSDP sharding. To make them loadable by FSDP2, you just need to iterate over the tensors in the optimizer state (which should match the parameter sharding) and shard them on dim-0 as DTensors. This can be done natively with some relatively simple code, but I will let @fegin or others comment on the right way to do this with the DCP APIs.
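A rough sketch of that dim-0 sharding, assuming every rank has loaded the full optimizer state dict (`full_optim_sd`) and a 1-D device mesh over all ranks; the names and mesh layout here are illustrative assumptions, not an official API:

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

# Assumed: a 1-D mesh matching FSDP2's per-parameter dim-0 sharding.
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

def shard_full_optim_state(full_optim_sd: dict) -> dict:
    """Replace full (unsharded) optimizer state tensors with dim-0 sharded DTensors.

    Must be called collectively on all ranks, each holding the same full tensors.
    """
    for param_state in full_optim_sd["state"].values():
        for key, value in param_state.items():
            # Skip scalar state such as `step`; shard only real tensors.
            if isinstance(value, torch.Tensor) and value.dim() > 0:
                param_state[key] = distribute_tensor(value, mesh, [Shard(0)])
    return full_optim_sd
```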

@tianyu-l added the question label Jan 2, 2025
@fegin
Contributor

fegin commented Jan 8, 2025

Yes, you can write a script to do the conversion offline: simply load the torch.save'd optimizer state_dict and then call DCP.save. The saved checkpoint should then be loadable with FSDP2 + DCP. This is a more complicated version of the same idea: https://github.com/pytorch/torchtitan/blob/main/scripts/convert_llama_to_dcp.py. In your case, if everything stays the same (e.g., the parameter groups), a plain torch.load -> DCP.save should work.
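A minimal sketch of that offline conversion (the file names and the single-process invocation are assumptions):

```python
import torch
import torch.distributed.checkpoint as dcp

# Load the full (unsharded) optimizer state dict that was saved with torch.save.
full_optim_sd = torch.load("optimizer_full_state.pt", map_location="cpu")

# Re-save it in the DCP on-disk format. The resulting checkpoint directory can
# then be loaded with FSDP2 + dcp.load, which reshards to the current setup.
dcp.save(full_optim_sd, checkpoint_id="optimizer_dcp_ckpt")
```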

@fegin
Contributor

fegin commented Jan 28, 2025

@syncdoth I'm going to close the issue. Please let me know if you have any further questions.

@fegin fegin closed this as completed Jan 28, 2025