
lint and pytorch pinning
mvpatel2000 committed May 8, 2024
1 parent 3db87d2 commit 1dd4245
Showing 3 changed files with 18 additions and 13 deletions.
11 changes: 6 additions & 5 deletions composer/core/state.py
@@ -198,7 +198,8 @@ def _create_device_mesh(device: Device, fsdp_config: Optional[Dict[str, Any]], t
         return None
 
     # Gather dimensions and names for the device mesh
-    dims, names = [], []
+    dims: List[int] = []
+    names: List[str] = []
     dims.append(fsdp_config['data_parallel_shard_degree'])
     names.append('dp_shard')
     if fsdp_config['data_parallel_replicate_degree'] != 1:
@@ -219,14 +220,14 @@ def _create_device_mesh(device: Device, fsdp_config: Optional[Dict[str, Any]], t
     if len(unspecified_dim_names) > 1:
         raise ValueError(
             f'Found multiple parallelism dimensions with -1: {unspecified_dim_names}. '
-            'Only one is allowed, which is set to fill the remaining dimensions.'
+            'Only one is allowed, which is set to fill the remaining dimensions.',
         )
     remaining_dimension = dist.get_world_size() // product_of_dims
     if remaining_dimension * product_of_dims != dist.get_world_size():
         raise ValueError(
             f'World size {dist.get_world_size()} is not divisible by the product of the specified '
             'parallelism degrees. Please ensure the product of the specified parallelism degrees '
-            'matches the world size.'
+            'matches the world size.',
         )
     for i, dim in enumerate(dims):
         if dim == -1:
@@ -237,7 +238,7 @@ def _create_device_mesh(device: Device, fsdp_config: Optional[Dict[str, Any]], t
     if device_type == 'gpu':
         device_type = 'cuda'
 
-    return init_device_mesh(device_type=device_type, mesh_shape=dims, mesh_dim_names=names)
+    return init_device_mesh(device_type=device_type, mesh_shape=tuple(dims), mesh_dim_names=tuple(names))
 
 
 _STATE_DICT_SERIALIZED_ATTRIBUTES = [
@@ -527,7 +528,7 @@ def __init__(
             raise ValueError(
                 'Tensor parallelism (TP) currently requires FSDP to be enabled .'
                 'An empty `fsdp_config` can be specified to enable FSDP with '
-                'default settings.'
+                'default settings.',
             )
 
         if self.load_fsdp_monolith_rank0_only:
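Note: the state.py hunks above are typing and lint fixes around device-mesh construction: the dims/names lists get explicit annotations and are passed to init_device_mesh as tuples (its mesh_shape and mesh_dim_names parameters are typed as tuples). Below is a minimal, self-contained sketch of that construction logic, not the Composer implementation; the function name build_mesh_dims and the standalone shard_degree/replicate_degree/world_size parameters are illustrative stand-ins for the values Composer reads from fsdp_config and dist.get_world_size().

# Sketch only: mirrors the -1 "fill the remainder" handling and the
# tuple(dims)/tuple(names) conversion shown in the diff above.
from typing import List, Tuple


def build_mesh_dims(shard_degree: int, replicate_degree: int, world_size: int) -> Tuple[Tuple[int, ...], Tuple[str, ...]]:
    dims: List[int] = [shard_degree]
    names: List[str] = ['dp_shard']
    if replicate_degree != 1:
        dims.append(replicate_degree)
        names.append('dp_replicate')

    # At most one dimension may be -1; it absorbs the remaining world size.
    unspecified = [name for name, dim in zip(names, dims) if dim == -1]
    if len(unspecified) > 1:
        raise ValueError(f'Found multiple parallelism dimensions with -1: {unspecified}. Only one is allowed.')

    product_of_dims = 1
    for dim in dims:
        if dim != -1:
            product_of_dims *= dim
    remaining = world_size // product_of_dims
    if remaining * product_of_dims != world_size:
        raise ValueError(f'World size {world_size} is not divisible by the specified parallelism degrees.')
    dims = [remaining if dim == -1 else dim for dim in dims]

    # init_device_mesh is typed to take tuples for mesh_shape and
    # mesh_dim_names, hence the tuple(...) calls in the diff.
    return tuple(dims), tuple(names)


# e.g. 8 ranks, shard degree left as -1, replicate degree 2 -> ((4, 2), ('dp_shard', 'dp_replicate'))
print(build_mesh_dims(shard_degree=-1, replicate_degree=2, world_size=8))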
3 changes: 2 additions & 1 deletion composer/trainer/dist_strategy.py
@@ -18,7 +18,6 @@
 )
 from torch.distributed.fsdp import FullyShardedDataParallel
 from torch.distributed.fsdp._common_utils import clean_tensor_name
-from torch.distributed.tensor.parallel import parallelize_module
 from torch.nn.parallel import DistributedDataParallel
 from torchmetrics import Metric, MetricCollection

@@ -186,6 +185,8 @@ def prepare_tp_module(
     tp_config: Dict[str, Any],
 ) -> None:
     """Prepare a module (assumed ComposerModel) for use with tensor parallel."""
+    from torch.distributed.tensor.parallel import parallelize_module
+
     device_mesh = tp_config['device_mesh']
     layer_plan = tp_config['layer_plan']
     parallelize_module(
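Note: the dist_strategy.py change moves the parallelize_module import from module scope into prepare_tp_module, so importing the module no longer pulls in torch.distributed.tensor.parallel unless tensor parallelism is actually requested. A rough sketch of that deferred-import pattern follows, under the assumption that the goal is to keep module import working across pinned torch versions; prepare_tp_module_sketch and its arguments are illustrative, not Composer's API.

# Sketch of the deferred-import pattern from the diff above; the real
# signature and call live in composer/trainer/dist_strategy.py.
from typing import Any, Dict

import torch


def prepare_tp_module_sketch(module: torch.nn.Module, tp_config: Dict[str, Any]) -> None:
    # Resolved at call time rather than at module import time.
    from torch.distributed.tensor.parallel import parallelize_module

    parallelize_module(
        module=module,
        device_mesh=tp_config['device_mesh'],
        parallelize_plan=tp_config['layer_plan'],
    )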
17 changes: 10 additions & 7 deletions tests/trainer/test_tp.py
@@ -3,14 +3,8 @@
 
 import pytest
 import torch
+from packaging import version
 from torch.distributed._tensor.device_mesh import init_device_mesh
-from torch.distributed.tensor.parallel import (
-    ColwiseParallel,
-    PrepareModuleInput,
-    RowwiseParallel,
-    SequenceParallel,
-    parallelize_module,
-)
 from torch.utils.data import DataLoader
 
 from composer.trainer.trainer import Trainer
@@ -24,7 +18,16 @@
 
 @pytest.mark.gpu
 @world_size(2)
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.3'), reason='requires PyTorch 2.3+')
 def test_tp_train(world_size: int):
+    from torch.distributed.tensor.parallel import (
+        ColwiseParallel,
+        PrepareModuleInput,
+        RowwiseParallel,
+        SequenceParallel,
+        parallelize_module,
+    )
+
     model = SimpleModel()
     dataset = RandomClassificationDataset(size=10)
     dataloader = DataLoader(dataset, sampler=dist.get_sampler(dataset))
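Note: the test change gates test_tp_train on PyTorch 2.3+ via pytest.mark.skipif and defers the tensor-parallel imports into the test body, so test collection still succeeds on older torch installs. Below is a minimal standalone illustration of that gating pattern (assuming torch, pytest, and packaging are installed); the test name and body here are invented for the example and are not the Composer test.

# Illustration only: version-gated test with imports deferred into the body.
import pytest
import torch
from packaging import version


@pytest.mark.skipif(
    version.parse(torch.__version__) < version.parse('2.3'),
    reason='requires PyTorch 2.3+',
)
def test_tp_imports_are_available():
    # Imported lazily: only evaluated when the skipif gate passes.
    from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

    assert ColwiseParallel is not None and RowwiseParallel is not None

Running pytest on a file like this skips the test outright on torch < 2.3 instead of failing at import time.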
