Skip to content

Commit

Permalink
Tensor Parallelism Integration (#3269)
Browse files Browse the repository at this point in the history
* v1

* add test

* fix test

* v1

* some lint

* lint

* lint and pytorch pinning

* results

* tweak warnings

* fix lint

* lint

* filter

* add ckpt

* checkdown

* fix tests

* fix

* fix test

* fix tests

* fix tests

* paralleli config

* fix some arg parsing

* rename to parallelism config

* lint

* fix docs

* fix edge case

* lint

* log

* change slicing

* fix patching

* fix core

* lint

* clean up v1

* lint

* shallow copy

* add checks

* device mesh

* fix type checking

* fix bugs

* fix tests

* rename variables and fix checkpointing

* lint

* lint

* v1 refacotr

* lint

* tests

* add enum

* fix test

* fix

* fix fsdp submesh order

* fix tests

* change to world size

* add docs v1

* fix docs

* add tests

* v1 of tp test

* fix lint

* fix arg prop

* fix tests

* lint

* brians comments

* pr review

* add more gating

* lint

* fix

* fix lint

* fix lint

* lint

* force to run

* fix assert

* parallelism config

* tweak some parallelism issues

* remove assert

---------

Co-authored-by: Your Name <[email protected]>
  • Loading branch information
mvpatel2000 and Your Name authored May 24, 2024
1 parent 47dbf9f commit 09f14f9
Show file tree
Hide file tree
Showing 32 changed files with 1,133 additions and 822 deletions.
5 changes: 3 additions & 2 deletions composer/callbacks/checkpoint_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,8 +468,9 @@ def _save_checkpoint(self, state: State, logger: Logger):
is_deepspeed,
keep_placeholders=True,
).lstrip('/')
assert state.sharded_ckpt_prefix_dir is not None
remote_prefix = state.sharded_ckpt_prefix_dir
assert state.fsdp_config is not None
remote_prefix = state.fsdp_config['sharded_ckpt_prefix_dir']
assert remote_prefix is not None
ckpt_filename = checkpoint._TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME
remote_file_name = os.path.join(pathlib.Path(remote_file_name).parent, remote_prefix, ckpt_filename)
remote_file_name = format_name_with_dist_and_time(remote_file_name, state.run_name, state.timestamp)
Expand Down
211 changes: 154 additions & 57 deletions composer/core/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import torch
import torch.nn.modules.utils
from packaging import version
from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
FullOptimStateDictConfig,
Expand All @@ -30,8 +31,6 @@
from torch.utils.data import DataLoader, Dataset
from torchmetrics import Metric

from composer.utils.warnings import VersionedDeprecationWarning

