diff --git a/jax/_src/cloud_tpu_init.py b/jax/_src/cloud_tpu_init.py index 3b33e2399..13253d4c3 100644 --- a/jax/_src/cloud_tpu_init.py +++ b/jax/_src/cloud_tpu_init.py @@ -15,12 +15,14 @@ import os def cloud_tpu_init(): - """Automatically sets Cloud TPU topology env vars. + """Automatically sets Cloud TPU topology and other env vars. **This must be called before the TPU runtime is loaded, which happens as soon as JAX's C++ backend is loaded! I.e. call this before xla_bridge or xla_client is imported.** + Safe to call in non-Cloud TPU environments. + Some of these environment variables are used to tell the TPU runtime what kind of mesh topology to use. It assumes a single-host topology by default, so we manually set them here to default to the full pod slice if applicable. @@ -28,20 +30,18 @@ def cloud_tpu_init(): This will not set any env vars if a single topology-related env var is already set. """ - if not _running_in_cloud_tpu_vm(): - return - - # Use pip-installed libtpu if applicable, rather than system default. try: # pylint: disable=import-outside-toplevel # pytype: disable=import-error import libtpu # pytype: enable=import-error # pylint: enable=import-outside-toplevel - libtpu.configure_library_path() except ImportError: - pass + # We assume libtpu is installed iff we're in a correctly-configured Cloud + # TPU environment. Exit early if we're not running on Cloud TPU. + return + libtpu.configure_library_path() os.environ.setdefault('GRPC_VERBOSITY', 'ERROR') # If the user has set any topology-related env vars, don't set any @@ -56,7 +56,6 @@ def cloud_tpu_init(): ]): return - # Don't assume non-Cloud TPU environments have requests installed # pylint: disable=import-outside-toplevel # pytype: disable=import-error import requests @@ -98,7 +97,3 @@ def cloud_tpu_init(): os.environ['TPU_MESH_CONTROLLER_ADDRESS'] = worker_network_endpoints.split( ',')[0].split(':')[2] + ':8476' os.environ['TPU_MESH_CONTROLLER_PORT'] = '8476' - - -def _running_in_cloud_tpu_vm(): - return os.path.isfile('/lib/libtpu.so') diff --git a/setup.py b/setup.py index 836d1b607..146fc8c79 100644 --- a/setup.py +++ b/setup.py @@ -51,7 +51,9 @@ setup( # Cloud TPU VM jaxlib can be installed via: # $ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/jax_releases.html 'tpu': [f'jaxlib=={_current_jaxlib_version}', - f'libtpu-nightly=={_libtpu_version}'], + f'libtpu-nightly=={_libtpu_version}', + # Required by cloud_tpu_init.py + 'requests'], # CUDA installations require adding jax releases URL; e.g. # $ pip install jax[cuda110] -f https://storage.googleapis.com/jax-releases/jax_releases.html