Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ckpt-rewr] Get Model State Dict Util Function #3250

Merged
merged 40 commits into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
17d80d5
first commit get model sd
eracah May 3, 2024
3bb5785
Add test stub
eracah May 3, 2024
1ebbd87
Add test stub
eracah May 6, 2024
cca228b
Added unit tests for non-sharded use cases
eracah May 7, 2024
37e9753
Add support for ComposerModel (+ Precision tweak)
eracah May 7, 2024
4558182
Add sharded state dict support
eracah May 8, 2024
23434c0
add precision test
eracah May 8, 2024
eb45d05
Add precision test for sharded state dict
eracah May 8, 2024
f1acf26
pre-commit
eracah May 10, 2024
299aae5
pre-commit
eracah May 10, 2024
5028039
Merge branch 'dev' into get-model-sd
eracah May 10, 2024
c08abd1
mark gpu
eracah May 10, 2024
b4f380c
Merge branch 'get-model-sd' of https://github.com/eracah/evan-compose…
eracah May 10, 2024
a3664ca
pre-commit
eracah May 11, 2024
8ab34f1
add error for sharded + non-fsdp
eracah May 15, 2024
ab00b11
Merge branch 'get-model-sd' of https://github.com/eracah/evan-compose…
eracah May 15, 2024
2bbdc27
change gating
eracah May 15, 2024
5cf52cb
got tests to pass
eracah May 15, 2024
da14580
fix tests
eracah May 15, 2024
f20ecb3
Merge branch 'dev' into get-model-sd
eracah May 15, 2024
24aecbf
Merge branch 'dev' into get-model-sd
eracah May 16, 2024
fae7746
docstring
eracah May 16, 2024
ff8d91c
Update composer/checkpoint/state_dict.py
eracah May 16, 2024
aba4900
Update composer/checkpoint/state_dict.py
eracah May 16, 2024
1975a56
pc
eracah May 16, 2024
d72e2d4
fix version
eracah May 16, 2024
f5e4f53
Merge branch 'get-model-sd' of https://github.com/eracah/evan-compose…
eracah May 16, 2024
f3a7cce
pc
eracah May 16, 2024
50c2308
pc
eracah May 16, 2024
7fec0c2
pre-commit
eracah May 17, 2024
2e3bf04
Merge branch 'get-model-sd' of https://github.com/eracah/evan-compose…
eracah May 17, 2024
2941373
Addressed some comments
eracah May 17, 2024
307c780
pre-commit
eracah May 17, 2024
11bdd89
add comments for new simple models
eracah May 17, 2024
677002b
remove todo's
eracah May 17, 2024
01c1560
remove docstring arg
eracah May 17, 2024
5ccbe12
change scope name for get_model_state_dict
eracah May 17, 2024
44468db
Merge branch 'get-model-sd' of https://github.com/eracah/evan-compose…
eracah May 17, 2024
aee42ff
pre-commit
eracah May 17, 2024
94ab3da
Merge branch 'dev' into get-model-sd
eracah May 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions composer/checkpoint/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Copyright 2024 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

from composer.checkpoint.state_dict import get_model_state_dict

__all__ = [
'get_model_state_dict',
]
159 changes: 159 additions & 0 deletions composer/checkpoint/state_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
# Copyright 2024 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Useful functions for generating state dicts and manipulating them."""

import fnmatch
import logging
from typing import Any, Dict, Optional, Sequence, Union

import torch
from packaging import version
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn.parallel import DistributedDataParallel

from composer.models import ComposerModel
from composer.utils import dist

log = logging.getLogger(__name__)


def get_model_state_dict(
model: Union[ComposerModel, nn.Module],
sharded: bool,
eracah marked this conversation as resolved.
Show resolved Hide resolved
precision: Union[str, torch.dtype] = 'fp32',
include_keys: Optional[Union[str, Sequence[str]]] = None,
ignore_keys: Optional[Union[str, Sequence[str]]] = None,
cpu_offload: Optional[bool] = None,
) -> Dict[str, Any]:
"""Generate the state dict of the model.

Args:
model: The model to get the state dict from.
sharded: Whether the model state dict should be sharded or not. If True, every rank returns the state dict of its shards.
If False, then rank 0 returns the state dict of the entire model.
eracah marked this conversation as resolved.
Show resolved Hide resolved
precision: The precision of the model. Can be specified as a string ('fp32', 'fp16', 'bf16') or a torch.dtype.
include_keys: The list of keys to exclusively include in the state dict. If None, all keys are included. Both include_keys and ignore_keys cannot be non-None.
ignore_keys: The list of keys to ignore in the state dict. If None, no keys are ignored. Both include_keys and ignore_keys cannot be non-None.
cpu_offload: Whether to offload the state dict to CPU. If None, it is set to True if FSDP is enabled, False otherwise.

Returns:
The state dict of the model.
"""
if include_keys is not None and ignore_keys is not None:
raise ValueError('Both include_keys and ignore_keys cannot be non-None.')
eracah marked this conversation as resolved.
Show resolved Hide resolved

