diff --git a/jax/_src/hardware_utils.py b/jax/_src/hardware_utils.py index 7fa1c0a53..dd3da5c4f 100644 --- a/jax/_src/hardware_utils.py +++ b/jax/_src/hardware_utils.py @@ -38,6 +38,7 @@ _TPU_ENHANCED_BARRIER_SUPPORTED = [ _NVIDIA_GPU_DEVICES = [ '/dev/nvidia0', + '/dev/nvidiactl', # Docker/Kubernetes '/dev/dxg', # WSL2 ]