From c82349aacf486776254c0157966750833268e668 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Sun, 20 Oct 2024 21:16:11 -0700
Subject: [PATCH] Move device check ahead of dist check

---
 src/open_clip_train/distributed.py | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/src/open_clip_train/distributed.py b/src/open_clip_train/distributed.py
index 0d3a71d3e..2fad34575 100644
--- a/src/open_clip_train/distributed.py
+++ b/src/open_clip_train/distributed.py
@@ -127,6 +127,12 @@ def init_distributed_device_so(
     global_rank = 0
     local_rank = 0
     device_type, *device_idx = device.split(':', maxsplit=1)
+    is_avail, is_known = is_device_available(device_type)
+    if not is_known:
+        warnings.warn(f"Device {device} was not known and checked for availability, trying anyways.")
+    elif not is_avail:
+        warnings.warn(f"Device {device} was not available, falling back to CPU.")
+        device_type = device = 'cpu'
 
     if horovod:
         import horovod.torch as hvd
@@ -172,13 +178,6 @@ def init_distributed_device_so(
         global_rank = torch.distributed.get_rank()
         distributed = True
 
-    is_avail, is_known = is_device_available(device_type)
-    if not is_known:
-        warnings.warn(f"Device {device} was not known and checked for availability, trying anyways.")
-    elif not is_avail:
-        warnings.warn(f"Device {device} was not available, falling back to CPU.")
-        device_type = device = 'cpu'
-
    if distributed and not no_set_device_rank and device_type not in ('cpu', 'mps'):
        # Ignore manually specified device index in distributed mode and
        # override with resolved local rank, fewer headaches in most setups.
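
Note: the patch hoists the is_device_available check from after the horovod /
torch.distributed initialization branches to immediately after the requested
device string is parsed. Per the subject line, the point of the move is that an
unknown or unavailable device now falls back to 'cpu' before any distributed
setup runs, so the rest of the function (including the per-rank device index
override) operates on the corrected device rather than one that was silently
invalid during process-group initialization.

The helper itself is not shown in this hunk. A minimal sketch of what an
(is_avail, is_known) probe like this could look like follows; this is an
assumption for illustration, not the repository's actual implementation:

    import torch

    def is_device_available(device_type: str):
        # Hypothetical sketch: return (is_avail, is_known), where is_known is
        # False when we have no way to test availability for this device type.
        if device_type == 'cpu':
            return True, True  # CPU is always usable
        if device_type == 'mps':
            return torch.backends.mps.is_available(), True
        # Probe a matching torch backend module, e.g. torch.cuda or torch.xpu.
        backend = getattr(torch, device_type, None)
        if backend is None or not hasattr(backend, 'is_available'):
            return False, False  # unknown type; caller warns and tries anyway
        return backend.is_available(), True

Under that sketch, requesting e.g. 'cuda:0' on a host without CUDA would, with
this patch applied, emit the fallback warning and proceed as 'cpu' before the
horovod or torch.distributed branches are ever reached.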