diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 4d32f787e..cbbd04da4 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -1015,7 +1015,9 @@ def _init_backend(platform: str) -> xla_client.Client: # factories instead of returning None. if backend is None: raise RuntimeError(f"Could not initialize backend '{platform}'") - if backend.device_count() == 0: + # TODO(b/356678989): Only check `backend.device_count()` when it counts + # CPU-only devices. + if backend.device_count() == 0 and len(backend._get_all_devices()) == 0: raise RuntimeError(f"Backend '{platform}' provides no devices.") util.distributed_debug_log(("Initialized backend", backend.platform), ("process_index", backend.process_index()), diff --git a/tests/xla_bridge_test.py b/tests/xla_bridge_test.py index 509a4244d..97e8765cc 100644 --- a/tests/xla_bridge_test.py +++ b/tests/xla_bridge_test.py @@ -217,9 +217,15 @@ class GetBackendTest(jtu.JaxTestCase): def process_index(self): return 0 + def devices(self): + return [] + def local_devices(self): return [] + def _get_all_devices(self): + return self.devices() + def _register_factory(self, platform: str, priority, device_count=1, assert_used_at_most_once=False, experimental=False): if assert_used_at_most_once: @@ -306,7 +312,6 @@ class GetBackendTest(jtu.JaxTestCase): ): xb.get_backend("error") - def test_no_devices(self): self._register_factory("no_devices", -10, device_count=0) with self.assertRaisesRegex(