is_fsdp = _is_model_fsdp(model)
if not is_fsdp and sharded:
raise ValueError('Sharded state dict can only be generated for FSDP models.')
cpu_offload = cpu_offload if cpu_offload is not None else is_fsdp
eracah marked this conversation as resolved.
Show resolved Hide resolved

log.debug('Extracting model state dict')
if version.parse(torch.__version__) >= version.parse('2.3.0') and dist.is_initialized():
from torch.distributed.checkpoint.state_dict import StateDictOptions
from torch.distributed.checkpoint.state_dict import get_model_state_dict as dcp_get_model_state_dict
eracah marked this conversation as resolved.
Show resolved Hide resolved

get_nonsharded_state_dict = not sharded
eracah marked this conversation as resolved.
Show resolved Hide resolved

log.debug('Calling torch get_model_state_dict...')
model_state_dict = torch_get_model_state_dict(
model=model,
submodules=None, # We will handle extracting submodules ourselves down below.
eracah marked this conversation as resolved.
Show resolved Hide resolved
options=StateDictOptions(
full_state_dict=use_unsharded_state_dict,
cpu_offload=cpu_offload,
),
)
else:
if is_fsdp:
log.debug('Calling legacy FSDP context manager to get model state dict...')
model_state_dict = _get_model_state_dict_with_fsdp_context_manager(model, sharded)
else:
log.debug('Calling model.state_dict() for non-FSDP model...')
model_state_dict = model.state_dict()
if isinstance(model, DistributedDataParallel):
nn.modules.utils.consume_prefix_in_state_dict_if_present(model_state_dict, 'module.')

if include_keys is not None:
model_state_dict = _extract_keys_from_state_dict(model_state_dict, include_keys)

if ignore_keys is not None:
model_state_dict = _remove_keys_from_state_dict(model_state_dict, ignore_keys)

model_state_dict = _cast_state_dict_to_precision(state_dict=model_state_dict, precision=precision)

log.debug('Finished extracting model state dict')
return model_state_dict


STR_TO_DTYPE = {
'fp32': torch.float32,
'fp16': torch.float16,
'bf16': torch.bfloat16,
}
eracah marked this conversation as resolved.
Show resolved Hide resolved


def _cast_state_dict_to_precision(state_dict: Dict[str, Any], precision: Union[str, torch.dtype]):
if isinstance(precision, str):
precision = STR_TO_DTYPE[precision]

new_state_dict = {k: v.to(precision) for k, v in state_dict.items()}
return new_state_dict


def _extract_keys_from_state_dict(state_dict: Dict[str, Any], include_keys: Union[str, Sequence[str]]):
if isinstance(include_keys, str):
include_keys = [include_keys]
new_state_dict = {k: v for k, v in state_dict.items() if any(fnmatch.fnmatch(k, key) for key in include_keys)}

return new_state_dict


def _remove_keys_from_state_dict(state_dict: Dict[str, Any], ignore_keys: Union[str, Sequence[str]]):
if isinstance(ignore_keys, str):
ignore_keys = [ignore_keys]
new_state_dict = {k: v for k, v in state_dict.items() if not any(fnmatch.fnmatch(k, key) for key in ignore_keys)}
return new_state_dict


def _is_model_fsdp(model) -> bool:
"""Indicates if FSDP is enabled.

Args:
model: The model to check if FSDP is enabled.

Returns:
True if FSDP is enabled, False otherwise.

"""
for module in model.modules():
if isinstance(module, FSDP):
return True
return False


