[ckpt-rewr] Get Model State Dict Util Function #3250

Merged 40 commits on May 17, 2024
Changes from 29 commits

Commits (40)
17d80d5
first commit get model sd
eracah May 3, 2024
3bb5785
Add test stub
eracah May 3, 2024
1ebbd87
Add test stub
eracah May 6, 2024
cca228b
Added unit tests for non-sharded use cases
eracah May 7, 2024
37e9753
Add support for ComposerModel (+ Precision tweak)
eracah May 7, 2024
4558182
Add sharded state dict support
eracah May 8, 2024
23434c0
add precision test
eracah May 8, 2024
eb45d05
Add precision test for sharded state dict
eracah May 8, 2024
f1acf26
pre-commit
eracah May 10, 2024
299aae5
pre-commit
eracah May 10, 2024
5028039
Merge branch 'dev' into get-model-sd
eracah May 10, 2024
c08abd1
mark gpu
eracah May 10, 2024
b4f380c
Merge branch 'get-model-sd' of https://github.com/eracah/evan-compose…
eracah May 10, 2024
a3664ca
pre-commit
eracah May 11, 2024
8ab34f1
add error for sharded + non-fsdp
eracah May 15, 2024
ab00b11
Merge branch 'get-model-sd' of https://github.com/eracah/evan-compose…
eracah May 15, 2024
2bbdc27
change gating
eracah May 15, 2024
5cf52cb
got tests to pass
eracah May 15, 2024
da14580
fix tests
eracah May 15, 2024
f20ecb3
Merge branch 'dev' into get-model-sd
eracah May 15, 2024
24aecbf
Merge branch 'dev' into get-model-sd
eracah May 16, 2024
fae7746
docstring
eracah May 16, 2024
ff8d91c
Update composer/checkpoint/state_dict.py
eracah May 16, 2024
aba4900
Update composer/checkpoint/state_dict.py
eracah May 16, 2024
1975a56
pc
eracah May 16, 2024
d72e2d4
fix version
eracah May 16, 2024
f5e4f53
Merge branch 'get-model-sd' of https://github.com/eracah/evan-compose…
eracah May 16, 2024
f3a7cce
pc
eracah May 16, 2024
50c2308
pc
eracah May 16, 2024
7fec0c2
pre-commit
eracah May 17, 2024
2e3bf04
Merge branch 'get-model-sd' of https://github.com/eracah/evan-compose…
eracah May 17, 2024
2941373
Addressed some comments
eracah May 17, 2024
307c780
pre-commit
eracah May 17, 2024
11bdd89
add comments for new simple models
eracah May 17, 2024
677002b
remove todo's
eracah May 17, 2024
01c1560
remove docstring arg
eracah May 17, 2024
5ccbe12
change scope name for get_model_state_dict
eracah May 17, 2024
44468db
Merge branch 'get-model-sd' of https://github.com/eracah/evan-compose…
eracah May 17, 2024
aee42ff
pre-commit
eracah May 17, 2024
94ab3da
Merge branch 'dev' into get-model-sd
eracah May 17, 2024
10 changes: 10 additions & 0 deletions composer/checkpoint/__init__.py
@@ -0,0 +1,10 @@
# Copyright 2024 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Module for checkpointing API."""

from composer.checkpoint.state_dict import get_model_state_dict

__all__ = [
'get_model_state_dict',
]
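The new `composer.checkpoint` package exposes `get_model_state_dict` as its only public symbol, so it can be imported directly from `composer.checkpoint`. A minimal usage sketch for the simplest case (a plain `nn.Module`, no FSDP, no distributed init) is shown below; the toy `nn.Sequential` model is purely illustrative and not part of this PR.

```python
import torch
from torch import nn

from composer.checkpoint import get_model_state_dict

# A throwaway model for illustration; any nn.Module or ComposerModel works.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))

# With no FSDP wrapping and no torch.distributed initialization, this falls back to
# model.state_dict() and returns the full state dict in fp32 (the default precision).
state_dict = get_model_state_dict(model, sharded_state_dict=False)
print(sorted(state_dict.keys()))  # ['0.bias', '0.weight', '2.bias', '2.weight']
```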
161 changes: 161 additions & 0 deletions composer/checkpoint/state_dict.py
@@ -0,0 +1,161 @@
# Copyright 2024 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Useful functions for generating state dicts and manipulating them."""

import fnmatch
import logging
from typing import Any, Dict, Optional, Sequence, Union

import torch
from packaging import version
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn.parallel import DistributedDataParallel

from composer.models import ComposerModel
from composer.utils import dist

log = logging.getLogger(__name__)


def get_model_state_dict(
model: Union[ComposerModel, nn.Module],
sharded_state_dict: bool,
precision: Union[str, torch.dtype] = 'fp32',
include_keys: Optional[Union[str, Sequence[str]]] = None,
ignore_keys: Optional[Union[str, Sequence[str]]] = None,
cpu_offload: Optional[bool] = None,
) -> Dict[str, Any]:
"""Generate the state dict of the model.

Args:
model: The model to get the state dict from.
sharded_state_dict: Whether the model state dict should be sharded or not. If True, every rank returns the state dict of its shards.
If False, then rank 0 returns the state dict of the entire model.
precision: The precision of the model. Can be specified as a string ('fp32', 'fp16', 'bf16') or a torch.dtype.
include_keys: The list of keys (or fnmatch-style patterns) to exclusively include in the state dict. If None, all keys are included. include_keys and ignore_keys cannot both be non-None.
ignore_keys: The list of keys (or fnmatch-style patterns) to ignore in the state dict. If None, no keys are ignored. include_keys and ignore_keys cannot both be non-None.
cpu_offload: Whether to offload the state dict to CPU. If None, it is set to True if FSDP is enabled, False otherwise.

Returns:
The state dict of the model.
"""
if include_keys is not None and ignore_keys is not None:
raise ValueError(f'Both {include_keys=} and {ignore_keys=} cannot be non-None.')

is_fsdp = _is_model_fsdp(model)
if not is_fsdp and sharded_state_dict:
raise ValueError('Sharded state dict can only be generated for FSDP models.')
cpu_offload = cpu_offload if cpu_offload is not None else (is_fsdp and not sharded_state_dict)

log.debug('Extracting model state dict')
if version.parse(torch.__version__) >= version.parse('2.2.0') and dist.is_initialized():
from torch.distributed.checkpoint.state_dict import StateDictOptions
from torch.distributed.checkpoint.state_dict import get_model_state_dict as torch_get_model_state_dict

use_unsharded_state_dict = not sharded_state_dict

log.debug('Calling torch get_model_state_dict...')
model_state_dict = torch_get_model_state_dict(
model=model,
submodules=None, # We extract submodules below
options=StateDictOptions(
full_state_dict=use_unsharded_state_dict,
cpu_offload=cpu_offload,
),
)
else:
if is_fsdp:
log.debug('Calling legacy FSDP context manager to get model state dict...')
model_state_dict = _get_model_state_dict_with_fsdp_context_manager(model, sharded_state_dict, cpu_offload)
else:
log.debug('Calling model.state_dict() for non-FSDP model...')
model_state_dict = model.state_dict()
if isinstance(model, DistributedDataParallel):
nn.modules.utils.consume_prefix_in_state_dict_if_present(model_state_dict, 'module.')

if include_keys is not None:
model_state_dict = _extract_keys_from_state_dict(model_state_dict, include_keys)

if ignore_keys is not None:
model_state_dict = _remove_keys_from_state_dict(model_state_dict, ignore_keys)

model_state_dict = _cast_state_dict_to_precision(state_dict=model_state_dict, precision=precision)

log.debug('Finished extracting model state dict')
return model_state_dict


STR_TO_DTYPE = {
'fp32': torch.float32,
'fp16': torch.float16,
'bf16': torch.bfloat16,
}


def _cast_state_dict_to_precision(state_dict: Dict[str, Any], precision: Union[str, torch.dtype]):
if isinstance(precision, str):
precision = STR_TO_DTYPE[precision]

new_state_dict = {k: v.to(precision) for k, v in state_dict.items()}
return new_state_dict


def _extract_keys_from_state_dict(state_dict: Dict[str, Any], include_keys: Union[str, Sequence[str]]):
if isinstance(include_keys, str):
include_keys = [include_keys]
new_state_dict = {k: v for k, v in state_dict.items() if any(fnmatch.fnmatch(k, key) for key in include_keys)}

return new_state_dict


def _remove_keys_from_state_dict(state_dict: Dict[str, Any], ignore_keys: Union[str, Sequence[str]]):
if isinstance(ignore_keys, str):
ignore_keys = [ignore_keys]
new_state_dict = {k: v for k, v in state_dict.items() if not any(fnmatch.fnmatch(k, key) for key in ignore_keys)}
return new_state_dict


def _is_model_fsdp(model) -> bool:
"""Indicates if FSDP is enabled.

Args:
model: The model to check if FSDP is enabled.

Returns:
True if FSDP is enabled, False otherwise.

"""
for module in model.modules():
if isinstance(module, FSDP):
return True
return False


def _get_model_state_dict_with_fsdp_context_manager(model: nn.Module, sharded_state_dict: bool,
cpu_offload: bool) -> Dict[str, Any]:
"""Get the model state dict with the FSDP context manager.

Args:
model: The model to get the state dict from.
sharded_state_dict: Whether the model state dict should be sharded or not. If True, every rank returns the state dict of its shards.
If False, then rank 0 returns the state dict of the entire model.
cpu_offload: Whether to offload the state dict to CPU.

Returns:
The state dict of the model.
"""
from torch.distributed.fsdp.fully_sharded_data_parallel import (
FullStateDictConfig,
ShardedStateDictConfig,
StateDictType,
)
state_dict_type = StateDictType.SHARDED_STATE_DICT if sharded_state_dict else StateDictType.FULL_STATE_DICT
state_dict_config = ShardedStateDictConfig(offload_to_cpu=cpu_offload,
) if sharded_state_dict else FullStateDictConfig(
rank0_only=True,
offload_to_cpu=cpu_offload,
)
with FSDP.state_dict_type(model, state_dict_type=state_dict_type, state_dict_config=state_dict_config):
model_state_dict = model.state_dict()
return model_state_dict
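For reference, here is a sketch of how the `include_keys`/`ignore_keys` fnmatch-style filters and the `precision` argument compose. The toy model is the same kind of illustrative `nn.Sequential` as above and is not code from this PR.

```python
import torch
from torch import nn

from composer.checkpoint import get_model_state_dict

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))

# Keep only the weight tensors (fnmatch patterns) and cast them to bf16.
weights_only = get_model_state_dict(
    model,
    sharded_state_dict=False,
    precision='bf16',
    include_keys='*.weight',
)
assert set(weights_only) == {'0.weight', '2.weight'}
assert all(v.dtype == torch.bfloat16 for v in weights_only.values())

# Or drop the biases instead; include_keys and ignore_keys are mutually exclusive,
# so passing both raises a ValueError.
no_bias = get_model_state_dict(
    model,
    sharded_state_dict=False,
    ignore_keys=['*.bias'],
)
assert '0.bias' not in no_bias
```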