Tensor Parallelism Integration #3269

Merged: 89 commits, May 24, 2024

Commits (89 total; the file changes shown below are from 3 of these commits)
ee61da1
v1
mvpatel2000 May 8, 2024
b93364d
add test
mvpatel2000 May 8, 2024
9c38204
fix test
mvpatel2000 May 8, 2024
a761be4
v1
May 8, 2024
21d67bd
some lint
mvpatel2000 May 8, 2024
5353727
Merge branch 'dev' into mvpatel2000/nd-parallelism
mvpatel2000 May 8, 2024
a5b4ef7
lint
May 8, 2024
3db87d2
merge
May 8, 2024
1dd4245
lint and pytorch pinning
mvpatel2000 May 8, 2024
d38c254
results
mvpatel2000 May 8, 2024
2c4cfc3
tweak warnings
May 8, 2024
bb74919
fix lint
May 8, 2024
8ebc28a
lint
mvpatel2000 May 8, 2024
e8253f5
Merge branch 'mvpatel2000/nd-parallelism' of github.com-mvpatel2000:m…
mvpatel2000 May 8, 2024
3cdf544
filter
May 9, 2024
e491ebc
Merge branch 'mvpatel2000/nd-parallelism' of github.com-mvpatel2000:m…
May 9, 2024
1ec42c6
add ckpt
May 9, 2024
d74082b
checkdown
May 13, 2024
a47d3df
Merge branch 'dev' into mvpatel2000/nd-parallelism
mvpatel2000 May 13, 2024
da05e5d
fix tests
May 13, 2024
5d628d0
fix
May 13, 2024
998da6b
fix test
May 13, 2024
ee185e1
fix tests
May 13, 2024
8b6664c
Merge branch 'dev' into mvpatel2000/nd-parallelism
mvpatel2000 May 15, 2024
50ff5bb
fix tests
May 16, 2024
3964e87
Merge branch 'mvpatel2000/nd-parallelism' of github.com-mvpatel2000:m…
May 16, 2024
6f6d0d9
paralleli config
May 16, 2024
b150a33
Merge branch 'dev' into mvpatel2000/nd-parallelism
mvpatel2000 May 16, 2024
69111b4
fix some arg parsing
May 16, 2024
2883d40
Merge branch 'mvpatel2000/nd-parallelism' of github.com-mvpatel2000:m…
May 16, 2024
5057ba1
rename to parallelism config
May 16, 2024
36dab25
lint
May 16, 2024
675bf6d
fix docs
May 16, 2024
3b0e8b8
Merge branch 'dev' into mvpatel2000/nd-parallelism
mvpatel2000 May 16, 2024
4c70715
fix edge case
May 16, 2024
d68557e
Merge branch 'mvpatel2000/nd-parallelism' of github.com-mvpatel2000:m…
May 16, 2024
aba7964
lint
May 16, 2024
1fa7fbd
Merge branch 'dev' into mvpatel2000/nd-parallelism
mvpatel2000 May 17, 2024
73d9126
Merge branch 'dev' into mvpatel2000/nd-parallelism
mvpatel2000 May 18, 2024
02f1494
Merge branch 'dev' into mvpatel2000/nd-parallelism
mvpatel2000 May 20, 2024
9015bbe
log
May 20, 2024
8c2b7e6
change slicing
May 20, 2024
4df1cfb
fix patching
May 20, 2024
d7c3668
fix core
May 20, 2024
da30aca
lint
May 20, 2024
f192812
clean up v1
May 20, 2024
cb50a10
lint
May 20, 2024
67e19d8
shallow copy
May 20, 2024
a86ca84
Merge branch 'dev' into mvpatel2000/nd-parallelism
mvpatel2000 May 20, 2024
1b23477
add checks
May 20, 2024
50c456b
device mesh
May 20, 2024
1ab4bf3
fix type checking
May 20, 2024
cea1d70
fix bugs
May 20, 2024
8ac1900
fix tests
May 21, 2024
74ef898
rename variables and fix checkpointing
May 21, 2024
7dc3978
lint
May 21, 2024
39ed363
lint
May 21, 2024
547fb08
v1 refacotr
May 21, 2024
3916a3b
lint
May 21, 2024
760cb2d
tests
May 21, 2024
806090d
add enum
May 21, 2024
3ebaf34
fix test
May 21, 2024
00ce8c6
fix
May 21, 2024
6b08e81
fix fsdp submesh order
May 21, 2024
cb0c604
fix tests
May 22, 2024
f21c6ae
change to world size
May 22, 2024
223f58e
add docs v1
May 22, 2024
18c97b8
fix docs
May 22, 2024
84e06c2
add tests
May 22, 2024
f0f2045
v1 of tp test
May 22, 2024
037cd9d
fix lint
May 22, 2024
886a252
fix arg prop
May 22, 2024
5968bea
fix tests
May 22, 2024
9d97611
lint
May 22, 2024
773e29d
brians comments
May 23, 2024
bfa0300
pr review
May 23, 2024
03ae706
Merge branch 'dev' into mvpatel2000/nd-parallelism
mvpatel2000 May 23, 2024
122a756
add more gating
May 24, 2024
6346ad3
lint
May 24, 2024
59ec49a
fix
May 24, 2024
8ee01c6
fix lint
May 24, 2024
a249166
fix lint
May 24, 2024
b26d2ec
lint
May 24, 2024
9a36d71
force to run
May 24, 2024
fc76734
fix assert
May 24, 2024
78de92a
parallelism config
May 24, 2024
135f729
Merge branch 'dev' into mvpatel2000/nd-parallelism
mvpatel2000 May 24, 2024
a85b7b6
tweak some parallelism issues
May 24, 2024
33a2e71
remove assert
May 24, 2024
5 changes: 3 additions & 2 deletions composer/callbacks/checkpoint_saver.py
@@ -468,8 +468,9 @@ def _save_checkpoint(self, state: State, logger: Logger):
is_deepspeed,
keep_placeholders=True,
).lstrip('/')
assert state.sharded_ckpt_prefix_dir is not None
remote_prefix = state.sharded_ckpt_prefix_dir
assert state.fsdp_config is not None
remote_prefix = state.fsdp_config['sharded_ckpt_prefix_dir']
assert remote_prefix is not None
ckpt_filename = checkpoint._TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME
remote_file_name = os.path.join(pathlib.Path(remote_file_name).parent, remote_prefix, ckpt_filename)
remote_file_name = format_name_with_dist_and_time(remote_file_name, state.run_name, state.timestamp)
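This hunk reads the sharded-checkpoint prefix directory from `state.fsdp_config` instead of a dedicated `state.sharded_ckpt_prefix_dir` attribute, which the `state.py` diff below removes. A minimal, self-contained sketch of the resulting path construction, with illustrative stand-in values for Composer's config and filename constant:

```python
import os
import pathlib

# Stand-ins for state.fsdp_config and checkpoint._TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME;
# the concrete values here are illustrative only.
fsdp_config = {'sharded_ckpt_prefix_dir': 'ep{epoch}-ba{batch}'}
ckpt_filename = '__{rank}_0.distcp'
remote_file_name = 'checkpoints/run-name/latest-rank{rank}.pt'

assert fsdp_config is not None
remote_prefix = fsdp_config['sharded_ckpt_prefix_dir']
assert remote_prefix is not None

# Sharded checkpoints land under <parent of remote_file_name>/<prefix_dir>/<shard filename>;
# Composer then fills the {placeholders} via format_name_with_dist_and_time.
remote_file_name = os.path.join(pathlib.Path(remote_file_name).parent, remote_prefix, ckpt_filename)
print(remote_file_name)  # checkpoints/run-name/ep{epoch}-ba{batch}/__{rank}_0.distcp
```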
106 changes: 52 additions & 54 deletions composer/core/state.py
@@ -539,28 +539,72 @@ def __init__(
parallelism_config = parallelism_config or {}
self.fsdp_config = parallelism_config.get('fsdp', None)
self.tp_config = parallelism_config.get('tp', None)
if self.fsdp_config is not None:
from composer.distributed import patch_pytorch

# Add an earlier call to patch_pytorch as we require device_mesh slicing before any
# model wrapping.
patch_pytorch()
self._validate_parallelism_configs()

self.device_mesh: Optional[DeviceMesh] = _create_device_mesh(self.device, self.fsdp_config, self.tp_config)
if self.fsdp_config is not None and self.device_mesh is not None:
fsdp_mesh_dim_names = []
if self.device_mesh.mesh_dim_names is not None and ParallelismType.DATA_PARALLEL_REPLICATE.value in self.device_mesh.mesh_dim_names:
fsdp_mesh_dim_names.append(ParallelismType.DATA_PARALLEL_REPLICATE.value)
fsdp_mesh_dim_names.append(ParallelismType.DATA_PARALLEL_SHARD.value)
self.fsdp_config['device_mesh'] = self.device_mesh[tuple(fsdp_mesh_dim_names)] # type: ignore
if self.tp_config is not None and self.device_mesh is not None:
self.tp_config['device_mesh'] = self.device_mesh[ParallelismType.TENSOR_PARALLEL.value]

# Set defaults for transient variables (to make pyright happy)
self.batch: Any = None
self.loss: Union[torch.Tensor, Sequence[torch.Tensor], Dict[Any, torch.Tensor]] = torch.Tensor()
self.outputs: Union[torch.Tensor, Sequence[torch.Tensor]] = torch.Tensor()

# These attributes will be serialized using .state_dict(), and loaded with .load_state_dict()
# All other attributes will not be serialized.
# For simplicity, omit the leading underscore for private attributes.
# For example, even though the optimizers are stored on the state
# as the "_optimizers" attribute, here we specify just "optimizers"
self.serialized_attributes = [
'model',
'optimizers',
'schedulers',
'algorithms',
'callbacks',
'scaler',
'timestamp',
'rank_zero_seed',
'train_metrics',
'eval_metrics',
'run_name',
'dataset_state',
]

self.train_metrics: Optional[Dict[str, Metric]] = {}
self.eval_metrics: Dict[str, Dict[str, Metric]] = {}
self.train_metric_values: Dict[str, float] = {}
self.eval_metric_values: Dict[str, float] = {}
self.total_loss_dict: Dict[str, float] = {}

self.metric_outputs: Dict[str, Any] = {}

def _validate_parallelism_configs(self):
# Validate TP config
if self.tp_config is not None:
warnings.warn('Tensor parallelism (TP) is experimental and may change in future versions.', FutureWarning)
if version.parse(torch.__version__.split('.dev')[0]) < version.parse('2.3.0'):
raise ValueError('Tensor parallelism (TP) requires torch>=2.3.0.')
if self.fsdp_config is None:
raise ValueError(
'Tensor parallelism (TP) currently requires FSDP to be enabled. '
'An empty `fsdp_config` can be specified to enable FSDP with '
'default settings.',
'default settings. Additionally, PyTorch currently errors if FSDP '
'data_parallel_shard_degree is not at least 2.',
)
if not self.fsdp_config['use_orig_params']:
raise ValueError(
'Tensor parallelism (TP) currently requires FSDP with use_orig_params=True, '
'which is the default and recommended setting.',
)

# Load monolith rank0 only
if self.load_monolith_rank0_only:
if self.tp_config is not None:
raise ValueError('load_fsdp_monolith_rank0_only is not compatible with tensor parallelism (TP).')
@@ -573,7 +617,7 @@ def __init__(
)
# Broadcast rank 0 meta check to all ranks so error can be raised on all ranks
rank0_on_meta = 0
if dist.get_global_rank() == 0 and next(model.parameters()).device.type == 'meta':
if dist.get_global_rank() == 0 and next(self.model.parameters()).device.type == 'meta':
rank0_on_meta = 1
rank0_on_meta_tensor = self.device.tensor_to_device(torch.tensor([rank0_on_meta], dtype=torch.uint8))
dist.all_reduce(rank0_on_meta_tensor, reduce_operation='MAX')
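The hunk above relies on a small but useful distributed pattern: rank 0 checks whether the model sits on the meta device, and a MAX all-reduce shares that flag so every rank can raise the same error instead of deadlocking in a later collective. A minimal sketch of the pattern, using raw `torch.distributed` rather than Composer's `dist` wrapper and assuming an already-initialized process group:

```python
import torch
import torch.distributed as dist


def raise_everywhere_if_rank0(condition_on_rank0: bool, message: str) -> None:
    """Broadcast a rank-0 condition via MAX all-reduce so all ranks raise together."""
    flag_value = 1 if (dist.get_rank() == 0 and condition_on_rank0) else 0
    flag = torch.tensor([flag_value], dtype=torch.uint8)
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)  # any rank seeing 1 means rank 0 saw it
    if flag.item() == 1:
        raise ValueError(message)
```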
@@ -586,10 +630,7 @@ def __init__(
if error_message != '':
raise ValueError(error_message)

self.sharded_ckpt_prefix_dir: Optional[str] = None
if self.fsdp_config is not None:
self.sharded_ckpt_prefix_dir = self.fsdp_config['sharded_ckpt_prefix_dir']

# Validate FSDP state dict type
if self.fsdp_state_dict_type not in [None, 'full', 'sharded']:
if self.fsdp_state_dict_type == 'local':
raise ValueError(
@@ -613,49 +654,6 @@ def __init__(
),
)

self.device_mesh: Optional[DeviceMesh] = _create_device_mesh(self.device, self.fsdp_config, self.tp_config)
if self.fsdp_config is not None and self.device_mesh is not None:
fsdp_mesh_dim_names = []
if self.device_mesh.mesh_dim_names is not None and ParallelismType.DATA_PARALLEL_REPLICATE.value in self.device_mesh.mesh_dim_names:
fsdp_mesh_dim_names.append(ParallelismType.DATA_PARALLEL_REPLICATE.value)
fsdp_mesh_dim_names.append(ParallelismType.DATA_PARALLEL_SHARD.value)
self.fsdp_config['device_mesh'] = self.device_mesh[tuple(fsdp_mesh_dim_names)] # type: ignore
if self.tp_config is not None and self.device_mesh is not None:
self.tp_config['device_mesh'] = self.device_mesh[ParallelismType.TENSOR_PARALLEL.value]

# Set defaults for transient variables (to make pyright happy)
self.batch: Any = None
self.loss: Union[torch.Tensor, Sequence[torch.Tensor], Dict[Any, torch.Tensor]] = torch.Tensor()
self.outputs: Union[torch.Tensor, Sequence[torch.Tensor]] = torch.Tensor()

# These attributes will be serialized using .state_dict(), and loaded with .load_state_dict()
# All other attributes will not be serialized.
# For simplicity, omit the leading underscore for private attributes.
# For example, even though the optimizers are stored on the state
# as the "_optimizers" attribute, here we specify just "optimizers"
self.serialized_attributes = [
'model',
'optimizers',
'schedulers',
'algorithms',
'callbacks',
'scaler',
'timestamp',
'rank_zero_seed',
'train_metrics',
'eval_metrics',
'run_name',
'dataset_state',
]

self.train_metrics: Optional[Dict[str, Metric]] = {}
self.eval_metrics: Dict[str, Dict[str, Metric]] = {}
self.train_metric_values: Dict[str, float] = {}
self.eval_metric_values: Dict[str, float] = {}
self.total_loss_dict: Dict[str, float] = {}

self.metric_outputs: Dict[str, Any] = {}

def _dataset_of(self, dataloader: Optional[Union[Evaluator, DataSpec, DataLoader, Iterable]]) -> Optional[Dataset]:
"""Get the dataset contained by the given dataloader-like object.

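For orientation, here is a hedged sketch of the mesh-slicing pattern the new `__init__` follows: build one named device mesh spanning the data-parallel and tensor-parallel dimensions, then hand FSDP and TP their own submeshes. The dimension names and degrees are illustrative assumptions (Composer's `_create_device_mesh` and `ParallelismType` may also include a replicate dimension for HSDP), and the script is meant to be launched under `torchrun --nproc_per_node=4`:

```python
# Sketch only; assumes torch>=2.3 (matching the TP version check) and 4 processes.
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group(backend='gloo')

# 2 data-parallel shards x 2 tensor-parallel ranks (assumed degrees).
mesh = init_device_mesh(
    'cpu',
    mesh_shape=(2, 2),
    mesh_dim_names=('data_parallel_shard', 'tensor_parallel'),
)

# FSDP gets the data-parallel slice and TP the tensor-parallel slice, mirroring how
# state.py assigns fsdp_config['device_mesh'] and tp_config['device_mesh'].
fsdp_mesh = mesh['data_parallel_shard']
tp_mesh = mesh['tensor_parallel']
print(f'rank {dist.get_rank()}: fsdp={fsdp_mesh}, tp={tp_mesh}')

dist.destroy_process_group()
```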
3 changes: 1 addition & 2 deletions composer/distributed/__init__.py
@@ -11,7 +11,7 @@
prepare_fsdp_module,
prepare_tp_module,
)
from composer.distributed.mosaic_fsdp import patch_pytorch, set_fsdp_default
from composer.distributed.mosaic_fsdp import set_fsdp_default

__all__ = [
'fix_batch_precision_for_deepspeed',
@@ -21,6 +21,5 @@
'prepare_ddp_module',
'prepare_fsdp_module',
'prepare_tp_module',
'patch_pytorch',
'set_fsdp_default',
]
3 changes: 0 additions & 3 deletions composer/distributed/dist_strategy.py
@@ -29,7 +29,6 @@
SHARDING_MAP,
get_cpu_offload,
get_mixed_precision,
patch_pytorch,
set_custom_fsdp_module_kwargs,
)
from composer.utils import StringEnum, dist, ensure_tuple
@@ -216,8 +215,6 @@ def prepare_fsdp_module(
auto_microbatching (bool, optional): Whether or not auto microbatching is enabled.
te_rng_seed(int): The seed to use for the Transformer Engine activation checkpointing RNG. Defaults to 1234.
"""
patch_pytorch()

# Check sync_module_states is True for mixed initialization or HSDP
if fsdp_config['sync_module_states'] == False:
rank_on_meta = 1 if next(model.parameters()).device.type == 'meta' else 0
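Finally, `prepare_fsdp_module` loses its `patch_pytorch()` call: per the comment shown in the `state.py` hunk, the PyTorch patches need to be in place before the device mesh is sliced, which now happens when `State` is constructed rather than when the model is wrapped. A toy ordering sketch with stub functions (not Composer's real call graph):

```python
# Stub functions only; they illustrate the call ordering, not Composer's actual APIs.
def patch_pytorch_stub() -> None:
    print('1) patches applied (now done early, at State construction)')


def create_and_slice_device_mesh_stub() -> str:
    print('2) device mesh created and sliced (requires the patches)')
    return 'device mesh'


def prepare_fsdp_module_stub(mesh: str) -> None:
    print(f'3) model wrapped with FSDP on {mesh} (no patch call needed here anymore)')


patch_pytorch_stub()
mesh = create_and_slice_device_mesh_stub()
prepare_fsdp_module_stub(mesh)
```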