Previously, the libtpu-nightly wheels were included in the same index
file as the jaxlib wheels (jax_releases.html). This caused issues
because it would cause `pip install jax[tpu] -f jaxlib_releases.html`
to install a cuda jaxlib, instead of the regular CPU/TPU jaxlib from
pypi.
Instead, we create a separate index file for the libtpu-nightly
wheels, so `pip install jax[tpu] -f libtpu_releases.html` still uses
the jaxlib from pypi.
This also renames generate_release_index.py to generate_release_indexes.py.
* Updates jax_releases.html index to include libtpu wheels
* Change [tpu] extras to specify `libtpu-nightly` instead of wheel URL
The full install command will now be:
`pip install pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/jax_releases.html`
(similar to the cuda install commands)
I've already pushed an updated jax_releases.html to the jax-releases GCS bucket.
To make sure that the CPU feature guard happens first, before any other code that may use instructions that do not exist, use a separate C extension module.
Fixes https://github.com/google/jax/issues/6671
PiperOrigin-RevId: 374683190
libdevice.10.bc is a redistributable part of the CUDA SDK.
This avoids problems trying to locate a copy of libdevice inside the user's CUDA installation.
The build_wheel.py script was copying the wrong module.
In addition the CUDA stubs from the TF repo were missing a number of cusparse symbols. The updated TF includes the correct stubs.
Update XLA.
CUDA 11.1 wheels are compatible with CUDA versions 11.1+, since NVidia now promises enhanced version compatibility between CUDA minor releases starting with CUDA 11.1
It's unnecessary because the image isn't used interactively in the
script, and it prevents the script from being used when no TTY is
available (e.g. when running from a different script).
This will make it easier to build a single wheel, e.g. for GPU CI testing.
TESTING=I manually ran both build_jaxlib_wheels.sh and the individual
helper functions. I didn't do a full release build, but I verified
that a complete nocuda wheel can be successfully built.
The type stubs allow using precise types for XLA primitives instead
of aliasing them to Any.
This commit does not change any type annotations within JAX. That will
be done in a followup. I have manually verified that type stubs are
discoverable by mypy once the new jaxlib is installed by type "checking"
from jaxlib import xla_extension as xe
d: xe._Dtype