To make sure that the CPU feature guard runs first, before any other code that may use instructions that do not exist on the host CPU, use a separate C extension module.
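A minimal sketch of the resulting import order, assuming the guard module is named `cpu_feature_guard` and exposes a `check_cpu_features()` function (names illustrative):

```python
# The guard extension is built without optional instruction sets (AVX etc.),
# so importing it cannot itself fault on an older CPU.
from jaxlib import cpu_feature_guard

# Check CPU features before importing any extension that may use them.
cpu_feature_guard.check_cpu_features()

from jaxlib import xla_client  # may contain instructions the host CPU lacks
```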
Fixes https://github.com/google/jax/issues/6671
PiperOrigin-RevId: 374683190
* strip DOS end-of-line characters from build.py for consistency with the rest of the source tree.
* use shutil.copy() instead of shutil.copyfile(), since on Unix systems we must preserve execute permissions.
* add code to explicitly delete and recreate the target directory.
* Move build/jaxlib/__init__.py to jaxlib/__init__.py and have the script move it into position, so the output directory for the jaxlib is an empty directory that the script creates (a sketch follows below).
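A sketch of the delete/recreate and copy steps above; the paths are illustrative, not the exact ones build.py uses:

```python
import os
import shutil

target = "build/jaxlib"
if os.path.exists(target):
    shutil.rmtree(target)   # explicitly delete the old output directory
os.makedirs(target)         # ...and recreate it empty

# shutil.copy() preserves permission bits (so executables stay executable
# on Unix); shutil.copyfile() copies file contents only.
shutil.copy("jaxlib/__init__.py", os.path.join(target, "__init__.py"))
```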
PocketFFT is the same FFT library used by NumPy (although we use the C++ variant rather than the C variant).
For the benchmark in #2952 on my workstation:
Before:
```
907.3490574884647
max: 4.362646594533903e-08
mean: 6.237288307614869e-09
min: 0.0
numpy fft execution time [ms]: 37.088446617126465
jax fft execution time [ms]: 74.93342399597168
```
After:
```
907.3490574884647
max: 1.9057386696477137e-12
mean: 3.9326737908882566e-13
min: 0.0
numpy fft execution time [ms]: 37.756404876708984
jax fft execution time [ms]: 28.128278255462646
```
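For reference, a rough sketch of how such a comparison can be reproduced; this is not the exact benchmark from #2952, and the input size and relative-error metric are assumptions:

```python
import time
import numpy as np
import jax.numpy as jnp

x = np.random.randn(1 << 20).astype(np.complex64)

t0 = time.perf_counter()
np_out = np.fft.fft(x)
print("numpy fft execution time [ms]:", (time.perf_counter() - t0) * 1e3)

jnp.fft.fft(x).block_until_ready()  # warm-up: compile before timing
t0 = time.perf_counter()
jax_out = jnp.fft.fft(x).block_until_ready()
print("jax fft execution time [ms]:", (time.perf_counter() - t0) * 1e3)

err = np.abs(np_out - np.asarray(jax_out)) / (np.abs(np_out) + 1e-12)
print("max:", err.max(), "mean:", err.mean(), "min:", err.min())
```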
Fixes https://github.com/google/jax/issues/2952
PiperOrigin-RevId: 338743753
* Use dynamic loading to locate CUDA libraries in jaxlib.
This should allow jaxlib CUDA wheels to be manylinux2010 compliant; a sketch of the idea follows this list.
* Tag CUDA jaxlib wheels as manylinux2010.
Drop support for CUDA 9.2, add support for CUDA 11.0.
* Reorder CUDA imports.
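A Python/ctypes sketch of the dynamic-loading idea; the real implementation lives in jaxlib's C++ code, and the library names below are assumptions:

```python
import ctypes

# Resolve the CUDA runtime at run time instead of linking against it at
# build time. Because the wheel carries no link-time dependency on CUDA,
# it can satisfy the manylinux2010 policy.
cudart = None
for name in ("libcudart.so.11.0", "libcudart.so"):
    try:
        cudart = ctypes.CDLL(name)
        break
    except OSError:
        continue
if cudart is None:
    raise RuntimeError("CUDA runtime library not found")
```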
* Implement batched Cholesky decomposition on CPU and GPU using LAPACK and cuSolver.
Adds support for complex batched Cholesky decomposition on both platforms.
Fix a concurrency bug in the batched cuBLAS kernels where a host-to-device memcpy could take place before the device buffer was ready.
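A usage sketch of the new batched, complex support:

```python
import jax.numpy as jnp
from jax import random

# Build a batch of complex Hermitian positive-definite matrices and factor
# them in one call; this lowers to the batched LAPACK/cuSolver kernels.
k1, k2 = random.split(random.PRNGKey(0))
a = random.normal(k1, (8, 4, 4)) + 1j * random.normal(k2, (8, 4, 4))
hpd = a @ jnp.conj(jnp.swapaxes(a, -1, -2)) + 8 * jnp.eye(4)
l = jnp.linalg.cholesky(hpd)  # shape (8, 4, 4): lower-triangular factors
```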
In principle, JAX should not need a hand-written CUDA kernel for the ThreeFry2x32 algorithm. In practice, XLA inlines aggressively, which causes compilation times on GPU to blow up when compiling potentially many copies of the PRNG kernel in a program. As a workaround, we add a hand-written CUDA kernel, mostly to reduce compilation time.
When XLA becomes smarter about compiling this particular hash function, we should be able to remove the hand-written kernel once again.
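For example, a program drawing many independent samples exercises this path:

```python
import jax
from jax import random

# Each draw goes through the threefry2x32 hash; with the custom kernel
# there is one CUDA kernel to compile rather than many inlined copies of
# the hash in the XLA program.
key = random.PRNGKey(0)
keys = random.split(key, 1024)
samples = jax.vmap(lambda k: random.uniform(k, (128,)))(keys)
print(samples.shape)  # (1024, 128)
```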
* LU decomposition
* Symmetric (Hermitian) eigendecomposition
* Singular value decomposition
Make LU decomposition tests less sensitive to the exact decomposition: check that we have a valid decomposition, not precisely the same one scipy returns.
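A property-style check in that spirit, as a sketch (the tolerances are assumptions):

```python
import numpy as np
import jax.scipy.linalg as jsp

# Verify that the returned factors reconstruct the input, rather than
# comparing them element-wise against scipy's factors.
a = np.random.randn(5, 5)
p, l, u = jsp.lu(a)  # permutation, unit-lower, and upper factors
np.testing.assert_allclose(p @ l @ u, a, atol=1e-5, rtol=1e-5)
```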
Make the C++ version of tree_multimap accept tree suffixes of the primary tree. Document and test this behavior.
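A small example of the suffix behavior, shown with `jax.tree_util.tree_map`, which in current JAX releases carries the multi-tree semantics that `tree_multimap` introduced:

```python
import jax

# The primary tree (1, 2) is a structural prefix of `other`, so the mapped
# function receives each of `other`'s extra subtrees whole.
primary = (1, 2)
other = ((10, 11), {"x": 12})
out = jax.tree_util.tree_map(lambda leaf, sub: (leaf, sub), primary, other)
# out == ((1, (10, 11)), (2, {"x": 12}))
```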
Remove unnecessary locking in the custom node registry; we already hold the GIL, so the additional locking serves no purpose.
Move jaxlib version test into jax/lib/__init__.py. Make jax/lib mirror the structure of jaxlib; e.g., xla_client is now available as jax.lib.xla_client.
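For example:

```python
# jax/lib mirrors jaxlib's structure, so jaxlib modules are importable
# under the jax.lib namespace:
from jax.lib import xla_client
```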