diff --git a/CHANGELOG.md b/CHANGELOG.md
index 93824e625a..267ce71e47 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -17,6 +17,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added gradient clipping to StaticCapture utilities.
 - Bistride Multiscale MeshGraphNet example.
 - FIGConvUNet model and example.
+- Ability to pass additional kwargs to torch.distributed from
+  DistributedManager.initialize
 - The Transolver model.
 - The XAeroNet model.
 - Incoporated CorrDiff-GEFS-HRRR model into CorrDiff, with lead-time aware SongUNet and
diff --git a/modulus/distributed/manager.py b/modulus/distributed/manager.py
index 61bc2687e9..582b34c9cc 100644
--- a/modulus/distributed/manager.py
+++ b/modulus/distributed/manager.py
@@ -243,7 +243,7 @@ def get_available_backend():
             return "gloo"
 
     @staticmethod
-    def initialize_env():
+    def initialize_env(**kwargs):
         """Setup method using generic initialization"""
         rank = int(os.environ.get("RANK"))
         world_size = int(os.environ.get("WORLD_SIZE"))
@@ -268,10 +268,11 @@ def initialize_env():
             addr=addr,
             port=port,
             backend=DistributedManager.get_available_backend(),
+            **kwargs,
         )
 
     @staticmethod
-    def initialize_open_mpi(addr, port):
+    def initialize_open_mpi(addr, port, **kwargs):
         """Setup method using OpenMPI initialization"""
         rank = int(os.environ.get("OMPI_COMM_WORLD_RANK"))
         world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE"))
@@ -285,10 +286,11 @@ def initialize_open_mpi(addr, port):
             port=port,
             backend=DistributedManager.get_available_backend(),
             method="openmpi",
+            **kwargs,
         )
 
     @staticmethod
-    def initialize_slurm(port):
+    def initialize_slurm(port, **kwargs):
         """Setup method using SLURM initialization"""
         rank = int(os.environ.get("SLURM_PROCID"))
         world_size = int(os.environ.get("SLURM_NPROCS"))
@@ -303,10 +305,11 @@ def initialize_slurm(port):
             port=port,
             backend=DistributedManager.get_available_backend(),
             method="slurm",
+            **kwargs,
         )
 
     @staticmethod
-    def initialize():
+    def initialize(**kwargs):
         """
         Initialize distributed manager
 
@@ -324,6 +327,9 @@ def initialize():
         listed above. Initialization method can also be explicitly controlled using the
         `MODULUS_DISTRIBUTED_INITIALIZATION_METHOD` environment variable and setting it to
         one of the options above.
+
+        kwargs are passed down to torch.distributed.init_process_group directly. This can be used
+        to set parameters like `timeout=timedelta(minutes=60)`
         """
         if DistributedManager.is_initialized():
             warn("Distributed manager is already intialized")
@@ -336,23 +342,23 @@ def initialize():
         initialization_method = os.getenv("MODULUS_DISTRIBUTED_INITIALIZATION_METHOD")
         if initialization_method is None:
             try:
-                DistributedManager.initialize_env()
+                DistributedManager.initialize_env(**kwargs)
             except TypeError:
                 if "SLURM_PROCID" in os.environ:
-                    DistributedManager.initialize_slurm(port)
+                    DistributedManager.initialize_slurm(port, **kwargs)
                 elif "OMPI_COMM_WORLD_RANK" in os.environ:
-                    DistributedManager.initialize_open_mpi(addr, port)
+                    DistributedManager.initialize_open_mpi(addr, port, **kwargs)
                 else:
                     warn(
                         "Could not initialize using ENV, SLURM or OPENMPI methods. Assuming this is a single process job"
                     )
                     DistributedManager._shared_state["_is_initialized"] = True
         elif initialization_method == "ENV":
-            DistributedManager.initialize_env()
+            DistributedManager.initialize_env(**kwargs)
         elif initialization_method == "SLURM":
-            DistributedManager.initialize_slurm(port)
+            DistributedManager.initialize_slurm(port, **kwargs)
         elif initialization_method == "OPENMPI":
-            DistributedManager.initialize_open_mpi(addr, port)
+            DistributedManager.initialize_open_mpi(addr, port, **kwargs)
         else:
             raise RuntimeError(
                 "Unknown initialization method "
@@ -374,6 +380,7 @@ def setup(
         port="12355",
         backend="nccl",
         method="env",
+        **kwargs,
     ):
         """Set up PyTorch distributed process group and update manager attributes"""
         os.environ["MASTER_ADDR"] = addr
@@ -404,6 +411,7 @@ def setup(
                     rank=manager.rank,
                     world_size=manager.world_size,
                     device_id=manager.device,
+                    **kwargs,
                 )
             except TypeError:
                 # device_id only introduced in PyTorch 2.3
@@ -411,6 +419,7 @@ def setup(
                     backend,
                     rank=manager.rank,
                     world_size=manager.world_size,
+                    **kwargs,
                 )
 
         if torch.cuda.is_available():
diff --git a/test/distributed/test_manager.py b/test/distributed/test_manager.py
index 97cfd652f8..3be15c22d1 100644
--- a/test/distributed/test_manager.py
+++ b/test/distributed/test_manager.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 
 import os
+from datetime import timedelta
 
 import pytest
 import torch
@@ -55,6 +56,29 @@ def test_manager():
     del os.environ["WORLD_SIZE"]
 
 
+def test_manager_timeout():
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "12345"
+    os.environ["RANK"] = "0"
+    os.environ["WORLD_SIZE"] = "1"
+    DistributedManager.initialize(timeout=timedelta(minutes=60))
+    print(DistributedManager())
+
+    manager = DistributedManager()
+
+    assert manager.is_initialized()
+    assert (
+        manager.distributed == torch.distributed.is_available()
+    ), "Manager should be in serial mode"
+    assert manager.rank == 0
+    assert manager.world_size == 1
+    assert manager.local_rank == 0
+
+    DistributedManager.cleanup()
+    del os.environ["RANK"]
+    del os.environ["WORLD_SIZE"]
+
+
 def test_manager_slurm():
     # Test distributed manager with Slurm variables
     os.environ["MASTER_ADDR"] = "localhost"
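
Example usage of the new kwargs passthrough (a minimal sketch, not part of the patch: the `modulus.distributed` import path is inferred from the file paths above, and the 60-minute timeout is only illustrative; any keyword argument accepted by `torch.distributed.init_process_group` can be forwarded the same way):

    # Assumes RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT are already set, e.g. by torchrun;
    # without them the manager falls back to single-process mode and the kwargs are unused.
    from datetime import timedelta

    from modulus.distributed import DistributedManager

    # Forwarded through initialize() -> setup() -> torch.distributed.init_process_group
    DistributedManager.initialize(timeout=timedelta(minutes=60))

    manager = DistributedManager()
    print(manager.rank, manager.world_size, manager.device)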