[ckpt-rewr] Get Optim State Dict Util API (#3299)
eracah authored May 31, 2024
1 parent 3241a85 commit 8b4c684
Showing 5 changed files with 345 additions and 14 deletions.
3 changes: 2 additions & 1 deletion composer/checkpoint/__init__.py
@@ -3,9 +3,10 @@

"""Module for checkpointing API."""

from composer.checkpoint.state_dict import get_metadata_state_dict, get_model_state_dict
from composer.checkpoint.state_dict import get_metadata_state_dict, get_model_state_dict, get_optim_state_dict

__all__ = [
'get_model_state_dict',
'get_optim_state_dict',
'get_metadata_state_dict',
]
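
With this change, the new helper is importable from the package root alongside the existing state-dict utilities. A minimal import sketch (assumes composer is installed; this snippet is illustrative and not part of the diff itself):

from composer.checkpoint import (
    get_metadata_state_dict,
    get_model_state_dict,
    get_optim_state_dict,
)
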
125 changes: 123 additions & 2 deletions composer/checkpoint/state_dict.py
@@ -6,7 +6,7 @@
import fnmatch
import logging
import sys
from typing import Any, Optional, Sequence, Union
from typing import Any, Dict, Optional, Sequence, Union

import torch
from packaging import version
@@ -20,6 +20,8 @@

log = logging.getLogger(__name__)

__all__ = ['get_model_state_dict', 'get_optim_state_dict']


def get_model_state_dict(
model: Union[ComposerModel, nn.Module],
@@ -89,7 +91,7 @@ def get_model_state_dict(
return model_state_dict


def _cast_state_dict_to_precision(state_dict: dict[str, Any], precision: Union[str, torch.dtype]):
def _cast_state_dict_to_precision(state_dict: Dict[str, Any], precision: Union[str, torch.dtype]) -> Dict[str, Any]:
if isinstance(precision, str):
precision = STR_TO_DTYPE[precision]

@@ -156,6 +158,125 @@ def _get_model_state_dict_with_fsdp_context_manager(model: nn.Module, sharded_st
return model_state_dict


def _get_optim_state_dict_with_fsdp_context_manager(
model: nn.Module,
optimizer: torch.optim.Optimizer,
sharded_state_dict: bool,
cpu_offload: bool,
) -> Dict[str, Any]:
"""Get the optimizer state dict with the FSDP context manager.
Args:
model: The model containing the parameters that the optimizer is optimizing.
optimizer: The optimizer to get the state dict from.
sharded_state_dict: Whether the optimizer state dict should be sharded or not. If True, every rank returns the state dict of its shards.
If False, then rank 0 returns the state dict of the entire optimizer.
cpu_offload: Whether to offload the state dict to CPU.
Returns:
The state dict of the optimizer.
"""
from torch.distributed.fsdp.fully_sharded_data_parallel import (
FullOptimStateDictConfig,
FullStateDictConfig,
ShardedOptimStateDictConfig,
ShardedStateDictConfig,
StateDictType,
)
state_dict_type = StateDictType.SHARDED_STATE_DICT if sharded_state_dict else StateDictType.FULL_STATE_DICT

    if sharded_state_dict:
        state_dict_config = ShardedStateDictConfig(offload_to_cpu=cpu_offload)
        optim_state_dict_config = ShardedOptimStateDictConfig(offload_to_cpu=cpu_offload)
    else:
        state_dict_config = FullStateDictConfig(rank0_only=True, offload_to_cpu=cpu_offload)
        optim_state_dict_config = FullOptimStateDictConfig(rank0_only=True, offload_to_cpu=cpu_offload)
with FSDP.state_dict_type(
model,
state_dict_type=state_dict_type,
state_dict_config=state_dict_config,
optim_state_dict_config=optim_state_dict_config,
):
optim_state_dict = FSDP.optim_state_dict(model, optimizer)
return optim_state_dict


def get_optim_state_dict(
model: Union[ComposerModel, nn.Module],
optimizer: torch.optim.Optimizer,
sharded_state_dict: bool = False,
precision: str = 'fp32',
include_keys: Optional[Union[str, Sequence[str]]] = None,
ignore_keys: Optional[Union[str, Sequence[str]]] = None,
cpu_offload: Optional[bool] = None,
) -> Dict[str, Any]:
"""Generate the state dict of the optimizer.
Args:
model: The model containing the parameters that the optimizer is optimizing.
optimizer: The optimizer to get the state dict from.
sharded: Whether the optimizer is sharded or not. If True, every rank returns the state dict of its shards.
If False, then rank 0 returns the state dict of the entire optimizer and all other ranks return an empty dict.
precision: The precision of the optimizer.
include_keys: The list of keys to exclusively include in the state dict. If None, all keys are included. Both include_keys and ignore_keys cannot be non-None.
ignore_keys: The list of keys to ignore in the state dict. If None, no keys are ignored. Both include_keys and ignore_keys cannot be non-None.
cpu_offload: Whether to offload the state dict to CPU. If None, it is set to True if FSDP is enabled with non-sharded state dict and False otherwise.
Returns:
The state dict of the optimizer.
"""
if include_keys is not None and ignore_keys is not None:
raise ValueError(f'Both {include_keys=} and {ignore_keys=} cannot be non-None.')

is_fsdp = _is_model_fsdp(model)
if not is_fsdp and sharded_state_dict:
raise ValueError('Sharded optim state dict can only be generated for FSDP models.')

cpu_offload = cpu_offload if cpu_offload is not None else (is_fsdp and not sharded_state_dict)
log.debug('Extracting optim state dict')
if version.parse(torch.__version__) >= version.parse('2.2.0') and dist.is_initialized():
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_optimizer_state_dict
log.debug('Calling torch get_optimizer_state_dict...')
optim_state_dict: Dict[str, Any] = get_optimizer_state_dict(
model=model,
optimizers=optimizer,
submodules=None, # We extract submodules below
options=StateDictOptions(
full_state_dict=not sharded_state_dict,
cpu_offload=cpu_offload,
),
)
else:
if is_fsdp:
log.debug('Calling legacy FSDP context manager to get optim state dict...')
optim_state_dict = _get_optim_state_dict_with_fsdp_context_manager(
model,
optimizer,
sharded_state_dict,
cpu_offload,
)
else:
optim_state_dict = optimizer.state_dict()

# For sharded models with non-sharded state dicts, only rank 0 has the full state dict including all the keys
target_state_dict_on_this_rank = (not sharded_state_dict and dist.get_global_rank() == 0) or sharded_state_dict

if target_state_dict_on_this_rank:
if ignore_keys is not None:
raise NotImplementedError('Ignoring keys in the optimizer state dict is not supported yet.')
if include_keys is not None:
            raise NotImplementedError('Including keys in the optimizer state dict is not supported yet.')

# param_key := index (0,1,2,..., len(model.parameters())-1) for unsharded models.
# param_key := fqn for sharded models.
for param_key, param_state_dict in optim_state_dict['state'].items():
optim_state_dict['state'][param_key] = _cast_state_dict_to_precision(param_state_dict, precision)
return optim_state_dict


def get_metadata_state_dict(
model: Optional[Union[ComposerModel, nn.Module]] = None,
sharded_state_dict: Optional[bool] = None,
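
A hedged usage sketch of the new API (not part of the diff): it assumes a model and optimizer already exist and that the optimizer has taken at least one step; the FSDP branches additionally assume torch.distributed is initialized and the model is FSDP-wrapped.

from composer.checkpoint import get_optim_state_dict

# Unsharded model: keys of optim_sd['state'] are parameter indices (0, 1, 2, ...).
optim_sd = get_optim_state_dict(model, optimizer, precision='fp32')

# FSDP model, full (non-sharded) state dict: keys are parameter FQNs; only rank 0
# receives the populated dict, all other ranks get an empty one.
full_optim_sd = get_optim_state_dict(model, optimizer, sharded_state_dict=False)

# FSDP model, sharded state dict: every rank returns the state for its own shards.
sharded_optim_sd = get_optim_state_dict(model, optimizer, sharded_state_dict=True)
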
197 changes: 195 additions & 2 deletions tests/checkpoint/test_state_dict.py
@@ -1,14 +1,15 @@
# Copyright 2024 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

from typing import Any
from typing import Any, Dict

import pytest
import torch
from packaging import version
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.optim import adam

from composer.checkpoint import get_metadata_state_dict, get_model_state_dict
from composer.checkpoint import get_metadata_state_dict, get_model_state_dict, get_optim_state_dict
from composer.devices import DeviceGPU
from composer.utils import dist
from tests.common.compare import deep_compare
@@ -233,6 +234,198 @@ def test_get_model_state_dict_precision_unsharded_model(precision: str, use_comp
assert tens.dtype == precision


def _init_model_and_optimizer(
use_composer_model: bool,
num_classes=3,
batch_size=5,
num_features=8,
take_step=True,
use_fsdp=False,
tensor_type='sharded_tensor',
):
model, loss_fn = _init_model(
use_composer_model,
num_classes=num_classes,
batch_size=batch_size,
num_features=num_features,
use_fsdp=use_fsdp,
tensor_type=tensor_type,
)

optimizer = _init_optimizer(
model,
loss_fn,
use_composer_model=use_composer_model,
num_classes=num_classes,
batch_size=batch_size,
num_features=num_features,
take_step=take_step,
)

return model, optimizer


def _init_model(
use_composer_model: bool = False,
num_classes=3,
batch_size=5,
num_features=8,
use_fsdp=False,
tensor_type='sharded_tensor',
):
if use_composer_model:
model = SimpleComposerMLP(num_features=num_features, num_classes=num_classes, device='cuda')
loss_fn = model._loss_fn
else:
model = EvenSimplerMLP(num_features=num_features, num_out_features=num_classes, device='cuda')
loss_fn = torch.nn.CrossEntropyLoss()

if use_fsdp:
fsdp_kwargs: Dict[str, Any] = dict(
use_orig_params=True,
sync_module_states=True, # To enable easy comparison between rank 0 unsharded model and full state dict
)

if tensor_type == 'dtensor':
from torch.distributed.device_mesh import init_device_mesh
device_mesh = init_device_mesh('cuda', (2,))
fsdp_kwargs['device_mesh'] = device_mesh

model = FSDP(
model,
**fsdp_kwargs,
)

return model, loss_fn


def _init_optimizer(
model,
loss_fn,
use_composer_model: bool = False,
num_classes=3,
batch_size=5,
num_features=8,
take_step=True,
):
inputs = torch.randn(batch_size, num_features, device='cuda')
targets = torch.randint(low=0, high=num_classes, size=(batch_size,), device='cuda', dtype=torch.long)
batch = (inputs, targets) if use_composer_model else inputs
optimizer = adam.Adam(model.parameters())
outputs = model(batch)
loss = loss_fn(outputs, targets)
loss.backward()
if take_step:
optimizer.step()
return optimizer


@pytest.mark.gpu
@pytest.mark.parametrize('use_composer_model', [True, False])
def test_get_optim_state_dict_unsharded_model(use_composer_model: bool):
model, optimizer = _init_model_and_optimizer(use_composer_model=use_composer_model, take_step=True)
optim_state_dict = get_optim_state_dict(model, optimizer)

# Dict mapping parameter index to optimizer state for that parameter.
osd_state = optim_state_dict['state']
# Dict mapping parameter itself to optimizer state for that parameter.
optim_state = optimizer.state

# Make sure optimizer state is the same between the state dict and the optimizer object.
for osd_param_state, opt_param_state in zip(osd_state.values(), optim_state.values()):
deep_compare(osd_param_state, opt_param_state)

    # Make sure the optimizer state in the state dict has the same shape as the parameter it corresponds to.
    # Because the model is unsharded, the optimizer state dict keys should be the indices of the model's parameters,
    # e.g. if the model has 3 parameters, the keys are (0, 1, 2).
params = list(model.parameters())
param_dict = dict(list(model.named_parameters()))
for param_key, param_state in osd_state.items():
if isinstance(param_key, str):
param = param_dict[param_key]
else:
param = params[param_key]
assert param.shape == param_state['exp_avg'].shape
assert param.shape == param_state['exp_avg_sq'].shape

# Make sure param groups between the state dict and the optimizer object are the same.
for osd_group, opt_group in zip(optim_state_dict['param_groups'], optimizer.param_groups):
# Only params should differ between the two.
# * in the optimizer state dict params will be indices into the model's parameters list.
# * in the optimizer object params will be the actual parameter tensors.
deep_compare(osd_group, opt_group, ignore_keys=['params'])


@pytest.mark.gpu
@pytest.mark.parametrize(
'precision',
[
torch.float32,
torch.float16,
torch.bfloat16,
],
)
@pytest.mark.parametrize('use_composer_model', [True, False])
def test_get_optim_state_dict_precision_unsharded_model(precision: torch.dtype, use_composer_model: bool):
model, optimizer = _init_model_and_optimizer(use_composer_model=use_composer_model, take_step=True)
optim_state_dict = get_optim_state_dict(model, optimizer, precision=precision)
for param_state in optim_state_dict['state'].values():
assert param_state['exp_avg'].dtype == precision
assert param_state['exp_avg_sq'].dtype == precision


@pytest.mark.gpu
@world_size(2)
@pytest.mark.parametrize('tensor_type', ['sharded_tensor', 'dtensor'])
@pytest.mark.parametrize('use_composer_model', [True, False])
def test_get_optim_dict_full_for_sharded_model(world_size, tensor_type, use_composer_model: bool):
if tensor_type == 'dtensor' and version.parse(torch.__version__) < version.parse('2.2.0'):
pytest.skip('DTensor is only supported in PyTorch >= 2.2.0')

model, optimizer = _init_model_and_optimizer(
use_composer_model=use_composer_model,
take_step=True,
use_fsdp=True,
tensor_type=tensor_type,
)
optim_state_dict = get_optim_state_dict(model, optimizer, sharded_state_dict=False)

with FSDP.summon_full_params(model):
# Make sure the optimizer state in the state dict is the same shape as the parameter it corresponds to.
fqn_to_shape_map = {fqn: param.shape for fqn, param in model.named_parameters()}
if dist.get_global_rank() == 0:
            # Because the model is sharded, the optimizer state dict keys should be the FQNs of the model's parameters.
for fqn, param_state in optim_state_dict['state'].items():
model_param_shape = fqn_to_shape_map[fqn]
assert model_param_shape == param_state['exp_avg'].shape
assert model_param_shape == param_state['exp_avg_sq'].shape


@pytest.mark.gpu
@world_size(2)
@pytest.mark.parametrize('tensor_type', ['sharded_tensor', 'dtensor'])
@pytest.mark.parametrize('use_composer_model', [True, False])
def test_get_optim_dict_sharded_for_sharded_model(world_size, tensor_type, use_composer_model: bool):
if tensor_type == 'dtensor' and version.parse(torch.__version__) < version.parse('2.2.0'):
pytest.skip('DTensor is only supported in PyTorch >= 2.2.0')

model, optimizer = _init_model_and_optimizer(
use_composer_model=use_composer_model,
take_step=True,
use_fsdp=True,
tensor_type=tensor_type,
)
model_state_dict = get_model_state_dict(model, sharded_state_dict=True)
optim_state_dict = get_optim_state_dict(model, optimizer, sharded_state_dict=True)

    # Check that on every rank the optimizer state dict keys and shapes match the model's.
fqn_to_shape_map = {fqn: param.shape for fqn, param in model_state_dict.items()}
for fqn, param_state in optim_state_dict['state'].items():
model_param_shape = fqn_to_shape_map[fqn]
assert model_param_shape == param_state['exp_avg'].shape
assert model_param_shape == param_state['exp_avg_sq'].shape


@pytest.mark.gpu
@world_size(1, 2)
def test_get_metadata_empty_call(world_size):
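
As a follow-on to these tests, a short sketch of gathering both state dicts together (assumptions: an existing model and optimizer in an unsharded, single-process setup; torch.save is just one possible serialization choice, not Composer's own checkpointing format):

import torch

from composer.checkpoint import get_model_state_dict, get_optim_state_dict

# Collect the model and optimizer state dicts produced by the new utilities
# and write them to disk in a single file.
checkpoint = {
    'model': get_model_state_dict(model),
    'optim': get_optim_state_dict(model, optimizer),
}
torch.save(checkpoint, 'checkpoint.pt')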