Check "jax_rocm_visible_devices" at client creation.

This aligns rocm with cuda when using jax.distributed in combination
with one of the mechanisms for cluster-autodetection that set visible
devices in the "jax_rocm_visible_devices" flag.

Fixes #26298
This commit is contained in:
Sebastian Kehl 2025-02-04 13:39:21 +01:00
parent cc830748bf
commit 6abc76c874

View File

@ -644,8 +644,9 @@ def _options_from_jax_configs(plugin_name):
"Should be in format 'key:value'")
options[option_list[0]] = option_list[1]
if plugin_name == "cuda":
visible_devices = CUDA_VISIBLE_DEVICES.value
if plugin_name in ("cuda", "rocm"):
visible_devices = (CUDA_VISIBLE_DEVICES.value if plugin_name == "cuda"
else _ROCM_VISIBLE_DEVICES.value)
if visible_devices != 'all':
options['visible_devices'] = [int(x) for x in visible_devices.split(',')]
mock_gpu_topology = _MOCK_GPU_TOPOLOGY.value or None