To make sure that the CPU feature guard happens first, before any other code that may use instructions that do not exist, use a separate C extension module.
Fixes https://github.com/google/jax/issues/6671
PiperOrigin-RevId: 374683190
libdevice.10.bc is a redistributable part of the CUDA SDK.
This avoids problems trying to locate a copy of libdevice inside the user's CUDA installation.
The build_wheel.py script was copying the wrong module.
In addition the CUDA stubs from the TF repo were missing a number of cusparse symbols. The updated TF includes the correct stubs.
Update XLA.
CUDA 11.1 wheels are compatible with CUDA versions 11.1+, since NVidia now promises enhanced version compatibility between CUDA minor releases starting with CUDA 11.1
It's unnecessary because the image isn't used interactively in the
script, and it prevents the script from being used when no TTY is
available (e.g. when running from a different script).
This will make it easier to build a single wheel, e.g. for GPU CI testing.
TESTING=I manually ran both build_jaxlib_wheels.sh and the individual
helper functions. I didn't do a full release build, but I verified
that a complete nocuda wheel can be successfully built.
The type stubs allow using precise types for XLA primitives instead
of aliasing them to Any.
This commit does not change any type annotations within JAX. That will
be done in a followup. I have manually verified that type stubs are
discoverable by mypy once the new jaxlib is installed by type "checking"
from jaxlib import xla_extension as xe
d: xe._Dtype
Remove NCCL from CUDA installation scripts. This is partially because there are no Ubuntu 16.04 CUDA 11.2 NCCL packages, but also because we don't need NCCL packages in the first place since we are building from source.
Bump jaxlib version to 0.1.61 and update changelog.
Change jaxlib numpy version limit to >=1.16 for next release. Releases older than 1.16 are deprecated per NEP 00029. Reenable NumPy 1.20.
Bump minimum jaxlib version to 0.1.60.
Resolution order of paths to bazel binary is as follows.
1. Use --bazel_path command line option.
2. Search bazel binary in PATH environment variable.
3. Download required bazel release.