From 1dd42452d24c2bf753c517b8ed708d8b6e3a1264 Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Wed, 8 May 2024 14:05:49 -0700
Subject: [PATCH] lint and pytorch pinning

---
 composer/core/state.py            | 11 ++++++-----
 composer/trainer/dist_strategy.py |  3 ++-
 tests/trainer/test_tp.py          | 17 ++++++++++-------
 3 files changed, 18 insertions(+), 13 deletions(-)

diff --git a/composer/core/state.py b/composer/core/state.py
index 5cf91cc809..c61aeb0007 100644
--- a/composer/core/state.py
+++ b/composer/core/state.py
@@ -198,7 +198,8 @@ def _create_device_mesh(device: Device, fsdp_config: Optional[Dict[str, Any]], t
         return None
 
     # Gather dimensions and names for the device mesh
-    dims, names = [], []
+    dims: List[int] = []
+    names: List[str] = []
     dims.append(fsdp_config['data_parallel_shard_degree'])
     names.append('dp_shard')
     if fsdp_config['data_parallel_replicate_degree'] != 1:
@@ -219,14 +220,14 @@
     if len(unspecified_dim_names) > 1:
         raise ValueError(
             f'Found multiple parallelism dimensions with -1: {unspecified_dim_names}. '
-            'Only one is allowed, which is set to fill the remaining dimensions.'
+            'Only one is allowed, which is set to fill the remaining dimensions.',
         )
     remaining_dimension = dist.get_world_size() // product_of_dims
     if remaining_dimension * product_of_dims != dist.get_world_size():
         raise ValueError(
             f'World size {dist.get_world_size()} is not divisible by the product of the specified '
             'parallelism degrees. Please ensure the product of the specified parallelism degrees '
-            'matches the world size.'
+            'matches the world size.',
         )
     for i, dim in enumerate(dims):
         if dim == -1:
@@ -237,7 +238,7 @@
     if device_type == 'gpu':
         device_type = 'cuda'
 
-    return init_device_mesh(device_type=device_type, mesh_shape=dims, mesh_dim_names=names)
+    return init_device_mesh(device_type=device_type, mesh_shape=tuple(dims), mesh_dim_names=tuple(names))
 
 
 _STATE_DICT_SERIALIZED_ATTRIBUTES = [
@@ -527,7 +528,7 @@ def __init__(
                 raise ValueError(
                     'Tensor parallelism (TP) currently requires FSDP to be enabled. '
                     'An empty `fsdp_config` can be specified to enable FSDP with '
-                    'default settings.'
+                    'default settings.',
                 )
 
         if self.load_fsdp_monolith_rank0_only:
diff --git a/composer/trainer/dist_strategy.py b/composer/trainer/dist_strategy.py
index b7d5fef259..b58c3ccbb7 100644
--- a/composer/trainer/dist_strategy.py
+++ b/composer/trainer/dist_strategy.py
@@ -18,7 +18,6 @@
 )
 from torch.distributed.fsdp import FullyShardedDataParallel
 from torch.distributed.fsdp._common_utils import clean_tensor_name
-from torch.distributed.tensor.parallel import parallelize_module
 from torch.nn.parallel import DistributedDataParallel
 from torchmetrics import Metric, MetricCollection
 
@@ -186,6 +185,8 @@ def prepare_tp_module(
     tp_config: Dict[str, Any],
 ) -> None:
     """Prepare a module (assumed ComposerModel) for use with tensor parallel."""
+    from torch.distributed.tensor.parallel import parallelize_module
+
     device_mesh = tp_config['device_mesh']
     layer_plan = tp_config['layer_plan']
     parallelize_module(
diff --git a/tests/trainer/test_tp.py b/tests/trainer/test_tp.py
index 216444d79d..9b5a6a0016 100644
--- a/tests/trainer/test_tp.py
+++ b/tests/trainer/test_tp.py
@@ -3,14 +3,8 @@
 
 import pytest
 import torch
+from packaging import version
 from torch.distributed._tensor.device_mesh import init_device_mesh
-from torch.distributed.tensor.parallel import (
-    ColwiseParallel,
-    PrepareModuleInput,
-    RowwiseParallel,
-    SequenceParallel,
-    parallelize_module,
-)
 from torch.utils.data import DataLoader
 
 from composer.trainer.trainer import Trainer
@@ -24,7 +18,16 @@
 
 @pytest.mark.gpu
 @world_size(2)
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.3'), reason='requires PyTorch 2.3+')
 def test_tp_train(world_size: int):
+    from torch.distributed.tensor.parallel import (
+        ColwiseParallel,
+        PrepareModuleInput,
+        RowwiseParallel,
+        SequenceParallel,
+        parallelize_module,
+    )
+
     model = SimpleModel()
     dataset = RandomClassificationDataset(size=10)
     dataloader = DataLoader(dataset, sampler=dist.get_sampler(dataset))
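
Reviewer note: the state.py hunks above touch the logic that lets exactly one parallelism
degree be left as -1 and inferred from the world size. Below is a minimal standalone sketch
of that inference under the same rules; fill_unspecified_dims is a hypothetical helper name,
not code from this patch.

# Hypothetical standalone sketch (not part of this patch) of the "-1 fill" rule
# that _create_device_mesh enforces: at most one degree may be -1, and it is set
# to whatever factor of the world size remains after the explicit degrees.
from typing import List


def fill_unspecified_dims(dims: List[int], world_size: int) -> List[int]:
    unspecified = [i for i, dim in enumerate(dims) if dim == -1]
    if len(unspecified) > 1:
        raise ValueError(f'Found multiple parallelism dimensions with -1: {unspecified}. Only one is allowed.')

    product_of_dims = 1
    for dim in dims:
        if dim != -1:
            product_of_dims *= dim

    remaining_dimension = world_size // product_of_dims
    if remaining_dimension * product_of_dims != world_size:
        raise ValueError(f'World size {world_size} is not divisible by the product of the specified degrees.')

    return [remaining_dimension if dim == -1 else dim for dim in dims]


# On 8 ranks with dp_shard=-1 and dp_replicate=2, dp_shard is inferred as 4:
assert fill_unspecified_dims([-1, 2], world_size=8) == [4, 2]

Converting dims and names to tuples before init_device_mesh matches the sequence types
that API expects, which is the reason for the mesh_shape/mesh_dim_names change above.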
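Reviewer note: the dist_strategy.py and test_tp.py hunks apply the same pinning pattern:
torch.distributed.tensor.parallel does not import cleanly on older PyTorch builds, so the
import moves from module scope into the function that needs it, and callers gate on the
torch version. A hedged sketch of the pattern, where prepare_tp is a hypothetical stand-in
for composer's prepare_tp_module rather than its actual implementation:

# Hypothetical sketch of the deferred-import pattern used by this patch.
from typing import Any, Dict

import torch
from packaging import version


def prepare_tp(module: torch.nn.Module, tp_config: Dict[str, Any]) -> None:
    if version.parse(torch.__version__) < version.parse('2.3'):
        raise RuntimeError('Tensor parallelism requires PyTorch 2.3+.')
    # Deferred so that merely importing this module still works on older PyTorch.
    from torch.distributed.tensor.parallel import parallelize_module
    parallelize_module(
        module=module,
        device_mesh=tp_config['device_mesh'],
        parallelize_plan=tp_config['layer_plan'],
    )

The test takes the same approach: the skipif decorator keeps the test from running on
PyTorch < 2.3, and the in-function imports keep collection from failing there.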