mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Add environment variable check in additional to libtpu check for cloud tpu vm
PiperOrigin-RevId: 504588621
This commit is contained in:
parent
d14e144651
commit
e6e513a6e9
@ -16,6 +16,25 @@ import os
|
||||
|
||||
running_in_cloud_tpu_vm = False
|
||||
|
||||
|
||||
def maybe_import_libtpu():
|
||||
try:
|
||||
# pylint: disable=import-outside-toplevel
|
||||
# pytype: disable=import-error
|
||||
import libtpu
|
||||
|
||||
# pytype: enable=import-error
|
||||
# pylint: enable=import-outside-toplevel
|
||||
except ImportError:
|
||||
return None
|
||||
else:
|
||||
return libtpu
|
||||
|
||||
|
||||
def jax_force_tpu_init() -> bool:
|
||||
return 'JAX_FORCE_TPU_INIT' in os.environ
|
||||
|
||||
|
||||
def cloud_tpu_init():
|
||||
"""Automatically sets Cloud TPU topology and other env vars.
|
||||
|
||||
@ -33,20 +52,18 @@ def cloud_tpu_init():
|
||||
set.
|
||||
"""
|
||||
global running_in_cloud_tpu_vm
|
||||
try:
|
||||
# pylint: disable=import-outside-toplevel
|
||||
# pytype: disable=import-error
|
||||
import libtpu
|
||||
# pytype: enable=import-error
|
||||
# pylint: enable=import-outside-toplevel
|
||||
except ImportError:
|
||||
# We assume libtpu is installed iff we're in a correctly-configured Cloud
|
||||
# TPU environment. Exit early if we're not running on Cloud TPU.
|
||||
|
||||
# We assume we are in a correctly-configured Cloud TPU environment
|
||||
# if the following hold: a) libtpu is installed b) JAX_FORCE_TPU_INIT is set
|
||||
# Exit early if we're not running on Cloud TPU.
|
||||
libtpu_module = maybe_import_libtpu()
|
||||
if libtpu_module is not None:
|
||||
libtpu_module.configure_library_path()
|
||||
elif not jax_force_tpu_init():
|
||||
return
|
||||
|
||||
running_in_cloud_tpu_vm = True
|
||||
|
||||
libtpu.configure_library_path()
|
||||
os.environ.setdefault('GRPC_VERBOSITY', 'ERROR')
|
||||
os.environ.setdefault('JAX_PLATFORMS', 'tpu,cpu')
|
||||
os.environ['TPU_ML_PLATFORM'] = 'JAX'
|
||||
|
Loading…
x
Reference in New Issue
Block a user