Change in preparation for allowing JAX tests to run under Bazel.
Remove code to patch paths in xla_client.py in the wheel build script; the patch is no longer used.
PiperOrigin-RevId: 458522398
Change in preparation for supporting running JAX tests under Bazel. This change allows the Bazel py_library() to see version.py.
Update symlink_files Bazel macro to a newer version.
PiperOrigin-RevId: 458481396
Currently jax and jaxlib have separate version numbers in the JAX source
tree. It is tedious and confusing to bump both version numbers.
However, there is a simpler way to think of things: it is the source
tree that is versioned using a single version number, and jax/jaxlib
releases are made using that unified source version number.
PiperOrigin-RevId: 458041752
We now have an ml_program dialect that describes global variables
including load and store operations. Expose this dialect to allow
exporting variables and constants.
Previously we depended on various .so files directly so they were pulled into the jaxlib wheel build, but it seems to work to add the libraries in question to //jaxlib and depend on that in the usual way.
It appears if a py_library() is used as a data-dependency of another rule, Bazel includes any transitive C++ extension deps, and that's what we want.
PiperOrigin-RevId: 446802592
an initial prototype of an alternate JAX compilation path
that emits the MLIR MHLO/CHLO dialects instead of classic XLA HLO
together with sparse tensor types.
PiperOrigin-RevId: 443438043
Most of the work here is porting the LAPACK interface from Cython to plain C++. This is something I wanted to do anyway to make use of C++ templating facilities: the code is noticeably shorter in C++.
This change removes the only use of Cython in JAX. It also removes the need for a build-time dependency on Scipy, which we only needed for Cython cimport reasons.
When using C++, we most likely do not want to fetch LAPACK and BLAS kernels from Python. Therefore we add another option: we define the LAPACK functions we need using weak symbols where supported; the user can then simply link against LAPACK to provide the necessary symbols.
Added a jaxlib:cpu_kernels module to facilitate using the JAX CPU kernels from C++.
PiperOrigin-RevId: 394705605
Some folks want to be able to run JAX-generated HLO computations from C++, and those computations may refer to JAX's custom kernels. This change splits the custom kernels into separate modules that may be used independently of Python.
The general pattern is that each extension now has two parts:
* xyz_kernels.{cc, h} — the C++ parts
* xyz.cc — Python bindings around the C++ parts, including code to build any descriptor objects.
There's also a new (minimally supported) module named "gpu_kernels.cc" which registers JAX's GPU kernels with the XLA C++ custom kernel registry.
PiperOrigin-RevId: 394460343
To make sure that the CPU feature guard happens first, before any other code that may use instructions that do not exist, use a separate C extension module.
Fixes https://github.com/google/jax/issues/6671
PiperOrigin-RevId: 374683190
libdevice.10.bc is a redistributable part of the CUDA SDK.
This avoids problems trying to locate a copy of libdevice inside the user's CUDA installation.
1. Build on Windows
2. Fix OverflowError
When calling `key = random.PRNGKey(0)` OverflowError: Python int too
large to convert to C long for casting value 4294967295 (0xFFFFFFFF)
from python int to int32.
3. fix file path in regex of errors_test
4. handle ValueError of os.path.commonpath
PocketFFT is the same FFT library used by NumPy (although we are using the C++ variant rather than the C variant.)
For the benchmark in #2952 on my workstation:
Before:
```
907.3490574884647
max: 4.362646594533903e-08
mean: 6.237288307614869e-09
min: 0.0
numpy fft execution time [ms]: 37.088446617126465
jax fft execution time [ms]: 74.93342399597168
```
After:
```
907.3490574884647
max: 1.9057386696477137e-12
mean: 3.9326737908882566e-13
min: 0.0
numpy fft execution time [ms]: 37.756404876708984
jax fft execution time [ms]: 28.128278255462646
```
Fixes https://github.com/google/jax/issues/2952
PiperOrigin-RevId: 338743753
In principle, JAX should not need a hand-written CUDA kernel for the ThreeFry2x32 algorithm. In practice XLA aggresively inlines, which causes compilation times on GPU blow up when compiling potentially many copies of the PRNG kernel in a program. As a workaround, we add a hand-written CUDA kernel mostly to reduce compilation time.
When XLA becomes smarter about compiling this particular hash function, we should be able to remove the hand-written kernel once again.
* LU decomposition
* Symmetric (Hermitian) eigendecomposition
* Singular value decomposition.
Make LU decomposition tests less sensitive to the exact decomposition; check that we have a decomposition, not precisely the same one scipy returns.
Move jaxlib version test into jax/lib/__init__.py. Make jax/lib mirror the structure of jaxlib; e.g., xla_client is now available as jax.lib.xla_client.