A few more distributed device handling tweaks
rwightman committed Oct 21, 2024
1 parent b84cf9b commit 0965313
Showing 1 changed file with 13 additions and 13 deletions.
src/open_clip_train/distributed.py (26 changes: 13 additions & 13 deletions)
@@ -101,8 +101,8 @@ def init_distributed_device(args):
         device=getattr(args, 'device', 'cuda'),
         dist_backend=getattr(args, 'dist_backend', None),
         dist_url=getattr(args, 'dist_url', None),
-        horovod=args.horovod,
-        no_set_device_rank=args.no_set_device_rank,
+        horovod=getattr(args, 'horovod', False),
+        no_set_device_rank=getattr(args, 'no_set_device_rank', False),
     )
     args.device = result['device']
     args.world_size = result['world_size']
@@ -128,17 +128,6 @@ def init_distributed_device_so(
     local_rank = 0
     device_type, *device_idx = device.split(':', maxsplit=1)
-
-    if dist_backend is None:
-        dist_backends = {
-            "xpu": "ccl",
-            "hpu": "hccl",
-            "cuda": "nccl",
-            "npu": "hccl",
-        }
-        dist_backend = dist_backends.get(device_type, 'gloo')
-    dist_url = dist_url or 'env://'
-
     # TBD, support horovod?
     if horovod:
         import horovod.torch as hvd
         assert hvd is not None, "Horovod is not installed"
@@ -148,6 +137,17 @@
         world_size = hvd.size()
         distributed = True
     elif is_using_distributed():
+        if dist_backend is None:
+            dist_backends = {
+                "cuda": "nccl",
+                "hpu": "hccl",
+                "npu": "hccl",
+                "xpu": "ccl",
+            }
+            dist_backend = dist_backends.get(device_type, 'gloo')
+
+        dist_url = dist_url or 'env://'
+
         if 'SLURM_PROCID' in os.environ:
             # DDP via SLURM
             local_rank, global_rank, world_size = world_info_from_env()
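For readers skimming the first hunk: replacing direct attribute access with getattr defaults means init_distributed_device no longer requires the args namespace to define horovod or no_set_device_rank. A minimal sketch of the pattern, not part of the commit; the SimpleNamespace args object below is hypothetical:

from types import SimpleNamespace

# Hypothetical args object that never defined --horovod or --no-set-device-rank.
args = SimpleNamespace(device='cuda', dist_backend=None, dist_url=None)

# Old form: args.horovod raises AttributeError for a namespace like this one.
# New form: falls back to False when the attribute is missing.
horovod = getattr(args, 'horovod', False)
no_set_device_rank = getattr(args, 'no_set_device_rank', False)
print(horovod, no_set_device_rank)  # False False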

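The second change moves the backend defaulting so it only runs inside the elif is_using_distributed(): branch. A standalone sketch of that mapping, mirroring the added lines; the default_backend helper name is made up for illustration:

def default_backend(device: str, dist_backend=None) -> str:
    # Mirrors the added block: derive a torch.distributed backend from the
    # device type unless the caller supplied one explicitly.
    device_type, *_ = device.split(':', maxsplit=1)
    if dist_backend is None:
        dist_backends = {
            "cuda": "nccl",
            "hpu": "hccl",
            "npu": "hccl",
            "xpu": "ccl",
        }
        dist_backend = dist_backends.get(device_type, 'gloo')
    return dist_backend

print(default_backend('cuda:0'))  # nccl
print(default_backend('xpu'))     # ccl
print(default_backend('cpu'))     # gloo (fallback for unlisted device types)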