From bddf44ba6c2645b32cd5916ff5533deeda109f82 Mon Sep 17 00:00:00 2001 From: Evan Racah Date: Thu, 16 May 2024 22:48:32 -0700 Subject: [PATCH] [ckpt-rewr] Get Model State Dict Util Function (#3250) --- composer/checkpoint/__init__.py | 10 ++ composer/checkpoint/state_dict.py | 154 ++++++++++++++++++ composer/utils/__init__.py | 2 + composer/utils/misc.py | 7 + tests/checkpoint/test_state_dict.py | 232 ++++++++++++++++++++++++++++ tests/common/compare.py | 3 + tests/common/models.py | 27 ++++ 7 files changed, 435 insertions(+) create mode 100644 composer/checkpoint/__init__.py create mode 100644 composer/checkpoint/state_dict.py create mode 100644 tests/checkpoint/test_state_dict.py diff --git a/composer/checkpoint/__init__.py b/composer/checkpoint/__init__.py new file mode 100644 index 0000000000..be9c380c2d --- /dev/null +++ b/composer/checkpoint/__init__.py @@ -0,0 +1,10 @@ +# Copyright 2024 MosaicML Composer authors +# SPDX-License-Identifier: Apache-2.0 + +"""Module for checkpointing API.""" + +from composer.checkpoint.state_dict import get_model_state_dict + +__all__ = [ + 'get_model_state_dict', +] diff --git a/composer/checkpoint/state_dict.py b/composer/checkpoint/state_dict.py new file mode 100644 index 0000000000..5417188466 --- /dev/null +++ b/composer/checkpoint/state_dict.py @@ -0,0 +1,154 @@ +# Copyright 2024 MosaicML Composer authors +# SPDX-License-Identifier: Apache-2.0 + +"""Useful functions for generating state dicts and manipulating them.""" + +import fnmatch +import logging +from typing import Any, Dict, Optional, Sequence, Union + +import torch +from packaging import version +from torch import nn +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.nn.parallel import DistributedDataParallel + +from composer.models import ComposerModel +from composer.utils import STR_TO_DTYPE, dist + +log = logging.getLogger(__name__) + + +def get_model_state_dict( + model: Union[ComposerModel, nn.Module], + sharded_state_dict: bool = False, + precision: Union[str, torch.dtype] = 'fp32', + include_keys: Optional[Union[str, Sequence[str]]] = None, + ignore_keys: Optional[Union[str, Sequence[str]]] = None, + cpu_offload: Optional[bool] = None, +) -> Dict[str, Any]: + """Generate the state dict of the model. + + Args: + model: The model to get the state dict from. + sharded_state_dict: Whether the model state dict should be sharded or not. If True, every rank returns the state dict of its shards. + If False, then rank 0 returns the state dict of the entire model and the other ranks return a dict of their shards. Default is False. + precision: The precision of the model. Can be specified as a string ('fp32', 'fp16', 'bf16') or a torch.dtype. + include_keys: The list of keys to exclusively include in the state dict. If None, all keys are included. Both include_keys and ignore_keys cannot be non-None. + ignore_keys: The list of keys to ignore in the state dict. If None, no keys are ignored. Both include_keys and ignore_keys cannot be non-None. + cpu_offload: Whether to offload the state dict to CPU. If None, it is set to True if FSDP is enabled, False otherwise. + + Returns: + The state dict of the model. 
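+
+ Example:
+ .. code-block:: python
+
+ # Illustrative sketch only; assumes ``model`` is an already-constructed
+ # ComposerModel or nn.Module (the variable name is hypothetical).
+ from composer.checkpoint import get_model_state_dict
+
+ state_dict = get_model_state_dict(model, precision='bf16', ignore_keys=['*.bias'])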
+ """ + if include_keys is not None and ignore_keys is not None: + raise ValueError(f'Both {include_keys=} and {ignore_keys=} cannot be non-None.') + + is_fsdp = _is_model_fsdp(model) + if not is_fsdp and sharded_state_dict: + raise ValueError('Sharded state dict can only be generated for FSDP models.') + cpu_offload = cpu_offload if cpu_offload is not None else (is_fsdp and not sharded_state_dict) + + log.debug('Extracting model state dict') + if version.parse(torch.__version__) >= version.parse('2.2.0') and dist.is_initialized(): + from torch.distributed.checkpoint import state_dict as DCPSD # Distributed Checkpoint State Dict + from torch.distributed.checkpoint.state_dict import StateDictOptions + + use_unsharded_state_dict = not sharded_state_dict + + log.debug('Calling torch get_model_state_dict...') + model_state_dict = DCPSD.get_model_state_dict( + model=model, + submodules=None, # We extract submodules below + options=StateDictOptions( + full_state_dict=use_unsharded_state_dict, + cpu_offload=cpu_offload, + ), + ) + else: + if is_fsdp: + log.debug('Calling legacy FSDP context manager to get model state dict...') + model_state_dict = _get_model_state_dict_with_fsdp_context_manager(model, sharded_state_dict, cpu_offload) + else: + log.debug('Calling model.state_dict() for non-FSDP model...') + model_state_dict = model.state_dict() + if isinstance(model, DistributedDataParallel): + nn.modules.utils.consume_prefix_in_state_dict_if_present(model_state_dict, 'module.') + + if include_keys is not None: + model_state_dict = _extract_keys_from_state_dict(model_state_dict, include_keys) + + if ignore_keys is not None: + model_state_dict = _remove_keys_from_state_dict(model_state_dict, ignore_keys) + + model_state_dict = _cast_state_dict_to_precision(state_dict=model_state_dict, precision=precision) + + log.debug('Finished extracting model state dict') + return model_state_dict + + +def _cast_state_dict_to_precision(state_dict: Dict[str, Any], precision: Union[str, torch.dtype]): + if isinstance(precision, str): + precision = STR_TO_DTYPE[precision] + + new_state_dict = {k: v.to(precision) for k, v in state_dict.items()} + return new_state_dict + + +def _extract_keys_from_state_dict(state_dict: Dict[str, Any], include_keys: Union[str, Sequence[str]]): + if isinstance(include_keys, str): + include_keys = [include_keys] + new_state_dict = {k: v for k, v in state_dict.items() if any(fnmatch.fnmatch(k, key) for key in include_keys)} + + return new_state_dict + + +def _remove_keys_from_state_dict(state_dict: Dict[str, Any], ignore_keys: Union[str, Sequence[str]]): + if isinstance(ignore_keys, str): + ignore_keys = [ignore_keys] + new_state_dict = {k: v for k, v in state_dict.items() if not any(fnmatch.fnmatch(k, key) for key in ignore_keys)} + return new_state_dict + + +def _is_model_fsdp(model) -> bool: + """Indicates if FSDP is enabled. + + Args: + model: The model to check if FSDP is enabled. + + Returns: + True if FSDP is enabled, False otherwise. + + """ + for module in model.modules(): + if isinstance(module, FSDP): + return True + return False + + +def _get_model_state_dict_with_fsdp_context_manager(model: nn.Module, sharded_state_dict: bool, + cpu_offload: bool) -> Dict[str, Any]: + """Get the model state dict with the FSDP context manager. + + Args: + model: The model to get the state dict from. + sharded: Whether the model state dict should be sharded or not. If True, every rank returns the state dict of its shards. + If False, then rank 0 returns the state dict of the entire model. 
+ + Returns: + The state dict of the model. + """ + from torch.distributed.fsdp.fully_sharded_data_parallel import ( + FullStateDictConfig, + ShardedStateDictConfig, + StateDictType, + ) + state_dict_type = StateDictType.SHARDED_STATE_DICT if sharded_state_dict else StateDictType.FULL_STATE_DICT + state_dict_config = ShardedStateDictConfig(offload_to_cpu=cpu_offload, + ) if sharded_state_dict else FullStateDictConfig( + rank0_only=True, + offload_to_cpu=cpu_offload, + ) + with FSDP.state_dict_type(model, state_dict_type=state_dict_type, state_dict_config=state_dict_config): + model_state_dict = model.state_dict() + return model_state_dict diff --git a/composer/utils/__init__.py b/composer/utils/__init__.py index 15d089a475..ea5f4cd14e 100644 --- a/composer/utils/__init__.py +++ b/composer/utils/__init__.py @@ -49,6 +49,7 @@ from composer.utils.inference import ExportFormat, Transform, export_for_inference, export_with_logger, quantize_dynamic from composer.utils.iter_helpers import IteratorFileStream, ensure_tuple, map_collection from composer.utils.misc import ( + STR_TO_DTYPE, add_vision_dataset_transform, create_interval_scheduler, get_free_tcp_port, @@ -143,4 +144,5 @@ 'CliCompressor', 'get_compressor', 'KNOWN_COMPRESSORS', + 'STR_TO_DTYPE', ] diff --git a/composer/utils/misc.py b/composer/utils/misc.py index 242095e0e4..88a1366336 100644 --- a/composer/utils/misc.py +++ b/composer/utils/misc.py @@ -27,10 +27,17 @@ 'model_eval_mode', 'create_interval_scheduler', 'add_vision_dataset_transform', + 'STR_TO_DTYPE', ] log = logging.getLogger(__name__) +STR_TO_DTYPE = { + 'fp32': torch.float32, + 'fp16': torch.float16, + 'bf16': torch.bfloat16, +} + def create_interval_scheduler( interval: Union[str, int, 'Time'], diff --git a/tests/checkpoint/test_state_dict.py b/tests/checkpoint/test_state_dict.py new file mode 100644 index 0000000000..8b40c83bcc --- /dev/null +++ b/tests/checkpoint/test_state_dict.py @@ -0,0 +1,232 @@ +# Copyright 2024 MosaicML Composer authors +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict + +import pytest +import torch +from packaging import version +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + +from composer.checkpoint import get_model_state_dict +from composer.utils import dist +from tests.common.compare import deep_compare +from tests.common.markers import world_size +from tests.common.models import EvenSimplerMLP, SimpleComposerMLP + + +@pytest.mark.gpu +@pytest.mark.parametrize('use_composer_model', [True, False]) +def test_get_model_state_dict_unsharded_model(use_composer_model: bool): + if use_composer_model: + model = SimpleComposerMLP(num_features=8, device='cuda') + else: + model = EvenSimplerMLP(num_features=8, device='cuda') + model_state_dict = get_model_state_dict(model, sharded_state_dict=False, include_keys=None, ignore_keys=None) + for name, param in model.named_parameters(): + print(name) + assert name in model_state_dict + assert torch.equal(model_state_dict[name], param) + + +@pytest.mark.gpu +@pytest.mark.parametrize('use_composer_model', [True, False]) +def test_get_model_state_dict_include(use_composer_model: bool): + if use_composer_model: + model = SimpleComposerMLP(num_features=8, device='cuda') + else: + model = EvenSimplerMLP(num_features=8, device='cuda') + model_state_dict = get_model_state_dict(model, sharded_state_dict=False, include_keys=['module.0.weight']) + assert set(model_state_dict.keys()) == {'module.0.weight'} + + model_state_dict = get_model_state_dict(model, 
sharded_state_dict=False, include_keys='module.2*') + assert set(model_state_dict.keys()) == {'module.2.weight'} + + +@pytest.mark.gpu +@pytest.mark.parametrize('use_composer_model', [True, False]) +def test_get_model_state_dict_ignore(use_composer_model: bool): + if use_composer_model: + model = SimpleComposerMLP(num_features=8, device='cuda') + else: + model = EvenSimplerMLP(num_features=8, device='cuda') + + model_state_dict = get_model_state_dict(model, sharded_state_dict=False, ignore_keys='module.2.weight') + assert set(model_state_dict.keys()) == {'module.0.weight'} + + model_state_dict = get_model_state_dict(model, sharded_state_dict=False, ignore_keys=['module.2*']) + assert set(model_state_dict.keys()) == {'module.0.weight'} + + +@pytest.mark.gpu +@world_size(2) +@pytest.mark.parametrize('tensor_type', ['sharded_tensor', 'dtensor']) +@pytest.mark.parametrize('use_composer_model', [True, False]) +def test_get_model_state_dict_full_for_sharded_model(world_size, tensor_type, use_composer_model: bool): + if tensor_type == 'dtensor' and version.parse(torch.__version__) < version.parse('2.2.0'): + pytest.skip('DTensor is only supported in PyTorch >= 2.2.0') + if use_composer_model: + model = SimpleComposerMLP(num_features=16, device='cuda') + else: + model = EvenSimplerMLP(num_features=16, device='cuda') + + # Torch flattens model params in place after wrapped with FSDP, so we need to cache unflattened params now + # before fsdp wrapping in order to keep pre-sharding shapes. + pre_shard_state_dict = get_model_state_dict( + model, + sharded_state_dict=False, + ) + fsdp_kwargs: Dict[str, Any] = dict( + use_orig_params=True, + sync_module_states=True, # To enable easy comparison between rank 0 unsharded model and full state dict + ) + + if tensor_type == 'dtensor': + from torch.distributed.device_mesh import init_device_mesh + device_mesh = init_device_mesh('cuda', (2,)) + fsdp_kwargs['device_mesh'] = device_mesh + + sharded_model = FSDP( + model, + **fsdp_kwargs, + ) + + post_shard_full_state_dict = get_model_state_dict(sharded_model, sharded_state_dict=False) + + if dist.get_global_rank() == 0: + deep_compare(pre_shard_state_dict, post_shard_full_state_dict) + + +@pytest.mark.gpu +@world_size(2) +@pytest.mark.parametrize('tensor_type', ['sharded_tensor', 'dtensor']) +@pytest.mark.parametrize('use_composer_model', [True, False]) +def test_get_model_state_dict_sharded(world_size, tensor_type, use_composer_model: bool): + if tensor_type == 'dtensor' and version.parse(torch.__version__) < version.parse('2.2.0'): + pytest.skip('DTensor is only supported in PyTorch >= 2.2.0') + + if use_composer_model: + model = SimpleComposerMLP(num_features=16, device='cuda') + else: + model = EvenSimplerMLP(num_features=16, device='cuda') + + # Torch flattens model params in place after wrapped with FSDP, so we need to cache unflattened params now + # before fsdp wrapping in order to keep pre-sharding shapes. 
+ pre_shard_full_state_dict = get_model_state_dict( + model, + sharded_state_dict=False, + ) + + fsdp_kwargs: Dict[str, Any] = dict( + use_orig_params=True, + sync_module_states=True, # To enable easy comparison between rank 0 unsharded model and full state dict + ) + + if tensor_type == 'dtensor': + from torch.distributed.device_mesh import init_device_mesh + device_mesh = init_device_mesh('cuda', (2,)) + fsdp_kwargs.update(device_mesh=device_mesh) + + sharded_model = FSDP( + model, + **fsdp_kwargs, + ) + + post_shard_sharded_sd = get_model_state_dict(sharded_model, sharded_state_dict=True) + + # In order to test if the sharded state dict is correct we go through this process: + # 1. Transform the each rank's state dict's values by extracting the the local tensor from the ShardedTensor object + # 2. Gather each rank's state dicts + # 3. Make a "reconstructed" full state dict by, for each key, concatenating all the tensor shards into one big tensor + # 4. Compare this "reconstructed" full state dict to the original model's state dict to ensure they are the same. + local_tensor_sd = { + n: (p.local_tensor() if tensor_type == 'sharded_tensor' else p.to_local()) + for n, p in post_shard_sharded_sd.items() + } + all_local_tensor_sd = dist.all_gather_object(local_tensor_sd) + post_shard_reconstructed_full_sd = { + n: torch.cat( + [sd[n].cuda() for sd in all_local_tensor_sd], + dim=0, # dim=0 because fsdp shards each tensor on the 0th dimension + ) for n in pre_shard_full_state_dict.keys() + } + if dist.get_global_rank() == 0: + deep_compare(pre_shard_full_state_dict, post_shard_reconstructed_full_sd) + + +@world_size(2) +@pytest.mark.gpu +@pytest.mark.parametrize( + 'precision', + [ + torch.float32, + torch.float16, + torch.bfloat16, + ], +) +@pytest.mark.parametrize('tensor_type', ['sharded_tensor', 'dtensor']) +@pytest.mark.parametrize('use_composer_model', [True, False]) +def test_get_model_state_dict_precision_sharded_model( + world_size, + tensor_type, + precision: str, + use_composer_model: bool, +): + if tensor_type == 'dtensor' and version.parse(torch.__version__) < version.parse('2.2.0'): + pytest.skip('DTensor is only supported in PyTorch >= 2.2.0') + if use_composer_model: + model = SimpleComposerMLP(num_features=8, device='cuda') + else: + model = EvenSimplerMLP(num_features=8, device='cuda') + + fsdp_kwargs: Dict[str, Any] = dict( + use_orig_params=True, + sync_module_states=True, # To enable easy comparison between rank 0 unsharded model and full state dict + ) + + if tensor_type == 'dtensor': + from torch.distributed.device_mesh import init_device_mesh + device_mesh = init_device_mesh('cuda', (2,)) + fsdp_kwargs.update(device_mesh=device_mesh) + + sharded_model = FSDP( + model, + **fsdp_kwargs, + ) + + model_state_dict = get_model_state_dict( + sharded_model, + precision=precision, + sharded_state_dict=True, + include_keys=None, + ignore_keys=None, + ) + for sharded_tens in model_state_dict.values(): + local_tensor = sharded_tens.local_tensor() if tensor_type == 'sharded_tensor' else sharded_tens.to_local() + assert local_tensor.dtype == precision + + +@pytest.mark.gpu +@pytest.mark.parametrize( + 'precision', + [ + torch.float32, + torch.float16, + torch.bfloat16, + ], +) +@pytest.mark.parametrize('use_composer_model', [True, False]) +def test_get_model_state_dict_precision_unsharded_model(precision: str, use_composer_model: bool): + if use_composer_model: + model = SimpleComposerMLP(num_features=8, device='cuda') + else: + model = EvenSimplerMLP(num_features=8, device='cuda') + 
model_state_dict = get_model_state_dict( + model, + precision=precision, + sharded_state_dict=False, + include_keys=None, + ignore_keys=None, + ) + for tens in model_state_dict.values(): + assert tens.dtype == precision diff --git a/tests/common/compare.py b/tests/common/compare.py index 870cf24f46..942fa67504 100644 --- a/tests/common/compare.py +++ b/tests/common/compare.py @@ -34,6 +34,9 @@ def _check_item(item1: Any, item2: Any, path: str, rtol: float = 0.0, atol: floa return if isinstance(item1, torch.Tensor): assert isinstance(item2, torch.Tensor) + if item1.device != item2.device: + item1 = item1.cpu() + item2 = item2.cpu() assert item1.allclose(item2, rtol=rtol, atol=atol), f'{path} differs' return if isinstance(item1, np.ndarray): diff --git a/tests/common/models.py b/tests/common/models.py index f1730c93d2..ea3546ad1a 100644 --- a/tests/common/models.py +++ b/tests/common/models.py @@ -113,6 +113,33 @@ def forward(self, x): return self.net(x) +# We use this Module to test state dict generation because fc1 and fc2 +# are not submodules of EvenSimplerMLP, like they are in SimpleMLP. +class EvenSimplerMLP(torch.nn.Module): + + def __init__(self, num_features: int, device: str): + super().__init__() + fc1 = torch.nn.Linear(num_features, num_features, device=device, bias=False) + fc2 = torch.nn.Linear(num_features, num_features, device=device, bias=False) + + self.module = torch.nn.Sequential(fc1, torch.nn.ReLU(), fc2) + + def forward(self, x): + return self.module(x) + + +# This model is used when you want a SimpleMLP, but you want to explicitly +# test ComposerModels instead of nn.Module. +class SimpleComposerMLP(ComposerClassifier): + + def __init__(self, num_features: int, device: str, num_classes: int = 3): + fc1 = torch.nn.Linear(num_features, num_features, device=device, bias=False) + fc2 = torch.nn.Linear(num_features, num_features, device=device, bias=False) + + net = torch.nn.Sequential(fc1, torch.nn.ReLU(), fc2) + super().__init__(num_classes=num_classes, module=net) + + class SimpleWeightTiedModel(ComposerClassifier): """Small classification model with tied weights. Typically this model will be used to test weight tying w/ FSDP