if version.parse(torch.__version__) >= version.parse('2.3.0'):
from torch.amp.grad_scaler import GradScaler # type: ignore
else:
Expand All @@ -44,6 +43,8 @@
from composer.core.time import Time, Timestamp, TimeUnit, ensure_time
from composer.devices import Device
from composer.utils import (
ParallelismType,
VersionedDeprecationWarning,
batch_get,
batch_set,
dist,
Expand Down Expand Up @@ -194,6 +195,75 @@ def _ensure_backwards_compatible_checkpointing(state_dict: Dict[str, Any]):
return state


def _create_device_mesh(
device: Device,
fsdp_config: Optional[Dict[str, Any]],
tp_config: Optional[Dict[str, Any]],
) -> Optional[DeviceMesh]:
if fsdp_config is None:
return None

# Gather dimensions and names for the device mesh
dims: List[int] = []
names: List[str] = []
if fsdp_config['data_parallel_replicate_degree'] != 1:
dims.append(fsdp_config['data_parallel_replicate_degree'])
names.append(ParallelismType.DATA_PARALLEL_REPLICATE.value)
dims.append(fsdp_config['data_parallel_shard_degree'])
names.append(ParallelismType.DATA_PARALLEL_SHARD.value)
if tp_config is not None:
dims.append(tp_config['tensor_parallel_degree'])
names.append(ParallelismType.TENSOR_PARALLEL.value)

# Fill in the unspecified dimensions
product_of_dims = 1
unspecified_dim_names = []
for dim, name in zip(dims, names):
if dim != -1:
product_of_dims *= dim
else:
unspecified_dim_names.append(name)
if len(unspecified_dim_names) > 1:
raise ValueError(
f'Found multiple parallelism dimensions with -1: {unspecified_dim_names}. '
'Only one is allowed, which is set to fill the remaining dimensions.',
)
elif len(unspecified_dim_names) == 1:
if product_of_dims > dist.get_world_size():
raise ValueError(
f'World size {dist.get_world_size()} is greater than the product of the specified parallelism degrees '
f'{product_of_dims}. Please ensure the product of the specified parallelism degrees matches the world ',
f'size. Currently specified degrees are {names=}, {dims=}. One dimension can also be left as -1, which '
'will automatically be specified to ensure the product matches the world size.',
)
remaining_dimension = dist.get_world_size() // product_of_dims
if remaining_dimension * product_of_dims != dist.get_world_size():
raise ValueError(
f'World size {dist.get_world_size()} is not divisible by the product of the specified '
'parallelism degrees. Please ensure the product of the specified parallelism degrees '
'matches the world size.',
)
for i, dim in enumerate(dims):
if dim == -1:
dims[i] = remaining_dimension
log.info(f'Automatically setting {names[i]} to have parallelization degree {remaining_dimension}.')
break
else:
if product_of_dims != dist.get_world_size():
raise ValueError(
f'World size {dist.get_world_size()} does not equal the product of the specified parallelism degrees '
f'{product_of_dims}. Please ensure the product of the specified parallelism degrees matches the world ',
f'size. Currently specified degrees are {names=}, {dims=}. One dimension can also be left as -1, which '
'will automatically be specified to ensure the product matches the world size.',
)

device_type = device.name
if device_type == 'gpu':
device_type = 'cuda'

return init_device_mesh(device_type=device_type, mesh_shape=tuple(dims), mesh_dim_names=tuple(names))


_STATE_DICT_SERIALIZED_ATTRIBUTES = [
# List of attributes that are serialized with state_dict
# Only the attributes listed in state.serialized_attributes will actually be saved.
Expand Down Expand Up @@ -255,8 +325,7 @@ class State(Serializable):
algorithms (Algorithm | Sequence[Algorithm], optional): The algorithms used for training.
callbacks (Callback | Sequence[Callback], optional): The callbacks used for training.
deepspeed_config (Dict[str, Any], optional): The configuration dictionary for deepspeed.
fsdp_config (Dict[str, Any], optional): The configuration dictionary for FSDP.
fsdp_auto_wrap (bool, optional): Whether to automatically wrap the model with FSDP.
parallelism_config (Dict[str, Any], optional): The configuration dictionary for parallelism.
Attributes:
batch (types.Batch): The batch. This will be the entire batch during the :attr:`.Event.AFTER_DATALOADER`, or a
Expand Down Expand Up @@ -423,8 +492,7 @@ def __init__(

# Distributed training configs
deepspeed_config: Optional[Dict[str, Any]] = None,
fsdp_config: Optional[Dict[str, Any]] = None,
fsdp_auto_wrap: bool = True,
parallelism_config: Optional[Dict[str, Any]] = None,
):
self.rank_zero_seed = rank_zero_seed
self.model = model
Expand Down Expand Up @@ -468,20 +536,88 @@ def __init__(
self.profiler: Optional[Profiler] = None

self.deepspeed_config = deepspeed_config
self.fsdp_config = fsdp_config
self.fsdp_auto_wrap = fsdp_auto_wrap
parallelism_config = parallelism_config or {}
self.fsdp_config = parallelism_config.get('fsdp', None)
self.tp_config = parallelism_config.get('tp', None)

self._validate_parallelism_configs()

self.device_mesh: Optional[DeviceMesh] = _create_device_mesh(self.device, self.fsdp_config, self.tp_config)
if self.fsdp_config is not None and self.device_mesh is not None:
fsdp_mesh_dim_names = []
if self.device_mesh.mesh_dim_names is not None and ParallelismType.DATA_PARALLEL_REPLICATE.value in self.device_mesh.mesh_dim_names:
fsdp_mesh_dim_names.append(ParallelismType.DATA_PARALLEL_REPLICATE.value)
fsdp_mesh_dim_names.append(ParallelismType.DATA_PARALLEL_SHARD.value)
self.fsdp_config['device_mesh'] = self.device_mesh[tuple(fsdp_mesh_dim_names)] # type: ignore
if self.tp_config is not None and self.device_mesh is not None:
self.tp_config['device_mesh'] = self.device_mesh[ParallelismType.TENSOR_PARALLEL.value]

# Set defaults for transient variables (to make pyright happy)
self.batch: Any = None
self.loss: Union[torch.Tensor, Sequence[torch.Tensor], Dict[Any, torch.Tensor]] = torch.Tensor()
self.outputs: Union[torch.Tensor, Sequence[torch.Tensor]] = torch.Tensor()

# These attributes will be serialized using .state_dict(), and loaded with .load_state_dict()
# All other attributes will not be serialized.
# For simplicity, omit the leading underscore for private attributes.
# For example, even though the optimizers are stored on the state
# as the "_optimizers" attribute, here we specify just "optimizers"
self.serialized_attributes = [
'model',
'optimizers',
'schedulers',
'algorithms',
'callbacks',
'scaler',
'timestamp',
'rank_zero_seed',
'train_metrics',
'eval_metrics',
'run_name',
'dataset_state',
]

self.train_metrics: Optional[Dict[str, Metric]] = {}
self.eval_metrics: Dict[str, Dict[str, Metric]] = {}
self.train_metric_values: Dict[str, float] = {}
self.eval_metric_values: Dict[str, float] = {}
self.total_loss_dict: Dict[str, float] = {}

self.metric_outputs: Dict[str, Any] = {}

def _validate_parallelism_configs(self):
# Validate TP config
if self.tp_config is not None:
warnings.warn('Tensor parallelism (TP) is experimental and may change in future versions.', FutureWarning)
if version.parse(torch.__version__.split('.dev')[0]) < version.parse('2.3.0'):
raise ValueError('Tensor parallelism (TP) requires torch>=2.3.0.')
if self.fsdp_config is None:
raise ValueError(
'Tensor parallelism (TP) currently requires FSDP to be enabled. '
'An empty `fsdp_config` can be specified to enable FSDP with '
'default settings. Additionally, PyTorch currently errors if FSDP '
'data_parallel_shard_degree is not at least 2.',
)
if not self.fsdp_config['use_orig_params']:
raise ValueError(
'Tensor parallelism (TP) currently requires FSDP with use_orig_params=True, '
'which is the default and recommended setting.',
)

# Load monolith rank0 only
if self.load_monolith_rank0_only:
assert fsdp_config is not None
if self.tp_config is not None:
raise ValueError('load_fsdp_monolith_rank0_only is not compatible with tensor parallelism (TP).')
assert self.fsdp_config is not None
error_message = ''
if fsdp_config['sync_module_states'] == False:
if self.fsdp_config['sync_module_states'] == False:
error_message += textwrap.dedent(
"load_monolith_rank0_only requires fsdp_config['sync_module_states'] to be True. "
"Either set fsdp_config['sync_module_states'] = True or set load_monolith_rank0_only = False. ",
)
# Broadcast rank 0 meta check to all ranks so error can be raised on all ranks
rank0_on_meta = 0
if dist.get_global_rank() == 0 and next(model.parameters()).device.type == 'meta':
if dist.get_global_rank() == 0 and next(self.model.parameters()).device.type == 'meta':
rank0_on_meta = 1
rank0_on_meta_tensor = self.device.tensor_to_device(torch.tensor([rank0_on_meta], dtype=torch.uint8))
dist.all_reduce(rank0_on_meta_tensor, reduce_operation='MAX')
Expand All @@ -494,10 +630,7 @@ def __init__(
if error_message != '':
raise ValueError(error_message)

self.sharded_ckpt_prefix_dir: Optional[str] = None
if self.fsdp_config is not None:
self.sharded_ckpt_prefix_dir = self.fsdp_config['sharded_ckpt_prefix_dir']

# Validate FSDP state dict type
if self.fsdp_state_dict_type not in [None, 'full', 'sharded']:
if self.fsdp_state_dict_type == 'local':
raise ValueError(
Expand All @@ -521,39 +654,6 @@ def __init__(
),
)

# Set defaults for transient variables (to make pyright happy)
self.batch: Any = None
self.loss: Union[torch.Tensor, Sequence[torch.Tensor], Dict[Any, torch.Tensor]] = torch.Tensor()
self.outputs: Union[torch.Tensor, Sequence[torch.Tensor]] = torch.Tensor()

# These attributes will be serialized using .state_dict(), and loaded with .load_state_dict()
# All other attributes will not be serialized.
# For simplicity, omit the leading underscore for private attributes.
# For example, even though the optimizers are stored on the state
# as the "_optimizers" attribute, here we specify just "optimizers"
self.serialized_attributes = [
'model',
'optimizers',
'schedulers',
'algorithms',
'callbacks',
'scaler',
'timestamp',
'rank_zero_seed',
'train_metrics',
'eval_metrics',
'run_name',
'dataset_state',
]

self.train_metrics: Optional[Dict[str, Metric]] = {}
self.eval_metrics: Dict[str, Dict[str, Metric]] = {}
self.train_metric_values: Dict[str, float] = {}
self.eval_metric_values: Dict[str, float] = {}
self.total_loss_dict: Dict[str, float] = {}

self.metric_outputs: Dict[str, Any] = {}

def _dataset_of(self, dataloader: Optional[Union[Evaluator, DataSpec, DataLoader, Iterable]]) -> Optional[Dataset]:
"""Get the dataset contained by the given dataloader-like object.
Expand Down Expand Up @@ -794,12 +894,8 @@ def fsdp_sharded_state_dict_enabled(self):

@property
def fsdp_device_mesh(self):
if self.fsdp_enabled:
if not hasattr(self.model, 'model') or not hasattr(self.model.model, '_device_mesh'):
return None
return self.model.model._device_mesh
else:
return None
warnings.warn(VersionedDeprecationWarning('fsdp_device_mesh is deprecated. Use device_mesh instead.', '0.24'))
return self.device_mesh

@property
def load_fsdp_monolith_rank0_only(self):
Expand All @@ -814,8 +910,8 @@ def load_fsdp_monolith_rank0_only(self):
@property
def load_monolith_rank0_only(self):
return (
self.fsdp_config is not None and self.fsdp_auto_wrap and self.fsdp_config['state_dict_type'] == 'full' and
self.fsdp_config['load_monolith_rank0_only'] == True
self.fsdp_config is not None and self.fsdp_config['auto_wrap'] and
self.fsdp_config['state_dict_type'] == 'full' and self.fsdp_config['load_monolith_rank0_only'] == True
)

def _get_integrations_state_dict(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -1289,8 +1385,9 @@ def load_model_state(
if self.load_monolith_rank0_only:
assert self.fsdp_config is not None
log.info('Wrapping model with FSDP after loading model_state.')
from composer.trainer.dist_strategy import prepare_fsdp_module
with reproducibility.seed_context(self.rank_zero_seed):
from composer.distributed import prepare_fsdp_module

prepare_fsdp_module(
self.model,
self.optimizers,
Expand Down
25 changes: 25 additions & 0 deletions composer/distributed/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Distributed training."""

from composer.distributed.deepspeed import fix_batch_precision_for_deepspeed, parse_deepspeed_config
from composer.distributed.dist_strategy import (
DDPSyncStrategy,
ddp_sync_context,
prepare_ddp_module,
prepare_fsdp_module,
prepare_tp_module,
)
from composer.distributed.mosaic_fsdp import set_fsdp_default

__all__ = [
'fix_batch_precision_for_deepspeed',
'parse_deepspeed_config',
'DDPSyncStrategy',
'ddp_sync_context',
'prepare_ddp_module',
'prepare_fsdp_module',
'prepare_tp_module',
'set_fsdp_default',
]
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from composer.core import Batch, Precision, State
from composer.utils import dist, map_collection

__all__ = ['_fix_batch_precision_for_deepspeed', '_parse_deepspeed_config']
__all__ = ['fix_batch_precision_for_deepspeed', 'parse_deepspeed_config']


def _add_batch_config(config: Dict[str, Any], state: State):
Expand Down Expand Up @@ -105,7 +105,7 @@ def _add_precision_config(config: Dict[str, Any], state: State):
config['bf16'] = cast(Dict[str, Any], {'enabled': True})


def _parse_deepspeed_config(
def parse_deepspeed_config(
config: Dict[str, Any],
state: State,
) -> Dict[str, Any]:
Expand Down Expand Up @@ -160,7 +160,7 @@ def _convert_fp32_tensor_to_bf16(tensor: torch.Tensor):
return tensor


def _fix_batch_precision_for_deepspeed(batch: Batch, precision: Precision) -> Batch:
def fix_batch_precision_for_deepspeed(batch: Batch, precision: Precision) -> Batch:
"""Ensures that a batch is properly formatted for DeepSpeed precisions, if active.
.. note:: Just because the precision is set to FP16 doesn't mean the entire batch can
Expand Down
Loading

0 comments on commit 09f14f9

Please sign in to comment.