Add environment variable check in additional to libtpu check for cloud tpu vm

PiperOrigin-RevId: 504588621
This commit is contained in:
jax authors 2023-01-25 09:50:56 -08:00
parent d14e144651
commit e6e513a6e9

View File

@ -16,6 +16,25 @@ import os
running_in_cloud_tpu_vm = False
def maybe_import_libtpu():
try:
# pylint: disable=import-outside-toplevel
# pytype: disable=import-error
import libtpu
# pytype: enable=import-error
# pylint: enable=import-outside-toplevel
except ImportError:
return None
else:
return libtpu
def jax_force_tpu_init() -> bool:
return 'JAX_FORCE_TPU_INIT' in os.environ
def cloud_tpu_init():
"""Automatically sets Cloud TPU topology and other env vars.
@ -33,20 +52,18 @@ def cloud_tpu_init():
set.
"""
global running_in_cloud_tpu_vm
try:
# pylint: disable=import-outside-toplevel
# pytype: disable=import-error
import libtpu
# pytype: enable=import-error
# pylint: enable=import-outside-toplevel
except ImportError:
# We assume libtpu is installed iff we're in a correctly-configured Cloud
# TPU environment. Exit early if we're not running on Cloud TPU.
# We assume we are in a correctly-configured Cloud TPU environment
# if the following hold: a) libtpu is installed b) JAX_FORCE_TPU_INIT is set
# Exit early if we're not running on Cloud TPU.
libtpu_module = maybe_import_libtpu()
if libtpu_module is not None:
libtpu_module.configure_library_path()
elif not jax_force_tpu_init():
return
running_in_cloud_tpu_vm = True
libtpu.configure_library_path()
os.environ.setdefault('GRPC_VERBOSITY', 'ERROR')
os.environ.setdefault('JAX_PLATFORMS', 'tpu,cpu')
os.environ['TPU_ML_PLATFORM'] = 'JAX'