mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Use backend._get_all_devices()
to validate devices.
PiperOrigin-RevId: 719367913
This commit is contained in:
parent
cbc2d623fb
commit
08d81e45d4
@ -1015,7 +1015,9 @@ def _init_backend(platform: str) -> xla_client.Client:
|
||||
# factories instead of returning None.
|
||||
if backend is None:
|
||||
raise RuntimeError(f"Could not initialize backend '{platform}'")
|
||||
if backend.device_count() == 0:
|
||||
# TODO(b/356678989): Only check `backend.device_count()` when it counts
|
||||
# CPU-only devices.
|
||||
if backend.device_count() == 0 and len(backend._get_all_devices()) == 0:
|
||||
raise RuntimeError(f"Backend '{platform}' provides no devices.")
|
||||
util.distributed_debug_log(("Initialized backend", backend.platform),
|
||||
("process_index", backend.process_index()),
|
||||
|
@ -217,9 +217,15 @@ class GetBackendTest(jtu.JaxTestCase):
|
||||
def process_index(self):
|
||||
return 0
|
||||
|
||||
def devices(self):
|
||||
return []
|
||||
|
||||
def local_devices(self):
|
||||
return []
|
||||
|
||||
def _get_all_devices(self):
|
||||
return self.devices()
|
||||
|
||||
def _register_factory(self, platform: str, priority, device_count=1,
|
||||
assert_used_at_most_once=False, experimental=False):
|
||||
if assert_used_at_most_once:
|
||||
@ -306,7 +312,6 @@ class GetBackendTest(jtu.JaxTestCase):
|
||||
):
|
||||
xb.get_backend("error")
|
||||
|
||||
|
||||
def test_no_devices(self):
|
||||
self._register_factory("no_devices", -10, device_count=0)
|
||||
with self.assertRaisesRegex(
|
||||
|
Loading…
x
Reference in New Issue
Block a user