Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ckpt-rewr] Get Model State Dict Util Function #3250

Merged
merged 40 commits into from
May 17, 2024
Merged

Conversation

eracah
Copy link
Contributor

@eracah eracah commented May 3, 2024

What does this PR do?

Adds an API for extracting model state dict from a model object.

State dict generation is a necessary operation before the save AND load of a checkpoint.
Currently in composer it is coupled with the State, and not very readable, hard to extend, hard to test, and hard for users to harness to do custom things. As such, we present a function to generate state_dict for the model decoupled from State as a standalone function. By making an explicit function for the model, it’s easier to test because we have a standalone function (we don’t have to make a dummy State function). Moreover, it’s easier to save each state dict as a separate file Also, an advanced user can just call these functions themselves if they have a custom, advanced script or callback.

This state dict generation function enables:

  • generating sharded or full state dicts
  • generating state dicts of different precision
  • specify keys to include
  • sprecify keys to exclude

These are all options that will be useful for save and load. Because save and load require state dict generation, we need these options in state dict generation as well

GRT-2903

@eracah eracah marked this pull request as draft May 3, 2024 04:44
@eracah eracah marked this pull request as ready for review May 11, 2024 00:17
Copy link
Contributor

@mvpatel2000 mvpatel2000 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this just pulling out existing code into a helper fn?

@eracah it would be great to get slightly more description so I know what parts to carefully read over and what is less important

@eracah
Copy link
Contributor Author

eracah commented May 13, 2024

Is this just pulling out existing code into a helper fn?

@eracah it would be great to get slightly more description so I know what parts to carefully read over and what is less important

It's detailed in the design doc, but I can copy and paste it in if you want

@mvpatel2000
Copy link
Contributor

Is this just pulling out existing code into a helper fn?
@eracah it would be great to get slightly more description so I know what parts to carefully read over and what is less important

It's detailed in the design doc, but I can copy and paste it in if you want

It's easier when reviewing PRs to either link to the right part of design doc or copy paste description

@eracah
Copy link
Contributor Author

eracah commented May 15, 2024

Is this just pulling out existing code into a helper fn?
@eracah it would be great to get slightly more description so I know what parts to carefully read over and what is less important

It's detailed in the design doc, but I can copy and paste it in if you want

It's easier when reviewing PRs to either link to the right part of design doc or copy paste description

Ok added description

@eracah eracah requested review from bigning and mvpatel2000 May 15, 2024 20:28
Copy link
Contributor

@mvpatel2000 mvpatel2000 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly LGTM! just a few minor nits that should be quick to clean up. Also looks like tests are failing, I think because torch gating is a bit off

tests/checkpoint/test_state_dict.py Outdated Show resolved Hide resolved
tests/checkpoint/test_state_dict.py Show resolved Hide resolved
tests/common/models.py Outdated Show resolved Hide resolved
composer/checkpoint/state_dict.py Outdated Show resolved Hide resolved
composer/checkpoint/state_dict.py Outdated Show resolved Hide resolved
composer/checkpoint/state_dict.py Outdated Show resolved Hide resolved
composer/checkpoint/state_dict.py Outdated Show resolved Hide resolved
composer/checkpoint/state_dict.py Outdated Show resolved Hide resolved
composer/checkpoint/state_dict.py Outdated Show resolved Hide resolved
tests/checkpoint/test_state_dict.py Outdated Show resolved Hide resolved
tests/checkpoint/test_state_dict.py Outdated Show resolved Hide resolved
Copy link
Contributor

@dakinggg dakinggg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

few nits, one larger comment/question.

I'm guessing there are a lot of tests that could be simplified by using the functionality that this PR adds. Is that true? If so, is it worth trying to do at least some of that as part of this PR? It would also implicitly test the functionality being added more, since those would be real uses cases.

composer/checkpoint/state_dict.py Show resolved Hide resolved
composer/checkpoint/state_dict.py Outdated Show resolved Hide resolved
composer/checkpoint/state_dict.py Outdated Show resolved Hide resolved
tests/common/compare.py Outdated Show resolved Hide resolved
@eracah
Copy link
Contributor Author

eracah commented May 17, 2024

few nits, one larger comment/question.

I'm guessing there are a lot of tests that could be simplified by using the functionality that this PR adds. Is that true? If so, is it worth trying to do at least some of that as part of this PR? It would also implicitly test the functionality being added more, since those would be real uses cases.

Yes in theory, we would want to do that. However, in this PR we are just adding the API to be used in a script; we aren't actually adding this code to be used in Trainer. So until we swap out state.py state dict generation code for this one, we don't need to change those other tests or add any E2E tests.

@eracah eracah enabled auto-merge (squash) May 17, 2024 05:06
@eracah eracah merged commit bddf44b into mosaicml:dev May 17, 2024
15 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants