mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Change determination of cloud TPU to check for TPU chips.
This is useful in the case of ahead of time compilation, when libtpu is present but there may not be any TPU chips, so we shouldn't attempt to initialize a TPU backend. PiperOrigin-RevId: 630055511
This commit is contained in:
parent
f4cf69f118
commit
582c56a707
@ -54,11 +54,10 @@ def cloud_tpu_init() -> None:
|
||||
"""
|
||||
global running_in_cloud_tpu_vm
|
||||
|
||||
# We assume we are in a correctly-configured Cloud TPU environment
|
||||
# if the following hold: a) libtpu is installed b) JAX_FORCE_TPU_INIT is set
|
||||
# Exit early if we're not running on Cloud TPU.
|
||||
# Exit early if we're not running on a Cloud TPU VM or libtpu isn't installed.
|
||||
libtpu_module = maybe_import_libtpu()
|
||||
if libtpu_module is None and not jax_force_tpu_init():
|
||||
num_tpu_chips = hardware_utils.num_available_tpu_chips_and_device_id()[0]
|
||||
if (libtpu_module is None or num_tpu_chips == 0) and not jax_force_tpu_init():
|
||||
return
|
||||
|
||||
running_in_cloud_tpu_vm = True
|
||||
|
Loading…
x
Reference in New Issue
Block a user