This is much more convenient and lets us register callbacks that trigger on
changes. I want to add more toggles (e.g. for the SPMD lowering that restricts
sharding of every intermediate), so I want to work out a reasonable approach to
do that first.
Second attempt, this time without hardening against the flags being
registered too late due to delayed imports.
This is much more convenient and lets us register callbacks that trigger on
changes. I want to add more toggles (e.g. for the SPMD lowering that restricts
sharding of every intermediate), so I want to work out a reasonable approach to
do that first.
PiperOrigin-RevId: 384902895
This is especially convenient when using JAX as an HLO generator, because the
HLO AllGather defaults to the tiling behavior.
PiperOrigin-RevId: 384897270
This is much more convenient and lets us register callbacks that trigger on
changes. I want to add more toggles (e.g. for the SPMD lowering that restricts
sharding of every intermediate), so I want to work out a reasonable approach to
do that first.
PiperOrigin-RevId: 384892199
1. In cloud_tpu_init.py, check whether we're on a Cloud TPU VM by
looking for the libtpu Python package, instead of /lib/libtpu.so
(which isn't necessarily present in a docker container). JAX now
relies on the libtpu package instead of the system libtpu.so, so
this makes more sense either way. This means we'll try/catch an
ImportError in all non-TPU environments when importing jax, which
hopefully isn't noticeably slow.
2. Add requests as a jax[tpu] dependency, since it's needed by
cloud_tpu_init.py. This comes pre-installed on Cloud TPU VMs, but
may not be installed in docker containers, virtualenvs, etc.
I manually tested by creating the following Dockerfile on a Cloud TPU VM:
```
FROM ubuntu:18.04
RUN apt update && apt install git python3-pip -y
RUN git clone https://github.com/skye/jax && cd jax && git checkout tpu_docker
WORKDIR jax
RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install .[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
CMD ["python3", "-c", "import jax; print(jax.device_count())"]
```
And then running the following commands:
```
$ sudo docker build -t jax-test .
$ sudo docker run --privileged jax-test
8
```
Note the `--privileged` flags is necessary to let the container access
the TPU devices in /dev.
Previously, the libtpu-nightly wheels were included in the same index
file as the jaxlib wheels (jax_releases.html). This caused issues
because it would cause `pip install jax[tpu] -f jaxlib_releases.html`
to install a cuda jaxlib, instead of the regular CPU/TPU jaxlib from
pypi.
Instead, we create a separate index file for the libtpu-nightly
wheels, so `pip install jax[tpu] -f libtpu_releases.html` still uses
the jaxlib from pypi.
This also renames generate_release_index.py to generate_release_indexes.py.