On Cloud TPU, use pip-installed libtpu instead of system default if applicable.

This commit is contained in:
Skye Wanderman-Milne 2021-06-22 23:31:12 +00:00
parent 3b84f85700
commit ba972f0207

View File

@ -21,9 +21,9 @@ def cloud_tpu_init():
as JAX's C++ backend is loaded! I.e. call this before xla_bridge or xla_client
is imported.**
These environment variables are used to tell the TPU runtime what kind of mesh
topology to use. It assumes a single-host topology by default, so we manually
set them here to default to the full pod slice if applicable.
Some of these environment variables are used to tell the TPU runtime what kind
of mesh topology to use. It assumes a single-host topology by default, so we
manually set them here to default to the full pod slice if applicable.
This will not set any env vars if a single topology-related env var is already
set.
@ -31,6 +31,17 @@ def cloud_tpu_init():
if not _running_in_cloud_tpu_vm():
return
# Use pip-installed libtpu if applicable, rather than system default.
try:
# pylint: disable=import-outside-toplevel
# pytype: disable=import-error
import libtpu
# pytype: enable=import-error
# pylint: enable=import-outside-toplevel
libtpu.configure_library_path()
except ImportError:
pass
os.environ.setdefault('GRPC_VERBOSITY', 'ERROR')
# If the user has set any topology-related env vars, don't set any