To make sure that the CPU feature guard happens first, before any other code that may use instructions that do not exist, use a separate C extension module.
Fixes https://github.com/google/jax/issues/6671
PiperOrigin-RevId: 374683190
Require 1.12 or newer, because we've only tested 1.12 and 2.0.
Require less than 3.0, because flatbuffers uses semantic versioning and version 3.0 would mean an incompatible change has been made.
libdevice.10.bc is a redistributable part of the CUDA SDK.
This avoids problems trying to locate a copy of libdevice inside the user's CUDA installation.
The type stubs allow using precise types for XLA primitives instead
of aliasing them to Any.
This commit does not change any type annotations within JAX. That will
be done in a followup. I have manually verified that type stubs are
discoverable by mypy once the new jaxlib is installed by type "checking"
from jaxlib import xla_extension as xe
d: xe._Dtype
Bump jaxlib version to 0.1.61 and update changelog.
Change jaxlib numpy version limit to >=1.16 for next release. Releases older than 1.16 are deprecated per NEP 00029. Reenable NumPy 1.20.
Bump minimum jaxlib version to 0.1.60.
* strip DOS end-of-line characters from build.py for consistency with the rest of the source tree.
* use shutil.copy() instead of shutil.copyfile(). On Unix systems we must preserve execute permissions.
* add code to explicitly delete and recreate the target directory.
* Move build/jaxlib/__init_py to jaxlib/__init__.py and have the script move it into position, so the output directory for the jaxlib is an empty directory that the script creates.
PocketFFT is the same FFT library used by NumPy (although we are using the C++ variant rather than the C variant.)
For the benchmark in #2952 on my workstation:
Before:
```
907.3490574884647
max: 4.362646594533903e-08
mean: 6.237288307614869e-09
min: 0.0
numpy fft execution time [ms]: 37.088446617126465
jax fft execution time [ms]: 74.93342399597168
```
After:
```
907.3490574884647
max: 1.9057386696477137e-12
mean: 3.9326737908882566e-13
min: 0.0
numpy fft execution time [ms]: 37.756404876708984
jax fft execution time [ms]: 28.128278255462646
```
Fixes https://github.com/google/jax/issues/2952
PiperOrigin-RevId: 338743753
--
138105a9ea44e7a8c3ce575a4e51b7ed51518d41 by Skye Wanderman-Milne <skyewm@google.com>:
Update README, CHANGELOG, and jaxlib.__version__ for new jaxlib release
PiperOrigin-RevId: 338063494
[JAX] Perform LAPACK workspace calculations in int64 to avoid overflows, clamp the values passed to lapack to int32.
Will fix https://github.com/google/jax/issues/4358 when incorporated into a jaxlib.
PiperOrigin-RevId: 337367394