[ckpt-rewr] Get Model State Dict Util Function #3250
Is this just pulling out existing code into a helper fn?
@eracah it would be great to get slightly more description so I know what parts to carefully read over and what is less important
It's detailed in the design doc, but I can copy and paste it in if you want
It's easier when reviewing PRs to either link to the right part of design doc or copy paste description
Ok added description
Mostly LGTM! just a few minor nits that should be quick to clean up. Also looks like tests are failing, I think because torch gating is a bit off
Co-authored-by: Mihir Patel <[email protected]>
few nits, one larger comment/question.
I'm guessing there are a lot of tests that could be simplified by using the functionality that this PR adds. Is that true? If so, is it worth trying to do at least some of that as part of this PR? It would also implicitly test the functionality being added more, since those would be real use cases.
Yes, in theory we would want to do that. However, in this PR we are just adding the API to be used in a script; we aren't actually adding this code to be used in Trainer. So until we swap out the state dict generation code in state.py for this one, we don't need to change those other tests or add any E2E tests.
What does this PR do?
Adds an API for extracting model state dict from a model object.
State dict generation is a necessary operation before the save AND load of a checkpoint.
Currently in Composer, state dict generation is coupled with the State, which makes it hard to read, hard to extend, hard to test, and hard for users to harness for custom workflows. As such, we present a standalone function that generates the state_dict for the model, decoupled from State. An explicit function for the model is easier to test because we don't have to construct a dummy State. Moreover, it's easier to save each state dict as a separate file. Also, an advanced user can call these functions directly from a custom, advanced script or callback.
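To illustrate the decoupling argument, here is a minimal sketch of what a standalone state dict helper might look like. It is not the function added in this PR: the name `extract_model_state_dict`, the `ignore_keys` parameter, and the `TinyModel` stand-in are all hypothetical, and the helper only assumes the model exposes a `state_dict()` method (as a `torch.nn.Module` does).

```python
from fnmatch import fnmatch
from typing import Any, Dict, Optional, Protocol, Sequence


class HasStateDict(Protocol):
    """Anything exposing state_dict(), e.g. a torch.nn.Module."""

    def state_dict(self) -> Dict[str, Any]: ...


def extract_model_state_dict(
    model: HasStateDict,
    ignore_keys: Optional[Sequence[str]] = None,  # hypothetical option, not from the PR
) -> Dict[str, Any]:
    """Return the model's state dict without involving any State object."""
    state_dict = dict(model.state_dict())
    if ignore_keys:
        # Drop every entry whose name matches one of the glob patterns.
        state_dict = {
            name: value
            for name, value in state_dict.items()
            if not any(fnmatch(name, pattern) for pattern in ignore_keys)
        }
    return state_dict


class TinyModel:
    """Stand-in for a real model so the sketch runs without PyTorch."""

    def state_dict(self) -> Dict[str, Any]:
        return {"linear.weight": [[1.0, 2.0]], "linear.bias": [0.5]}


# The helper can be exercised directly; no dummy State is needed.
full = extract_model_state_dict(TinyModel())
weights_only = extract_model_state_dict(TinyModel(), ignore_keys=["*.bias"])
```

Because the helper takes only the model, a unit test can pass in any small object with a `state_dict()` method, which is the testability benefit described above.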
This state dict generation function enables:
These are all options that will be useful for save and load. Because save and load require state dict generation, we need these options in state dict generation as well.
GRT-2903