From 08d81e45d44b37caaf388143d623c03837482a17 Mon Sep 17 00:00:00 2001 From: Yang Chen Date: Fri, 24 Jan 2025 11:08:39 -0800 Subject: [PATCH] Use `backend._get_all_devices()` to validate devices. PiperOrigin-RevId: 719367913 --- jax/_src/xla_bridge.py | 4 +++- tests/xla_bridge_test.py | 7 ++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 4d32f787e..cbbd04da4 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -1015,7 +1015,9 @@ def _init_backend(platform: str) -> xla_client.Client: # factories instead of returning None. if backend is None: raise RuntimeError(f"Could not initialize backend '{platform}'") - if backend.device_count() == 0: + # TODO(b/356678989): Only check `backend.device_count()` when it counts + # CPU-only devices. + if backend.device_count() == 0 and len(backend._get_all_devices()) == 0: raise RuntimeError(f"Backend '{platform}' provides no devices.") util.distributed_debug_log(("Initialized backend", backend.platform), ("process_index", backend.process_index()), diff --git a/tests/xla_bridge_test.py b/tests/xla_bridge_test.py index 509a4244d..97e8765cc 100644 --- a/tests/xla_bridge_test.py +++ b/tests/xla_bridge_test.py @@ -217,9 +217,15 @@ class GetBackendTest(jtu.JaxTestCase): def process_index(self): return 0 + def devices(self): + return [] + def local_devices(self): return [] + def _get_all_devices(self): + return self.devices() + def _register_factory(self, platform: str, priority, device_count=1, assert_used_at_most_once=False, experimental=False): if assert_used_at_most_once: @@ -306,7 +312,6 @@ class GetBackendTest(jtu.JaxTestCase): ): xb.get_backend("error") - def test_no_devices(self): self._register_factory("no_devices", -10, device_count=0) with self.assertRaisesRegex(