From 582c56a7070ff972ac66848e245592103596c948 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 2 May 2024 07:21:41 -0700 Subject: [PATCH] Change determination of cloud TPU to check for TPU chips. This is useful in the case of ahead of time compilation, when libtpu is present but there may not be any TPU chips, so we shouldn't attempt to initialize a TPU backend. PiperOrigin-RevId: 630055511 --- jax/_src/cloud_tpu_init.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/jax/_src/cloud_tpu_init.py b/jax/_src/cloud_tpu_init.py index e12080027..2cccbd301 100644 --- a/jax/_src/cloud_tpu_init.py +++ b/jax/_src/cloud_tpu_init.py @@ -54,11 +54,10 @@ def cloud_tpu_init() -> None: """ global running_in_cloud_tpu_vm - # We assume we are in a correctly-configured Cloud TPU environment - # if the following hold: a) libtpu is installed b) JAX_FORCE_TPU_INIT is set - # Exit early if we're not running on Cloud TPU. + # Exit early if we're not running on a Cloud TPU VM or libtpu isn't installed. libtpu_module = maybe_import_libtpu() - if libtpu_module is None and not jax_force_tpu_init(): + num_tpu_chips = hardware_utils.num_available_tpu_chips_and_device_id()[0] + if (libtpu_module is None or num_tpu_chips == 0) and not jax_force_tpu_init(): return running_in_cloud_tpu_vm = True