def _get_model_state_dict_with_fsdp_context_manager(model: nn.Module, sharded: bool) -> Dict[str, Any]:
"""Get the model state dict with the FSDP context manager.

Args:
model: The model to get the state dict from.
sharded: Whether the model state dict should be sharded or not. If True, every rank returns the state dict of its shards.
If False, then rank 0 returns the state dict of the entire model.

Returns:
The state dict of the model.
"""
from torch.distributed.fsdp.fully_sharded_data_parallel import (
FullStateDictConfig,
ShardedStateDictConfig,
StateDictType,
)
state_dict_type = StateDictType.SHARDED_STATE_DICT if sharded else StateDictType.FULL_STATE_DICT
state_dict_config = ShardedStateDictConfig(offload_to_cpu=True,) if sharded else FullStateDictConfig(
rank0_only=True,
offload_to_cpu=True,
)
with FSDP.state_dict_type(model, state_dict_type=state_dict_type, state_dict_config=state_dict_config):
model_state_dict = model.state_dict()
return model_state_dict
202 changes: 202 additions & 0 deletions tests/checkpoint/test_state_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
# Copyright 2024 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
from torch.distributed.device_mesh import init_device_mesh
eracah marked this conversation as resolved.
Show resolved Hide resolved
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from composer.checkpoint import get_model_state_dict
eracah marked this conversation as resolved.
Show resolved Hide resolved
from composer.utils import dist
from tests.common.compare import deep_compare
from tests.common.markers import world_size
from tests.common.models import EvenSimplerMLP, SimpleComposerMLP


@pytest.mark.gpu
@pytest.mark.parametrize('use_composer_model', [True, False])
def test_get_model_state_dict_unsharded_model(use_composer_model: bool):
if use_composer_model:
model = SimpleComposerMLP(num_features=8, device='cuda')
else:
model = EvenSimplerMLP(num_features=8, device='cuda')
model_state_dict = get_model_state_dict(model, sharded=False, include_keys=None, ignore_keys=None)
for name, param in model.named_parameters():
print(name)
assert name in model_state_dict
assert torch.equal(model_state_dict[name], param)


@pytest.mark.gpu
@pytest.mark.parametrize('use_composer_model', [True, False])
def test_get_model_state_dict_include(use_composer_model: bool):
if use_composer_model:
model = SimpleComposerMLP(num_features=8, device='cuda')
else:
model = EvenSimplerMLP(num_features=8, device='cuda')
model_state_dict = get_model_state_dict(model, sharded=False, include_keys=['module.0.weight'])
assert set(model_state_dict.keys()) == {'module.0.weight'}

model_state_dict = get_model_state_dict(model, sharded=False, include_keys='module.2*')
assert set(model_state_dict.keys()) == {'module.2.weight'}


@pytest.mark.gpu
@pytest.mark.parametrize('use_composer_model', [True, False])
def test_get_model_state_dict_ignore(use_composer_model: bool):
if use_composer_model:
model = SimpleComposerMLP(num_features=8, device='cuda')
else:
model = EvenSimplerMLP(num_features=8, device='cuda')

model_state_dict = get_model_state_dict(model, sharded=False, ignore_keys='module.2.weight')
assert set(model_state_dict.keys()) == {'module.0.weight'}

model_state_dict = get_model_state_dict(model, sharded=False, ignore_keys=['module.2*'])
assert set(model_state_dict.keys()) == {'module.0.weight'}


#TODO add tests for sharded and for precision
@pytest.mark.gpu
@world_size(2)
@pytest.mark.parametrize('tensor_type', ['sharded_tensor', 'dtensor'])
@pytest.mark.parametrize('use_composer_model', [True, False])
eracah marked this conversation as resolved.
Show resolved Hide resolved
def test_get_model_state_dict_full_for_sharded_model(world_size, tensor_type, use_composer_model: bool):
if use_composer_model:
model = SimpleComposerMLP(num_features=16, device='cuda')
else:
model = EvenSimplerMLP(num_features=16, device='cuda')

# Torch flattens model params in place after wrapped with FSDP, so we need to cache unflattened params now
# before fsdp wrapping in order to keep pre-sharding shapes.
pre_shard_state_dict = get_model_state_dict(
model,
sharded=False,
cpu_offload=True, # Set this to True, so that both state dicts will be on cpu
)
device_mesh = init_device_mesh('cuda', (2,)) if tensor_type == 'dtensor' else None
sharded_model = FSDP(
model,
use_orig_params=True,
sync_module_states=True, # We set this to enable easy comparison between rank 0 unsharded model and full state dict
device_mesh=device_mesh,
)

