From ba972f020738f31b7a4d80b5412b3f2cd646a356 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Tue, 22 Jun 2021 23:31:12 +0000 Subject: [PATCH] On Cloud TPU, use pip-installed libtpu instead of system default if applicable. --- jax/_src/cloud_tpu_init.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/jax/_src/cloud_tpu_init.py b/jax/_src/cloud_tpu_init.py index b53ebf0a3..3b33e2399 100644 --- a/jax/_src/cloud_tpu_init.py +++ b/jax/_src/cloud_tpu_init.py @@ -21,9 +21,9 @@ def cloud_tpu_init(): as JAX's C++ backend is loaded! I.e. call this before xla_bridge or xla_client is imported.** - These environment variables are used to tell the TPU runtime what kind of mesh - topology to use. It assumes a single-host topology by default, so we manually - set them here to default to the full pod slice if applicable. + Some of these environment variables are used to tell the TPU runtime what kind + of mesh topology to use. It assumes a single-host topology by default, so we + manually set them here to default to the full pod slice if applicable. This will not set any env vars if a single topology-related env var is already set. @@ -31,6 +31,17 @@ def cloud_tpu_init(): if not _running_in_cloud_tpu_vm(): return + # Use pip-installed libtpu if applicable, rather than system default. + try: + # pylint: disable=import-outside-toplevel + # pytype: disable=import-error + import libtpu + # pytype: enable=import-error + # pylint: enable=import-outside-toplevel + libtpu.configure_library_path() + except ImportError: + pass + os.environ.setdefault('GRPC_VERBOSITY', 'ERROR') # If the user has set any topology-related env vars, don't set any