Merge pull request #22703 from Rifur13:plugin-fix

PiperOrigin-RevId: 657283607
This commit is contained in:
jax authors 2024-07-29 12:10:42 -07:00
commit f070c0658f
2 changed files with 15 additions and 6 deletions

View File

@ -36,6 +36,11 @@ _TPU_ENHANCED_BARRIER_SUPPORTED = [
'0x005e',
]
# Device nodes whose presence on the filesystem is taken as evidence that an
# NVIDIA GPU is exposed to this process. A single existing path is sufficient.
_NVIDIA_GPU_DEVICES = [
    '/dev/nvidia0',
    # Driver control node. Needed because a container may be handed a GPU
    # under a non-zero index (e.g. only /dev/nvidia1), in which case
    # /dev/nvidia0 is absent even though a GPU is usable.
    '/dev/nvidiactl',  # Docker/Kubernetes
    '/dev/dxg',  # WSL2
]
def num_available_tpu_chips_and_device_id():
"""Returns the device id and number of TPU chips attached through PCI."""
num_chips = 0
@ -57,3 +62,9 @@ def tpu_enhanced_barrier_supported() -> bool:
"""Returns if tpu_enhanced_barrier flag is supported on this TPU version."""
_, device_id = num_available_tpu_chips_and_device_id()
return device_id in _TPU_ENHANCED_BARRIER_SUPPORTED
def has_visible_nvidia_gpu() -> bool:
  """Reports whether any known NVIDIA GPU device node exists.

  Returns:
    True if at least one path listed in `_NVIDIA_GPU_DEVICES` exists on the
    filesystem, False otherwise.
  """
  for device_path in _NVIDIA_GPU_DEVICES:
    # Stop at the first node found; one match is enough.
    if os.path.exists(device_path):
      return True
  return False

View File

@ -881,6 +881,9 @@ def backends() -> dict[str, xla_client.Client]:
default_priority = -1000
for platform, priority, fail_quietly in platform_registrations:
try:
if platform == "cuda" and not hardware_utils.has_visible_nvidia_gpu():
continue
backend = _init_backend(platform)
_backends[platform] = backend
@ -918,12 +921,7 @@ def _suggest_missing_backends():
assert _default_backend is not None
default_platform = _default_backend.platform
nvidia_gpu_devices = [
"/dev/nvidia0",
"/dev/dxg", # WSL2
]
if ("cuda" not in _backends and
any(os.path.exists(d) for d in nvidia_gpu_devices)):
if "cuda" not in _backends and hardware_utils.has_visible_nvidia_gpu():
if hasattr(xla_extension, "GpuAllocatorConfig") and "cuda" in _backend_errors:
err = _backend_errors["cuda"]
warning_msg = f"CUDA backend failed to initialize: {err}."