Skip to content

Commit

Permalink
Get and force the use of the device local to the process in distributed runs
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Jan 24, 2025
1 parent d57c8ea commit e09ed4d
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions skrl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,12 @@ def __init__(self) -> None:
process_id=self._rank,
local_device_ids=self._local_rank,
)
# get the device local to this process (first device reported for its index/rank)
try:
self._device = jax.local_devices(process_index=self._rank)[0]
logger.info(f"Using device local to process with index/rank {self._rank} ({self._device})")
except Exception as e:
logger.warning(f"Failed to get the device local to process with index/rank {self._rank}: {e}")

@staticmethod
def parse_device(device: Union[str, "jax.Device", None]) -> "jax.Device":
Expand All @@ -204,6 +210,15 @@ def parse_device(device: Union[str, "jax.Device", None]) -> "jax.Device":
"""
import jax

# in distributed runs, force the use of the device local to this process:
# the `device` argument is ignored unless the local-device lookup fails
if config.jax.is_distributed:
try:
return jax.local_devices(process_index=config.jax.rank)[0]
except Exception as e:
logger.warning(
f"Failed to get the device local to process with index/rank {config.jax.rank}: {e}"
)

if isinstance(device, jax.Device):
return device
elif isinstance(device, str):
Expand Down

0 comments on commit e09ed4d

Please sign in to comment.