This is useful in the case of ahead of time compilation, when libtpu is present but there may not be any TPU chips, so we shouldn't attempt to initialize a TPU backend.
PiperOrigin-RevId: 630055511
```
$ JAX_USE_PJRT_C_API_ON_TPU=0 python3 -c "import jax"
/home/skyewm/jax/jax/_src/cloud_tpu_init.py:77: UserWarning: JAX_USE_PJRT_C_API_ON_TPU=0 will no longer be supported in an upcoming future release. Please file an issue at https://github.com/google/jax/issues if you need this setting.
warnings.warn(
```
I forgot that the default setting is actually in jaxlib:
fbe9a80fdb/xla/python/xla_client.py (L135)
To be able to make this change as a jax-only release, I manually set
the env var on Cloud TPU if it isn't already set.
This used to be necessary. However, now these are automatically set in
libtpu. Beyond being redundant, the Python logic needs to be updated
to avoid getting KeyErrors on new topologies and TPU versions, so
better to remove it.
This also moves `get_metadata` to cloud_tpu_cluster.py since it's only
used in that file now.
We frequently use the pattern
try:
import m
except ImportError:
# do something else.
This suppresses errors when the module can be found but does not import
successfully for any reason. Instead, catch only ModuleNotFoundError so
missing modules are allowed but buggy modules still report errors.
1. In cloud_tpu_init.py, check whether we're on a Cloud TPU VM by
looking for the libtpu Python package, instead of /lib/libtpu.so
(which isn't necessarily present in a docker container). JAX now
relies on the libtpu package instead of the system libtpu.so, so
this makes more sense either way. This means we'll try/catch an
ImportError in all non-TPU environments when importing jax, which
hopefully isn't noticeably slow.
2. Add requests as a jax[tpu] dependency, since it's needed by
cloud_tpu_init.py. This comes pre-installed on Cloud TPU VMs, but
may not be installed in docker containers, virtualenvs, etc.
I manually tested by creating the following Dockerfile on a Cloud TPU VM:
```
FROM ubuntu:18.04
RUN apt update && apt install git python3-pip -y
RUN git clone https://github.com/skye/jax && cd jax && git checkout tpu_docker
WORKDIR jax
RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install .[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
CMD ["python3", "-c", "import jax; print(jax.device_count())"]
```
And then running the following commands:
```
$ sudo docker build -t jax-test .
$ sudo docker run --privileged jax-test
8
```
Note the `--privileged` flags is necessary to let the container access
the TPU devices in /dev.
Automatically initialize Cloud TPU topology env vars if running on a Cloud TPU VM.
This removes the need to manually set these env vars when running on a Cloud TPU pod slice.
PiperOrigin-RevId: 361681134