Expose all torch.distributed.init_process_group parameters in the DistributedManager #690

Open · wants to merge 3 commits into base: main
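For context, a minimal usage sketch of the kwargs pass-through this PR adds. The import path, port, and the 60-minute timeout are illustrative assumptions mirroring the docstring and test in the diff below, not project defaults:

```python
# Sketch only: mirrors the new test below; assumes DistributedManager is
# importable from modulus.distributed. The env vars set up a single-process run.
import os
from datetime import timedelta

from modulus.distributed import DistributedManager

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"

# Extra keyword arguments are forwarded verbatim to
# torch.distributed.init_process_group, e.g. a longer collective timeout.
DistributedManager.initialize(timeout=timedelta(minutes=60))

manager = DistributedManager()
print(manager.rank, manager.world_size, manager.device)

DistributedManager.cleanup()
```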
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -17,6 +17,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added gradient clipping to StaticCapture utilities.
- Bistride Multiscale MeshGraphNet example.
- FIGConvUNet model and example.
- Ability to pass additional kwargs from DistributedManager.initialize down to
  torch.distributed.init_process_group.

### Changed

29 changes: 19 additions & 10 deletions modulus/distributed/manager.py
@@ -243,7 +243,7 @@ def get_available_backend():
return "gloo"

@staticmethod
def initialize_env():
def initialize_env(**kwargs):
"""Setup method using generic initialization"""
rank = int(os.environ.get("RANK"))
world_size = int(os.environ.get("WORLD_SIZE"))
@@ -268,10 +268,11 @@ def initialize_env():
addr=addr,
port=port,
backend=DistributedManager.get_available_backend(),
**kwargs,
)

@staticmethod
def initialize_open_mpi(addr, port):
def initialize_open_mpi(addr, port, **kwargs):
"""Setup method using OpenMPI initialization"""
rank = int(os.environ.get("OMPI_COMM_WORLD_RANK"))
world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE"))
@@ -285,10 +286,11 @@ def initialize_open_mpi(addr, port):
port=port,
backend=DistributedManager.get_available_backend(),
method="openmpi",
**kwargs,
)

@staticmethod
def initialize_slurm(port):
def initialize_slurm(port, **kwargs):
"""Setup method using SLURM initialization"""
rank = int(os.environ.get("SLURM_PROCID"))
world_size = int(os.environ.get("SLURM_NPROCS"))
@@ -303,10 +305,11 @@ def initialize_slurm(port):
port=port,
backend=DistributedManager.get_available_backend(),
method="slurm",
**kwargs,
)

@staticmethod
def initialize():
def initialize(**kwargs):
"""
Initialize distributed manager

@@ -324,6 +327,9 @@ def initialize():
listed above. Initialization method can also be explicitly controlled using the
`MODULUS_DISTRIBUTED_INITIALIZATION_METHOD` environment variable and setting it
to one of the options above.

Any additional kwargs are passed down directly to torch.distributed.init_process_group. This can
be used to set parameters like `timeout=timedelta(minutes=60)`.
"""
if DistributedManager.is_initialized():
warn("Distributed manager is already intialized")
@@ -336,23 +342,23 @@ def initialize():
initialization_method = os.getenv("MODULUS_DISTRIBUTED_INITIALIZATION_METHOD")
if initialization_method is None:
try:
DistributedManager.initialize_env()
DistributedManager.initialize_env(**kwargs)
except TypeError:
if "SLURM_PROCID" in os.environ:
DistributedManager.initialize_slurm(port)
DistributedManager.initialize_slurm(port, **kwargs)
elif "OMPI_COMM_WORLD_RANK" in os.environ:
DistributedManager.initialize_open_mpi(addr, port)
DistributedManager.initialize_open_mpi(addr, port, **kwargs)
else:
warn(
"Could not initialize using ENV, SLURM or OPENMPI methods. Assuming this is a single process job"
)
DistributedManager._shared_state["_is_initialized"] = True
elif initialization_method == "ENV":
DistributedManager.initialize_env()
DistributedManager.initialize_env(**kwargs)
elif initialization_method == "SLURM":
DistributedManager.initialize_slurm(port)
DistributedManager.initialize_slurm(port, **kwargs)
elif initialization_method == "OPENMPI":
DistributedManager.initialize_open_mpi(addr, port)
DistributedManager.initialize_open_mpi(addr, port, **kwargs)
else:
raise RuntimeError(
"Unknown initialization method "
@@ -374,6 +380,7 @@ def setup(
port="12355",
backend="nccl",
method="env",
**kwargs,
):
"""Set up PyTorch distributed process group and update manager attributes"""
os.environ["MASTER_ADDR"] = addr
@@ -404,13 +411,15 @@ def setup(
rank=manager.rank,
world_size=manager.world_size,
device_id=manager.device,
**kwargs,
)
except TypeError:
# device_id only introduced in PyTorch 2.3
dist.init_process_group(
backend,
rank=manager.rank,
world_size=manager.world_size,
**kwargs,
)

if torch.cuda.is_available():
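Note for reviewers: the forwarded kwargs have to survive the existing PyTorch-version fallback in setup(), shown in the hunk above. A minimal sketch of that pattern follows; the helper name init_process_group_compat is hypothetical and not part of this PR:

```python
# Sketch of the compatibility pattern setup() relies on: try the newer
# init_process_group signature that accepts device_id (PyTorch >= 2.3) and
# fall back on TypeError for older releases. User kwargs such as `timeout`
# are forwarded in both branches.
import torch.distributed as dist


def init_process_group_compat(backend, rank, world_size, device=None, **kwargs):
    try:
        # device_id was only introduced in PyTorch 2.3
        dist.init_process_group(
            backend, rank=rank, world_size=world_size, device_id=device, **kwargs
        )
    except TypeError:
        dist.init_process_group(backend, rank=rank, world_size=world_size, **kwargs)
```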
24 changes: 24 additions & 0 deletions test/distributed/test_manager.py
@@ -15,6 +15,7 @@
# limitations under the License.

import os
from datetime import timedelta

import pytest
import torch
@@ -55,6 +56,29 @@ def test_manager():
del os.environ["WORLD_SIZE"]


def test_manager_timeout():
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12345"
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
DistributedManager.initialize(timeout=timedelta(minutes=60))
print(DistributedManager())

manager = DistributedManager()

assert manager.is_initialized()
assert (
manager.distributed == torch.distributed.is_available()
), "Manager should be in serial mode"
assert manager.rank == 0
assert manager.world_size == 1
assert manager.local_rank == 0

DistributedManager.cleanup()
del os.environ["RANK"]
del os.environ["WORLD_SIZE"]


def test_manager_slurm():
# Test distributed manager with Slurm variables
os.environ["MASTER_ADDR"] = "localhost"
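To exercise the new test in isolation (assuming pytest as the runner, which the imports above suggest), something like `pytest test/distributed/test_manager.py::test_manager_timeout` should be enough, since it only needs a single process and sets its own environment variables.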