Do not destroy PG as it conflicts with hydra sweep
BloodAxe committed Feb 12, 2024
1 parent 7b1bfbd commit bdc37d8
Showing 1 changed file with 7 additions and 8 deletions.
pytorch_toolbelt/utils/distributed.py (7 additions, 8 deletions)
@@ -56,20 +56,19 @@ def __init__(
         )
 
     def __enter__(self):
-        if self.dist_is_available and self.world_size > 1:
-            if self.dist_is_initialized:
-                raise RuntimeError("Torch distributed is already initialized. This indicates an error.")
-            torch.cuda.set_device(self.device)
+        torch.cuda.set_device(self.device)
 
-            logger.info(f"Setting CUDA device {self.device} for rank {self.local_rank}/{self.world_size}")
-            torch.distributed.init_process_group(backend="nccl", world_size=self.world_size, rank=self.local_rank)
+        if self.dist_is_available and self.world_size > 1:
+            if not self.dist_is_initialized:
+                logger.info(f"Setting CUDA device {self.device} for rank {self.local_rank}/{self.world_size}")
+                torch.distributed.init_process_group(backend="nccl", world_size=self.world_size, rank=self.local_rank)
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         try:
-            if self.dist_is_available and self.dist_is_initialized:
+            if self.dist_is_available:
                 torch.distributed.barrier()
-                torch.distributed.destroy_process_group()
+                # torch.distributed.destroy_process_group()
         except Exception as e:
            logger.exception(e)
         finally:
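
For context, here is a minimal self-contained sketch of the pattern the patched __enter__/__exit__ leave in place: initialize the NCCL process group only if it is not already initialized, and keep it alive on exit so that consecutive runs inside the same process (e.g. a Hydra sweep) can reuse it. The class name, constructor, and training call below are illustrative, not the library's exact API, and the usual MASTER_ADDR/MASTER_PORT environment variables are assumed to be set by the launcher.

import logging

import torch
import torch.distributed as dist

logger = logging.getLogger(__name__)


class ReusableProcessGroup:
    """Illustrative sketch: initialize the process group lazily and never destroy it on exit."""

    def __init__(self, local_rank: int, world_size: int):
        self.local_rank = local_rank
        self.world_size = world_size
        self.device = torch.device("cuda", local_rank)

    def __enter__(self):
        # Always bind this process to its GPU, even in single-process runs.
        torch.cuda.set_device(self.device)

        if dist.is_available() and self.world_size > 1:
            if not dist.is_initialized():
                # First run in this process: create the NCCL process group.
                logger.info(f"Setting CUDA device {self.device} for rank {self.local_rank}/{self.world_size}")
                dist.init_process_group(backend="nccl", world_size=self.world_size, rank=self.local_rank)
            # Otherwise reuse the group that an earlier run already created.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            if dist.is_available() and dist.is_initialized():
                # Synchronize ranks, but keep the group alive so the next run
                # in the same process does not have to re-create it.
                dist.barrier()
        except Exception as e:
            logger.exception(e)


# Hypothetical usage inside a Hydra sweep: each run enters the context manager,
# but only the first one actually initializes the process group.
# with ReusableProcessGroup(local_rank, world_size):
#     train_one_config(cfg)  # placeholder training entry point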
