From 8d7b3b7b9b66c362867d725f484bdef0b7954109 Mon Sep 17 00:00:00 2001
From: Nicholas Cilfone
Date: Tue, 19 Oct 2021 15:33:18 -0400
Subject: [PATCH 1/2] Borrows DeepSpeed's MPI detection for DDP, which allows
 non-DeepSpeed DDP to be easily used across multiple nodes when launched with
 OpenMPI (or the MPI Operator on k8s). Removes the DataLoader batch size
 option, as this can be pulled directly from the Stoke object.

---
 README.md                 | 11 +++++++++--
 docs/Launchers.md         | 28 +++++++++++++++++++++++++---
 docs/Quick-Start.md       |  2 +-
 examples/cifar10/train.py |  2 --
 stoke/configs.py          |  4 ++++
 stoke/distributed.py      | 12 ++++++++++++
 stoke/fp16.py             |  1 -
 stoke/stoke.py            |  7 ++-----
 8 files changed, 53 insertions(+), 14 deletions(-)

diff --git a/README.md b/README.md
index 6c6c8db..f38d872 100644
--- a/README.md
+++ b/README.md
@@ -51,7 +51,10 @@ and are only conditionally imported).
 
 Follow the instructions [here](https://github.com/NVIDIA/apex#quick-start).
 
-### (Optional) OpenMPI Support
+### (Optional) Underlying OpenMPI Support
+
+**Note: MPI support is necessary if you plan to run Stoke across multiple compute nodes (e.g. 2 nodes with 4 GPUs each)
+with DDP, Horovod, or DeepSpeed backends.**
 
 Follow the instructions [here](https://www.open-mpi.org/faq/?category=building)
 or [here](https://edu.itp.phys.ethz.ch/hs12/programming_techniques/openmpi.pdf)
@@ -64,6 +67,10 @@
 pip install stoke
 ```
 
 ### via PyPi w/ Optional MPI Support
+
+**Note: MPI support is necessary if you plan to run Stoke across multiple compute nodes (e.g. 2 nodes with 4 GPUs each)
+with DDP, Horovod, or DeepSpeed backends.**
+
 ```bash
 pip install stoke[mpi]
 ```
@@ -211,10 +218,10 @@ sampler = DistributedSampler(
 )
 
 # Call the DataLoader method on the stoke_obj to correctly create a DataLoader instance
+# The DataLoader object already knows the batch size from the Stoke object creation
 data_loader = stoke_obj.DataLoader(
     dataset=dataset,
     collate_fn=lambda batch: dataset.collate_fn(batch),
-    batch_size=32,
     sampler=sampler,
     num_workers=4
 )
diff --git a/docs/Launchers.md b/docs/Launchers.md
index 708a514..b4b425b 100644
--- a/docs/Launchers.md
+++ b/docs/Launchers.md
@@ -44,9 +44,31 @@ mpirun -np 16 \
 ```
 
 ### Deepspeed w/ OpenMPI
-Prefer the OpenMPI version [
-here](https://www.deepspeed.ai/getting-started/#multi-node-environment-variables) over the native
-launcher. Deepspeed will automatically discover devices, etc. via mpi4py.
+Prefer the OpenMPI version [here](https://www.deepspeed.ai/getting-started/#multi-node-environment-variables) over the
+native launcher. DeepSpeed will automatically discover devices, etc. via mpi4py. It can also be used
+with k8s via the [MPI Operator](https://github.com/kubeflow/mpi-operator).

```shell
mpirun -np 4 \
    --allow-run-as-root -bind-to none -map-by slot \
    -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x PATH \
    -mca pml ob1 -mca btl ^openib \
    python train.py
```
or
```shell
mpirun -np 16 \
    -H server1:4,server2:4,server3:4,server4:4 \
    -bind-to none -map-by slot \
    -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x PATH \
    -mca pml ob1 -mca btl ^openib \
    python train.py
```


### PyTorch DDP w/ OpenMPI
Leverages DeepSpeed functionality to automatically discover devices, etc. via mpi4py. It can also be used
with k8s via the [MPI Operator](https://github.com/kubeflow/mpi-operator).

```shell
mpirun -np 4 \
    --allow-run-as-root -bind-to none -map-by slot \
    -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x PATH \
    -mca pml ob1 -mca btl ^openib \
    python train.py
```
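For reference, the Python-side counterpart to the launcher above is just a `DDPConfig` with the new flag enabled — a minimal sketch, assuming `DDPConfig` is importable from the top-level `stoke` package (it is defined in `stoke/configs.py`, patched below) and with the surrounding model/optimizer/loss setup elided:

```python
from stoke import DDPConfig

# Under mpirun, RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT / LOCAL_RANK are
# typically not exported, so the new flag lets Stoke derive them from the MPI
# environment at init time instead of requiring them to be set by hand.
ddp_config = DDPConfig(
    local_rank=None,          # resolved via MPI discovery rather than --local_arg
    auto_mpi_discovery=True,  # the new flag added in this patch
    backend="nccl",
)
```

The config is then handed to the `Stoke` object at construction like any other `DDPConfig`.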
diff --git a/docs/Quick-Start.md b/docs/Quick-Start.md
index 0e09e9f..60aed16 100644
--- a/docs/Quick-Start.md
+++ b/docs/Quick-Start.md
@@ -140,10 +140,10 @@ sampler = DistributedSampler(
 )
 
 # Call the DataLoader method on the stoke_obj to correctly create a DataLoader instance
+# The DataLoader object already knows the batch size from the Stoke object creation
 data_loader = stoke_obj.DataLoader(
     dataset=dataset,
     collate_fn=lambda batch: dataset.collate_fn(batch),
-    batch_size=32,
     sampler=sampler,
     num_workers=4
 )
diff --git a/examples/cifar10/train.py b/examples/cifar10/train.py
index 3272c31..78d5909 100644
--- a/examples/cifar10/train.py
+++ b/examples/cifar10/train.py
@@ -147,7 +147,6 @@ def main():
     # Construct the DataLoader
     train_loader = cifar_stoke.DataLoader(
         dataset=training_dataset,
-        batch_size=configs.DataConfig.batch_size,
         sampler=train_sampler,
         num_workers=configs.DataConfig.n_workers
         if configs.DataConfig.n_workers is not None
@@ -165,7 +164,6 @@ def main():
     )
     test_loader = cifar_stoke.DataLoader(
         dataset=test_dataset,
-        batch_size=configs.DataConfig.batch_size,
         sampler=test_sampler,
         num_workers=configs.DataConfig.n_workers
         if configs.DataConfig.n_workers is not None
diff --git a/stoke/configs.py b/stoke/configs.py
index 3f17393..4e8dddb 100644
--- a/stoke/configs.py
+++ b/stoke/configs.py
@@ -135,6 +135,9 @@ class DDPConfig:
     ----------
     local_rank: Optional[int]
         Current local rank of the device (provided here, as LOCAL_RANK env var, or parsed from --local_arg)
+    auto_mpi_discovery: bool, default: False
+        If the distributed environment variables are not set, attempt to discover them from MPI (using the underlying
+        DeepSpeed function call)
     convert_to_sync_batch_norm: bool, default: False
         Automatically convert all batch norm calls to torch.nn.SyncBatchNorm calls
         https://pytorch.org/docs/stable/generated/torch.nn.SyncBatchNorm.html
@@ -168,6 +171,7 @@ class DDPConfig:
     """
 
     local_rank: Optional[int]
+    auto_mpi_discovery: bool = False
     convert_to_sync_batch_norm: bool = False
     backend: BackendOptions = "nccl"
     broadcast_buffers: bool = True
diff --git a/stoke/distributed.py b/stoke/distributed.py
index 5639410..84e0e4e 100644
--- a/stoke/distributed.py
+++ b/stoke/distributed.py
@@ -12,6 +12,7 @@
 from typing import List, Optional, Tuple, Union
 
 import deepspeed as ds
+from deepspeed.utils.distributed import mpi_discovery
 import horovod.torch as hvd
 import torch
 from fairscale.optim.oss import OSS
@@ -490,11 +491,22 @@ def _create_ddp_handler(kwargs: dict):
 
     def _call_init(self):
         """Does any backend initialization work related to DDP setup
 
+        Borrows code from DeepSpeed to set up DDP via OpenMPI
+        https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/utils/distributed.py
+
         Returns
         -------
         None
 
         """
+        # Borrowing a bit of code from DeepSpeed
+        required_env = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
+        if self._ddp_config.auto_mpi_discovery and not all(map(lambda v: v in os.environ, required_env)):
+            try:
+                from mpi4py import MPI
+                mpi_discovery(verbose=True)
+            except ImportError as e:
+                print(e, ": mpi4py cannot be imported -- please install Stoke with the MPI option (pip install stoke[mpi])")
         # Initialize call for DDP
         torch.distributed.init_process_group(
             backend=self._ddp_config.backend, init_method=self._ddp_config.init_method
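Out of diff context, the fallback added above reduces to the following condensed sketch (the standalone function name is illustrative, not library code):

```python
import os


def _maybe_mpi_discovery(auto_mpi_discovery: bool) -> None:
    # Explicitly set environment variables always win; MPI is only consulted
    # when the user opted in AND the torch.distributed env is incomplete.
    required_env = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
    if auto_mpi_discovery and not all(v in os.environ for v in required_env):
        try:
            from mpi4py import MPI  # noqa: F401 -- presence check only
            from deepspeed.utils.distributed import mpi_discovery

            # Populates the five env vars above from the MPI communicator so
            # that torch.distributed.init_process_group can proceed normally.
            mpi_discovery(verbose=True)
        except ImportError as e:
            print(f"{e}: mpi4py cannot be imported -- pip install stoke[mpi]")
```

Note the ordering: discovery runs before `torch.distributed.init_process_group`, so after a successful `mpi_discovery` the DDP init proceeds exactly as if the environment had been configured by hand.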
diff --git a/stoke/fp16.py b/stoke/fp16.py
index 3ffd868..439f394 100644
--- a/stoke/fp16.py
+++ b/stoke/fp16.py
@@ -460,7 +460,6 @@ def _apex_convert_to_sync_batch_norm(self, model: torch.nn.Module):
         )
         try:
             from apex.parallel import convert_syncbn_model
-
             model = convert_syncbn_model(module=model)
         except ImportError as e:
             print(
diff --git a/stoke/stoke.py b/stoke/stoke.py
index fcc1ac9..4ef1ab8 100644
--- a/stoke/stoke.py
+++ b/stoke/stoke.py
@@ -736,7 +736,6 @@ def _get_fp16_mixin(self):
     def DataLoader(
         self,
         dataset: Dataset[T_co],
-        batch_size: Optional[int] = 1,
         shuffle: bool = False,
         sampler: Optional[Sampler[int]] = None,
         batch_sampler: Optional[Sampler[Sequence[int]]] = None,
@@ -764,8 +763,6 @@ def DataLoader(
         ----------
         dataset: Dataset
             dataset from which to load the data.
-        batch_size: int, default: 1
-            how many samples per batch to load .
         shuffle: bool, default: False
             set to ``True`` to have the data reshuffled at every epoch.
         sampler: Sampler or Iterable, default: None
@@ -817,12 +814,12 @@ def DataLoader(
         if self._verbose and self.gpu:
             print(f"Automatically handling moving model input data to GPU(s)...")
-
+        # Forward the already-known options from the Stoke object
         return StokeDataLoader(
             gpu=self.gpu,
             fp16=self.fp16,
+            batch_size=self.batch_size,
             dataset=dataset,
-            batch_size=batch_size,
             shuffle=shuffle,
             sampler=sampler,
             batch_sampler=batch_sampler,

From a5eb779d45534c75c3e6748530d4cb8f8b53934b Mon Sep 17 00:00:00 2001
From: Nicholas Cilfone
Date: Wed, 20 Oct 2021 08:48:19 -0400
Subject: [PATCH 2/2] linted

---
 stoke/distributed.py | 20 ++++++++++++++++----
 stoke/fp16.py        |  1 +
 2 files changed, 17 insertions(+), 4 deletions(-)

diff --git a/stoke/distributed.py b/stoke/distributed.py
index 84e0e4e..9848dcb 100644
--- a/stoke/distributed.py
+++ b/stoke/distributed.py
@@ -12,9 +12,9 @@
 from typing import List, Optional, Tuple, Union
 
 import deepspeed as ds
-from deepspeed.utils.distributed import mpi_discovery
 import horovod.torch as hvd
 import torch
+from deepspeed.utils.distributed import mpi_discovery
 from fairscale.optim.oss import OSS
 
 from stoke.configs import ClipGradConfig, ClipGradNormConfig
@@ -500,13 +500,25 @@ def _call_init(self):
         """
 
         # Borrowing a bit of code from DeepSpeed
-        required_env = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
-        if self._ddp_config.auto_mpi_discovery and not all(map(lambda v: v in os.environ, required_env)):
+        required_env = [
+            "RANK",
+            "WORLD_SIZE",
+            "MASTER_ADDR",
+            "MASTER_PORT",
+            "LOCAL_RANK",
+        ]
+        if self._ddp_config.auto_mpi_discovery and not all(
+            map(lambda v: v in os.environ, required_env)
+        ):
             try:
                 from mpi4py import MPI
+
                 mpi_discovery(verbose=True)
             except ImportError as e:
-                print(e, ": mpi4py cannot be imported -- please install Stoke with the MPI option (pip install stoke[mpi])")
+                print(
+                    e,
+                    ": mpi4py cannot be imported -- please install Stoke with the MPI option (pip install stoke[mpi])",
+                )
         # Initialize call for DDP
         torch.distributed.init_process_group(
             backend=self._ddp_config.backend, init_method=self._ddp_config.init_method
diff --git a/stoke/fp16.py b/stoke/fp16.py
index 439f394..3ffd868 100644
--- a/stoke/fp16.py
+++ b/stoke/fp16.py
@@ -460,6 +460,7 @@ def _apex_convert_to_sync_batch_norm(self, model: torch.nn.Module):
         )
         try:
             from apex.parallel import convert_syncbn_model
+
             model = convert_syncbn_model(module=model)
         except ImportError as e:
             print(
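Taken together, the user-facing effect of the batch-size change is that the batch size is declared once, on the `Stoke` object, and forwarded to `StokeDataLoader` automatically. A before/after sketch pieced together from the README hunk in the first patch (construction of `stoke_obj`, `dataset`, and `sampler` elided):

```python
# Before: batch size was declared at Stoke creation AND at every DataLoader
# call, and the two could silently drift apart.
# data_loader = stoke_obj.DataLoader(dataset=dataset, batch_size=32, sampler=sampler)

# After: StokeDataLoader receives batch_size from the Stoke object itself,
# so the call site drops the argument entirely.
data_loader = stoke_obj.DataLoader(
    dataset=dataset,
    collate_fn=lambda batch: dataset.collate_fn(batch),
    sampler=sampler,
    num_workers=4,
)
```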