mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Check "jax_rocm_visible_devices" at client creation.
This aligns rocm with cuda when using jax.distributed in combination with one of the mechanisms for cluster-autodetection that set visible devices in the "jax_rocm_visible_devices" flag. Fixes #26298
This commit is contained in:
parent
cc830748bf
commit
6abc76c874
@ -644,8 +644,9 @@ def _options_from_jax_configs(plugin_name):
|
||||
"Should be in format 'key:value'")
|
||||
options[option_list[0]] = option_list[1]
|
||||
|
||||
if plugin_name == "cuda":
|
||||
visible_devices = CUDA_VISIBLE_DEVICES.value
|
||||
if plugin_name in ("cuda", "rocm"):
|
||||
visible_devices = (CUDA_VISIBLE_DEVICES.value if plugin_name == "cuda"
|
||||
else _ROCM_VISIBLE_DEVICES.value)
|
||||
if visible_devices != 'all':
|
||||
options['visible_devices'] = [int(x) for x in visible_devices.split(',')]
|
||||
mock_gpu_topology = _MOCK_GPU_TOPOLOGY.value or None
|
||||
|
Loading…
x
Reference in New Issue
Block a user