diff --git a/src/open_clip_train/distributed.py b/src/open_clip_train/distributed.py
index 0d3a71d3e..2fad34575 100644
--- a/src/open_clip_train/distributed.py
+++ b/src/open_clip_train/distributed.py
@@ -127,6 +127,12 @@ def init_distributed_device_so(
     global_rank = 0
     local_rank = 0
     device_type, *device_idx = device.split(':', maxsplit=1)
+    is_avail, is_known = is_device_available(device_type)
+    if not is_known:
+        warnings.warn(f"Device {device} was not known and checked for availability, trying anyways.")
+    elif not is_avail:
+        warnings.warn(f"Device {device} was not available, falling back to CPU.")
+        device_type = device = 'cpu'
 
     if horovod:
         import horovod.torch as hvd
@@ -172,13 +178,6 @@ def init_distributed_device_so(
         global_rank = torch.distributed.get_rank()
         distributed = True
 
-    is_avail, is_known = is_device_available(device_type)
-    if not is_known:
-        warnings.warn(f"Device {device} was not known and checked for availability, trying anyways.")
-    elif not is_avail:
-        warnings.warn(f"Device {device} was not available, falling back to CPU.")
-        device_type = device = 'cpu'
-
     if distributed and not no_set_device_rank and device_type not in ('cpu', 'mps'):
         # Ignore manually specified device index in distributed mode and
         # override with resolved local rank, fewer headaches in most setups.
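
The effect of the move is that the availability check (and any fallback to CPU) now runs right after the requested device string is parsed, before the horovod / torch.distributed branches act on the device, instead of afterwards. Below is a minimal standalone sketch of that check-and-fallback step. The `resolve_device` wrapper and the body of `is_device_available` shown here are illustrative assumptions for this sketch, not the actual open_clip_train helper, which may probe devices differently.

```python
import warnings

import torch


def is_device_available(device_type: str):
    """Return (is_available, is_known) for a device type string.

    Hypothetical re-implementation for illustration only; the real helper
    in open_clip_train.distributed may differ.
    """
    if device_type == 'cpu':
        return True, True
    if device_type == 'cuda':
        return torch.cuda.is_available(), True
    if device_type == 'mps':
        mps_backend = getattr(torch.backends, 'mps', None)
        return bool(mps_backend and mps_backend.is_available()), True
    # Anything else (e.g. 'npu', 'xpu') is treated as unknown here.
    return False, False


def resolve_device(device: str = 'cuda:0') -> str:
    # Split off an optional index: 'cuda:0' -> ('cuda', ['0']).
    device_type, *device_idx = device.split(':', maxsplit=1)
    is_avail, is_known = is_device_available(device_type)
    if not is_known:
        warnings.warn(f"Device {device} was not known and checked for availability, trying anyways.")
    elif not is_avail:
        warnings.warn(f"Device {device} was not available, falling back to CPU.")
        device_type = device = 'cpu'
    return device


print(resolve_device('cuda:1'))  # 'cuda:1' on a CUDA machine, 'cpu' otherwise
```

Doing the check this early keeps the rest of the initialization working with a device string that is already known to be usable (or already downgraded to `'cpu'`), rather than validating it only after the distributed backend has been set up.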