Use backend._get_all_devices() to validate devices.

PiperOrigin-RevId: 719367913
This commit is contained in:
Yang Chen 2025-01-24 11:08:39 -08:00 committed by jax authors
parent cbc2d623fb
commit 08d81e45d4
2 changed files with 9 additions and 2 deletions

View File

@ -1015,7 +1015,9 @@ def _init_backend(platform: str) -> xla_client.Client:
# factories instead of returning None.
if backend is None:
raise RuntimeError(f"Could not initialize backend '{platform}'")
if backend.device_count() == 0:
# TODO(b/356678989): Only check `backend.device_count()` when it counts
# CPU-only devices.
if backend.device_count() == 0 and len(backend._get_all_devices()) == 0:
raise RuntimeError(f"Backend '{platform}' provides no devices.")
util.distributed_debug_log(("Initialized backend", backend.platform),
("process_index", backend.process_index()),

View File

@ -217,9 +217,15 @@ class GetBackendTest(jtu.JaxTestCase):
def process_index(self):
return 0
def devices(self):
return []
def local_devices(self):
return []
def _get_all_devices(self):
return self.devices()
def _register_factory(self, platform: str, priority, device_count=1,
assert_used_at_most_once=False, experimental=False):
if assert_used_at_most_once:
@ -306,7 +312,6 @@ class GetBackendTest(jtu.JaxTestCase):
):
xb.get_backend("error")
def test_no_devices(self):
self._register_factory("no_devices", -10, device_count=0)
with self.assertRaisesRegex(