diff --git a/composer/callbacks/checkpoint_saver.py b/composer/callbacks/checkpoint_saver.py index 482b3862a1..03f43882d8 100644 --- a/composer/callbacks/checkpoint_saver.py +++ b/composer/callbacks/checkpoint_saver.py @@ -468,8 +468,9 @@ def _save_checkpoint(self, state: State, logger: Logger): is_deepspeed, keep_placeholders=True, ).lstrip('/') - assert state.sharded_ckpt_prefix_dir is not None - remote_prefix = state.sharded_ckpt_prefix_dir + assert state.fsdp_config is not None + remote_prefix = state.fsdp_config['sharded_ckpt_prefix_dir'] + assert remote_prefix is not None ckpt_filename = checkpoint._TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME remote_file_name = os.path.join(pathlib.Path(remote_file_name).parent, remote_prefix, ckpt_filename) remote_file_name = format_name_with_dist_and_time(remote_file_name, state.run_name, state.timestamp) diff --git a/composer/core/state.py b/composer/core/state.py index 4506bdeec9..49e766f1e5 100644 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -17,6 +17,7 @@ import torch import torch.nn.modules.utils from packaging import version +from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.fully_sharded_data_parallel import ( FullOptimStateDictConfig, @@ -30,8 +31,6 @@ from torch.utils.data import DataLoader, Dataset from torchmetrics import Metric -from composer.utils.warnings import VersionedDeprecationWarning - if version.parse(torch.__version__) >= version.parse('2.3.0'): from torch.amp.grad_scaler import GradScaler # type: ignore else: @@ -44,6 +43,8 @@ from composer.core.time import Time, Timestamp, TimeUnit, ensure_time from composer.devices import Device from composer.utils import ( + ParallelismType, + VersionedDeprecationWarning, batch_get, batch_set, dist, @@ -194,6 +195,75 @@ def _ensure_backwards_compatible_checkpointing(state_dict: Dict[str, Any]): return state +def _create_device_mesh( + device: Device, + fsdp_config: Optional[Dict[str, Any]], + tp_config: Optional[Dict[str, Any]], +) -> Optional[DeviceMesh]: + if fsdp_config is None: + return None + + # Gather dimensions and names for the device mesh + dims: List[int] = [] + names: List[str] = [] + if fsdp_config['data_parallel_replicate_degree'] != 1: + dims.append(fsdp_config['data_parallel_replicate_degree']) + names.append(ParallelismType.DATA_PARALLEL_REPLICATE.value) + dims.append(fsdp_config['data_parallel_shard_degree']) + names.append(ParallelismType.DATA_PARALLEL_SHARD.value) + if tp_config is not None: + dims.append(tp_config['tensor_parallel_degree']) + names.append(ParallelismType.TENSOR_PARALLEL.value) + + # Fill in the unspecified dimensions + product_of_dims = 1 + unspecified_dim_names = [] + for dim, name in zip(dims, names): + if dim != -1: + product_of_dims *= dim + else: + unspecified_dim_names.append(name) + if len(unspecified_dim_names) > 1: + raise ValueError( + f'Found multiple parallelism dimensions with -1: {unspecified_dim_names}. ' + 'Only one is allowed, which is set to fill the remaining dimensions.', + ) + elif len(unspecified_dim_names) == 1: + if product_of_dims > dist.get_world_size(): + raise ValueError( + f'World size {dist.get_world_size()} is greater than the product of the specified parallelism degrees ' + f'{product_of_dims}. Please ensure the product of the specified parallelism degrees matches the world ', + f'size. Currently specified degrees are {names=}, {dims=}. 
One dimension can also be left as -1, which ' + 'will automatically be specified to ensure the product matches the world size.', + ) + remaining_dimension = dist.get_world_size() // product_of_dims + if remaining_dimension * product_of_dims != dist.get_world_size(): + raise ValueError( + f'World size {dist.get_world_size()} is not divisible by the product of the specified ' + 'parallelism degrees. Please ensure the product of the specified parallelism degrees ' + 'matches the world size.', + ) + for i, dim in enumerate(dims): + if dim == -1: + dims[i] = remaining_dimension + log.info(f'Automatically setting {names[i]} to have parallelization degree {remaining_dimension}.') + break + else: + if product_of_dims != dist.get_world_size(): + raise ValueError( + f'World size {dist.get_world_size()} does not equal the product of the specified parallelism degrees ' + f'{product_of_dims}. Please ensure the product of the specified parallelism degrees matches the world ', + f'size. Currently specified degrees are {names=}, {dims=}. One dimension can also be left as -1, which ' + 'will automatically be specified to ensure the product matches the world size.', + ) + + device_type = device.name + if device_type == 'gpu': + device_type = 'cuda' + + return init_device_mesh(device_type=device_type, mesh_shape=tuple(dims), mesh_dim_names=tuple(names)) + + _STATE_DICT_SERIALIZED_ATTRIBUTES = [ # List of attributes that are serialized with state_dict # Only the attributes listed in state.serialized_attributes will actually be saved. @@ -255,8 +325,7 @@ class State(Serializable): algorithms (Algorithm | Sequence[Algorithm], optional): The algorithms used for training. callbacks (Callback | Sequence[Callback], optional): The callbacks used for training. deepspeed_config (Dict[str, Any], optional): The configuration dictionary for deepspeed. - fsdp_config (Dict[str, Any], optional): The configuration dictionary for FSDP. - fsdp_auto_wrap (bool, optional): Whether to automatically wrap the model with FSDP. + parallelism_config (Dict[str, Any], optional): The configuration dictionary for parallelism. Attributes: batch (types.Batch): The batch. 
This will be the entire batch during the :attr:`.Event.AFTER_DATALOADER`, or a @@ -423,8 +492,7 @@ def __init__( # Distributed training configs deepspeed_config: Optional[Dict[str, Any]] = None, - fsdp_config: Optional[Dict[str, Any]] = None, - fsdp_auto_wrap: bool = True, + parallelism_config: Optional[Dict[str, Any]] = None, ): self.rank_zero_seed = rank_zero_seed self.model = model @@ -468,20 +536,88 @@ def __init__( self.profiler: Optional[Profiler] = None self.deepspeed_config = deepspeed_config - self.fsdp_config = fsdp_config - self.fsdp_auto_wrap = fsdp_auto_wrap + parallelism_config = parallelism_config or {} + self.fsdp_config = parallelism_config.get('fsdp', None) + self.tp_config = parallelism_config.get('tp', None) + + self._validate_parallelism_configs() + + self.device_mesh: Optional[DeviceMesh] = _create_device_mesh(self.device, self.fsdp_config, self.tp_config) + if self.fsdp_config is not None and self.device_mesh is not None: + fsdp_mesh_dim_names = [] + if self.device_mesh.mesh_dim_names is not None and ParallelismType.DATA_PARALLEL_REPLICATE.value in self.device_mesh.mesh_dim_names: + fsdp_mesh_dim_names.append(ParallelismType.DATA_PARALLEL_REPLICATE.value) + fsdp_mesh_dim_names.append(ParallelismType.DATA_PARALLEL_SHARD.value) + self.fsdp_config['device_mesh'] = self.device_mesh[tuple(fsdp_mesh_dim_names)] # type: ignore + if self.tp_config is not None and self.device_mesh is not None: + self.tp_config['device_mesh'] = self.device_mesh[ParallelismType.TENSOR_PARALLEL.value] + + # Set defaults for transient variables (to make pyright happy) + self.batch: Any = None + self.loss: Union[torch.Tensor, Sequence[torch.Tensor], Dict[Any, torch.Tensor]] = torch.Tensor() + self.outputs: Union[torch.Tensor, Sequence[torch.Tensor]] = torch.Tensor() + + # These attributes will be serialized using .state_dict(), and loaded with .load_state_dict() + # All other attributes will not be serialized. + # For simplicity, omit the leading underscore for private attributes. + # For example, even though the optimizers are stored on the state + # as the "_optimizers" attribute, here we specify just "optimizers" + self.serialized_attributes = [ + 'model', + 'optimizers', + 'schedulers', + 'algorithms', + 'callbacks', + 'scaler', + 'timestamp', + 'rank_zero_seed', + 'train_metrics', + 'eval_metrics', + 'run_name', + 'dataset_state', + ] + self.train_metrics: Optional[Dict[str, Metric]] = {} + self.eval_metrics: Dict[str, Dict[str, Metric]] = {} + self.train_metric_values: Dict[str, float] = {} + self.eval_metric_values: Dict[str, float] = {} + self.total_loss_dict: Dict[str, float] = {} + + self.metric_outputs: Dict[str, Any] = {} + + def _validate_parallelism_configs(self): + # Validate TP config + if self.tp_config is not None: + warnings.warn('Tensor parallelism (TP) is experimental and may change in future versions.', FutureWarning) + if version.parse(torch.__version__.split('.dev')[0]) < version.parse('2.3.0'): + raise ValueError('Tensor parallelism (TP) requires torch>=2.3.0.') + if self.fsdp_config is None: + raise ValueError( + 'Tensor parallelism (TP) currently requires FSDP to be enabled. ' + 'An empty `fsdp_config` can be specified to enable FSDP with ' + 'default settings. 
Additionally, PyTorch currently errors if FSDP ' + 'data_parallel_shard_degree is not at least 2.', + ) + if not self.fsdp_config['use_orig_params']: + raise ValueError( + 'Tensor parallelism (TP) currently requires FSDP with use_orig_params=True, ' + 'which is the default and recommended setting.', + ) + + # Load monolith rank0 only if self.load_monolith_rank0_only: - assert fsdp_config is not None + if self.tp_config is not None: + raise ValueError('load_fsdp_monolith_rank0_only is not compatible with tensor parallelism (TP).') + assert self.fsdp_config is not None error_message = '' - if fsdp_config['sync_module_states'] == False: + if self.fsdp_config['sync_module_states'] == False: error_message += textwrap.dedent( "load_monolith_rank0_only requires fsdp_config['sync_module_states'] to be True. " "Either set fsdp_config['sync_module_states'] = True or set load_monolith_rank0_only = False. ", ) # Broadcast rank 0 meta check to all ranks so error can be raised on all ranks rank0_on_meta = 0 - if dist.get_global_rank() == 0 and next(model.parameters()).device.type == 'meta': + if dist.get_global_rank() == 0 and next(self.model.parameters()).device.type == 'meta': rank0_on_meta = 1 rank0_on_meta_tensor = self.device.tensor_to_device(torch.tensor([rank0_on_meta], dtype=torch.uint8)) dist.all_reduce(rank0_on_meta_tensor, reduce_operation='MAX') @@ -494,10 +630,7 @@ def __init__( if error_message != '': raise ValueError(error_message) - self.sharded_ckpt_prefix_dir: Optional[str] = None - if self.fsdp_config is not None: - self.sharded_ckpt_prefix_dir = self.fsdp_config['sharded_ckpt_prefix_dir'] - + # Validate FSDP state dict type if self.fsdp_state_dict_type not in [None, 'full', 'sharded']: if self.fsdp_state_dict_type == 'local': raise ValueError( @@ -521,39 +654,6 @@ def __init__( ), ) - # Set defaults for transient variables (to make pyright happy) - self.batch: Any = None - self.loss: Union[torch.Tensor, Sequence[torch.Tensor], Dict[Any, torch.Tensor]] = torch.Tensor() - self.outputs: Union[torch.Tensor, Sequence[torch.Tensor]] = torch.Tensor() - - # These attributes will be serialized using .state_dict(), and loaded with .load_state_dict() - # All other attributes will not be serialized. - # For simplicity, omit the leading underscore for private attributes. - # For example, even though the optimizers are stored on the state - # as the "_optimizers" attribute, here we specify just "optimizers" - self.serialized_attributes = [ - 'model', - 'optimizers', - 'schedulers', - 'algorithms', - 'callbacks', - 'scaler', - 'timestamp', - 'rank_zero_seed', - 'train_metrics', - 'eval_metrics', - 'run_name', - 'dataset_state', - ] - - self.train_metrics: Optional[Dict[str, Metric]] = {} - self.eval_metrics: Dict[str, Dict[str, Metric]] = {} - self.train_metric_values: Dict[str, float] = {} - self.eval_metric_values: Dict[str, float] = {} - self.total_loss_dict: Dict[str, float] = {} - - self.metric_outputs: Dict[str, Any] = {} - def _dataset_of(self, dataloader: Optional[Union[Evaluator, DataSpec, DataLoader, Iterable]]) -> Optional[Dataset]: """Get the dataset contained by the given dataloader-like object. @@ -794,12 +894,8 @@ def fsdp_sharded_state_dict_enabled(self): @property def fsdp_device_mesh(self): - if self.fsdp_enabled: - if not hasattr(self.model, 'model') or not hasattr(self.model.model, '_device_mesh'): - return None - return self.model.model._device_mesh - else: - return None + warnings.warn(VersionedDeprecationWarning('fsdp_device_mesh is deprecated. 
Use device_mesh instead.', '0.24')) + return self.device_mesh @property def load_fsdp_monolith_rank0_only(self): @@ -814,8 +910,8 @@ def load_fsdp_monolith_rank0_only(self): @property def load_monolith_rank0_only(self): return ( - self.fsdp_config is not None and self.fsdp_auto_wrap and self.fsdp_config['state_dict_type'] == 'full' and - self.fsdp_config['load_monolith_rank0_only'] == True + self.fsdp_config is not None and self.fsdp_config['auto_wrap'] and + self.fsdp_config['state_dict_type'] == 'full' and self.fsdp_config['load_monolith_rank0_only'] == True ) def _get_integrations_state_dict(self) -> Dict[str, Any]: @@ -1289,8 +1385,9 @@ def load_model_state( if self.load_monolith_rank0_only: assert self.fsdp_config is not None log.info('Wrapping model with FSDP after loading model_state.') - from composer.trainer.dist_strategy import prepare_fsdp_module with reproducibility.seed_context(self.rank_zero_seed): + from composer.distributed import prepare_fsdp_module + prepare_fsdp_module( self.model, self.optimizers, diff --git a/composer/distributed/__init__.py b/composer/distributed/__init__.py new file mode 100644 index 0000000000..9a994dd762 --- /dev/null +++ b/composer/distributed/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2022 MosaicML Composer authors +# SPDX-License-Identifier: Apache-2.0 + +"""Distributed training.""" + +from composer.distributed.deepspeed import fix_batch_precision_for_deepspeed, parse_deepspeed_config +from composer.distributed.dist_strategy import ( + DDPSyncStrategy, + ddp_sync_context, + prepare_ddp_module, + prepare_fsdp_module, + prepare_tp_module, +) +from composer.distributed.mosaic_fsdp import set_fsdp_default + +__all__ = [ + 'fix_batch_precision_for_deepspeed', + 'parse_deepspeed_config', + 'DDPSyncStrategy', + 'ddp_sync_context', + 'prepare_ddp_module', + 'prepare_fsdp_module', + 'prepare_tp_module', + 'set_fsdp_default', +] diff --git a/composer/trainer/_deepspeed.py b/composer/distributed/deepspeed.py similarity index 97% rename from composer/trainer/_deepspeed.py rename to composer/distributed/deepspeed.py index 38b1532052..5858ae4e0c 100644 --- a/composer/trainer/_deepspeed.py +++ b/composer/distributed/deepspeed.py @@ -13,7 +13,7 @@ from composer.core import Batch, Precision, State from composer.utils import dist, map_collection -__all__ = ['_fix_batch_precision_for_deepspeed', '_parse_deepspeed_config'] +__all__ = ['fix_batch_precision_for_deepspeed', 'parse_deepspeed_config'] def _add_batch_config(config: Dict[str, Any], state: State): @@ -105,7 +105,7 @@ def _add_precision_config(config: Dict[str, Any], state: State): config['bf16'] = cast(Dict[str, Any], {'enabled': True}) -def _parse_deepspeed_config( +def parse_deepspeed_config( config: Dict[str, Any], state: State, ) -> Dict[str, Any]: @@ -160,7 +160,7 @@ def _convert_fp32_tensor_to_bf16(tensor: torch.Tensor): return tensor -def _fix_batch_precision_for_deepspeed(batch: Batch, precision: Precision) -> Batch: +def fix_batch_precision_for_deepspeed(batch: Batch, precision: Precision) -> Batch: """Ensures that a batch is properly formatted for DeepSpeed precisions, if active. .. 
note:: Just because the precision is set to FP16 doesn't mean the entire batch can diff --git a/composer/trainer/dist_strategy.py b/composer/distributed/dist_strategy.py similarity index 92% rename from composer/trainer/dist_strategy.py rename to composer/distributed/dist_strategy.py index 8255a80f39..7ffc9f3c7f 100644 --- a/composer/trainer/dist_strategy.py +++ b/composer/distributed/dist_strategy.py @@ -23,18 +23,17 @@ from composer.core import Precision, State from composer.devices import Device -from composer.trainer.meta_safe_apply import meta_safe_apply -from composer.trainer.mosaic_fsdp import patch_pytorch -from composer.trainer.mosaic_fsdp_utils import ( +from composer.distributed.meta_safe_apply import meta_safe_apply +from composer.distributed.mosaic_fsdp import ( BACKWARD_PREFETCH_MAP, SHARDING_MAP, - _set_custom_fsdp_module_kwargs, get_cpu_offload, get_mixed_precision, + set_custom_fsdp_module_kwargs, ) from composer.utils import StringEnum, dist, ensure_tuple -__all__ = ['DDPSyncStrategy', 'ddp_sync_context', 'prepare_ddp_module', 'prepare_fsdp_module'] +__all__ = ['DDPSyncStrategy', 'ddp_sync_context', 'prepare_ddp_module', 'prepare_fsdp_module', 'prepare_tp_module'] log = logging.getLogger(__name__) @@ -142,35 +141,6 @@ def prepare_ddp_module(module: torch.nn.Module, find_unused_parameters: bool) -> ) -def set_fsdp_default(fsdp_config: Dict[str, Any]): - """Modify fsdp_config to set default values for missing keys.""" - fsdp_config.setdefault('activation_checkpointing', False) - fsdp_config.setdefault('activation_checkpointing_reentrant', True) - fsdp_config.setdefault('activation_cpu_offload', False) - fsdp_config.setdefault('te_checkpoint_wrapper', False) - fsdp_config.setdefault('te_shard_fp8_weight', False) - fsdp_config.setdefault('backward_prefetch', 'BACKWARD_POST') - fsdp_config.setdefault('backward_prefetch_limit', 1) - fsdp_config.setdefault('cpu_offload', False) - fsdp_config.setdefault('forward_prefetch', False) - fsdp_config.setdefault('forward_prefetch_limit', 1) - fsdp_config.setdefault('ignored_modules', None) - fsdp_config.setdefault('keep_low_precision_grads', False) - fsdp_config.setdefault('limit_all_gathers', True) - fsdp_config.setdefault('load_monolith_rank0_only', False) - fsdp_config.setdefault('load_planner', None) - fsdp_config.setdefault('mixed_precision', 'DEFAULT') - fsdp_config.setdefault('process_group', None) - fsdp_config.setdefault('save_planner', None) - fsdp_config.setdefault('sharded_ckpt_prefix_dir', 'ep{epoch}-ba{batch}') - fsdp_config.setdefault('sharding_strategy', 'FULL_SHARD') - fsdp_config.setdefault('state_dict_type', 'full') - fsdp_config.setdefault('sync_module_states', False) - fsdp_config.setdefault('use_orig_params', True) - fsdp_config.setdefault('verbose', False) - return fsdp_config - - def _recreate_fsdp_param_groups_from_unwrapped_opt_info( fsdp_wrapped_named_params: Iterator[Tuple[str, torch.nn.Parameter]], non_wrapped_param_names_to_group_num: Dict[str, int], @@ -209,6 +179,22 @@ def _recreate_fsdp_param_groups_from_unwrapped_opt_info( return [group_num_to_optimizer_info[num] for num in sorted(group_num_to_optimizer_info.keys())] +def prepare_tp_module( + model: torch.nn.Module, + tp_config: Dict[str, Any], +) -> None: + """Prepare a module (assumed ComposerModel) for use with tensor parallel.""" + from torch.distributed.tensor.parallel import parallelize_module + + device_mesh = tp_config['device_mesh'] + layer_plan = tp_config['layer_plan'] + parallelize_module( + module=model, + device_mesh=device_mesh, + 
parallelize_plan=layer_plan, + ) + + def prepare_fsdp_module( model: torch.nn.Module, optimizers: Optional[Union[torch.optim.Optimizer, Sequence[torch.optim.Optimizer]]], @@ -229,10 +215,6 @@ def prepare_fsdp_module( auto_microbatching (bool, optional): Whether or not auto microbatching is enabled. te_rng_seed(int): The seed to use for the Transformer Engine activation checkpointing RNG. Defaults to 1234. """ - patch_pytorch() - - set_fsdp_default(fsdp_config) - # Check sync_module_states is True for mixed initialization or HSDP if fsdp_config['sync_module_states'] == False: rank_on_meta = 1 if next(model.parameters()).device.type == 'meta' else 0 @@ -319,31 +301,26 @@ def sync_hook(*args): sharding_strategy = SHARDING_MAP[sharding_map_key] kwargs = {} - if version.parse(torch.__version__.split('.dev')[0]) >= version.parse('2.2.0'): - if 'device_mesh' in fsdp_config: - device_mesh_size = len(fsdp_config['device_mesh']) - if sharding_strategy in [ - ShardingStrategy.FULL_SHARD, - ShardingStrategy.SHARD_GRAD_OP, - ShardingStrategy.NO_SHARD, - ] and device_mesh_size != 1: - raise ValueError( - f'FSDP sharding strategy {sharding_map_key.upper()} requires a device mesh ' - f'of size 1 but got device mesh size of {device_mesh_size}.', - ) - elif sharding_strategy in [ - ShardingStrategy.HYBRID_SHARD, - ShardingStrategy._HYBRID_SHARD_ZERO2, - ] and device_mesh_size != 2: - raise ValueError( - f'FSDP sharding strategy {sharding_map_key.upper()} requires a device mesh ' - f'of size 2 but got device mesh size of {device_mesh_size}.', - ) - from torch.distributed._tensor import init_device_mesh - kwargs['device_mesh'] = init_device_mesh( - 'cuda', - tuple([int(x) for x in fsdp_config['device_mesh']]), + if version.parse(torch.__version__.split('.dev')[0]) >= version.parse('2.2.0') and 'device_mesh' in fsdp_config: + if fsdp_config['process_group'] is not None: + warnings.warn( + 'process_group and device_mesh are set for FSDP, so ignoring device_mesh. Please set process_group to None.', ) + else: + ndim = fsdp_config['device_mesh'].ndim + if ndim == 1 and sharding_strategy == ShardingStrategy.HYBRID_SHARD: + sharding_strategy = ShardingStrategy.FULL_SHARD + warnings.warn('HYBRID_SHARD is not supported with 1D device mesh. Using FULL_SHARD instead.') + elif ndim == 1 and sharding_strategy == ShardingStrategy._HYBRID_SHARD_ZERO2: + sharding_strategy = ShardingStrategy.SHARD_GRAD_OP + warnings.warn('_HYBRID_SHARD_ZERO2 is not supported with 1D device mesh. Using SHARD_GRAD_OP instead.') + elif ndim == 2 and sharding_strategy == ShardingStrategy.SHARD_GRAD_OP: + sharding_strategy = ShardingStrategy._HYBRID_SHARD_ZERO2 + warnings.warn('SHARD_GRAD_OP is not supported with 2D device mesh. Using _HYBRID_SHARD_ZERO2 instead.') + elif ndim == 2 and sharding_strategy == ShardingStrategy.FULL_SHARD: + sharding_strategy = ShardingStrategy.HYBRID_SHARD + warnings.warn('FULL_SHARD is not supported with 2D device mesh. 
Using HYBRID_SHARD instead.') + kwargs['device_mesh'] = fsdp_config['device_mesh'] cpu_offload = get_cpu_offload(cpu_offload=fsdp_config['cpu_offload']) @@ -382,7 +359,7 @@ def sync_hook(*args): process_group = None if fsdp_config['process_group'] is not None: process_group_dict = {'process_group': fsdp_config['process_group']} - process_group = _set_custom_fsdp_module_kwargs(process_group_dict, process_group_cache)['process_group'] + process_group = set_custom_fsdp_module_kwargs(process_group_dict, process_group_cache)['process_group'] backward_prefetch = BACKWARD_PREFETCH_MAP[fsdp_config['backward_prefetch'].upper()] activation_checkpointing = fsdp_config['activation_checkpointing'] activation_cpu_offload = fsdp_config['activation_cpu_offload'] @@ -556,7 +533,7 @@ def lambda_fn(module: torch.nn.Module) -> Union[bool, dict]: elif hasattr(obj, 'fsdp_wrap_fn') and isinstance(obj.fsdp_wrap_fn, Callable): ret = obj.fsdp_wrap_fn(module) if isinstance(ret, dict): - ret = _set_custom_fsdp_module_kwargs(ret, process_group_cache) + ret = set_custom_fsdp_module_kwargs(ret, process_group_cache) if ret and auto_microbatching: module.register_forward_hook(sync_hook) module.register_full_backward_hook(sync_hook) diff --git a/composer/trainer/meta_safe_apply.py b/composer/distributed/meta_safe_apply.py similarity index 100% rename from composer/trainer/meta_safe_apply.py rename to composer/distributed/meta_safe_apply.py diff --git a/composer/distributed/mosaic_fsdp.py b/composer/distributed/mosaic_fsdp.py new file mode 100644 index 0000000000..c754a05156 --- /dev/null +++ b/composer/distributed/mosaic_fsdp.py @@ -0,0 +1,252 @@ +# Copyright 2022 MosaicML Composer authors +# SPDX-License-Identifier: Apache-2.0 + +"""FSDP related configs and helper functions.""" + +import logging +import warnings +from typing import Any, Dict, Tuple, Union + +import torch +from packaging import version +from torch import distributed +from torch.distributed import ProcessGroup +from torch.distributed.fsdp import ( + BackwardPrefetch, + CPUOffload, + MixedPrecision, + ShardingStrategy, +) + +from composer.core import Precision +from composer.utils import VersionedDeprecationWarning, dist + +log = logging.getLogger(__name__) + +SHARDING_MAP = { + 'NO_SHARD': ShardingStrategy.NO_SHARD, + 'SHARD_GRAD_OP': ShardingStrategy.SHARD_GRAD_OP, + 'FULL_SHARD': ShardingStrategy.FULL_SHARD, +} + +if version.parse(torch.__version__) >= version.parse('2.1.0'): + SHARDING_MAP['_HYBRID_SHARD_ZERO2'] = ShardingStrategy._HYBRID_SHARD_ZERO2 + SHARDING_MAP['HYBRID_SHARD'] = ShardingStrategy.HYBRID_SHARD + +BACKWARD_PREFETCH_MAP = { + 'NONE': None, + 'BACKWARD_PRE': BackwardPrefetch.BACKWARD_PRE, + 'BACKWARD_POST': BackwardPrefetch.BACKWARD_POST, +} + + +def set_fsdp_default(fsdp_config: Dict[str, Any]): + """Modify fsdp_config to set default values for missing keys.""" + if 'process_group' in fsdp_config: + warnings.warn( + VersionedDeprecationWarning( + 'process_group is deprecated. Please specify `data_parallel_shard_degree` and `data_parallel_replicate_degree` instead.', + remove_version='0.24.0', + ), + ) + + if 'device_mesh' in fsdp_config: + warnings.warn( + VersionedDeprecationWarning( + 'device_mesh is deprecated. 
Please specify `data_parallel_shard_degree` and `data_parallel_replicate_degree` instead.', + remove_version='0.24.0', + ), + ) + if 'data_parallel_shard_degree' in fsdp_config or 'data_parallel_replicate_degree' in fsdp_config: + raise ValueError( + 'Cannot specify both `device_mesh` and `data_parallel_shard_degree` or `data_parallel_replicate_degree`. Please remove `device_mesh`.', + ) + device_mesh = fsdp_config.pop('device_mesh') + fsdp_config['data_parallel_shard_degree'] = device_mesh[0] + if len(device_mesh) > 1: + fsdp_config['data_parallel_replicate_degree'] = device_mesh[1] + + fsdp_config.setdefault('activation_checkpointing', False) + fsdp_config.setdefault('activation_checkpointing_reentrant', True) + fsdp_config.setdefault('activation_cpu_offload', False) + fsdp_config.setdefault('auto_wrap', True) + fsdp_config.setdefault('te_checkpoint_wrapper', False) + fsdp_config.setdefault('te_shard_fp8_weight', False) + fsdp_config.setdefault('backward_prefetch', 'BACKWARD_POST') + fsdp_config.setdefault('backward_prefetch_limit', 1) + fsdp_config.setdefault('cpu_offload', False) + fsdp_config.setdefault('data_parallel_shard_degree', -1) + fsdp_config.setdefault('data_parallel_replicate_degree', 1) + fsdp_config.setdefault('forward_prefetch', False) + fsdp_config.setdefault('forward_prefetch_limit', 1) + fsdp_config.setdefault('ignored_modules', None) + fsdp_config.setdefault('keep_low_precision_grads', False) + fsdp_config.setdefault('limit_all_gathers', True) + fsdp_config.setdefault('load_monolith_rank0_only', False) + fsdp_config.setdefault('load_planner', None) + fsdp_config.setdefault('mixed_precision', 'DEFAULT') + fsdp_config.setdefault('process_group', None) + fsdp_config.setdefault('save_planner', None) + fsdp_config.setdefault('sharded_ckpt_prefix_dir', 'ep{epoch}-ba{batch}') + fsdp_config.setdefault('sharding_strategy', 'FULL_SHARD') + fsdp_config.setdefault('state_dict_type', 'full') + fsdp_config.setdefault('sync_module_states', False) + fsdp_config.setdefault('use_orig_params', True) + fsdp_config.setdefault('verbose', False) + + return fsdp_config + + +def _get_torch_dtype(dtype: Union[Precision, str]): + """Convert common string representations of dtypes to torch dtypes.""" + dtype = dtype.value if isinstance(dtype, Precision) else dtype + if dtype in ['float32', 'torch.float32', 'fp32']: + return torch.float32 + elif dtype in ['float16', 'torch.float16', 'half', 'fp16', 'amp', 'amp_fp16']: + return torch.float16 + elif dtype in ['bfloat16', 'bfloat', 'torch.bfloat16', 'bf16', 'amp_bf16']: + return torch.bfloat16 + elif dtype in ['float8', 'torch.float8', 'fp8', 'amp_fp8']: + if hasattr(torch, 'float8'): + raise NotImplementedError('Torch has enabled float8. 
This should be updated to `return torch.float8`') + else: + warnings.warn('We use torch.bfloat16 by default for amp_fp8 as there is no fp8 datatype in PyTorch yet.') + return torch.bfloat16 + else: + raise ValueError(f'Not sure how to convert dtype={dtype} to a torch dtype.') + + +def get_mixed_precision(precision, mixed_precision='DEFAULT', keep_low_precision_grads=False): + """Helper function for configuring mixed_precision.""" + param_dtype = None + reduce_dtype = None + buffer_dtype = None + if isinstance(mixed_precision, dict): + param_dtype = mixed_precision.get('param_dtype', None) + if param_dtype is not None: + param_dtype = _get_torch_dtype(param_dtype) + reduce_dtype = mixed_precision.get('reduce_dtype', None) + if reduce_dtype is not None: + reduce_dtype = _get_torch_dtype(reduce_dtype) + buffer_dtype = mixed_precision.get('buffer_dtype', None) + if buffer_dtype is not None: + buffer_dtype = _get_torch_dtype(buffer_dtype) + elif isinstance(mixed_precision, str): + mixed_precision = mixed_precision.upper() + if mixed_precision == 'FULL': + pass + elif mixed_precision == 'DEFAULT': + param_dtype = _get_torch_dtype(precision) + reduce_dtype = torch.float32 + buffer_dtype = _get_torch_dtype(precision) + elif mixed_precision == 'PURE': + param_dtype = _get_torch_dtype(precision) + reduce_dtype = _get_torch_dtype(precision) + buffer_dtype = _get_torch_dtype(precision) + else: + raise ValueError(f'Unable to interpret mixed_precision={mixed_precision}') + else: + raise ValueError(f'Unable to interpret mixed_precision={mixed_precision}') + + mixed_precision = MixedPrecision( + param_dtype=param_dtype, + reduce_dtype=reduce_dtype, + buffer_dtype=buffer_dtype, + keep_low_precision_grads=keep_low_precision_grads, + ) + + return mixed_precision, param_dtype, reduce_dtype, buffer_dtype + + +def get_cpu_offload(cpu_offload=False): + """Helper function for configuring cpu_offload.""" + cpu_offload = CPUOffload(offload_params=True) if cpu_offload else None + if cpu_offload is not None: + raise ValueError('FSDP CPU Offload not supported yet.') + return cpu_offload + + +def _get_process_group(pg, process_group_cache=None): + """Helper function for configuring and/or retrieving process groups.""" + if pg is None or isinstance(pg, ProcessGroup): # Return as is, no caching + return pg + + world_size = dist.get_world_size() + local_world_size = dist.get_local_world_size() + + # Handle special str process_group cases + if pg == 'self': + pg = 'set1' + log.info(f"Converting process_group='self' to process_group='{pg}'") + elif pg == 'node': + pg = f'set{local_world_size}' + log.info(f"Converting process_group='node' to process_group='{pg}'") + elif pg == 'local_rank_across_nodes': + pg = f'mod{local_world_size}' + log.info(f"Converting process_group='local_rank_across_nodes' to process_group='{pg}'") + + # Handle str and Union[List[int], Tuple[int]] process_group cases + if isinstance(pg, str) and pg.startswith('set'): + k = int(pg.strip('set')) + world_size = dist.get_world_size() + if world_size % k != 0: + raise RuntimeError(f'{world_size} must be divisible by set size ({k})') + start = dist.get_global_rank() // k * k + ranks = tuple(range(start, start + k)) + elif isinstance(pg, str) and pg.startswith('mod'): + k = int(pg.strip('mod')) + world_size = dist.get_world_size() + if world_size % k != 0: + raise RuntimeError(f'{world_size} must be divisible by mod ({k})') + ranks = tuple(range(dist.get_global_rank() % k, world_size, k)) + elif isinstance(pg, (list, tuple)): + ranks = tuple(pg) + else: + 
raise ValueError(f'Unsure how to setup process_group={pg}') + + if process_group_cache is not None and ranks in process_group_cache: + log.info(f'Using cached progress group with {ranks=} on rank={dist.get_global_rank()}.') + return process_group_cache[ranks] + + log.info(f'Instantiating custom process groups with {ranks=} on rank={dist.get_global_rank()}.') + + ranks_per_subgroup_list = list(set(dist.all_gather_object(ranks))) + ( + current_group, + _subgroups, + ) = distributed.distributed_c10d.new_subgroups_by_enumeration( # type: ignore[reportGeneralTypeIssues] + ranks_per_subgroup_list, + ) + + if process_group_cache is not None: + process_group_cache[ranks] = current_group + return current_group + + +def set_custom_fsdp_module_kwargs(module_kwargs: Dict, process_group_cache: Dict[Tuple[int], Any]) -> Dict: + """Set custom module_kwargs per fsdp module.""" + if ('sharding_strategy' in module_kwargs and module_kwargs['sharding_strategy'] not in SHARDING_MAP.values()): + module_kwargs['sharding_strategy'] = SHARDING_MAP[module_kwargs['sharding_strategy'].upper()] + if 'backward_prefetch' in module_kwargs: + if module_kwargs['backward_prefetch'] not in BACKWARD_PREFETCH_MAP.values(): + module_kwargs['backward_prefetch'] = BACKWARD_PREFETCH_MAP[module_kwargs['backward_prefetch'].upper()] + if 'cpu_offload' in module_kwargs and not isinstance(module_kwargs['cpu_offload'], CPUOffload): + module_kwargs['cpu_offload'] = get_cpu_offload(cpu_offload=module_kwargs['cpu_offload'].upper()) + if 'mixed_precision' in module_kwargs and not isinstance(module_kwargs['mixed_precision'], MixedPrecision): + # `precision` needs to set `'mixed_precision'`, but `precision` is not part of fsdp kwargs + raise NotImplementedError( + f"Automated setting of custom per module mixed_precision is not implemented, but it can be set if `isinstance(module_kwargs['mixed_precision'], MixedPrecision)`", + ) + if 'process_group' in module_kwargs: + # Call on every process group if it is a tuple/list of non-ints + if type(module_kwargs['process_group']) in [ + list, + tuple, + ] and not all(isinstance(x, int) for x in module_kwargs['process_group']): + module_kwargs['process_group'] = tuple( + _get_process_group(pg, process_group_cache) for pg in module_kwargs['process_group'] + ) + else: + module_kwargs['process_group'] = _get_process_group(module_kwargs['process_group'], process_group_cache) + + return module_kwargs diff --git a/composer/trainer/mosaic_fsdp_utils.py b/composer/trainer/_patch_pytorch.py similarity index 59% rename from composer/trainer/mosaic_fsdp_utils.py rename to composer/trainer/_patch_pytorch.py index ec75dfdfeb..c3a1f99bfa 100644 --- a/composer/trainer/mosaic_fsdp_utils.py +++ b/composer/trainer/_patch_pytorch.py @@ -6,377 +6,93 @@ # yapf: disable # isort: skip_file +# pyright: reportGeneralTypeIssues=false -"""Utilities for monkey patching FSDP.""" +"""PyTorch, especially PyTorch Distributed, monkeypatches.""" -import functools import logging import math -import warnings -import contextlib -from dataclasses import asdict -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union, cast, no_type_check +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union, no_type_check import torch import torch.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta +from torch.distributed._shard.sharding_spec import ChunkShardingSpec import torch.nn as nn import torch.nn.functional as F from packaging import version -from torch import distributed 
-from torch.distributed import ProcessGroup from torch.distributed._shard.sharding_spec import ShardMetadata from torch.distributed._shard.sharding_spec._internals import get_chunked_dim_size, get_split_size -from torch.distributed.distributed_c10d import get_process_group_ranks -from torch.distributed.fsdp import ( - BackwardPrefetch, CPUOffload, FullyShardedDataParallel, MixedPrecision, - ShardingStrategy, -) +from torch.distributed.fsdp import FullyShardedDataParallel, ShardingStrategy from torch.distributed.fsdp._fsdp_extensions import _ext_pre_load_state_dict_transform from torch.distributed.utils import _replace_by_prefix -from composer.core import Precision -from composer.utils import dist +log = logging.getLogger(__name__) -if TYPE_CHECKING: - if version.parse(torch.__version__) >= version.parse('2.0.1') and version.parse( - torch.__version__, - ) < version.parse('2.2.0'): - from torch.distributed.fsdp._common_utils import _FSDPState +def patch_pytorch(): + """Monkey patches pytorch functions based on pytorch version.""" + if version.parse(torch.__version__) < version.parse('2.1.1'): + # Monkey patch for torch < 2.1.1 ie torch == 2.1.0 -log = logging.getLogger(__name__) + # Monkey patch sharding method + ChunkShardingSpec.build_metadata = build_metadata -SHARDING_MAP = { - 'NO_SHARD': ShardingStrategy.NO_SHARD, - 'SHARD_GRAD_OP': ShardingStrategy.SHARD_GRAD_OP, - 'FULL_SHARD': ShardingStrategy.FULL_SHARD, -} - -if version.parse(torch.__version__) >= version.parse('2.1.0'): - SHARDING_MAP['_HYBRID_SHARD_ZERO2'] = ShardingStrategy._HYBRID_SHARD_ZERO2 - SHARDING_MAP['HYBRID_SHARD'] = ShardingStrategy.HYBRID_SHARD - -BACKWARD_PREFETCH_MAP = { - 'NONE': None, - 'BACKWARD_PRE': BackwardPrefetch.BACKWARD_PRE, - 'BACKWARD_POST': BackwardPrefetch.BACKWARD_POST, -} - -logger = logging.getLogger(__name__) - - -def _get_torch_dtype(dtype: Union[Precision, str]): - """Convert common string representations of dtypes to torch dtypes.""" - dtype = dtype.value if isinstance(dtype, Precision) else dtype - if dtype in ['float32', 'torch.float32', 'fp32']: - return torch.float32 - elif dtype in ['float16', 'torch.float16', 'half', 'fp16', 'amp', 'amp_fp16']: - return torch.float16 - elif dtype in ['bfloat16', 'bfloat', 'torch.bfloat16', 'bf16', 'amp_bf16']: - return torch.bfloat16 - elif dtype in ['float8', 'torch.float8', 'fp8', 'amp_fp8']: - if hasattr(torch, 'float8'): - raise NotImplementedError('Torch has enabled float8. 
This should be updated to `return torch.float8`') - else: - warnings.warn('We use torch.bfloat16 by default for amp_fp8 as there is no fp8 datatype in PyTorch yet.') - return torch.bfloat16 - else: - raise ValueError(f'Not sure how to convert dtype={dtype} to a torch dtype.') - - -def get_mixed_precision(precision, mixed_precision='DEFAULT', keep_low_precision_grads=False): - """Helper function for configuring mixed_precision.""" - param_dtype = None - reduce_dtype = None - buffer_dtype = None - if isinstance(mixed_precision, dict): - param_dtype = mixed_precision.get('param_dtype', None) - if param_dtype is not None: - param_dtype = _get_torch_dtype(param_dtype) - reduce_dtype = mixed_precision.get('reduce_dtype', None) - if reduce_dtype is not None: - reduce_dtype = _get_torch_dtype(reduce_dtype) - buffer_dtype = mixed_precision.get('buffer_dtype', None) - if buffer_dtype is not None: - buffer_dtype = _get_torch_dtype(buffer_dtype) - elif isinstance(mixed_precision, str): - mixed_precision = mixed_precision.upper() - if mixed_precision == 'FULL': - pass - elif mixed_precision == 'DEFAULT': - param_dtype = _get_torch_dtype(precision) - reduce_dtype = torch.float32 - buffer_dtype = _get_torch_dtype(precision) - elif mixed_precision == 'PURE': - param_dtype = _get_torch_dtype(precision) - reduce_dtype = _get_torch_dtype(precision) - buffer_dtype = _get_torch_dtype(precision) - else: - raise ValueError(f'Unable to interpret mixed_precision={mixed_precision}') - else: - raise ValueError(f'Unable to interpret mixed_precision={mixed_precision}') - - mixed_precision = MixedPrecision( - param_dtype=param_dtype, - reduce_dtype=reduce_dtype, - buffer_dtype=buffer_dtype, - keep_low_precision_grads=keep_low_precision_grads, - ) + # Monkey patch partial state dict handling + from torch.distributed.fsdp import _state_dict_utils - return mixed_precision, param_dtype, reduce_dtype, buffer_dtype - - -def get_cpu_offload(cpu_offload=False): - """Helper function for configuring cpu_offload.""" - cpu_offload = CPUOffload(offload_params=True) if cpu_offload else None - if cpu_offload is not None: - raise ValueError('FSDP CPU Offload not supported yet.') - return cpu_offload - - -def _get_process_group(pg, process_group_cache=None): - """Helper function for configuring and/or retrieving process groups.""" - if pg is None or isinstance(pg, ProcessGroup): # Return as is, no caching - return pg - - world_size = dist.get_world_size() - local_world_size = dist.get_local_world_size() - - # Handle special str process_group cases - if pg == 'self': - pg = 'set1' - log.info(f"Converting process_group='self' to process_group='{pg}'") - elif pg == 'node': - pg = f'set{local_world_size}' - log.info(f"Converting process_group='node' to process_group='{pg}'") - elif pg == 'local_rank_across_nodes': - pg = f'mod{local_world_size}' - log.info(f"Converting process_group='local_rank_across_nodes' to process_group='{pg}'") - - # Handle str and Union[List[int], Tuple[int]] process_group cases - if isinstance(pg, str) and pg.startswith('set'): - k = int(pg.strip('set')) - world_size = dist.get_world_size() - if world_size % k != 0: - raise RuntimeError(f'{world_size} must be divisible by set size ({k})') - start = dist.get_global_rank() // k * k - ranks = tuple(range(start, start + k)) - elif isinstance(pg, str) and pg.startswith('mod'): - k = int(pg.strip('mod')) - world_size = dist.get_world_size() - if world_size % k != 0: - raise RuntimeError(f'{world_size} must be divisible by mod ({k})') - ranks = 
tuple(range(dist.get_global_rank() % k, world_size, k)) - elif isinstance(pg, (list, tuple)): - ranks = tuple(pg) - else: - raise ValueError(f'Unsure how to setup process_group={pg}') - - if process_group_cache is not None and ranks in process_group_cache: - log.info(f'Using cached progress group with {ranks=} on rank={dist.get_global_rank()}.') - return process_group_cache[ranks] - - log.info(f'Instantiating custom process groups with {ranks=} on rank={dist.get_global_rank()}.') - - ranks_per_subgroup_list = list(set(dist.all_gather_object(ranks))) - ( - current_group, - _subgroups, - ) = distributed.distributed_c10d.new_subgroups_by_enumeration(ranks_per_subgroup_list) - - if process_group_cache is not None: - process_group_cache[ranks] = current_group - return current_group - - -def _set_custom_fsdp_module_kwargs(module_kwargs: Dict, process_group_cache: Dict[Tuple[int], Any]) -> Dict: - """Set custom module_kwargs per fsdp module.""" - if ('sharding_strategy' in module_kwargs and module_kwargs['sharding_strategy'] not in SHARDING_MAP.values()): - module_kwargs['sharding_strategy'] = SHARDING_MAP[module_kwargs['sharding_strategy'].upper()] - if 'backward_prefetch' in module_kwargs: - if module_kwargs['backward_prefetch'] not in BACKWARD_PREFETCH_MAP.values(): - module_kwargs['backward_prefetch'] = BACKWARD_PREFETCH_MAP[module_kwargs['backward_prefetch'].upper()] - if 'cpu_offload' in module_kwargs and not isinstance(module_kwargs['cpu_offload'], CPUOffload): - module_kwargs['cpu_offload'] = get_cpu_offload(cpu_offload=module_kwargs['cpu_offload'].upper()) - if 'mixed_precision' in module_kwargs and not isinstance(module_kwargs['mixed_precision'], MixedPrecision): - # `precision` needs to set `'mixed_precision'`, but `precision` is not part of fsdp kwargs - raise NotImplementedError( - f"Automated setting of custom per module mixed_precision is not implemented, but it can be set if `isinstance(module_kwargs['mixed_precision'], MixedPrecision)`", - ) - if 'process_group' in module_kwargs: - # Call on every process group if it is a tuple/list of non-ints - if type(module_kwargs['process_group']) in [ - list, tuple, - ] and not all(isinstance(x, int) for x in module_kwargs['process_group']): - module_kwargs['process_group'] = tuple( - _get_process_group(pg, process_group_cache) for pg in module_kwargs['process_group'] - ) - else: - module_kwargs['process_group'] = _get_process_group(module_kwargs['process_group'], process_group_cache) + _state_dict_utils._sharded_pre_load_state_dict_hook = (_sharded_pre_load_state_dict_hook) - return module_kwargs + # Allow 2D HSDP + from torch.distributed.fsdp import _runtime_utils + _runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None -def _custom_recursive_wrap_t2p0p1( - module: nn.Module, - auto_wrap_policy: Callable, - wrapper_cls: Callable, - ignored_modules: Set[nn.Module], - ignored_params: Set[nn.Parameter], - process_group_cache: Dict[Tuple[int], Any], - only_wrap_children: bool = False, - **kwargs: Any, -) -> Tuple[nn.Module, int]: - """Updates FSDPs _recursive_wrap to enable module_kwargs and custom process_group cache. - - torch version must be 2.0.1. - - modified version of - https://github.com/pytorch/pytorch/blob/96ca226a7332be0d8f3d6159d0c797e032ab0721/torch/distributed/fsdp/wrap.py#L320 - which recursively wraps modules as FSDP modules for parameter sharding. - This modification enables the user to pass custom FSDP arguments for every wrapped module. 
- The added process_group_cache enables different FSDP modules to, when appropriate, use the - same process group instead of instantiating a new process group. - - Wraps submodules of ``module`` for which ``auto_wrap_policy`` returns - ``True`` with ``wrapper_cls``. - - Args: - module (nn.Module): Module to recursively wrap. - auto_wrap_policy (Callable): A callable representing a policy that - determines which modules to recursively wrap with ``wrapper_cls``. - wrapper_cls: wrapper_cls - ignored_modules (Set[torch.nn.Module]): Modules to ignore when - wrapping. - ignored_params (Set[torch.nn.Parameter]): Parameters to ignore when - wrapping; these should be the parameters contained in the modules - in ``ignored_modules``. - process_group_cache (Dict[Tuple[int], Any]): a cache of process_group to - use instead of potentially instantiating a new process_group - only_wrap_children: warp only children - Returns: - (nn.Module, int): - ``module`` after wrapping and the numel recursively wrapped. - """ - from torch.distributed.fsdp.wrap import _wrap - - assert auto_wrap_policy is not None, 'Must specify auto_wrap_policy.' - assert wrapper_cls is not None, 'Must specify wrapper_cls' - # Make sure no child is already wrapped. - for _, child in module.named_modules(): - if child in ignored_modules: - continue - try: - assert not isinstance(child, cast(type, wrapper_cls)) - except TypeError: - # wrapper_cls is a function as opposed to a class type, just bypass above check. - pass - - # We count all params, assuming none of them are already wrapped. - nonwrapped_numel = sum(p.numel() for p in module.parameters() if p not in ignored_params) - - assert auto_wrap_policy is not None - if auto_wrap_policy(module=module, recurse=True, nonwrapped_numel=nonwrapped_numel): - total_wrapped_numel = 0 - # Iterate through the children, recursively wrap if necessary - for name, child in module.named_children(): - if child in ignored_modules: - continue - wrapped_child, num_wrapped_params = _custom_recursive_wrap_t2p0p1( - module=child, - auto_wrap_policy=auto_wrap_policy, - wrapper_cls=wrapper_cls, - ignored_modules=ignored_modules, - ignored_params=ignored_params, - process_group_cache=process_group_cache, - **kwargs, - ) - setattr(module, name, wrapped_child) - # Keep track of how many parameters have been wrapped - total_wrapped_numel += num_wrapped_params - # decide if we need to wrap the current module, - # since the left over parameters exceed the number of params to wrap - remainder = nonwrapped_numel - total_wrapped_numel - module_kwargs = auto_wrap_policy(module=module, recurse=False, nonwrapped_numel=remainder) - if not only_wrap_children and module_kwargs: - # CHANGE: We modify the original code to support custom FSDP kwargs and add - # the process_group_cache to avoid instantiating a new process group. 
- module_kwargs = module_kwargs if isinstance(module_kwargs, dict) else {} - module_kwargs = _set_custom_fsdp_module_kwargs(module_kwargs, process_group_cache) - - final_kwargs = {**kwargs, **module_kwargs} - - if final_kwargs.get('process_group', None) is not None: - _pg_ranks = distributed.get_process_group_ranks(final_kwargs['process_group']) - _meta_init = any(p.device.type == 'meta' for p in module.parameters()) - if (_meta_init and len(_pg_ranks) != dist.get_world_size() and final_kwargs.get('use_orig_params')): - raise NotImplementedError( - f'FSDP with custom process groups cannot use `use_orig_params: True` when using meta init.', - ) - - # Leaf node or final wrapping of the remainder both happen here. - return _wrap(module, wrapper_cls, **final_kwargs), nonwrapped_numel - else: - return module, total_wrapped_numel - return module, 0 + elif version.parse(torch.__version__) < version.parse('2.1.3'): + # Monkey patch for torch < 2.1.3 ie torch == 2.1.1, 2.1.2 + # Allow 2D HSDP + from torch.distributed.fsdp import _runtime_utils + _runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None -def _custom_auto_wrap_t2p0p1( - auto_wrap_kwargs: Dict[str, Any], - fsdp_kwargs: Dict[str, Any], - module_wrapper_cls: Any, # e.g. `FullyShardedDataParallel` -) -> None: - """Updates _auto_wrap to enable module_kwargs. + elif version.parse(torch.__version__) < version.parse('2.2.1'): + # Monkey patch for torch < 2.2.1 ie torch == 2.2.0 - torch version must be 2.0.1. + # Allow 2D HSDP + from torch.distributed.fsdp import _runtime_utils + _runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None - modified version of - https://github.com/pytorch/pytorch/blob/96ca226a7332be0d8f3d6159d0c797e032ab0721/torch/distributed/fsdp/_wrap_utils.py#L31 - FSDP's _auto_wrap recursively wraps modules as FSDP modules for parameter sharding. - This modification enables the user to pass custom FSDP arguments for every wrapped module. - The added process_group_cache enables different FSDP modules to, when appropriate, use the - same process group instead of instantiating a new process group. + elif version.parse(torch.__version__) < version.parse('2.2.3'): + # Monkey patch for torch < 2.2.3 ie torch == 2.2.1/2.2.2 currently - Recursively auto wraps the root module given by the key "module" in - ``auto_wrap_kwargs`` with the arguments in ``auto_wrap_kwargs`` and - ``fsdp_kwargs``. + # Fix memory leak for FSDP.optim_state_dict_to_load + # https://github.com/pytorch/pytorch/issues/116553 + from torch.distributed.fsdp import _optim_utils - Precondition: ``auto_wrap_policy`` contains the arguments expected by - ``_recursive_wrap()``, where ``auto_wrap_policy`` is not ``None``. - ``fsdp_kwargs`` contains all FSDP arguments except ``module``. 
- """ - from torch.distributed.fsdp._utils import _contains_batchnorm, _override_batchnorm_mixed_precision - from torch.distributed.fsdp.wrap import _FSDPPolicy, _or_policy, _wrap_batchnorm_individually - - auto_wrap_policy = auto_wrap_kwargs['auto_wrap_policy'] - # Support new way to pass an auto wrap policy - if isinstance(auto_wrap_policy, _FSDPPolicy): - auto_wrap_policy = auto_wrap_policy.policy - root_module = auto_wrap_kwargs['module'] - assert auto_wrap_policy is not None - # For auto wrapping, submodules should not already be wrapped with FSDP - # since double wrapping is not supported - for module_name, module in root_module.named_modules(): - if isinstance(module, module_wrapper_cls): - raise ValueError( - f'Expected {module_name} to NOT be FullyShardedDataParallel ' - 'if using an `auto_wrap_policy`', - ) - mixed_precision = fsdp_kwargs['mixed_precision'] - if mixed_precision is not None and _contains_batchnorm(root_module): - _override_batchnorm_mixed_precision(root_module) - auto_wrap_policy = functools.partial(_or_policy, policies=[_wrap_batchnorm_individually, auto_wrap_policy]) - warnings.warn( - 'Both mixed precision and an `auto_wrap_policy` were specified ' - 'for FSDP, where the wrapped module has batch norm submodules. ' - 'The batch norm submodules will be wrapped as separate FSDP ' - 'instances with mixed precision disabled since some batch norm ' - 'kernels do not support low precision.', - ) - auto_wrap_kwargs['auto_wrap_policy'] = auto_wrap_policy + _optim_utils._shard_orig_param_state = _shard_orig_param_state + + elif version.parse(torch.__version__) < version.parse('2.3.1'): + # Monkey patch for torch < 2.3.1 ie torch == 2.3.0 + + # Monkeypatch _flat_param.py to fix 2D with SHARD_GRAD_OP + # Issue: https://github.com/pytorch/pytorch/issues/123272 + from torch.distributed.fsdp import _flat_param + + _flat_param._same_storage = _same_storage + + # Monkeypatch state_dict to get FQNs correctly. + # Issue: https://github.com/pytorch/pytorch/pull/124698 + from torch.distributed.checkpoint import state_dict + + state_dict.set_model_state_dict = set_model_state_dict + state_dict.set_optimizer_state_dict = set_optimizer_state_dict + state_dict._get_fqns = _get_fqns + + # Monkeypatch for ND child submeshes + # PR: https://github.com/pytorch/pytorch/pull/119752 + from torch.distributed.device_mesh import DeviceMesh, _MeshEnv - # CHANGE: Add process group cache and call our custom _recursive_wrap - auto_wrap_kwargs['process_group_cache'] = {} - _custom_recursive_wrap_t2p0p1(**auto_wrap_kwargs, **fsdp_kwargs) + _MeshEnv.create_child_mesh = create_child_mesh + DeviceMesh.__getitem__ = device_mesh__getitem__ + DeviceMesh.__init__ = device_mesh__init__ def build_metadata( @@ -419,7 +135,7 @@ def build_metadata( @no_type_check def _sharded_pre_load_state_dict_hook( module: nn.Module, - fsdp_state: '_FSDPState', + fsdp_state, state_dict: Dict[str, Any], prefix: str, ) -> None: @@ -443,7 +159,7 @@ def _sharded_pre_load_state_dict_hook( return handle = _module_handle(fsdp_state, module) - if not handle.uses_sharded_strategy: + if not handle.uses_sharded_strategy: # type: ignore raise RuntimeError( 'load_sharded_state_dict can only be called when parameters ' 'are flattened and sharded.', @@ -458,7 +174,7 @@ def _sharded_pre_load_state_dict_hook( try: param = state_dict.pop(fqn_from_global_root) except KeyError: - logger.warning( + log.warning( f'Did not find param with FQN {fqn_from_global_root}, skipping it. 
' # noqa: G004 'The weight will not be filled if you expect it to be.', ) @@ -504,8 +220,8 @@ def _sharded_pre_load_state_dict_hook( tensor = tensor.narrow(0, 0, param_numel).reshape(param.size()) state_dict[fqn_from_global_root] = tensor else: - if param.device != fsdp_state._device_mesh.device_type: - param = param.to(fsdp_state._device_mesh.device_type) + if param.device != fsdp_state._device_mesh.device_type: # type: ignore + param = param.to(fsdp_state._device_mesh.device_type) # type: ignore param = param.redistribute(device_mesh=param.device_mesh, placements=[Replicate()]) state_dict[fqn_from_global_root] = param.to_local() @@ -727,7 +443,7 @@ def set_optimizer_state_dict( ) import dataclasses from collections import defaultdict, ChainMap - from typing import Dict, List, Set, TYPE_CHECKING + from typing import Dict, List, Set from torch.distributed.checkpoint.planner import SavePlan, WriteItem from torch.distributed.checkpoint.metadata import MetadataIndex, Metadata @@ -831,13 +547,13 @@ def create_child_mesh( ) res_sub_mesh = sub_mesh - res_sub_mesh._dim_group_infos = [ # type: ignore[possibly-undefined] + res_sub_mesh._dim_group_infos = [ # type: ignore device_mesh._dim_group_infos[mesh_dim] for mesh_dim in mesh_dims ] # Assign the current DeviceMesh as the parent of the child DeviceMesh. - self.child_to_parent_mapping[res_sub_mesh] = device_mesh - return res_sub_mesh + self.child_to_parent_mapping[res_sub_mesh] = device_mesh # type: ignore + return res_sub_mesh # type: ignore from torch.distributed.device_mesh import _mesh_resources diff --git a/composer/trainer/mosaic_fsdp.py b/composer/trainer/mosaic_fsdp.py deleted file mode 100644 index 7ebc800f88..0000000000 --- a/composer/trainer/mosaic_fsdp.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright 2022 MosaicML Composer authors -# SPDX-License-Identifier: Apache-2.0 - -# Released under BSD 3-Clause License, -# Copyright (c) Facebook, Inc. and its affiliates. 
- -"""Monkey patch FSDPs _auto_wrap to enable module_kwargs and custom process_group cache and ChunkShardingSpec to enable sharding over all gpus.""" - -# pyright: reportGeneralTypeIssues=false -import torch -from packaging import version -from torch.distributed._shard.sharding_spec import ChunkShardingSpec - - -def patch_pytorch(): - """Monkey patches pytorch functions based on pytorch version.""" - if version.parse(torch.__version__) < version.parse('2.1.1'): - # Monkey patch for torch < 2.1.1 ie torch == 2.1.0 - - # Monkey patch sharding method - from composer.trainer.mosaic_fsdp_utils import build_metadata - - ChunkShardingSpec.build_metadata = build_metadata - - # Monkey patch partial state dict handling - from torch.distributed.fsdp import _state_dict_utils - - from composer.trainer.mosaic_fsdp_utils import _sharded_pre_load_state_dict_hook - - _state_dict_utils._sharded_pre_load_state_dict_hook = (_sharded_pre_load_state_dict_hook) - - # Allow 2D HSDP - from torch.distributed.fsdp import _runtime_utils - _runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None - - elif version.parse(torch.__version__) < version.parse('2.1.3'): - # Monkey patch for torch < 2.1.3 ie torch == 2.1.1, 2.1.2 - - # Allow 2D HSDP - from torch.distributed.fsdp import _runtime_utils - _runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None - - elif version.parse(torch.__version__) < version.parse('2.2.1'): - # Monkey patch for torch < 2.2.1 ie torch == 2.2.0 - - # Allow 2D HSDP - from torch.distributed.fsdp import _runtime_utils - _runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None - - elif version.parse(torch.__version__) < version.parse('2.2.3'): - # Monkey patch for torch < 2.2.3 ie torch == 2.2.1/2.2.2 currently - - # Fix memory leak for FSDP.optim_state_dict_to_load - # https://github.com/pytorch/pytorch/issues/116553 - from torch.distributed.fsdp import _optim_utils - - from composer.trainer.mosaic_fsdp_utils import _shard_orig_param_state - _optim_utils._shard_orig_param_state = _shard_orig_param_state - - elif version.parse(torch.__version__) < version.parse('2.3.1'): - # Monkey patch for torch < 2.3.1 ie torch == 2.3.0 - - # Monkeypatch _flat_param.py to fix 2D with SHARD_GRAD_OP - # Issue: https://github.com/pytorch/pytorch/issues/123272 - from torch.distributed.fsdp import _flat_param - - from composer.trainer.mosaic_fsdp_utils import _same_storage - _flat_param._same_storage = _same_storage - - # Monkeypatch state_dict to get FQNs correctly. 
- # Issue: https://github.com/pytorch/pytorch/pull/124698 - from torch.distributed.checkpoint import state_dict - - from composer.trainer.mosaic_fsdp_utils import _get_fqns, set_model_state_dict, set_optimizer_state_dict - state_dict.set_model_state_dict = set_model_state_dict - state_dict.set_optimizer_state_dict = set_optimizer_state_dict - state_dict._get_fqns = _get_fqns - - # Monkeypatch for ND child submeshes - # PR: https://github.com/pytorch/pytorch/pull/119752 - from torch.distributed.device_mesh import DeviceMesh, _MeshEnv - - from composer.trainer.mosaic_fsdp_utils import create_child_mesh, device_mesh__getitem__, device_mesh__init__ - _MeshEnv.create_child_mesh = create_child_mesh - DeviceMesh.__getitem__ = device_mesh__getitem__ - DeviceMesh.__init__ = device_mesh__init__ diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 26a9f87e5e..961847f85b 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -79,6 +79,16 @@ get_precision_context, ) from composer.devices import Device, DeviceCPU, DeviceGPU, DeviceMPS, DeviceTPU +from composer.distributed import ( + DDPSyncStrategy, + ddp_sync_context, + fix_batch_precision_for_deepspeed, + parse_deepspeed_config, + prepare_ddp_module, + prepare_fsdp_module, + prepare_tp_module, + set_fsdp_default, +) from composer.loggers import ( ConsoleLogger, Logger, @@ -93,21 +103,15 @@ from composer.models import ComposerModel from composer.optim import ComposerScheduler, DecoupledSGDW, compile_composer_scheduler from composer.profiler import Profiler -from composer.trainer._deepspeed import _fix_batch_precision_for_deepspeed, _parse_deepspeed_config +from composer.trainer._patch_pytorch import patch_pytorch from composer.trainer._scale_schedule import scale_pytorch_scheduler from composer.trainer._scaler import ClosureGradScaler -from composer.trainer.dist_strategy import ( - DDPSyncStrategy, - ddp_sync_context, - prepare_ddp_module, - prepare_fsdp_module, - set_fsdp_default, -) from composer.utils import ( ExportFormat, MissingConditionalImportError, ObjectStore, Transform, + VersionedDeprecationWarning, checkpoint, dist, ensure_tuple, @@ -911,7 +915,20 @@ class Trainer: disable FSDP, set to ``None``. (default: ``None``) fsdp_auto_wrap (bool, optional): option to let trainer wrap the module, or if the module is already wrapped outside, allow the user to disable auto-wrapping. + parallelism_config (Dict[str, Any], optional): Configuration for parallelism options. + Currently supports fsdp and tensor parallelism, whose respective configs are specified + as the keys ``fsdp`` and ``tp``. (default: ``None``) + + For `parallelism_config['fsdp']`, see :doc:`FSDP Documentation ` + for more details. To use FSDP with default values, set to the empty dictionary ``{}``. To + disable FSDP, set to ``None`` or remove the key from the dictionary. + For `parallelism_config['tp']`, see :doc:`TP Documentation ` + for more details. To use Tensor Parallelism with default values, set to the empty dictionary ``{}``. To + disable Tensor Parallelism, set to ``None`` or remove the key from the dictionary. + + .. note:: This parameter is experimental and subject to change without standard deprecation + cycles. device (Device | str, optional): The device to use for training, which can be ``'cpu'``, ``'gpu'``, ``'tpu'``, or ``'mps'``. 
(default: ``None``) @@ -1051,10 +1068,11 @@ def __init__( # Graceful Resumption autoresume: bool = False, - # DeepSpeed + # Parallelism deepspeed_config: Optional[Dict[str, Any]] = None, fsdp_config: Optional[Dict[str, Any]] = None, fsdp_auto_wrap: bool = True, + parallelism_config: Optional[Dict[str, Any]] = None, # System/Numerics device: Optional[Union[str, Device]] = None, @@ -1079,7 +1097,6 @@ def __init__( # compile config for PyTorch 2.0 or higher compile_config: Optional[Dict[str, Any]] = None, ): - self.auto_log_hparams = auto_log_hparams self.python_log_level = python_log_level if self.python_log_level is not None: @@ -1093,6 +1110,7 @@ def __init__( ) logging.getLogger('composer').setLevel(self.python_log_level.upper()) + # Algorithms algorithms = list(ensure_tuple(algorithms)) # Device @@ -1105,7 +1123,7 @@ def __init__( precision = Precision(precision) _validate_precision(precision, device) - # check if provided model is compiled or not + # Check if provided model is compiled or not is_model_compiled = False if isinstance(model, OptimizedModule): log.warning( @@ -1161,10 +1179,56 @@ def __init__( assert not isinstance(device_train_microbatch_size, str) # Distributed - if deepspeed_config is not None or fsdp_config is not None or dist.get_world_size() > 1: + if fsdp_config is not None: + warnings.warn( + VersionedDeprecationWarning( + "fsdp_config is deprecated. Please use parallelism_config['fsdp'] instead.", + remove_version='0.26.0', + ), + ) + if parallelism_config is None: + parallelism_config = {} + if parallelism_config.get('fsdp') is not None: + raise ValueError( + 'fsdp_config is specified in both fsdp_config and parallelism_config. Please specify it only in parallelism_config.', + ) + parallelism_config['fsdp'] = fsdp_config + if not fsdp_auto_wrap: + warnings.warn( + VersionedDeprecationWarning( + "fsdp_auto_wrap=False is deprecated. Please use parallelism_config['fsdp']['auto_wrap'] instead.", + remove_version='0.26.0', + ), + ) + if parallelism_config is None: + parallelism_config = {} + if parallelism_config.get('fsdp') is None: + parallelism_config['fsdp'] = {} + parallelism_config['fsdp']['auto_wrap'] = fsdp_auto_wrap + if parallelism_config is not None: + # Set defaults and create shallow copies of configs to avoid changing user's config + parallelism_config = {**parallelism_config} + if parallelism_config.get('fsdp', None) is not None: + parallelism_config['fsdp'] = set_fsdp_default({**parallelism_config['fsdp']}) + if parallelism_config.get('tp', None) is not None: + parallelism_config['tp'] = {**parallelism_config['tp']} + # Remove empty configs + for key in list(parallelism_config.keys()): + if parallelism_config[key] is None: + del parallelism_config[key] + if len(parallelism_config) == 0: + parallelism_config = None + if deepspeed_config is not None and parallelism_config is not None: + raise ValueError( + 'Both deepspeed_config and parallelism_config are specified but incompatible.
Please specify only one.', + ) + if deepspeed_config is not None or parallelism_config is not None or dist.get_world_size() > 1: # Deepspeed and FSDP both require torch.distributed to be initialized, even if the world size is 1 # And torch.distributed is always required for multi-rank training dist.initialize_dist(device, dist_timeout) + if parallelism_config is not None or deepspeed_config is not None: + # Patch PyTorch to fix distributed bugs + patch_pytorch() # Reproducibility rank_zero_seed, seed = _distribute_and_get_random_seed(seed, device) @@ -1197,8 +1261,8 @@ def __init__( raise NotImplementedError(f'Only one optimizer is supported; found {num_optimizers} optimizers') # Move the model and optimizers to the device - if deepspeed_config is None and fsdp_config is None: - # check if model is already on tpu + if deepspeed_config is None and parallelism_config is None: + # Check if model is already on tpu if isinstance(device, DeviceTPU) and 'xla' not in str(next(model.parameters()).device): raise ValueError( 'Use model.to(xm.xla_device()) to set the model to the TPU before providing to the trainer.', @@ -1234,8 +1298,7 @@ def __init__( run_name=run_name, save_metrics=save_metrics, deepspeed_config=deepspeed_config, - fsdp_config=set_fsdp_default(fsdp_config) if fsdp_config is not None else None, - fsdp_auto_wrap=fsdp_auto_wrap, + parallelism_config=parallelism_config, ) # Console Logging @@ -1410,7 +1473,7 @@ def __init__( if latest_remote_file_name is not None: latest_remote_file_name = partial_format(latest_remote_file_name, **mlflow_format_kwargs) - # Log hparams. + # Log hparams if self.auto_log_hparams: locs = locals() if 'cb' in locs: @@ -1423,7 +1486,7 @@ def __init__( self.logger.log_hyperparameters({'composer_version': composer_env_dict['composer_version']}) self.logger.log_hyperparameters({'composer_commit_hash': str(composer_env_dict['composer_commit_hash'])}) - # Log gpus and nodes. + # Log gpus and nodes device_name = self.state.device.__class__.__name__.lstrip('Device').lower() self.logger.log_hyperparameters({ 'num_nodes': int(dist.get_world_size() / dist.get_local_world_size()), @@ -1553,12 +1616,22 @@ def __init__( self._original_model = self.state.model # If using PyTorch DDP, the model must be loaded before it is wrapped with DDP. - # If using DeepSpeed, the engine must be initialized before the model is loaded. + # If using TP, the model must be wrapped before FSDP. # If using FSDP, the model must be wrapped and then loaded unless loading a monolith # checkpoint on rank 0 only, in which case the model be loaded before it is wrapped. + # If using DeepSpeed, the engine must be initialized before the model is loaded. 
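For callers still passing the deprecated arguments, the shim above simply relocates them under ``parallelism_config``; the following minimal sketch (illustrative only, not part of this diff, assuming an existing ``ComposerModel`` instance named ``model``) shows the equivalent call before and after the deprecation:

.. code:: python

    from composer import Trainer

    # Deprecated style: top-level fsdp_config / fsdp_auto_wrap arguments.
    # trainer = Trainer(model=model, fsdp_config={'sharding_strategy': 'FULL_SHARD'}, fsdp_auto_wrap=True)

    # Preferred style: the same settings nested under parallelism_config['fsdp'],
    # with auto wrapping controlled by the 'auto_wrap' key.
    trainer = Trainer(
        model=model,
        parallelism_config={
            'fsdp': {
                'sharding_strategy': 'FULL_SHARD',
                'auto_wrap': True,
            },
        },
    )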
+ + # TP wrap + if self.state.tp_config is not None: + with reproducibility.seed_context(self.state.rank_zero_seed): + prepare_tp_module( + model, + self.state.tp_config, + ) # FSDP wrap if not using monolith checkpoint on rank 0 only - if self.state.fsdp_config is not None and fsdp_auto_wrap and not self.state.load_monolith_rank0_only: + if self.state.fsdp_config is not None and self.state.fsdp_config['auto_wrap' + ] and not self.state.load_monolith_rank0_only: with reproducibility.seed_context(self.state.rank_zero_seed): prepare_fsdp_module( model, @@ -1588,7 +1661,7 @@ def __init__( conda_package='deepspeed>=0.5.5', conda_channel=None, ) from e - self.state.deepspeed_config = _parse_deepspeed_config(self.state.deepspeed_config, state=self.state) + self.state.deepspeed_config = parse_deepspeed_config(self.state.deepspeed_config, state=self.state) optimizer = ensure_tuple(self.state.optimizers)[0] log.debug('Initializing deepspeed') (self.state.model, self.state.optimizers, _, _) = deepspeed.initialize( @@ -1712,6 +1785,7 @@ def __init__( if wandb.run is None: load_object_store.init(self.state, self.logger) _, _, parsed_load_path = parse_uri(load_path) + self._rng_state = checkpoint.load_checkpoint( state=self.state, logger=self.logger, @@ -1728,7 +1802,10 @@ def __init__( # FSDP wrap if model is not yet wrapped and FSDP is enabled. This can happen if # load_monolith_rank0_only=True but no checkpoint was loaded. - if not self.state.fsdp_enabled and self.state.fsdp_config is not None and self.state.fsdp_auto_wrap and self.state.load_monolith_rank0_only: + if ( + not self.state.fsdp_enabled and self.state.fsdp_config is not None and + self.state.fsdp_config['auto_wrap'] and self.state.load_monolith_rank0_only + ): with reproducibility.seed_context(self.state.rank_zero_seed): prepare_fsdp_module(model, optimizers, self.state.fsdp_config, precision, device, auto_microbatching) @@ -2416,7 +2493,7 @@ def _train_loop(self) -> None: rank_num_tokens = self._train_data_spec.get_num_tokens_in_batch(self.state.batch) if self.state.deepspeed_enabled: - self.state.batch = _fix_batch_precision_for_deepspeed(self.state.batch, self.state.precision) + self.state.batch = fix_batch_precision_for_deepspeed(self.state.batch, self.state.precision) self.engine.run_event(Event.AFTER_DATALOADER) @@ -3010,7 +3087,7 @@ def predict_batch_end(self, state: State, logger: Logger) -> None: # Fix the batch if using DeepSpeed if self.state.deepspeed_enabled: - self.state.batch = _fix_batch_precision_for_deepspeed(self.state.batch, self.state.precision) + self.state.batch = fix_batch_precision_for_deepspeed(self.state.batch, self.state.precision) self.engine.run_event(Event.PREDICT_BATCH_START) @@ -3294,7 +3371,7 @@ def _eval_loop( last_batch = self.state.eval_timestamp.sample + batch_num_samples >= dataset_len if self.state.deepspeed_enabled: - self.state.batch = _fix_batch_precision_for_deepspeed(self.state.batch, self.state.precision) + self.state.batch = fix_batch_precision_for_deepspeed(self.state.batch, self.state.precision) self.engine.run_event(Event.EVAL_BATCH_START) diff --git a/composer/utils/__init__.py b/composer/utils/__init__.py index ea5f4cd14e..7829c9fe76 100644 --- a/composer/utils/__init__.py +++ b/composer/utils/__init__.py @@ -50,6 +50,7 @@ from composer.utils.iter_helpers import IteratorFileStream, ensure_tuple, map_collection from composer.utils.misc import ( STR_TO_DTYPE, + ParallelismType, add_vision_dataset_transform, create_interval_scheduler, get_free_tcp_port, @@ -145,4 +146,5 @@ 
'get_compressor', 'KNOWN_COMPRESSORS', 'STR_TO_DTYPE', + 'ParallelismType', ] diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index 37dd9ed8eb..ee52e9818e 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -25,6 +25,7 @@ from torch.distributed.checkpoint.metadata import Metadata from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict from torch.distributed.checkpoint.planner import LoadPlan, LoadPlanner +from torch.distributed.distributed_c10d import ProcessGroup from composer.utils import dist, reproducibility from composer.utils.compression import get_compressor, is_compressed_pt @@ -37,7 +38,7 @@ is_tar, parse_uri, ) -from composer.utils.misc import is_model_deepspeed, partial_format +from composer.utils.misc import ParallelismType, is_model_deepspeed, partial_format from composer.utils.object_store import ObjectStore from composer.utils.retrying import retry @@ -614,7 +615,7 @@ def load_sharded_checkpoint( source_path=source_path, destination_path=str(Path(rank0_download_tempdir) / Path('checkpoints')), object_store=object_store, - device_mesh=state.fsdp_device_mesh, + device_mesh=state.device_mesh, ) else: storage_reader = FileSystemReaderWithValidation(source_path) @@ -1033,8 +1034,10 @@ def get_save_filename( return PartialFilePath(filename).format(state, is_deepspeed) # Sharded checkpoints get their own little folder. - assert state.sharded_ckpt_prefix_dir is not None - save_dirpath = Path(Path(filename).parent) / Path(state.sharded_ckpt_prefix_dir) + assert state.fsdp_config is not None + remote_prefix = state.fsdp_config['sharded_ckpt_prefix_dir'] + assert remote_prefix is not None + save_dirpath = Path(Path(filename).parent) / Path(remote_prefix) save_dirpath = format_name_with_dist_and_time(str(save_dirpath), state.run_name, state.timestamp) # New name is now Trainer.save_folder / sharded_ckpt_prefix_dir / __{dist.get_global_rank()}_0.distcp’ # e.g. 
path/to/my/checkpoints/ep1-ba2/__1_0.distcp @@ -1110,12 +1113,14 @@ def _save_checkpoint( log.debug(f'Saving sharded checkpoints to {save_filename}...') process_group = None - device_mesh = state.fsdp_device_mesh - if device_mesh is not None and device_mesh.ndim == 2: + device_mesh = state.device_mesh + if device_mesh is not None and device_mesh.mesh_dim_names is not None and ParallelismType.DATA_PARALLEL_REPLICATE.value in device_mesh.mesh_dim_names: # If hybrid shard, only rank in first replica saves - expect_file = device_mesh.get_local_rank(mesh_dim=0) == 0 + hsdp_index = device_mesh.mesh_dim_names.index(ParallelismType.DATA_PARALLEL_REPLICATE.value) + expect_file = device_mesh.get_local_rank(mesh_dim=hsdp_index) == 0 if expect_file: process_group = device_mesh.get_group(1) # Shard process_group for first replica + assert isinstance(process_group, ProcessGroup) # For type checker log.debug(f'Saving on global_rank={dist.get_global_rank()}, {expect_file=}') else: expect_file = True @@ -1124,7 +1129,8 @@ def _save_checkpoint( if version.parse(torch.__version__) >= version.parse('2.3.0'): save_planner = state.fsdp_config['save_planner'] if save_planner is None: - from composer.trainer.mosaic_fsdp_utils import SavePlannerWithDedupFix + from composer.trainer._patch_pytorch import SavePlannerWithDedupFix + save_planner = SavePlannerWithDedupFix() dist_cp.save( state_dict=state_dict, diff --git a/composer/utils/misc.py b/composer/utils/misc.py index 88a1366336..31480a11b6 100644 --- a/composer/utils/misc.py +++ b/composer/utils/misc.py @@ -15,6 +15,8 @@ from torchvision import transforms from torchvision.datasets import VisionDataset +from composer.utils.string_enum import StringEnum + if TYPE_CHECKING: from composer.core import Event, State, Time @@ -39,6 +41,19 @@ } +class ParallelismType(StringEnum): + """Enum class for different parallelism types in the device mesh. + + Attributes: + DATA_PARALLEL_SHARD: Data parallel shard dimension. + DATA_PARALLEL_REPLICATE: Data parallel replicate dimension. + TENSOR_PARALLEL: Tensor parallel dimension. + """ + DATA_PARALLEL_SHARD = 'data_parallel_shard' + DATA_PARALLEL_REPLICATE = 'data_parallel_replicate' + TENSOR_PARALLEL = 'tensor_parallel' + + def create_interval_scheduler( interval: Union[str, int, 'Time'], include_end_of_training: bool = True, diff --git a/docs/source/notes/distributed_training.rst b/docs/source/notes/distributed_training.rst index c64b51dca2..192167c935 100644 --- a/docs/source/notes/distributed_training.rst +++ b/docs/source/notes/distributed_training.rst @@ -18,7 +18,9 @@ performing the same work, so inspecting the rank zero is sufficient to reason about memory, performance, and other properties. Within Composer, we have three options for data-parallelism-only -execution: `Pytorch DDP`_ (default), `Pytorch FSDP`_, and `DeepSpeed Zero`_. Although Pytorch DDP is the default, DeepSpeed Zero provides better performance and lower memory utilization when configured correctly, and Pytorch FSDP increases memory and computational efficiency, while producing the same results as Pytorch DDP. +execution: `Pytorch DDP`_ (default), `Pytorch FSDP`_, and `DeepSpeed Zero`_. +Although Pytorch DDP is the default, Pytorch FSDP increases memory and computational +efficiency when configured correctly while producing the same results and is the recommended option. Usage ----- @@ -170,15 +172,28 @@ from the trainer. 
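The sharded-checkpoint change in ``composer/utils/checkpoint.py`` above no longer assumes a two-dimensional mesh; it looks up the replicate dimension by name before deciding which ranks write files. A condensed restatement of that rank-selection rule (a sketch only, assuming ``mesh`` is the trainer's ``state.device_mesh``):

.. code:: python

    from typing import Optional

    from torch.distributed.device_mesh import DeviceMesh

    from composer.utils import ParallelismType


    def rank_writes_shard(mesh: Optional[DeviceMesh]) -> bool:
        """Return True if the current rank should write a checkpoint shard (sketch)."""
        if mesh is None or mesh.mesh_dim_names is None:
            return True  # plain FSDP: every rank writes its own shard
        if ParallelismType.DATA_PARALLEL_REPLICATE.value not in mesh.mesh_dim_names:
            return True  # no replicate dimension, so every rank writes
        # HSDP: only ranks in the first replica write files
        replicate_dim = mesh.mesh_dim_names.index(ParallelismType.DATA_PARALLEL_REPLICATE.value)
        return mesh.get_local_rank(mesh_dim=replicate_dim) == 0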
FullyShardedDataParallel (FSDP) ------------------------------- -Composer integrates Pytorch's `FullyShardedDataParallel `__ engine with some syntactic sugar to make it easy to write custom models that work with Composer + FSDP. +Composer integrates Pytorch's `FullyShardedDataParallel `__ +engine with some syntactic sugar to make it easy to write custom models that work with Composer + FSDP. -At a high level, when you use the Composer Trainer, you must pass it a :mod:`ComposerModel` like `ComposerGPT `__ that defines certain functions like :code:`forward`, :code:`eval_forward`, :code:`loss`, etc. that are called during the training loop. +At a high level, when you use the Composer Trainer, you must pass it a :mod:`ComposerModel` like +`ComposerGPT `__ +that defines certain functions like :code:`forward`, :code:`eval_forward`, :code:`loss`, etc. that +are called during the training loop. -Inside that :mod:`ComposerModel` you may have one or many submodules, such as a :code:`.model` or :code:`.language_model` or :code:`.classifier` that is the actual :mod:`torch.nn.Module` that you will be deploying at inference time. In our case, this is the `GPT `__ module that we build and attach :mod:`ComposerGPT.model`. +Inside that :mod:`ComposerModel` you may have one or many submodules, such as a :code:`.model` or +:code:`.language_model` or :code:`.classifier` that is the actual :mod:`torch.nn.Module` that you +will be deploying at inference time. In our case, this is the +`GPT `__ +module that we build and attach :mod:`ComposerGPT.model`. -When you provide an :code:`fsdp_config={...}` dictionary to the Composer Trainer, then on :code:`__init__`, the Trainer will attempt to wrap **each of the submodules** of your :mod:`ComposerModel` with an FSDP auto wrap policy. This wrapping is recursive, so not only is `GPT` wrapped, but all submodules of `GPT` may/may not be wrapped too. See the `FSDP documentation `__ for more details on how auto wrap policies work. +When you provide an :code:`parallelism_config={'fsdp': {...}}` dictionary to the Composer Trainer, +then on :code:`__init__`, the Trainer will attempt to wrap **each of the submodules** of your +:mod:`ComposerModel` with an FSDP auto wrap policy. This wrapping is recursive, so not only is +`GPT` wrapped, but all submodules of `GPT` may/may not be wrapped too. See the +`FSDP documentation `__ for more details on how auto +wrap policies work. -The full spec and defaults for Composer's `fsdp_config` is here: +The full spec and defaults for Composer's fsdp config is here: .. 
code:: python @@ -188,6 +203,8 @@ The full spec and defaults for Composer's `fsdp_config` is here: 'activation_cpu_offload': bool = True | False, # Default: False 'backward_prefetch': str = 'BACKWARD_PRE' | 'BACKWARD_POST' | 'NONE', # Default: 'BACKWARD_POST' 'cpu_offload': bool = True | False, # Default: False, cpu_offload not supported yet + 'data_parallel_shard_degree': int = -1, # Default: -1 + 'data_parallel_replicate_degree': int = 1, # Default: 1 'forward_prefetch': bool = True | False, # Default: False 'ignored_modules': Optional[Iterable[torch.nn.Module]], # Default: None 'keep_low_precision_grads': bool = True | False, # Default: False @@ -201,7 +218,6 @@ The full spec and defaults for Composer's `fsdp_config` is here: # 'reduce_dtype': 'fp32' | 'fp16' | 'bf16', # 'buffer_dtype': 'fp32' | 'fp16' | 'bf16', # }, - 'process_group': str = 'self' | 'node' | 'local_rank_across_nodes' | 'setK' | 'modK', # Default: None 'save_planner': torch.distributed.checkpoint.planner.SavePlanner, # Default: None 'sharded_ckpt_prefix_dir': str = 'ep{epoch}-ba{batch}', # Default: 'ep{epoch}-ba{batch}' 'sharding_strategy': str = 'FULL_SHARD' | 'SHARD_GRAD_OP' | 'NO_SHARD', # Default: 'FULL_SHARD' @@ -211,9 +227,17 @@ The full spec and defaults for Composer's `fsdp_config` is here: 'verbose': bool = True | False, # Default: False } -All values come with defaults and can be optionally defined in the :code:`fsdp_config`. Most parameters map directly to parameters in the `FSDP documentation `__. +All values come with defaults and can be optionally defined in the :code:`fsdp_config`. Most +parameters map directly to parameters in the +`FSDP documentation `__. +This config is passed under `parallelism_config['fsdp']` to the Composer Trainer. Two important +parameters which do not map include `data_parallel_shard_degree`, which dictates the number of +devices to shard across, and `data_parallel_replicate_degree`, which dictates the number of +devices to replicate across. -One Composer-specific pattern is that if :code:`mixed_precision` is provided as a :code:`str`, then we automatically infer the settings to use from the Trainer's :code:`precision`, which is already being used for autocast, and we construct an associated MixedPrecision object for FSDP: +One Composer-specific pattern is that if :code:`mixed_precision` is provided as a :code:`str`, +then we automatically infer the settings to use from the Trainer's :code:`precision`, which is +already being used for autocast, and we construct an associated MixedPrecision object for FSDP: .. code:: python @@ -243,7 +267,7 @@ An example code snippet for using FSDP with composer is provided below: import torch.nn as nn from composer import Trainer - class Block (nn.Module): + class Block(nn.Module): ... class Model(nn.Module): @@ -299,7 +323,7 @@ An example code snippet for using FSDP with composer is provided below: trainer = Trainer( model=composer_model, - fsdp_config=fsdp_config, + parallelism_config={'fsdp': fsdp_config}, ... ) @@ -371,7 +395,8 @@ A very similar auto wrap policy is provided for activation checkpointing, with a def activation_checkpointing_fn(self, module): return isinstance(module, Block) -While the user can instantiate and pass in process groups, Composer enables process groups to be specified using the following options: +While the user can instantiate and pass in process groups, Composer enables process groups to be +specified using the following options: 1. 
:code:`'self'`: the degenerate case where all process groups only operate within their current rank (:code:`'self'` == :code:`'set1'`). This is useful when you do not want a layer to be synchonized across accelerators. @@ -392,8 +417,9 @@ Depending on the value you set for :code:`state_dict_type`, you can get differen 1. :code:`state_dict_type='full'` The default. Saves one big checkpoint file for the whole model. It does this by gathering the model state to the global rank 0 device, unflattening it, and then saving it out. -If `load_monolith_rank0_only=True`, then when loading checkpoints the global rank 0 device will load in the checkpoint file and scatter the -model and optimizer state to the other ranks, which will will dramatically reduce the memory usage on system. Otherwise, all ranks will separately load in the checkpoint file. +If `load_monolith_rank0_only=True`, then when loading checkpoints the global rank 0 device will load +in the checkpoint file and scatter the model and optimizer state to the other ranks, which will +dramatically reduce the memory usage on the system. Otherwise, all ranks will separately load in the checkpoint file. A minimal configuration sketch for this monolithic path appears after the loading notes below. 2. :code:`state_dict_type='sharded'` Each rank saves out an unflattened shard. For loading, each rank loads in the checkpoint file @@ -403,8 +429,9 @@ corresponding to their unflattened shard. See `The FSDP docs `__ for more info. If you use sharded checkpoints (`state_dict_type='sharded'`), your run will save as many files as you have -ranks at each checkpointing event (plus one metadata file for torch versions 2.0.0 or higher). This can quicky pollute your `save_folder` with a lot of files after a couple checkpointing events. -To help keep your checkpoint shard files organized, Composer will save each set of shards in it's own prefix directory, which you can configure +ranks at each checkpointing event (plus one metadata file for torch versions 2.0.0 or higher). This can quickly +pollute your `save_folder` with a lot of files after a couple checkpointing events. To help keep your +checkpoint shard files organized, Composer will save each set of shards in its own prefix directory, which you can configure by using `'sharded_ckpt_prefix_dir'` (default value `sharded_ckpt_prefix_dir='ep{epoch}-ba{batch}'`). Checkpoint shards will be saved to `{save_folder} / {sharded_ckpt_prefix_dir}` @@ -415,7 +442,7 @@ For example, to save sharded checkpoints to disk locally (`state_dict_type='shar import torch.nn as nn from composer import Trainer - class Block (nn.Module): + class Block(nn.Module): ... class Model(nn.Module): @@ -464,7 +491,7 @@ For example, to save sharded checkpoints to disk locally (`state_dict_type='shar trainer = Trainer( model=composer_model, max_duration='4ba' - fsdp_config=fsdp_config, + parallelism_config={'fsdp': fsdp_config}, save_folder='checkpoints', save_interval='2ba', ... @@ -490,7 +517,7 @@ To load these checkpoint files, you would need to do something like this: trainer = Trainer( model=composer_model, max_duration='4ba' - fsdp_config=fsdp_config, + parallelism_config={'fsdp': fsdp_config}, load_path='./checkpoints/ba2-shards' # load_path must be the path to the prefix directory and not to a specific file. ... ) @@ -506,60 +533,108 @@ Four things to note in this load example: 4. To do multinode resumption (resuming on more than one node regardless of how many nodes you saved on), you must be using torch 2.0.1 or higher due a bug in torch 2.0.0.
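The monolithic path described under ``state_dict_type='full'`` above can be configured through the same ``parallelism_config`` entry point. A minimal sketch (illustrative only, reusing the ``composer_model`` from the earlier examples):

.. code:: python

    from composer import Trainer

    fsdp_config = {
        'sharding_strategy': 'FULL_SHARD',
        'state_dict_type': 'full',          # one monolithic checkpoint file
        'load_monolith_rank0_only': True,   # rank 0 loads and scatters state on resume
    }

    trainer = Trainer(
        model=composer_model,
        max_duration='4ba',
        parallelism_config={'fsdp': fsdp_config},
        save_folder='checkpoints',
        save_interval='2ba',
    )

    trainer.fit()

When resuming from such a checkpoint, ``load_path`` points at the single saved checkpoint file rather than at a prefix directory, in contrast to the sharded case above.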
+Tensor Parallel (TP) +-------------------- -Saving and Loading Sharded Checkpoints with FSDP and Torch 1.13 ----------------------------------------------------------------- +Composer integrates Pytorch's `Tensor Parallel `__ +API with some syntactic sugar to make it easy to write custom models that work with Composer + TP. -To save sharded checkpoints to disk locally (`state_dict_type='sharded'`) with FSDP on PyTorch version 1.13, you must do: +To enable Tensor Parallel, a tensor parallel config must be passed to the Composer Trainer. The +full spec and defaults for Composer's tp config is here: .. code:: python - trainer = Trainer( - model=composer_model, - max_duration='4ba' - fsdp_config=fsdp_config, - save_folder='checkpoints', - save_filename='ba{batch}_rank{rank}.pt', - save_interval='2ba', - ... - ) + tp_config = { + 'tensor_parallel_degree': int = 1, # Default: 1 + 'pipeline_parallel_degree': int = 1, # Default: 1 + } - trainer.fit() +All values come with defaults and can be optionally defined in the :code:`tp_config`. Most parameters +map directly to parameters in the +`Tensor Parallel documentation `__. +This config is passed under `parallelism_config['tp']` to the Composer Trainer. One important parameter +which does not map directly is `tensor_parallel_degree`, which dictates the number of devices to shard across. -After the second batch, this code will save N checkpoint files to the local directory ``./checkpoints/ba2-shards``. For example, -if you trained with 4 ranks, ``./checkpoints/ba2-shards`` would contain 4 files: ``ba2_rank0.pt``, ``ba2_rank1.pt``, ``ba2_rank2.pt``, and ``ba2_rank3.pt``. -After the fourth batch, N checkpoint files (``ba4_rank0.pt``, ``ba4_rank1.pt``, etc.) will saved to ``./checkpoints/ba4-shards`` -To load these checkpoint files, you would need to do something like this: + +An example code snippet for using TP and FSDP with Composer is provided below: .. code:: python + import torch.nn as nn + from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel from composer import Trainer + class Block(nn.Module): + ... + + class Model(nn.Module): + def __init__(self, n_layers): + super().__init__() + self.blocks = nn.ModuleList([ + Block(...) for _ in range(n_layers) + ]), + self.head = nn.Linear(...) + + def forward(self, inputs): + ... + + # FSDP Wrap Function + def fsdp_wrap_fn(self, module): + return isinstance(module, Block) + + # Activation Checkpointing Function + def activation_checkpointing_fn(self, module): + return isinstance(module, Block) + + + class MyComposerModel(ComposerModel): + + def __init__(self, n_layers): + super().__init__() + self.model = Model(n_layers) + ... + + def forward(self, batch): + ... + + def eval_forward(self, batch, outputs=None): + ... + + def loss(self, outputs, batch): + ... + + ... + + composer_model = MyComposerModel(n_layers=3) + fsdp_config = { 'sharding_strategy': 'FULL_SHARD', - 'state_dict_type': 'sharded', + 'cpu_offload': False, # Not supported yet + 'mixed_precision': 'DEFAULT', + 'backward_prefetch': 'BACKWARD_POST', + 'activation_checkpointing': False, + 'activation_cpu_offload': False, + 'verbose': True + } + tp_config = { + 'tensor_parallel_degree': 2, + 'layer_plan': { + 'model.0.fc1': ColwiseParallel(), + 'model.0.fc2': RowwiseParallel(), + } } trainer = Trainer( model=composer_model, - max_duration='4ba' - fsdp_config=fsdp_config, - load_path='./checkpoints/ba2-shards/ba2_rank{rank}.pt' + parallelism_config={'fsdp': fsdp_config, 'tp': tp_config}, ...
) -Three things to note in this torch 1.13 load example: - -1. Instead of setting ``load_path`` to the path to a specific file, we keep the ``{rank}`` placeholder to denote that -the file to load is different for each rank. - -2. We must set ``'state_dict_type': 'sharded'``, like we did during the save. - -3. Composer with torch 1.13 does not support elastic checkpointing (more ranks than checkpoint files or more files than ranks), so you -must make sure the number of ranks you run on during load is the same as the number you used during save (the same as the number of files). -Upgrading to torch 2.0.0 or higher will enable elastic checkpointing! + trainer.fit() +.. note:: + This is an experimental feature and is subject to change. Many features, such as `load_monolith_rank0_only` or tensor parallelism without FSDP, are not yet supported. .. _Pytorch DDP: https://pytorch.org/docs/master/generated/torch.nn.parallel.DistributedDataParallel.html .. _Deepspeed Zero: https://www.deepspeed.ai/ diff --git a/tests/algorithms/test_sam.py b/tests/algorithms/test_sam.py index 8858d2c5ca..11e386cedc 100644 --- a/tests/algorithms/test_sam.py +++ b/tests/algorithms/test_sam.py @@ -88,22 +88,22 @@ def config(self, request): scheduler, 'precision': 'amp_bf16', - 'fsdp_config': + 'parallelism_config': None, 'deepspeed_config': None, } if distributed_mode == 'FSDP': - config_dict['fsdp_config'] = {'sharding_strategy': 'NO_SHARD'} + config_dict['parallelism_config'] = {'fsdp': {'sharding_strategy': 'NO_SHARD'}} else: config_dict['deepspeed_config'] = {'prescale_gradients': True} # Simulate world_size checking as in LLMFoundry. See: # * https://github.com/mosaicml/llm-foundry/blob/bfbb8c57053eaa3cb99a5d51ba602d1a6c872aa7/scripts/train/train.py#L519-L523 if dist.get_world_size( - ) == 1 and (config_dict['fsdp_config'] is not None or config_dict['deepspeed_config'] is not None): - config_dict['fsdp_config'] = config_dict['deepspeed_config'] = None + ) == 1 and (config_dict['parallelism_config'] is not None or config_dict['deepspeed_config'] is not None): + config_dict['parallelism_config'] = config_dict['deepspeed_config'] = None return config_dict diff --git a/tests/callbacks/test_generate.py b/tests/callbacks/test_generate.py index 6d5e5ea026..d4bbe003ef 100644 --- a/tests/callbacks/test_generate.py +++ b/tests/callbacks/test_generate.py @@ -46,7 +46,9 @@ def _create_trainer(self, device, max_duration, use_fsdp, generate_cb: Optional[ device=device, max_duration=max_duration, callbacks=generate_cb, - fsdp_config={'sharding_strategy': 'FULL_SHARD'} if use_fsdp else None, + parallelism_config={'fsdp': { + 'sharding_strategy': 'FULL_SHARD', + }} if use_fsdp else None, ) def test_no_effect_on_training(self, device, world_size, use_fsdp): diff --git a/tests/callbacks/test_optimizer_monitor.py b/tests/callbacks/test_optimizer_monitor.py index 8e19264621..523408c09c 100644 --- a/tests/callbacks/test_optimizer_monitor.py +++ b/tests/callbacks/test_optimizer_monitor.py @@ -75,15 +75,17 @@ def test_fsdp_optimizer_monitor(device, world_size, use_orig_params): train_dataloader=DataLoader(dataset, sampler=dist.get_sampler(dataset)), optimizers=DecoupledAdamW(model.parameters()), max_duration='11ba', - fsdp_config={ - 'sharding_strategy': 'FULL_SHARD' if world_size > 1 else 'NO_SHARD', - 'cpu_offload': False, - 'mixed_precision': 'PURE', - 'backward_prefetch': 'BACKWARD_PRE', - 'activation_checkpointing': False, - 'activation_cpu_offload': False, - 'verbose': False, - 'use_orig_params': use_orig_params, + 
parallelism_config={ + 'fsdp': { + 'sharding_strategy': 'FULL_SHARD' if world_size > 1 else 'NO_SHARD', + 'cpu_offload': False, + 'mixed_precision': 'PURE', + 'backward_prefetch': 'BACKWARD_PRE', + 'activation_checkpointing': False, + 'activation_cpu_offload': False, + 'verbose': False, + 'use_orig_params': use_orig_params, + }, }, ) trainer.fit() @@ -147,15 +149,17 @@ def test_fsdp_optimizer_monitor_transformer(device, world_size, tiny_gpt2_model, train_dataloader=train_dataloader, optimizers=DecoupledAdamW(model.parameters()), max_duration='11ba', - fsdp_config={ - 'sharding_strategy': 'FULL_SHARD' if world_size > 1 else 'NO_SHARD', - 'cpu_offload': False, - 'mixed_precision': 'PURE', - 'backward_prefetch': 'BACKWARD_PRE', - 'activation_checkpointing': False, - 'activation_cpu_offload': False, - 'verbose': False, - 'use_orig_params': use_orig_params, + parallelism_config={ + 'fsdp': { + 'sharding_strategy': 'FULL_SHARD' if world_size > 1 else 'NO_SHARD', + 'cpu_offload': False, + 'mixed_precision': 'PURE', + 'backward_prefetch': 'BACKWARD_PRE', + 'activation_checkpointing': False, + 'activation_cpu_offload': False, + 'verbose': False, + 'use_orig_params': use_orig_params, + }, }, ) trainer.fit() diff --git a/tests/common/__init__.py b/tests/common/__init__.py index 4be8f9a348..3481fbf5ee 100644 --- a/tests/common/__init__.py +++ b/tests/common/__init__.py @@ -21,7 +21,9 @@ ConvModel, EmbeddedWeightTiedModel, EmptyModel, + EvenSimplerMLP, SimpleConvModel, + SimpleMLP, SimpleModel, SimpleModelWithDropout, SimpleTransformerClassifier, @@ -64,4 +66,6 @@ def get_module_subclasses(module: types.ModuleType, cls: Type) -> List[Type]: 'SimpleDataset', 'InfiniteClassificationDataset', 'composer_resnet', + 'SimpleMLP', + 'EvenSimplerMLP', ] diff --git a/tests/common/markers.py b/tests/common/markers.py index 33cf1bcd2d..7ad1529259 100644 --- a/tests/common/markers.py +++ b/tests/common/markers.py @@ -67,7 +67,7 @@ def test_something(world_size: int): if world_size == 1: parameters.append(pytest.param(1)) else: - parameters.append(pytest.param(2, marks=pytest.mark.world_size(2))) + parameters.append(pytest.param(world_size, marks=pytest.mark.world_size(world_size))) def decorator(test: Callable): if len(parameters) == 0: diff --git a/tests/common/models.py b/tests/common/models.py index ea3546ad1a..2910c37e9f 100644 --- a/tests/common/models.py +++ b/tests/common/models.py @@ -55,6 +55,9 @@ class SimpleModel(ComposerClassifier): Args: num_features (int): number of input features (default: 1) num_classes (int): number of classes (default: 2) + num_hidden (int): number of hidden units (default: 8) + device (str): the device to initialize the model (default: 'cpu') + bias (bool): whether or not to include bias in the linear layers (default: True) """ def __init__( @@ -102,7 +105,7 @@ def param_init_fn(self, module): class SimpleMLP(torch.nn.Module): - def __init__(self, num_features: int, device: str): + def __init__(self, num_features: int, device: str = 'cpu'): super().__init__() self.fc1 = torch.nn.Linear(num_features, num_features, device=device, bias=False) self.fc2 = torch.nn.Linear(num_features, num_features, device=device, bias=False) @@ -117,7 +120,7 @@ def forward(self, x): # are not submodules of EvenSimplerMLP, like they are in SimpleMLP. 
class EvenSimplerMLP(torch.nn.Module): - def __init__(self, num_features: int, device: str): + def __init__(self, num_features: int, device: str = 'cpu'): super().__init__() fc1 = torch.nn.Linear(num_features, num_features, device=device, bias=False) fc2 = torch.nn.Linear(num_features, num_features, device=device, bias=False) diff --git a/tests/conftest.py b/tests/conftest.py index 6e401eca0a..05841a73f9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,7 +12,7 @@ # Important: when updating this list, make sure to also up ./.ci/test.sh # (so tests of all world sizes will be executed) and tests/README.md # (so the documentation is correct) -WORLD_SIZE_OPTIONS = (1, 2) +WORLD_SIZE_OPTIONS = (1, 2, 4) # Enforce deterministic mode before any tests start. reproducibility.configure_deterministic_mode() diff --git a/tests/models/test_hf_model.py b/tests/models/test_hf_model.py index 53721b01d9..e00fa9ec3d 100644 --- a/tests/models/test_hf_model.py +++ b/tests/models/test_hf_model.py @@ -503,7 +503,7 @@ def get_lm_trainer( load_path: Optional[str] = None, is_conditional_generation: bool = False, do_eval: bool = False, - fsdp_config: Optional[Dict[str, Any]] = None, + parallelism_config: Optional[Dict[str, Any]] = None, mlm: bool = True, add_padding: bool = False, device_train_microbatch_size: Optional[int] = None, @@ -594,7 +594,7 @@ def get_lm_trainer( save_interval='1ep', save_filename='hf-checkpoint.pt', load_path=load_path, - fsdp_config=fsdp_config, + parallelism_config=parallelism_config, loggers=in_memory_logger, device_train_microbatch_size=batch_size if device_train_microbatch_size is None else device_train_microbatch_size, @@ -1028,17 +1028,19 @@ def test_hf_fsdp(tiny_bert_config, tiny_bert_tokenizer): tiny_bert_model = transformers.AutoModelForMaskedLM.from_config(tiny_bert_config) - fsdp_config = { - 'sharding_strategy': 'FULL_SHARD', - 'cpu_offload': False, - 'mixed_precision': 'PURE', - 'backward_prefetch': 'BACKWARD_PRE', - 'activation_checkpointing': False, - 'activation_cpu_offload': False, - 'verbose': False, + parallelism_config = { + 'fsdp': { + 'sharding_strategy': 'FULL_SHARD', + 'cpu_offload': False, + 'mixed_precision': 'PURE', + 'backward_prefetch': 'BACKWARD_PRE', + 'activation_checkpointing': False, + 'activation_cpu_offload': False, + 'verbose': False, + }, } - trainer = get_lm_trainer(tiny_bert_model, tiny_bert_tokenizer, None, fsdp_config=fsdp_config) + trainer = get_lm_trainer(tiny_bert_model, tiny_bert_tokenizer, None, parallelism_config=parallelism_config) assert is_model_fsdp(trainer.state.model) @@ -1222,18 +1224,16 @@ def test_generate(device, world_size, hf_model, hf_tokenizer, use_fsdp): 'GPT2 is not currently supported with DDP. See https://github.com/huggingface/transformers/issues/22482 for more details.', ) - fsdp_config = None + parallelism_config = None if use_fsdp: - fsdp_config = { - 'sharding_strategy': 'FULL_SHARD', - } + parallelism_config = {'fsdp': {'sharding_strategy': 'FULL_SHARD',}} hf_tokenizer = hf_tokenizer() model = HuggingFaceModel(hf_model, tokenizer=hf_tokenizer, use_logits=True) # just instantiating Trainer to go through the normal FSDP code path - trainer = Trainer(model=model, fsdp_config=fsdp_config, device=device) + trainer = Trainer(model=model, parallelism_config=parallelism_config, device=device) device = trainer.state.device @@ -1292,18 +1292,16 @@ def test_eval_forward_generate(device, world_size, hf_model, hf_tokenizer, use_f 'GPT2 is not currently supported with DDP. 
See https://github.com/huggingface/transformers/issues/22482 for more details.', ) - fsdp_config = None + parallelism_config = None if use_fsdp: - fsdp_config = { - 'sharding_strategy': 'FULL_SHARD', - } + parallelism_config = {'fsdp': {'sharding_strategy': 'FULL_SHARD',}} hf_tokenizer = hf_tokenizer() model = HuggingFaceModel(hf_model, tokenizer=hf_tokenizer, use_logits=True) # just instantiating Trainer to go through the normal FSDP code path - trainer = Trainer(model=model, fsdp_config=fsdp_config, device=device) + trainer = Trainer(model=model, parallelism_config=parallelism_config, device=device) device = trainer.state.device @@ -1490,14 +1488,10 @@ def test_peft_fsdp_trains( ): pytest.importorskip('peft') - fsdp_config = { - 'sharding_strategy': 'FULL_SHARD', - 'cpu_offload': False, - 'mixed_precision': 'PURE', - 'backward_prefetch': 'BACKWARD_PRE', - 'activation_checkpointing': False, - 'activation_cpu_offload': False, - 'verbose': False, + parallelism_config = { + 'fsdp': { + 'sharding_strategy': 'FULL_SHARD', + }, } stashed_model = copy.deepcopy(tiny_gpt2_model) @@ -1509,7 +1503,7 @@ def test_peft_fsdp_trains( peft_config=gpt2_peft_config, device_train_microbatch_size=1, mlm=False, - fsdp_config=fsdp_config, + parallelism_config=parallelism_config, should_save_peft_only=should_save_peft_only, ) @@ -1530,7 +1524,7 @@ def test_peft_fsdp_trains( device_train_microbatch_size=1, mlm=False, load_path=str(tmp_path / 'trainer1' / 'hf-checkpoint.pt'), - fsdp_config=fsdp_config, + parallelism_config=parallelism_config, should_save_peft_only=should_save_peft_only, ) diff --git a/tests/test_events.py b/tests/test_events.py index c007eb62fc..633457ff57 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -107,23 +107,25 @@ def test_event_calls(self, world_size, device, deepspeed_zero_stage, use_fsdp, p if deepspeed_zero_stage: deepspeed_config = {'zero_optimization': {'stage': deepspeed_zero_stage}} - fsdp_config = None + parallelism_config = None if use_fsdp: - fsdp_config = { - 'sharding_strategy': 'FULL_SHARD', - 'cpu_offload': False, - 'mixed_precision': 'PURE', - 'backward_prefetch': 'BACKWARD_PRE', - 'activation_checkpointing': False, - 'activation_ocpu_offload': False, - 'verbose': False, + parallelism_config = { + 'fsdp': { + 'sharding_strategy': 'FULL_SHARD', + 'cpu_offload': False, + 'mixed_precision': 'PURE', + 'backward_prefetch': 'BACKWARD_PRE', + 'activation_checkpointing': False, + 'activation_ocpu_offload': False, + 'verbose': False, + }, } trainer = self.get_trainer( precision=precision, device=device, deepspeed_config=deepspeed_config, - fsdp_config=fsdp_config, + parallelism_config=parallelism_config, save_interval=save_interval, eval_interval=save_interval, ) diff --git a/tests/trainer/test_ddp.py b/tests/trainer/test_ddp.py index f0dbcde0fb..9241482456 100644 --- a/tests/trainer/test_ddp.py +++ b/tests/trainer/test_ddp.py @@ -160,16 +160,18 @@ def test_ddp(device: str, world_size: int, deepspeed: bool, fsdp: bool, tmp_path ), ) - fsdp_config = None + parallelism_config = None if fsdp: - fsdp_config = { - 'sharding_strategy': 'FULL_SHARD', - 'cpu_offload': False, - 'mixed_precision': 'PURE', - 'backward_prefetch': 'BACKWARD_PRE', - 'activation_checkpointing': False, - 'activation_cpu_offload': False, - 'verbose': False, + parallelism_config = { + 'fsdp': { + 'sharding_strategy': 'FULL_SHARD', + 'cpu_offload': False, + 'mixed_precision': 'PURE', + 'backward_prefetch': 'BACKWARD_PRE', + 'activation_checkpointing': False, + 'activation_cpu_offload': False, + 
'verbose': False, + }, } max_epochs = 2 @@ -183,7 +185,7 @@ def test_ddp(device: str, world_size: int, deepspeed: bool, fsdp: bool, tmp_path eval_subset_num_batches=eval_subset_num_batches, train_subset_num_batches=train_subset_num_batches, deepspeed_config={} if deepspeed else None, - fsdp_config=fsdp_config, + parallelism_config=parallelism_config, callbacks=[CheckBatch0(tmp_path)], ) diff --git a/tests/trainer/test_ddp_sync_strategy.py b/tests/trainer/test_ddp_sync_strategy.py index b3061481ca..ac72313fe2 100644 --- a/tests/trainer/test_ddp_sync_strategy.py +++ b/tests/trainer/test_ddp_sync_strategy.py @@ -11,7 +11,7 @@ from composer.core import State from composer.devices import DeviceCPU, DeviceGPU -from composer.trainer.dist_strategy import ddp_sync_context, prepare_ddp_module +from composer.distributed import ddp_sync_context, prepare_ddp_module from composer.utils import dist from tests.common.datasets import RandomClassificationDataset diff --git a/tests/trainer/test_fsdp.py b/tests/trainer/test_fsdp.py index b766b4d1ed..355fa21c7e 100644 --- a/tests/trainer/test_fsdp.py +++ b/tests/trainer/test_fsdp.py @@ -1,7 +1,6 @@ # Copyright 2022 MosaicML Composer authors # SPDX-License-Identifier: Apache-2.0 -import contextlib from unittest.mock import MagicMock import pytest @@ -61,10 +60,12 @@ def test_fsdp_device_initialization( model=model, optimizers=optimizer, train_dataloader=dataloader, - fsdp_config={ - 'activation_checkpointing_reentrant': reentrant, - 'mixed_precision': mixed_precision, - 'sync_module_states': True if device == 'mixed' else False, + parallelism_config={ + 'fsdp': { + 'activation_checkpointing_reentrant': reentrant, + 'mixed_precision': mixed_precision, + 'sync_module_states': True if device == 'mixed' else False, + }, }, max_duration='3ba', ) @@ -126,10 +127,12 @@ def dummy_param_init_fn(module: torch.nn.Module): model=model, optimizers=optimizer, train_dataloader=dataloader, - fsdp_config={ - 'mixed_precision': 'PURE', - 'sharding_strategy': 'SHARD_GRAD_OP', - 'sync_module_states': True if device == 'mixed' else False, + parallelism_config={ + 'fsdp': { + 'mixed_precision': 'PURE', + 'sharding_strategy': 'SHARD_GRAD_OP', + 'sync_module_states': True if device == 'mixed' else False, + }, }, max_duration='3ba', ) @@ -167,10 +170,10 @@ def test_fsdp_meta_initialization_none(model: ComposerClassifier, mixed_precisio model=model, optimizers=optimizer, train_dataloader=dataloader, - fsdp_config={ + parallelism_config={'fsdp': { 'mixed_precision': mixed_precision, 'sharding_strategy': 'SHARD_GRAD_OP', - }, + }}, max_duration='3ba', ) @@ -191,9 +194,11 @@ def test_fsdp_prefetch_limit(forward_prefetch_limit: int, backward_prefetch_limi model=model, optimizers=optimizer, train_dataloader=dataloader, - fsdp_config={ - 'forward_prefetch_limit': forward_prefetch_limit, - 'backward_prefetch_limit': backward_prefetch_limit, + parallelism_config={ + 'fsdp': { + 'forward_prefetch_limit': forward_prefetch_limit, + 'backward_prefetch_limit': backward_prefetch_limit, + }, }, max_duration='3ba', ) @@ -205,6 +210,7 @@ def test_fsdp_prefetch_limit(forward_prefetch_limit: int, backward_prefetch_limi @world_size(2) @pytest.mark.filterwarnings('ignore:Instantiating FSDP with custom process groups.*:UserWarning') @pytest.mark.filterwarnings('ignore:Composer is instantiating custom process groups.*:UserWarning') +@pytest.mark.filterwarnings('ignore:.*process_group and device_mesh are set for FSDP.*.:UserWarning') def test_fsdp_process_group(world_size: int): model = SimpleModel() 
model.fc1._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues] @@ -217,8 +223,10 @@ def test_fsdp_process_group(world_size: int): model=model, optimizers=optimizer, train_dataloader=dataloader, - fsdp_config={ - 'process_group': 'mod1', # all ranks + parallelism_config={ + 'fsdp': { + 'process_group': 'mod1', # all ranks + }, }, max_duration='3ba', ) @@ -226,30 +234,6 @@ def test_fsdp_process_group(world_size: int): trainer.fit() -@pytest.mark.gpu -@world_size(2) -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.2.0'), reason='Device mesh requires Torch 2.2') -@pytest.mark.parametrize( - 'sharding_strategy', - ['SHARD_GRAD_OP', 'FULL_SHARD', 'HYBRID_SHARD', '_HYBRID_SHARD_ZERO2'], -) -@pytest.mark.parametrize('device_mesh', [[2], [1, 2]]) -def test_wrong_size_device_mesh_error(world_size: int, sharding_strategy: str, device_mesh: list[int]): - context = contextlib.nullcontext() - if sharding_strategy in ['SHARD_GRAD_OP', 'FULL_SHARD'] and len(device_mesh) != 1: - context = pytest.raises(ValueError, match='.*requires a device mesh of size 1.*') - if sharding_strategy in ['HYBRID_SHARD', '_HYBRID_SHARD_ZERO2'] and len(device_mesh) != 2: - context = pytest.raises(ValueError, match='.*requires a device mesh of size 2.*') - with context: - Trainer( - model=SimpleModel(), - fsdp_config={ - 'sharding_strategy': sharding_strategy, - 'device_mesh': device_mesh, - }, - ) - - class SimpleMLP(ComposerModel): def __init__(self, num_features: int = 128, device: str = 'cuda'): @@ -279,10 +263,12 @@ def test_fsdp_act_ckpt_offload( ): model = SimpleMLP() - fsdp_config = { - 'activation_checkpointing': activation_checkpointing, - 'activation_checkpointing_reentrant': False, - 'activation_cpu_offload': activation_cpu_offload, + parallelism_config = { + 'fsdp': { + 'activation_checkpointing': activation_checkpointing, + 'activation_checkpointing_reentrant': False, + 'activation_cpu_offload': activation_cpu_offload, + }, } model.fc1._activation_checkpointing = True # pyright: ignore[reportGeneralTypeIssues] @@ -290,7 +276,7 @@ def test_fsdp_act_ckpt_offload( trainer = Trainer( model=model, device='gpu', - fsdp_config=fsdp_config, + parallelism_config=parallelism_config, ) assert trainer.state.fsdp_enabled @@ -324,7 +310,7 @@ def oom_hook(*args): trainer = Trainer( model=model, - fsdp_config={}, + parallelism_config={'fsdp': {}}, max_duration='3ba', ) fsdp_model = trainer.state.model @@ -360,7 +346,7 @@ def test_fsdp_same_state_after_oom_reshard(world_size: int): trainer = Trainer( model=model, - fsdp_config={}, + parallelism_config={'fsdp': {}}, dist_timeout=20, optimizers=optimizer, seed=1, @@ -382,7 +368,7 @@ def oom_hook(module, grad_input, grad_ouput): oom_handle = oom_model.fc2.register_full_backward_hook(oom_hook) oom_trainer = Trainer( model=oom_model, - fsdp_config={}, + parallelism_config={'fsdp': {}}, dist_timeout=20, optimizers=oom_model_optimizer, seed=1, @@ -417,3 +403,54 @@ def oom_hook(module, grad_input, grad_ouput): output_2 = fsdp_oom_model(x) assert torch.equal(output_1, output_2) + + +@pytest.mark.gpu +@world_size(2) +def test_fsdp_device_mesh(world_size: int): + model = SimpleModel() + model.fc1._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues] + model.fc2._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues] + + # Expect warning via pytest + with pytest.warns(DeprecationWarning): + Trainer( + model=model, + parallelism_config={'fsdp': { + 'device_mesh': [2], + }}, + max_duration='3ba', + ) + + +@pytest.mark.gpu +@world_size(2) +def 
test_fsdp_shard(world_size: int): + model = SimpleModel() + model.fc1._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues] + model.fc2._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues] + + Trainer( + model=model, + parallelism_config={'fsdp': { + 'data_parallel_shard_degree': 2, + }}, + max_duration='3ba', + ) + + +@pytest.mark.gpu +@world_size(2) +def test_fsdp_shard_and_replicate(world_size: int): + model = SimpleModel() + model.fc1._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues] + model.fc2._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues] + + Trainer( + model=model, + parallelism_config={'fsdp': { + 'data_parallel_shard_degree': 2, + 'data_parallel_replicate_degree': 1, + }}, + max_duration='3ba', + ) diff --git a/tests/trainer/test_fsdp_checkpoint.py b/tests/trainer/test_fsdp_checkpoint.py index 93c60d8e97..d799b3279b 100644 --- a/tests/trainer/test_fsdp_checkpoint.py +++ b/tests/trainer/test_fsdp_checkpoint.py @@ -19,6 +19,7 @@ import torch from packaging import version from torch.distributed._shard.sharded_tensor import ShardedTensor +from torch.distributed._tensor import DTensor from torch.utils.data import DataLoader from torchmetrics import Metric, MetricCollection from torchmetrics.classification import MulticlassAccuracy @@ -84,10 +85,11 @@ class FSDPConfig: sharding_strategy: str = 'FULL_SHARD' sharded_ckpt_prefix_dir: str = 'ba{batch}' sync_module_states: bool = True - use_orig_params: bool = False + use_orig_params: bool = True load_monolith_rank0_only: bool = False save_planner: Optional[Any] = None load_planner: Optional[Any] = None + data_parallel_shard_degree: int = -1 def get_trainer( @@ -95,7 +97,7 @@ def get_trainer( save_folder: Optional[str] = None, save_filename: str = 'ba{batch}-rank{rank}.pt', save_overwrite: bool = False, - num_features: int = 2, + num_features: int = 4, num_classes: int = 2, load_path: Optional[str] = None, autoresume: bool = False, @@ -112,6 +114,7 @@ def get_trainer( train_metrics: Optional[Any] = None, val_metrics: Optional[Any] = None, fsdp_config: Optional[FSDPConfig] = None, + tp_config: Optional[dict[str, Any]] = None, ): if fsdp_config is None: fsdp_config = FSDPConfig() @@ -122,7 +125,7 @@ def get_trainer( val_metrics=val_metrics, ) model.module.to(model_init_device) - dataset = RandomClassificationDataset(shape=(num_features,), size=128) + dataset = RandomClassificationDataset(shape=(num_features,), num_classes=num_classes, size=128) dataloader = DataLoader( dataset, sampler=dist.get_sampler(dataset), @@ -135,12 +138,16 @@ def get_trainer( else: raise ValueError(f'Unsupported optimizer name {optimizer}') + parallelism_config = {'fsdp': dataclasses.asdict(fsdp_config)} + if tp_config is not None: + parallelism_config['tp'] = tp_config + trainer = Trainer( algorithms=algorithms, model=model, optimizers=optim, train_dataloader=dataloader, - fsdp_config=dataclasses.asdict(fsdp_config), + parallelism_config=parallelism_config, save_folder=save_folder, max_duration=max_duration, save_interval=save_interval, @@ -190,6 +197,11 @@ def _compare_optims_between_state_dicts(state_dict1, state_dict2): state_dict1_moment = state_dict1_moment.local_tensor() if isinstance(state_dict2_moment, ShardedTensor): state_dict2_moment = state_dict2_moment.local_tensor() + if isinstance(state_dict1_moment, DTensor): + state_dict1_moment = state_dict1_moment.to_local() + if isinstance(state_dict2_moment, DTensor): + state_dict2_moment = state_dict2_moment.to_local() + print(param_name, state_dict1_moment, 
state_dict2_moment) torch.testing.assert_close(state_dict1_moment, state_dict2_moment) @@ -213,6 +225,11 @@ def _compare_model_params_between_state_dicts(state_dict1, state_dict2): state_dict1_model_tensor = state_dict1_model_tensor.local_tensor() if isinstance(state_dict2_model_tensor, ShardedTensor): state_dict2_model_tensor = state_dict2_model_tensor.local_tensor() + if isinstance(state_dict1_model_tensor, DTensor): + state_dict1_model_tensor = state_dict1_model_tensor.to_local() + if isinstance(state_dict2_model_tensor, DTensor): + state_dict2_model_tensor = state_dict2_model_tensor.to_local() + torch.testing.assert_close(state_dict1_model_tensor, state_dict2_model_tensor) @@ -285,18 +302,19 @@ def _compare_timestamps_between_state_dicts(state_dict1, state_dict2): @pytest.mark.gpu -@world_size(2) @pytest.mark.filterwarnings(r'ignore:.*scatter_full_optim_state_dict``is being deprecated.*:UserWarning') @pytest.mark.parametrize( - 'optimizer,autoresume,precision,save_weights_only,load_weights_only,load_monolith_rank0_only', + 'world_size,optimizer,autoresume,precision,save_weights_only,load_weights_only,load_monolith_rank0_only,use_tp', [ - ['adam', False, 'amp_bf16', False, False, False], - ['adamw', False, 'amp_bf16', False, False, False], - ['adam', True, 'amp_bf16', False, False, False], - ['adam', False, 'amp_fp16', False, False, False], - ['adam', False, 'amp_bf16', True, True, False], # save_weights_only requires load_weights_only - ['adam', False, 'amp_bf16', False, True, False], - ['adam', False, 'amp_bf16', False, False, True], + pytest.param(2, 'adam', False, 'amp_bf16', False, False, False, False, marks=pytest.mark.world_size(2)), + pytest.param(2, 'adamw', False, 'amp_bf16', False, False, False, False, marks=pytest.mark.world_size(2)), + pytest.param(2, 'adam', True, 'amp_bf16', False, False, False, False, marks=pytest.mark.world_size(2)), + pytest.param(2, 'adam', False, 'amp_fp16', False, False, False, False, marks=pytest.mark.world_size(2)), + pytest.param(2, 'adam', False, 'amp_bf16', True, True, False, False, + marks=pytest.mark.world_size(2)), # save_weights_only requires load_weights_only + pytest.param(2, 'adam', False, 'amp_bf16', False, True, False, False, marks=pytest.mark.world_size(2)), + pytest.param(2, 'adam', False, 'amp_bf16', False, False, True, False, marks=pytest.mark.world_size(2)), + pytest.param(4, 'adam', False, 'amp_bf16', False, False, False, True, marks=pytest.mark.world_size(4)), ], ) def test_fsdp_full_state_dict_load( @@ -308,6 +326,7 @@ def test_fsdp_full_state_dict_load( save_weights_only: bool, load_weights_only: bool, load_monolith_rank0_only: bool, + use_tp: bool, ): if autoresume: run_name = 'my-cool-autoresume-run' @@ -317,6 +336,16 @@ def test_fsdp_full_state_dict_load( save_filename = 'rank{rank}.pt' fsdp_config = FSDPConfig(load_monolith_rank0_only=load_monolith_rank0_only) + tp_config = None + if use_tp: + from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel + tp_config = { + 'tensor_parallel_degree': 2, + 'layer_plan': { + 'module.0': ColwiseParallel(), + 'module.2': RowwiseParallel(), + }, + } trainer1 = get_trainer( save_folder=str(save_folder), @@ -326,6 +355,7 @@ def test_fsdp_full_state_dict_load( autoresume=autoresume, optimizer=optimizer, fsdp_config=fsdp_config, + tp_config=tp_config, ) trainer1.fit() state_dict_from_trainer1 = trainer1.state.state_dict() @@ -341,7 +371,9 @@ def test_fsdp_full_state_dict_load( max_duration='4ba', optimizer=optimizer, fsdp_config=fsdp_config, + 
@@ -326,6 +355,7 @@ def test_fsdp_full_state_dict_load(
         autoresume=autoresume,
         optimizer=optimizer,
         fsdp_config=fsdp_config,
+        tp_config=tp_config,
     )
     trainer1.fit()
     state_dict_from_trainer1 = trainer1.state.state_dict()
@@ -341,7 +371,9 @@ def test_fsdp_full_state_dict_load(
         max_duration='4ba',
         optimizer=optimizer,
         fsdp_config=fsdp_config,
+        save_weights_only=save_weights_only,
         load_weights_only=load_weights_only,
+        tp_config=tp_config,
     )
     state_dict_from_trainer2 = trainer2.state.state_dict()
@@ -717,18 +749,18 @@ def mock_get_checkpoint_validation_function():


 @pytest.mark.gpu
-@world_size(2)
 @pytest.mark.parametrize('use_remote', [pytest.param(True, marks=pytest.mark.remote), False])
 @pytest.mark.parametrize(
-    'weights_only,optimizer,precision,autoresume,load_ignore_keys,use_symlink',
+    'world_size,weights_only,optimizer,precision,autoresume,load_ignore_keys,use_symlink,use_tp',
     [
-        [False, 'adamw', 'amp_bf16', False, None, True],
-        [False, 'adamw', 'amp_bf16', False, None, False],
-        [True, 'adamw', 'amp_bf16', False, None, False],
-        [False, 'adam', 'amp_bf16', False, None, False],
-        [False, 'adamw', 'amp_fp16', False, None, False],
-        [False, 'adamw', 'amp_bf16', True, None, False],
-        [False, 'adamw', 'amp_bf16', False, ['rng'], False],
+        pytest.param(2, False, 'adamw', 'amp_bf16', False, None, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param(2, True, 'adamw', 'amp_bf16', False, None, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param(2, False, 'adam', 'amp_bf16', False, None, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param(2, False, 'adamw', 'amp_fp16', False, None, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param(2, False, 'adamw', 'amp_bf16', True, None, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param(2, False, 'adamw', 'amp_bf16', False, ['rng'], False, False, marks=pytest.mark.world_size(2)),
+        pytest.param(2, False, 'adamw', 'amp_bf16', False, None, True, False, marks=pytest.mark.world_size(2)),
+        pytest.param(4, False, 'adamw', 'amp_bf16', False, None, False, True, marks=pytest.mark.world_size(4)),
     ],
 )
 @pytest.mark.filterwarnings(r'ignore:TypedStorage is deprecated.:UserWarning')
@@ -743,13 +775,17 @@ def test_fsdp_partitioned_state_dict_load(
     weights_only: bool,
     load_ignore_keys: Union[list[str], None],
     use_symlink: bool,
+    use_tp: bool,
     use_remote,
     s3_bucket,
     s3_ephemeral_prefix,
     request,
 ):
     if weights_only and autoresume:
-        pytest.xfail('Weights only with autoresume is not supported')
+        pytest.skip('Weights only with autoresume is not supported')
+    if use_tp and version.parse(torch.__version__) < version.parse('2.3.0'):
+        pytest.skip('TP requires torch 2.3.0 or later')
+
     load_ignore_keys = [] if load_ignore_keys is None else load_ignore_keys

     if autoresume:
@@ -767,6 +803,17 @@ def test_fsdp_partitioned_state_dict_load(
     save_filename = 'ba{batch}-rank{rank}.pt'
     fsdp_config = FSDPConfig(state_dict_type='sharded')
+    tp_config = None
+    if use_tp:
+        fsdp_config = FSDPConfig(state_dict_type='sharded')
+        from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
+        tp_config = {
+            'tensor_parallel_degree': 2,
+            'layer_plan': {
+                'module.0': ColwiseParallel(),
+                'module.2': RowwiseParallel(),
+            },
+        }

     trainer1 = get_trainer(
         save_folder=str(save_folder),
@@ -779,6 +826,7 @@ def test_fsdp_partitioned_state_dict_load(
         save_interval='2ba',
         save_weights_only=weights_only,
         fsdp_config=fsdp_config,
+        tp_config=tp_config,
     )
     run_name = trainer1.state.run_name
     trainer1.fit()
@@ -814,6 +862,7 @@ def test_fsdp_partitioned_state_dict_load(
         optimizer=optimizer,
         load_weights_only=weights_only,
         fsdp_config=fsdp_config,
+        tp_config=tp_config,
         load_ignore_keys=load_ignore_keys,
     )
     state_dict_from_trainer2 = trainer2.state.state_dict()
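The DTensor branches added to the comparison helpers above mirror the existing ShardedTensor handling: both compare only the local shard held by each rank. A small illustration of those accessors on a hypothetical tensor, outside of the test fixtures:

# Illustrative sketch, assuming a 2-rank `torchrun` launch.
import os

import torch
from torch.distributed._tensor import DTensor, Shard, distribute_tensor
from torch.distributed.device_mesh import init_device_mesh

torch.cuda.set_device(int(os.environ['LOCAL_RANK']))
mesh = init_device_mesh('cuda', (2,))

full = torch.arange(8, dtype=torch.float32, device='cuda').reshape(4, 2)
dt = distribute_tensor(full, mesh, placements=[Shard(0)])  # row-sharded DTensor

local = dt.to_local()  # this rank's (2, 2) shard, a plain torch.Tensor
assert isinstance(dt, DTensor) and not isinstance(local, DTensor)

# full_tensor() is the counterpart when an all-gathered copy is needed instead.
assert torch.equal(dt.full_tensor(), full)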
diff --git a/tests/trainer/test_fsdp_param_groups.py b/tests/trainer/test_fsdp_param_groups.py
index b51b54c33d..7cbd52520e 100644
--- a/tests/trainer/test_fsdp_param_groups.py
+++ b/tests/trainer/test_fsdp_param_groups.py
@@ -38,10 +38,12 @@ def test_fsdp_param_groups_without_orig_params(mixed_precision: str, device: str
         model=model,
         optimizers=optimizer,
         train_dataloader=dataloader,
-        fsdp_config={
-            'activation_checkpointing_reentrant': reentrant,
-            'mixed_precision': mixed_precision,
-            'use_orig_params': False,
+        parallelism_config={
+            'fsdp': {
+                'activation_checkpointing_reentrant': reentrant,
+                'mixed_precision': mixed_precision,
+                'use_orig_params': False,
+            },
         },
         max_duration='3ba',
         device=device,
@@ -83,9 +85,11 @@ def test_fsdp_with_param_groups(mixed_precision: str, device: str, reentrant: bo
         model=model,
         optimizers=optimizer,
         train_dataloader=dataloader,
-        fsdp_config={
-            'activation_checkpointing_reentrant': reentrant,
-            'mixed_precision': mixed_precision,
+        parallelism_config={
+            'fsdp': {
+                'activation_checkpointing_reentrant': reentrant,
+                'mixed_precision': mixed_precision,
+            },
         },
         max_duration='3ba',
         device=device,
diff --git a/tests/trainer/test_tp.py b/tests/trainer/test_tp.py
new file mode 100644
index 0000000000..8146ebad40
--- /dev/null
+++ b/tests/trainer/test_tp.py
@@ -0,0 +1,48 @@
+# Copyright 2022 MosaicML Composer authors
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+import torch
+from packaging import version
+from torch.utils.data import DataLoader
+
+from composer.trainer.trainer import Trainer
+from composer.utils import dist
+from tests.common import (
+    RandomClassificationDataset,
+    SimpleModel,
+    world_size,
+)
+
+
+@pytest.mark.gpu
+@world_size(4)
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.3'), reason='requires PyTorch 2.3+')
+def test_tp_train(world_size: int):
+    from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
+
+    model = SimpleModel()
+    dataset = RandomClassificationDataset(size=8)
+    dataloader = DataLoader(dataset, batch_size=2, sampler=dist.get_sampler(dataset))
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+
+    layer_plan = {
+        'fc1': ColwiseParallel(),
+        'fc2': RowwiseParallel(),
+    }
+
+    trainer = Trainer(
+        model=model,
+        optimizers=optimizer,
+        train_dataloader=dataloader,
+        parallelism_config={
+            'tp': {
+                'layer_plan': layer_plan,
+                'tensor_parallel_degree': 2,
+            },
+            'fsdp': {},
+        },
+        max_duration='3ba',
+    )
+
+    trainer.fit()
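test_tp_train drives the 2-D case through Composer: TP degree 2 on 4 ranks, with the remaining factor of 2 handled by FSDP. The raw PyTorch composition this corresponds to can be sketched without the Trainer (mesh names and layer sizes below are illustrative, not what Composer builds internally):

# Pure-PyTorch sketch of a 4-rank layout: FSDP degree 2 x TP degree 2.
import os

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module

torch.cuda.set_device(int(os.environ['LOCAL_RANK']))  # assumes a torchrun launch
mesh = init_device_mesh('cuda', (2, 2), mesh_dim_names=('dp', 'tp'))

torch.manual_seed(0)  # identical initialization on every rank
model = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 8)).cuda()

# Tensor-parallelize over the inner 'tp' sub-mesh first...
model = parallelize_module(model, mesh['tp'], {'0': ColwiseParallel(), '2': RowwiseParallel()})
# ...then shard the TP'd module with FSDP over the outer 'dp' sub-mesh.
# Composing FSDP with DTensor parameters requires use_orig_params=True.
model = FSDP(model, device_mesh=mesh['dp'], use_orig_params=True)

out = model(torch.randn(4, 8, device='cuda'))
out.sum().backward()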
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 36bd02131f..c9b0073411 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -664,14 +664,16 @@ def test_fsdp(
     if precision == Precision.FP32:  # FSDP FULL_SHARD doesn't support FP32
         return
-    fsdp_config = {
-        'sharding_strategy': 'FULL_SHARD',
-        'cpu_offload': False,
-        'mixed_precision': 'PURE',
-        'backward_prefetch': 'BACKWARD_PRE',
-        'activation_checkpointing': False,
-        'activation_cpu_offload': False,
-        'verbose': False,
+    parallelism_config = {
+        'fsdp': {
+            'sharding_strategy': 'FULL_SHARD',
+            'cpu_offload': False,
+            'mixed_precision': 'PURE',
+            'backward_prefetch': 'BACKWARD_PRE',
+            'activation_checkpointing': False,
+            'activation_cpu_offload': False,
+            'verbose': False,
+        },
     }

     # Need to catch the case where we try to train
@@ -683,7 +685,7 @@ def test_fsdp(
     trainer = Trainer(
         model=model,
         precision=precision,
-        fsdp_config=fsdp_config,
+        parallelism_config=parallelism_config,
         max_duration=max_duration,
         train_dataloader=train_dataloader,
     )
@@ -706,14 +708,16 @@ def test_fsdp_torch_compile(
     max_duration: Time[int],
     train_dataloader: DataLoader,
 ):
-    fsdp_config = {
-        'sharding_strategy': 'FULL_SHARD',
-        'cpu_offload': False,
-        'mixed_precision': 'PURE',
-        'backward_prefetch': 'BACKWARD_PRE',
-        'activation_checkpointing': False,
-        'activation_cpu_offload': False,
-        'verbose': False,
+    parallelism_config = {
+        'fsdp': {
+            'sharding_strategy': 'FULL_SHARD',
+            'cpu_offload': False,
+            'mixed_precision': 'PURE',
+            'backward_prefetch': 'BACKWARD_PRE',
+            'activation_checkpointing': False,
+            'activation_cpu_offload': False,
+            'verbose': False,
+        },
     }

     # Need to catch the case where we try to train
@@ -725,7 +729,7 @@ def test_fsdp_torch_compile(
     trainer = Trainer(
         model=model,
         precision=precision,
-        fsdp_config=fsdp_config,
+        parallelism_config=parallelism_config,
         max_duration=max_duration,
         train_dataloader=train_dataloader,
         auto_log_hparams=True,
diff --git a/tests/utils/test_autolog_hparams.py b/tests/utils/test_autolog_hparams.py
index e2e8da4c20..2836d2962d 100644
--- a/tests/utils/test_autolog_hparams.py
+++ b/tests/utils/test_autolog_hparams.py
@@ -178,6 +178,7 @@ def test_extract_hparams_trainer():
         'deepspeed_config': None,
         'fsdp_config': None,
         'fsdp_auto_wrap': True,
+        'parallelism_config': None,

         # System/Numerics
         'device': 'DeviceCPU',
diff --git a/tests/utils/test_inference.py b/tests/utils/test_inference.py
index d152a3e1ae..36d60fd7ad 100644
--- a/tests/utils/test_inference.py
+++ b/tests/utils/test_inference.py
@@ -17,10 +17,10 @@

 from composer.core import State
 from composer.devices import DeviceCPU, DeviceGPU
+from composer.distributed import prepare_ddp_module
 from composer.functional import apply_gated_linear_units
 from composer.loggers import InMemoryLogger, Logger
 from composer.loggers.logger_destination import LoggerDestination
-from composer.trainer.dist_strategy import prepare_ddp_module
 from composer.trainer.trainer import Trainer
 from composer.utils import dist, export_with_logger, inference
 from composer.utils.device import get_device
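Since prepare_ddp_module now lives in composer.distributed, external code that must work on both sides of this move can guard the import. A small compatibility sketch, where the fallback path is simply the old location removed above:

# Compatibility sketch: prefer the new location, fall back to the pre-move one.
try:
    from composer.distributed import prepare_ddp_module
except ImportError:
    from composer.trainer.dist_strategy import prepare_ddp_module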