diff --git a/composer/checkpoint/__init__.py b/composer/checkpoint/__init__.py index 7ebfe33cff..84d6c9f4cf 100644 --- a/composer/checkpoint/__init__.py +++ b/composer/checkpoint/__init__.py @@ -3,9 +3,10 @@ """Module for checkpointing API.""" -from composer.checkpoint.state_dict import get_metadata_state_dict, get_model_state_dict +from composer.checkpoint.state_dict import get_metadata_state_dict, get_model_state_dict, get_optim_state_dict __all__ = [ 'get_model_state_dict', + 'get_optim_state_dict', 'get_metadata_state_dict', ] diff --git a/composer/checkpoint/state_dict.py b/composer/checkpoint/state_dict.py index e241436fb7..511fba5a9f 100644 --- a/composer/checkpoint/state_dict.py +++ b/composer/checkpoint/state_dict.py @@ -6,7 +6,7 @@ import fnmatch import logging import sys -from typing import Any, Optional, Sequence, Union +from typing import Any, Dict, Optional, Sequence, Union import torch from packaging import version @@ -20,6 +20,8 @@ log = logging.getLogger(__name__) +__all__ = ['get_model_state_dict', 'get_optim_state_dict'] + def get_model_state_dict( model: Union[ComposerModel, nn.Module], @@ -89,7 +91,7 @@ def get_model_state_dict( return model_state_dict -def _cast_state_dict_to_precision(state_dict: dict[str, Any], precision: Union[str, torch.dtype]): +def _cast_state_dict_to_precision(state_dict: Dict[str, Any], precision: Union[str, torch.dtype]) -> Dict[str, Any]: if isinstance(precision, str): precision = STR_TO_DTYPE[precision] @@ -156,6 +158,125 @@ def _get_model_state_dict_with_fsdp_context_manager(model: nn.Module, sharded_st return model_state_dict +def _get_optim_state_dict_with_fsdp_context_manager( + model: nn.Module, + optimizer: torch.optim.Optimizer, + sharded_state_dict: bool, + cpu_offload: bool, +) -> Dict[str, Any]: + """Get the optimizer state dict with the FSDP context manager. + + Args: + model: The model containing the parameters that the optimizer is optimizing. + optimizer: The optimizer to get the state dict from. + sharded_state_dict: Whether the optimizer state dict should be sharded or not. If True, every rank returns the state dict of its shards. + If False, then rank 0 returns the state dict of the entire optimizer. + cpu_offload: Whether to offload the state dict to CPU. + + Returns: + The state dict of the optimizer. 
+
+    """
+    from torch.distributed.fsdp.fully_sharded_data_parallel import (
+        FullOptimStateDictConfig,
+        FullStateDictConfig,
+        ShardedOptimStateDictConfig,
+        ShardedStateDictConfig,
+        StateDictType,
+    )
+    state_dict_type = StateDictType.SHARDED_STATE_DICT if sharded_state_dict else StateDictType.FULL_STATE_DICT
+
+    state_dict_config = ShardedStateDictConfig(
+        offload_to_cpu=cpu_offload,
+    ) if sharded_state_dict else FullStateDictConfig(rank0_only=True, offload_to_cpu=cpu_offload)
+    optim_state_dict_config = ShardedOptimStateDictConfig(
+        offload_to_cpu=cpu_offload,
+    ) if sharded_state_dict else FullOptimStateDictConfig(rank0_only=True, offload_to_cpu=cpu_offload)
+    with FSDP.state_dict_type(
+        model,
+        state_dict_type=state_dict_type,
+        state_dict_config=state_dict_config,
+        optim_state_dict_config=optim_state_dict_config,
+    ):
+        optim_state_dict = FSDP.optim_state_dict(model, optimizer)
+    return optim_state_dict
+
+
+def get_optim_state_dict(
+    model: Union[ComposerModel, nn.Module],
+    optimizer: torch.optim.Optimizer,
+    sharded_state_dict: bool = False,
+    precision: str = 'fp32',
+    include_keys: Optional[Union[str, Sequence[str]]] = None,
+    ignore_keys: Optional[Union[str, Sequence[str]]] = None,
+    cpu_offload: Optional[bool] = None,
+) -> Dict[str, Any]:
+    """Generate the state dict of the optimizer.
+
+    Args:
+        model: The model containing the parameters that the optimizer is optimizing.
+        optimizer: The optimizer to get the state dict from.
+        sharded_state_dict: Whether the optimizer state dict should be sharded or not. If True, every rank returns the state dict of its shards.
+            If False, then rank 0 returns the state dict of the entire optimizer and all other ranks return an empty dict.
+        precision: The precision to cast the optimizer state dict to. Defaults to 'fp32'.
+        include_keys: The list of keys to exclusively include in the state dict. If None, all keys are included. include_keys and ignore_keys cannot both be non-None.
+        ignore_keys: The list of keys to ignore in the state dict. If None, no keys are ignored. include_keys and ignore_keys cannot both be non-None.
+        cpu_offload: Whether to offload the state dict to CPU. If None, it is set to True if FSDP is enabled with a non-sharded state dict and False otherwise.
+
+    Returns:
+        The state dict of the optimizer.
+    """
+    if include_keys is not None and ignore_keys is not None:
+        raise ValueError(f'Both {include_keys=} and {ignore_keys=} cannot be non-None.')
+
+    is_fsdp = _is_model_fsdp(model)
+    if not is_fsdp and sharded_state_dict:
+        raise ValueError('Sharded optim state dict can only be generated for FSDP models.')
+
+    cpu_offload = cpu_offload if cpu_offload is not None else (is_fsdp and not sharded_state_dict)
+    log.debug('Extracting optim state dict')
+    if version.parse(torch.__version__) >= version.parse('2.2.0') and dist.is_initialized():
+        from torch.distributed.checkpoint.state_dict import StateDictOptions, get_optimizer_state_dict
+        log.debug('Calling torch get_optimizer_state_dict...')
+        optim_state_dict: Dict[str, Any] = get_optimizer_state_dict(
+            model=model,
+            optimizers=optimizer,
+            submodules=None,  # We do not restrict the state dict to any submodules.
+            options=StateDictOptions(
+                full_state_dict=not sharded_state_dict,
+                cpu_offload=cpu_offload,
+            ),
+        )
+    else:
+        if is_fsdp:
+            log.debug('Calling legacy FSDP context manager to get optim state dict...')
+            optim_state_dict = _get_optim_state_dict_with_fsdp_context_manager(
+                model,
+                optimizer,
+                sharded_state_dict,
+                cpu_offload,
+            )
+        else:
+            optim_state_dict = optimizer.state_dict()
+
+    # For sharded (FSDP) models with non-sharded state dicts, only rank 0 holds the full state dict with all the keys.
+    target_state_dict_on_this_rank = (not sharded_state_dict and dist.get_global_rank() == 0) or sharded_state_dict
+
+    if target_state_dict_on_this_rank:
+        if ignore_keys is not None:
+            raise NotImplementedError('Ignoring keys in the optimizer state dict is not supported yet.')
+        if include_keys is not None:
+            raise NotImplementedError('Including only specific keys in the optimizer state dict is not supported yet.')
+
+        # param_key := index (0, 1, 2, ..., len(model.parameters()) - 1) for unsharded models.
+        # param_key := the parameter's fully qualified name (FQN) for sharded models.
+ for param_key, param_state_dict in optim_state_dict['state'].items(): + optim_state_dict['state'][param_key] = _cast_state_dict_to_precision(param_state_dict, precision) + return optim_state_dict + + def get_metadata_state_dict( model: Optional[Union[ComposerModel, nn.Module]] = None, sharded_state_dict: Optional[bool] = None, diff --git a/tests/checkpoint/test_state_dict.py b/tests/checkpoint/test_state_dict.py index 856fd77213..9618756b83 100644 --- a/tests/checkpoint/test_state_dict.py +++ b/tests/checkpoint/test_state_dict.py @@ -1,14 +1,15 @@ # Copyright 2024 MosaicML Composer authors # SPDX-License-Identifier: Apache-2.0 -from typing import Any +from typing import Any, Dict import pytest import torch from packaging import version from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.optim import adam -from composer.checkpoint import get_metadata_state_dict, get_model_state_dict +from composer.checkpoint import get_metadata_state_dict, get_model_state_dict, get_optim_state_dict from composer.devices import DeviceGPU from composer.utils import dist from tests.common.compare import deep_compare @@ -233,6 +234,198 @@ def test_get_model_state_dict_precision_unsharded_model(precision: str, use_comp assert tens.dtype == precision +def _init_model_and_optimizer( + use_composer_model: bool, + num_classes=3, + batch_size=5, + num_features=8, + take_step=True, + use_fsdp=False, + tensor_type='sharded_tensor', +): + model, loss_fn = _init_model( + use_composer_model, + num_classes=num_classes, + batch_size=batch_size, + num_features=num_features, + use_fsdp=use_fsdp, + tensor_type=tensor_type, + ) + + optimizer = _init_optimizer( + model, + loss_fn, + use_composer_model=use_composer_model, + num_classes=num_classes, + batch_size=batch_size, + num_features=num_features, + take_step=take_step, + ) + + return model, optimizer + + +def _init_model( + use_composer_model: bool = False, + num_classes=3, + batch_size=5, + num_features=8, + use_fsdp=False, + tensor_type='sharded_tensor', +): + if use_composer_model: + model = SimpleComposerMLP(num_features=num_features, num_classes=num_classes, device='cuda') + loss_fn = model._loss_fn + else: + model = EvenSimplerMLP(num_features=num_features, num_out_features=num_classes, device='cuda') + loss_fn = torch.nn.CrossEntropyLoss() + + if use_fsdp: + fsdp_kwargs: Dict[str, Any] = dict( + use_orig_params=True, + sync_module_states=True, # To enable easy comparison between rank 0 unsharded model and full state dict + ) + + if tensor_type == 'dtensor': + from torch.distributed.device_mesh import init_device_mesh + device_mesh = init_device_mesh('cuda', (2,)) + fsdp_kwargs['device_mesh'] = device_mesh + + model = FSDP( + model, + **fsdp_kwargs, + ) + + return model, loss_fn + + +def _init_optimizer( + model, + loss_fn, + use_composer_model: bool = False, + num_classes=3, + batch_size=5, + num_features=8, + take_step=True, +): + inputs = torch.randn(batch_size, num_features, device='cuda') + targets = torch.randint(low=0, high=num_classes, size=(batch_size,), device='cuda', dtype=torch.long) + batch = (inputs, targets) if use_composer_model else inputs + optimizer = adam.Adam(model.parameters()) + outputs = model(batch) + loss = loss_fn(outputs, targets) + loss.backward() + if take_step: + optimizer.step() + return optimizer + + +@pytest.mark.gpu +@pytest.mark.parametrize('use_composer_model', [True, False]) +def test_get_optim_state_dict_unsharded_model(use_composer_model: bool): + model, optimizer = 
_init_model_and_optimizer(use_composer_model=use_composer_model, take_step=True) + optim_state_dict = get_optim_state_dict(model, optimizer) + + # Dict mapping parameter index to optimizer state for that parameter. + osd_state = optim_state_dict['state'] + # Dict mapping parameter itself to optimizer state for that parameter. + optim_state = optimizer.state + + # Make sure optimizer state is the same between the state dict and the optimizer object. + for osd_param_state, opt_param_state in zip(osd_state.values(), optim_state.values()): + deep_compare(osd_param_state, opt_param_state) + + # Make sure the optimizer state in the state dict is the same shape as the parameter it corresponds to. + # Because model is unsharded the optimizer state should have keys corresponding to the index of the model's parameters. + # e.g. if the model has 3 parameters, the optimizer state dict keys would be (0,1,2). + params = list(model.parameters()) + param_dict = dict(list(model.named_parameters())) + for param_key, param_state in osd_state.items(): + if isinstance(param_key, str): + param = param_dict[param_key] + else: + param = params[param_key] + assert param.shape == param_state['exp_avg'].shape + assert param.shape == param_state['exp_avg_sq'].shape + + # Make sure param groups between the state dict and the optimizer object are the same. + for osd_group, opt_group in zip(optim_state_dict['param_groups'], optimizer.param_groups): + # Only params should differ between the two. + # * in the optimizer state dict params will be indices into the model's parameters list. + # * in the optimizer object params will be the actual parameter tensors. + deep_compare(osd_group, opt_group, ignore_keys=['params']) + + +@pytest.mark.gpu +@pytest.mark.parametrize( + 'precision', + [ + torch.float32, + torch.float16, + torch.bfloat16, + ], +) +@pytest.mark.parametrize('use_composer_model', [True, False]) +def test_get_optim_state_dict_precision_unsharded_model(precision: str, use_composer_model: bool): + model, optimizer = _init_model_and_optimizer(use_composer_model=use_composer_model, take_step=True) + optim_state_dict = get_optim_state_dict(model, optimizer, precision=precision) + for param_state in optim_state_dict['state'].values(): + assert param_state['exp_avg'].dtype == precision + assert param_state['exp_avg_sq'].dtype == precision + + +@pytest.mark.gpu +@world_size(2) +@pytest.mark.parametrize('tensor_type', ['sharded_tensor', 'dtensor']) +@pytest.mark.parametrize('use_composer_model', [True, False]) +def test_get_optim_dict_full_for_sharded_model(world_size, tensor_type, use_composer_model: bool): + if tensor_type == 'dtensor' and version.parse(torch.__version__) < version.parse('2.2.0'): + pytest.skip('DTensor is only supported in PyTorch >= 2.2.0') + + model, optimizer = _init_model_and_optimizer( + use_composer_model=use_composer_model, + take_step=True, + use_fsdp=True, + tensor_type=tensor_type, + ) + optim_state_dict = get_optim_state_dict(model, optimizer, sharded_state_dict=False) + + with FSDP.summon_full_params(model): + # Make sure the optimizer state in the state dict is the same shape as the parameter it corresponds to. + fqn_to_shape_map = {fqn: param.shape for fqn, param in model.named_parameters()} + if dist.get_global_rank() == 0: + # Because model is sharded, the state dict should have the same keys as the model's parameters. 
+ for fqn, param_state in optim_state_dict['state'].items(): + model_param_shape = fqn_to_shape_map[fqn] + assert model_param_shape == param_state['exp_avg'].shape + assert model_param_shape == param_state['exp_avg_sq'].shape + + +@pytest.mark.gpu +@world_size(2) +@pytest.mark.parametrize('tensor_type', ['sharded_tensor', 'dtensor']) +@pytest.mark.parametrize('use_composer_model', [True, False]) +def test_get_optim_dict_sharded_for_sharded_model(world_size, tensor_type, use_composer_model: bool): + if tensor_type == 'dtensor' and version.parse(torch.__version__) < version.parse('2.2.0'): + pytest.skip('DTensor is only supported in PyTorch >= 2.2.0') + + model, optimizer = _init_model_and_optimizer( + use_composer_model=use_composer_model, + take_step=True, + use_fsdp=True, + tensor_type=tensor_type, + ) + model_state_dict = get_model_state_dict(model, sharded_state_dict=True) + optim_state_dict = get_optim_state_dict(model, optimizer, sharded_state_dict=True) + + # Check to make sure on every rank optimizer state name and shape matches model's + fqn_to_shape_map = {fqn: param.shape for fqn, param in model_state_dict.items()} + for fqn, param_state in optim_state_dict['state'].items(): + model_param_shape = fqn_to_shape_map[fqn] + assert model_param_shape == param_state['exp_avg'].shape + assert model_param_shape == param_state['exp_avg_sq'].shape + + @pytest.mark.gpu @world_size(1, 2) def test_get_metadata_empty_call(world_size): diff --git a/tests/common/compare.py b/tests/common/compare.py index c35334fcb1..432ac55dfd 100644 --- a/tests/common/compare.py +++ b/tests/common/compare.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import datetime -from typing import Any, Union +from typing import Any, Dict, List, Optional, Union import numpy as np import torch @@ -12,7 +12,7 @@ from composer.core.time import TimeUnit -def deep_compare(item1: Any, item2: Any, atol: float = 0.0, rtol: float = 0.0): +def deep_compare(item1: Any, item2: Any, atol: float = 0.0, rtol: float = 0.0, ignore_keys: Optional[List[str]] = None): """Compare two items recursively. Supports dicts, lists, tuples, tensors, numpy arrays, Composer Time objects, and callables. 
Args: @@ -21,10 +21,17 @@ def deep_compare(item1: Any, item2: Any, atol: float = 0.0, rtol: float = 0.0): atol (bool): Atol tolerance for torch tensors and numpy arrays (default: 0.0) rtol (float): Rtol tolerance for torch tensors and numpy arrays (default: 0.0) """ - return _check_item(item1, item2, path='', atol=atol, rtol=rtol) + return _check_item(item1, item2, path='', atol=atol, rtol=rtol, ignore_keys=ignore_keys) -def _check_item(item1: Any, item2: Any, path: str, rtol: float = 0.0, atol: float = 0.0): +def _check_item( + item1: Any, + item2: Any, + path: str, + rtol: float = 0.0, + atol: float = 0.0, + ignore_keys: Optional[List[str]] = None, +): if item1 is None: assert item2 is None, f'{path} differs: {item1} != {item2}' return @@ -45,7 +52,7 @@ def _check_item(item1: Any, item2: Any, path: str, rtol: float = 0.0, atol: floa return if isinstance(item1, dict): assert isinstance(item2, dict), f'{path} differs: {item1} != {item2}' - _check_dict_recursively(item1, item2, path, atol=atol, rtol=rtol) + _check_dict_recursively(item1, item2, path, atol=atol, rtol=rtol, ignore_keys=ignore_keys) return if isinstance(item1, (tuple, list)): assert isinstance(item2, type(item1)), f'{path} differs: {item1} != {item2}' @@ -89,9 +96,18 @@ def _check_list_recursively( _check_item(item1, item2, path=f'{path}/{i}', atol=atol, rtol=rtol) -def _check_dict_recursively(dict1: dict[str, Any], dict2: dict[str, Any], path: str, atol: float, rtol: float): +def _check_dict_recursively( + dict1: Dict[str, Any], + dict2: Dict[str, Any], + path: str, + atol: float, + rtol: float, + ignore_keys: Optional[List[str]] = None, +): assert len(dict1) == len(dict2), f'{path} differs: {dict1} != {dict2}' for k, val1 in dict1.items(): + if ignore_keys is not None and k in ignore_keys: + continue val2 = dict2[k] # special case fused optimizer to allow comparing a GPU checkpoint with a CPU checkpoint diff --git a/tests/common/models.py b/tests/common/models.py index e7fd084744..2310a03a82 100644 --- a/tests/common/models.py +++ b/tests/common/models.py @@ -120,10 +120,10 @@ def forward(self, x): # are not submodules of EvenSimplerMLP, like they are in SimpleMLP. class EvenSimplerMLP(torch.nn.Module): - def __init__(self, num_features: int, device: str = 'cpu'): + def __init__(self, num_features: int, device: str = 'cpu', num_out_features: int = 3): super().__init__() fc1 = torch.nn.Linear(num_features, num_features, device=device, bias=False) - fc2 = torch.nn.Linear(num_features, num_features, device=device, bias=False) + fc2 = torch.nn.Linear(num_features, num_out_features, device=device, bias=False) self.module = torch.nn.Sequential(fc1, torch.nn.ReLU(), fc2) @@ -137,7 +137,7 @@ class SimpleComposerMLP(ComposerClassifier): def __init__(self, num_features: int, device: str, num_classes: int = 3): fc1 = torch.nn.Linear(num_features, num_features, device=device, bias=False) - fc2 = torch.nn.Linear(num_features, num_features, device=device, bias=False) + fc2 = torch.nn.Linear(num_features, num_classes, device=device, bias=False) net = torch.nn.Sequential(fc1, torch.nn.ReLU(), fc2) super().__init__(num_classes=num_classes, module=net)
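Usage sketch (illustrative only, not part of this diff): a minimal single-process, non-FSDP example of the new get_optim_state_dict API, mirroring the unsharded path exercised by test_get_optim_state_dict_unsharded_model. The toy torch.nn.Linear model and the torch.bfloat16 precision value are assumptions for illustration; the new tests pass torch.dtype values for precision even though the annotation is str.

# Illustrative usage only -- not part of this diff. Single-process, non-FSDP setup.
import torch

from composer.checkpoint import get_model_state_dict, get_optim_state_dict

model = torch.nn.Linear(8, 3, bias=False)  # toy model, assumed here for illustration
optimizer = torch.optim.Adam(model.parameters())

# Take one optimizer step so there is state to extract (exp_avg / exp_avg_sq for Adam).
loss = model(torch.randn(5, 8)).sum()
loss.backward()
optimizer.step()

model_sd = get_model_state_dict(model)
# Full (non-sharded) optimizer state dict, with the per-parameter state cast to bf16.
optim_sd = get_optim_state_dict(model, optimizer, sharded_state_dict=False, precision=torch.bfloat16)

# 'state' maps a param key (an index for unsharded models, an FQN for sharded ones)
# to that parameter's optimizer state; 'param_groups' mirrors optimizer.param_groups.
for param_key, param_state in optim_sd['state'].items():
    print(param_key, param_state['exp_avg'].shape, param_state['exp_avg'].dtype)

For an FSDP-wrapped model, the same call with sharded_state_dict=True returns each rank's shard, while sharded_state_dict=False returns the full state dict on rank 0 and an empty dict on other ranks, as covered by the two sharded tests above.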