Over time JAX has sprouted many variants of XLA translation rules, each with slightly different but overlapping arguments. This change consolidates them into a new xla.TranslationRule signature:
rule(ctx, avals_in, avals_out, *args, **params)
where ctx contains the parts of the other signatures that were typically not specific to a particular equation.
Since there are many JAX rules to migrate, and even a number of translation rules belonging to projects downstream of JAX, we leave backwards compatibility shims around `xla.translations`, `xla.backend_specific_translations`, and `xla.call_translations` which seem to be the only ones used outside JAX itself.
In passing, this change alters the semantics of `backend` arguments to nested `jit` blocks. We now always canonicalize the backend to a specific backend at the outermost `jit`, and do not complain if an inner `jit` has an explicit `backend` that matches the current default choice.
PiperOrigin-RevId: 403607667
1. In cloud_tpu_init.py, check whether we're on a Cloud TPU VM by
looking for the libtpu Python package, instead of /lib/libtpu.so
(which isn't necessarily present in a docker container). JAX now
relies on the libtpu package instead of the system libtpu.so, so
this makes more sense either way. This means we'll try/catch an
ImportError in all non-TPU environments when importing jax, which
hopefully isn't noticeably slow.
2. Add requests as a jax[tpu] dependency, since it's needed by
cloud_tpu_init.py. This comes pre-installed on Cloud TPU VMs, but
may not be installed in docker containers, virtualenvs, etc.
I manually tested by creating the following Dockerfile on a Cloud TPU VM:
```
FROM ubuntu:18.04
RUN apt update && apt install git python3-pip -y
RUN git clone https://github.com/skye/jax && cd jax && git checkout tpu_docker
WORKDIR jax
RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install .[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
CMD ["python3", "-c", "import jax; print(jax.device_count())"]
```
And then running the following commands:
```
$ sudo docker build -t jax-test .
$ sudo docker run --privileged jax-test
8
```
Note the `--privileged` flags is necessary to let the container access
the TPU devices in /dev.
* Updates jax_releases.html index to include libtpu wheels
* Change [tpu] extras to specify `libtpu-nightly` instead of wheel URL
The full install command will now be:
`pip install pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/jax_releases.html`
(similar to the cuda install commands)
I've already pushed an updated jax_releases.html to the jax-releases GCS bucket.
This can be used on Cloud TPU VMs to automatically install compatible
versions of jax, jaxlib, and libtpu (the low-level library JAX uses to
access the TPU on Cloud TPU VMs).
The new install command requires a new jax release (`>=0.2.15`) and
jaxlib release (`>=0.1.68`) to work, since it requires both
cdfbd9dde1
and
ce2bc24996
to pick up the pip-installed libtpu. I'll update the README and Cloud
TPU VM documentation once these releases are out.
Revert also previous changes that pinned numpy to 1.19.
One of the changes in numpy 1.20 is to add more type annotations.
However, this sometimes make mypy give errors. A common example is
numpy.take, which with the new type annotation does not appear to
mypy as indexable.
Another change is that np.int and np.bool are deprecated. One
should use np.bool_ or np.int_, or the built-ins bool and int.
Changes:
- Fix unnecessary generator
- Iterate dictionary directly instead of calling .keys()
- Remove global statement at the module level
- Use list() instead of a list comprehension
- Use with statement to open the file
- Merge isinstance calls
Updates XLA to 00afc7bb81.
The new XLA release removes the use of protocol buffers from the XLA client. Fixes#349.
Add backward compatibility shims to jaxlib to allow older jax releases to still work on an up to date jaxlib.
The new XLA release also incorporates a fix that avoids a host-device copy for every iteration of a `lax.fori_loop()` on GPU. Fixes#402.
Add a new jaxlib.__version__ field, change jax/jaxlib compatibility logic to check for it.