post_shard_full_state_dict = get_model_state_dict(sharded_model, sharded=False)

if dist.get_global_rank() == 0:
deep_compare(pre_shard_state_dict, post_shard_full_state_dict)


@pytest.mark.gpu
@world_size(2)
@pytest.mark.parametrize('tensor_type', ['sharded_tensor', 'dtensor'])
@pytest.mark.parametrize('use_composer_model', [True, False])
def test_get_model_state_dict_sharded(world_size, tensor_type, use_composer_model: bool):
if use_composer_model:
model = SimpleComposerMLP(num_features=16, device='cuda')
else:
model = EvenSimplerMLP(num_features=16, device='cuda')

# Torch flattens model params in place after wrapped with FSDP, so we need to cache unflattened params now
# before fsdp wrapping in order to keep pre-sharding shapes.
pre_shard_full_state_dict = get_model_state_dict(
model,
sharded=False,
cpu_offload=True, # Set this to True, so that both state dicts will be on cpu
)

device_mesh = init_device_mesh('cuda', (2,)) if tensor_type == 'dtensor' else None
sharded_model = FSDP(
model,
use_orig_params=True,
sync_module_states=True,
device_mesh=device_mesh,
)

post_shard_sharded_sd = get_model_state_dict(sharded_model, sharded=True)

# In order to test if the sharded state dict is correct we go through this process:
# 1. Transform the each rank's state dict's values by extracting the the local tensor from the ShardedTensor object
# 2. Gather each rank's state dicts
# 3. Make a "reconstructed" full state dict by, for each key, concatenating all the tensor shards into one big tensor
# 4. Compare this "reconstructed" full state dict to the original model's state dict to ensure they are the same.
local_tensor_sd = {
n: (p.local_tensor() if tensor_type == 'sharded_tensor' else p.to_local())
for n, p in post_shard_sharded_sd.items()
}
all_local_tensor_sd = dist.all_gather_object(local_tensor_sd)
post_shard_reconstructed_full_sd = {
n: torch.cat(
[sd[n] for sd in all_local_tensor_sd],
dim=0, # dim=0 because fsdp shards each tensor on the 0th dimension
) for n in pre_shard_full_state_dict.keys()
}
if dist.get_global_rank() == 0:
deep_compare(pre_shard_full_state_dict, post_shard_reconstructed_full_sd)


# TODO test precision
@pytest.mark.gpu
@pytest.mark.parametrize(
'precision',
[
torch.float32,
torch.float16,
torch.bfloat16,
],
)
@pytest.mark.parametrize('use_composer_model', [True, False])
def test_get_model_state_dict_precision_unsharded_model(precision: str, use_composer_model: bool):
if use_composer_model:
model = SimpleComposerMLP(num_features=8, device='cuda')
else:
model = EvenSimplerMLP(num_features=8, device='cuda')
model_state_dict = get_model_state_dict(
model,
precision=precision,
sharded=False,
include_keys=None,
ignore_keys=None,
)
for tens in model_state_dict.values():
assert tens.dtype == precision


@world_size(2)
@pytest.mark.gpu
@pytest.mark.parametrize(
'precision',
[
torch.float32,
torch.float16,
torch.bfloat16,
],
)
@pytest.mark.parametrize('tensor_type', ['sharded_tensor', 'dtensor'])
@pytest.mark.parametrize('use_composer_model', [True, False])
def test_get_model_state_dict_precision_sharded_model(
world_size, tensor_type, precision: str, use_composer_model: bool
):
if use_composer_model:
model = SimpleComposerMLP(num_features=8, device='cuda')
else:
model = EvenSimplerMLP(num_features=8, device='cuda')

device_mesh = init_device_mesh('cuda', (2,)) if tensor_type == 'dtensor' else None
sharded_model = FSDP(
model,
use_orig_params=True,
sync_module_states=True,
device_mesh=device_mesh,
)
model_state_dict = get_model_state_dict(
sharded_model,
precision=precision,
sharded=True,
include_keys=None,
ignore_keys=None,
)
for sharded_tens in model_state_dict.values():
local_tensor = sharded_tens.local_tensor() if tensor_type == 'sharded_tensor' else sharded_tens.to_local()
assert local_tensor.dtype == precision
Loading
Loading