diff --git a/jax/_src/cloud_tpu_init.py b/jax/_src/cloud_tpu_init.py index b53ebf0a3..3b33e2399 100644 --- a/jax/_src/cloud_tpu_init.py +++ b/jax/_src/cloud_tpu_init.py @@ -21,9 +21,9 @@ def cloud_tpu_init(): as JAX's C++ backend is loaded! I.e. call this before xla_bridge or xla_client is imported.** - These environment variables are used to tell the TPU runtime what kind of mesh - topology to use. It assumes a single-host topology by default, so we manually - set them here to default to the full pod slice if applicable. + Some of these environment variables are used to tell the TPU runtime what kind + of mesh topology to use. It assumes a single-host topology by default, so we + manually set them here to default to the full pod slice if applicable. This will not set any env vars if a single topology-related env var is already set. @@ -31,6 +31,17 @@ def cloud_tpu_init(): if not _running_in_cloud_tpu_vm(): return + # Use pip-installed libtpu if applicable, rather than system default. + try: + # pylint: disable=import-outside-toplevel + # pytype: disable=import-error + import libtpu + # pytype: enable=import-error + # pylint: enable=import-outside-toplevel + libtpu.configure_library_path() + except ImportError: + pass + os.environ.setdefault('GRPC_VERBOSITY', 'ERROR') # If the user has set any topology-related env vars, don't set any