[ckpt-rewr] Get Model State Dict Util Function (#3250)
eracah authored May 17, 2024
1 parent 435c295 commit bddf44b
Showing 7 changed files with 435 additions and 0 deletions.
10 changes: 10 additions & 0 deletions composer/checkpoint/__init__.py
@@ -0,0 +1,10 @@
# Copyright 2024 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Module for checkpointing API."""

from composer.checkpoint.state_dict import get_model_state_dict

__all__ = [
'get_model_state_dict',
]
154 changes: 154 additions & 0 deletions composer/checkpoint/state_dict.py
@@ -0,0 +1,154 @@
# Copyright 2024 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Useful functions for generating state dicts and manipulating them."""

import fnmatch
import logging
from typing import Any, Dict, Optional, Sequence, Union

import torch
from packaging import version
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn.parallel import DistributedDataParallel

from composer.models import ComposerModel
from composer.utils import STR_TO_DTYPE, dist

log = logging.getLogger(__name__)


def get_model_state_dict(
model: Union[ComposerModel, nn.Module],
sharded_state_dict: bool = False,
precision: Union[str, torch.dtype] = 'fp32',
include_keys: Optional[Union[str, Sequence[str]]] = None,
ignore_keys: Optional[Union[str, Sequence[str]]] = None,
cpu_offload: Optional[bool] = None,
) -> Dict[str, Any]:
"""Generate the state dict of the model.
Args:
model: The model to get the state dict from.
sharded_state_dict: Whether the model state dict should be sharded or not. If True, every rank returns the state dict of its shards.
If False, then rank 0 returns the state dict of the entire model and the other ranks return a dict of their shards. Default is False.
precision: The precision of the model. Can be specified as a string ('fp32', 'fp16', 'bf16') or a torch.dtype.
include_keys: The list of keys to exclusively include in the state dict. If None, all keys are included. Both include_keys and ignore_keys cannot be non-None.
ignore_keys: The list of keys to ignore in the state dict. If None, no keys are ignored. Both include_keys and ignore_keys cannot be non-None.
cpu_offload: Whether to offload the state dict to CPU. If None, it is set to True if FSDP is enabled, False otherwise.
Returns:
The state dict of the model.
"""
if include_keys is not None and ignore_keys is not None:
raise ValueError(f'Both {include_keys=} and {ignore_keys=} cannot be non-None.')

is_fsdp = _is_model_fsdp(model)
if not is_fsdp and sharded_state_dict:
raise ValueError('Sharded state dict can only be generated for FSDP models.')
cpu_offload = cpu_offload if cpu_offload is not None else (is_fsdp and not sharded_state_dict)

log.debug('Extracting model state dict')
if version.parse(torch.__version__) >= version.parse('2.2.0') and dist.is_initialized():
from torch.distributed.checkpoint import state_dict as DCPSD # Distributed Checkpoint State Dict
from torch.distributed.checkpoint.state_dict import StateDictOptions

use_unsharded_state_dict = not sharded_state_dict

log.debug('Calling torch get_model_state_dict...')
model_state_dict = DCPSD.get_model_state_dict(
model=model,
submodules=None, # We extract submodules below
options=StateDictOptions(
full_state_dict=use_unsharded_state_dict,
cpu_offload=cpu_offload,
),
)
else:
if is_fsdp:
log.debug('Calling legacy FSDP context manager to get model state dict...')
model_state_dict = _get_model_state_dict_with_fsdp_context_manager(model, sharded_state_dict, cpu_offload)
else:
log.debug('Calling model.state_dict() for non-FSDP model...')
model_state_dict = model.state_dict()
if isinstance(model, DistributedDataParallel):
nn.modules.utils.consume_prefix_in_state_dict_if_present(model_state_dict, 'module.')

if include_keys is not None:
model_state_dict = _extract_keys_from_state_dict(model_state_dict, include_keys)

if ignore_keys is not None:
model_state_dict = _remove_keys_from_state_dict(model_state_dict, ignore_keys)

model_state_dict = _cast_state_dict_to_precision(state_dict=model_state_dict, precision=precision)

log.debug('Finished extracting model state dict')
return model_state_dict


def _cast_state_dict_to_precision(state_dict: Dict[str, Any], precision: Union[str, torch.dtype]):
if isinstance(precision, str):
precision = STR_TO_DTYPE[precision]

new_state_dict = {k: v.to(precision) for k, v in state_dict.items()}
return new_state_dict


def _extract_keys_from_state_dict(state_dict: Dict[str, Any], include_keys: Union[str, Sequence[str]]):
if isinstance(include_keys, str):
include_keys = [include_keys]
new_state_dict = {k: v for k, v in state_dict.items() if any(fnmatch.fnmatch(k, key) for key in include_keys)}

return new_state_dict


def _remove_keys_from_state_dict(state_dict: Dict[str, Any], ignore_keys: Union[str, Sequence[str]]):
if isinstance(ignore_keys, str):
ignore_keys = [ignore_keys]
new_state_dict = {k: v for k, v in state_dict.items() if not any(fnmatch.fnmatch(k, key) for key in ignore_keys)}
return new_state_dict


def _is_model_fsdp(model: nn.Module) -> bool:
    """Indicates whether any module in the model is wrapped with FSDP.

    Args:
        model: The model to check.

    Returns:
        True if any submodule is an FSDP instance, False otherwise.
    """
for module in model.modules():
if isinstance(module, FSDP):
return True
return False


def _get_model_state_dict_with_fsdp_context_manager(
    model: nn.Module,
    sharded_state_dict: bool,
    cpu_offload: bool,
) -> Dict[str, Any]:
    """Get the model state dict with the FSDP context manager.

    Args:
        model: The model to get the state dict from.
        sharded_state_dict: Whether the model state dict should be sharded or not. If True, every rank returns the state dict of its shards.
            If False, rank 0 returns the state dict of the entire model.
        cpu_offload: Whether to offload the state dict to CPU.

    Returns:
        The state dict of the model.
    """
from torch.distributed.fsdp.fully_sharded_data_parallel import (
FullStateDictConfig,
ShardedStateDictConfig,
StateDictType,
)
state_dict_type = StateDictType.SHARDED_STATE_DICT if sharded_state_dict else StateDictType.FULL_STATE_DICT
    state_dict_config = (
        ShardedStateDictConfig(offload_to_cpu=cpu_offload)
        if sharded_state_dict
        else FullStateDictConfig(rank0_only=True, offload_to_cpu=cpu_offload)
    )
with FSDP.state_dict_type(model, state_dict_type=state_dict_type, state_dict_config=state_dict_config):
model_state_dict = model.state_dict()
return model_state_dict
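
A minimal usage sketch of the new utility, not part of the commit: the toy nn.Sequential model and the '0.*' glob pattern below are illustrative assumptions, and the run assumes a single, non-distributed process.

import torch
from torch import nn

from composer.checkpoint import get_model_state_dict

# Toy, non-FSDP model purely for illustration.
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))

# Full (unsharded) state dict of every parameter and buffer, cast to fp32 (the default).
full_sd = get_model_state_dict(model)

# Keep only the first Linear layer's tensors (fnmatch-style glob) and cast them to bf16.
first_layer_sd = get_model_state_dict(model, precision='bf16', include_keys='0.*')
assert set(first_layer_sd) == {'0.weight', '0.bias'}
assert first_layer_sd['0.weight'].dtype == torch.bfloat16

# For an FSDP-wrapped model, sharded_state_dict=True would return each rank's shards instead;
# passing both include_keys and ignore_keys raises a ValueError.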
2 changes: 2 additions & 0 deletions composer/utils/__init__.py
@@ -49,6 +49,7 @@
from composer.utils.inference import ExportFormat, Transform, export_for_inference, export_with_logger, quantize_dynamic
from composer.utils.iter_helpers import IteratorFileStream, ensure_tuple, map_collection
from composer.utils.misc import (
STR_TO_DTYPE,
add_vision_dataset_transform,
create_interval_scheduler,
get_free_tcp_port,
@@ -143,4 +144,5 @@
'CliCompressor',
'get_compressor',
'KNOWN_COMPRESSORS',
'STR_TO_DTYPE',
]
7 changes: 7 additions & 0 deletions composer/utils/misc.py
@@ -27,10 +27,17 @@
'model_eval_mode',
'create_interval_scheduler',
'add_vision_dataset_transform',
'STR_TO_DTYPE',
]

log = logging.getLogger(__name__)

STR_TO_DTYPE = {
'fp32': torch.float32,
'fp16': torch.float16,
'bf16': torch.bfloat16,
}


def create_interval_scheduler(
interval: Union[str, int, 'Time'],
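
For reference, a small sketch (not part of the commit) of how the new STR_TO_DTYPE export backs the string form of the precision argument; the dict comprehension mirrors the private _cast_state_dict_to_precision helper above.

import torch

from composer.utils import STR_TO_DTYPE

assert STR_TO_DTYPE['bf16'] is torch.bfloat16

# Casting a state dict the same way get_model_state_dict does for string precisions:
state_dict = {'w': torch.zeros(4)}  # stand-in for a real model state dict
state_dict = {k: v.to(STR_TO_DTYPE['bf16']) for k, v in state_dict.items()}
assert state_dict['w'].dtype == torch.bfloat16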