49 Commits

Author SHA1 Message Date
Peter Hawkins
d481013f47 Add a CPU feature guard module to JAX.
To make sure that the CPU feature guard happens first, before any other code that may use instructions that do not exist, use a separate C extension module.

Fixes https://github.com/google/jax/issues/6671

PiperOrigin-RevId: 374683190
2021-05-19 10:58:35 -07:00
Jake VanderPlas
919b11e81a Remove unnecessary dependency
PiperOrigin-RevId: 368882451
2021-04-16 11:10:33 -07:00
Jake VanderPlas
0d4bcde7ca Add experimental/sparse_ops & cusparse wrappers in jaxlib
PiperOrigin-RevId: 368663407
2021-04-15 10:11:10 -07:00
Andreas Hoenselaar
a19098d462 Reimplement as JAX Primitive 2021-04-03 14:11:36 -07:00
Peter Hawkins
9bdc2ecc66 Consolidate build macros into a single jax.bzl file.
PiperOrigin-RevId: 352871429
2021-01-20 14:06:22 -08:00
jax authors
8420ee200f Merge pull request #5472 from hawkinsp:build
PiperOrigin-RevId: 352797345
2021-01-20 08:21:02 -08:00
jax authors
d58563845b Merge pull request #5470 from inailuig:fix-rocm-build
PiperOrigin-RevId: 352783431
2021-01-20 06:46:44 -08:00
Peter Hawkins
3bec4b331d Fix ABSL build dependencies of //jaxlib:handle_pool 2021-01-20 09:38:55 -05:00
Clemens Giuliani
57405b0dc1 fix building on ROCm 2021-01-20 10:50:04 +01:00
jax authors
a8518769a2 Merge pull request #5115 from inailuig:rocm-gpukernels
PiperOrigin-RevId: 348077827
2020-12-17 14:01:04 -08:00
Clemens Giuliani
4981c53ac1 Add BLAS and LAPACK gpu kernels for ROCm 2020-12-16 16:00:17 +01:00
Clemens Giuliani
c128bdd90c extract the shared handle pool code from cublas and cusolver 2020-12-16 16:00:16 +01:00
Peter Hawkins
21914db261 Export LICENSE.txt file in jaxlib wheels. 2020-12-11 10:20:03 -05:00
Peter Hawkins
c06ead6b04 Change jaxlib build rules to build a wheel, rather than writing output to the source directory. 2020-11-20 11:47:00 -05:00
jax authors
9aa04443e9 Merge pull request #4965 from hawkinsp:build
PiperOrigin-RevId: 343333501
2020-11-19 11:32:00 -08:00
Peter Hawkins
4bb5dca779 Fix build.py to work on Linux once again.
* strip DOS end-of-line characters from build.py for consistency with the rest of the source tree.
* use shutil.copy() instead of shutil.copyfile(). On Unix systems we must preserve execute permissions.
* add code to explicitly delete and recreate the target directory.
* Move build/jaxlib/__init_py to jaxlib/__init__.py and have the script move it into position, so the output directory for the jaxlib is an empty directory that the script creates.
2020-11-19 13:14:34 -05:00
Peter Hawkins
f58f1ee456 [JAX] Use PocketFFT for FFTs on CPU instead of Eigen.
PocketFFT is the same FFT library used by NumPy (although we are using the C++ variant rather than the C variant.)

For the benchmark in #2952 on my workstation:

Before:
```
907.3490574884647
max:     4.362646594533903e-08
mean:    6.237288307614869e-09
min:     0.0
numpy fft execution time [ms]:   37.088446617126465
jax fft execution time [ms]:     74.93342399597168
```

After:
```
907.3490574884647
max:     1.9057386696477137e-12
mean:    3.9326737908882566e-13
min:     0.0
numpy fft execution time [ms]:   37.756404876708984
jax fft execution time [ms]:     28.128278255462646
```

Fixes https://github.com/google/jax/issues/2952

PiperOrigin-RevId: 338743753
2020-10-23 14:20:32 -07:00
Jean-Baptiste Lespiau
2ab6b42a45
Use pytree defined in tensorflow. (#4087)
It also adds some tests on the scalar C++ conversion.
2020-08-18 08:58:43 +03:00
George Necula
c7aff1da06
Revert "Use pytree from xla_client. (#4063)" (#4081)
This reverts commit d8de6b61411179dcd2f63d7639bbcd69b30ac15f.

Tryting to revert because it seems that this produces test
failures in Google.
2020-08-17 12:53:18 +03:00
Jean-Baptiste Lespiau
d8de6b6141
Use pytree from xla_client. (#4063) 2020-08-14 11:44:03 -04:00
Jean-Baptiste Lespiau
11afb3b62f
Extract pytree as it's own library. (#3909) 2020-08-03 11:31:34 -04:00
Peter Hawkins
a141cc6e8d
Make CUDA wheels manylinux2010 compliant, add CUDA 11, drop CUDA 9.2 (#3555)
* Use dynamic loading to locate CUDA libraries in jaxlib.

This should allow jaxlib CUDA wheels to be manylinux2010 compliant.

* Tag CUDA jaxlib wheels as manylinux2010.

Drop support for CUDA 9.2, add support for CUDA 11.0.

* Reorder CUDA imports.
2020-06-25 14:37:14 -04:00
Peter Hawkins
832bb71c5d
Add missing BUILD dependency. (#2089) 2020-01-27 13:15:41 -05:00
Peter Hawkins
c5a9eba3a8
Implement batched cholesky decomposition using LAPACK/Cusolver (#1956)
* Implement batched Cholesky decomposition on CPU and GPU using LAPACK and cuSolver.

Adds support for complex batched Cholesky decomposition on both platforms..
Fix concurrency bug in batched cuBlas kernels where a host to device memcpy could take place too early before the device buffer was ready.
2020-01-07 10:56:15 -05:00
Peter Hawkins
94203bf022
Update XLA. (#1837)
Update jaxlib BUILD for ead06270dc
2019-12-10 11:25:09 -05:00
Skye Wanderman-Milne
7a154f71bc
Fix jaxlib build by not exposing nvcc to pybind11. (#1819) 2019-12-05 18:59:29 -08:00
Matthew Johnson
b757949269 fix pulldown bugs 2019-11-26 17:06:57 -08:00
Peter Hawkins
34dfbc8ae6
Add error checking to PRNG CUDA kernel. (#1760)
Refactor error checking code into a common helper library.
2019-11-25 11:48:45 -05:00
Peter Hawkins
3b7d92db79 Add missing pybind11 dependency. 2019-11-24 14:17:18 -05:00
Peter Hawkins
d1aa01874d Fix BUILD file formatting. 2019-11-24 13:13:39 -05:00
Peter Hawkins
534d812b57
Add a handwritten ThreeFry2x32 CUDA kernel. (#1756)
In principle, JAX should not need a hand-written CUDA kernel for the ThreeFry2x32 algorithm. In practice XLA aggresively inlines, which causes compilation times on GPU blow up when compiling potentially many copies of the PRNG kernel in a program. As a workaround, we add a hand-written CUDA kernel mostly to reduce compilation time.

When XLA becomes smarter about compiling this particular hash function, we should be able to remove the hand-written kernel once again.
2019-11-24 13:06:23 -05:00
Peter Hawkins
1abf7cb2dd
Remove -Wno-c++98-c++11-compat directive from jaxlib BUILD file. (#1544)
We require C++14 now, so the directive is moot.
2019-10-21 11:41:28 -04:00
Skye Wanderman-Milne
796d369efa Remove licenses() rule comment in BUILD files.
Internal tooling doesn't like it.
2019-09-26 14:54:07 -07:00
Peter Hawkins
2725c7e648 Update XLA. 2019-08-09 15:38:20 -04:00
Peter Hawkins
dd10bdba8d Remove newline from build file. 2019-08-08 16:33:50 -04:00
Peter Hawkins
233598a753 Add newline to build file. 2019-08-08 16:33:04 -04:00
Peter Hawkins
fef315b6e6 Add ability to pass extra bazel options to build script.
Remove cublas/cusolver dependencies from Jaxlib python code.
2019-08-08 16:14:45 -04:00
Peter Hawkins
5ac356d680 Add support for batched triangular solve and LU decomposition on GPU using cuBlas. 2019-08-08 13:34:53 -04:00
Peter Hawkins
72047c6eca Update XLA. 2019-08-07 12:55:09 -04:00
Peter Hawkins
6bc476261b More build formatting fixes. 2019-08-02 13:32:14 -04:00
Peter Hawkins
e0b31ac310 Build formatting fixes. 2019-08-02 13:29:52 -04:00
Peter Hawkins
ed3e2308c1 Add support for linear algebra ops on GPU using Cusolver:
* LU decomposition
* Symmetric (Hermitian) eigendecomposition
* Singular value decomposition.

Make LU decomposition tests less sensitive to the exact decomposition; check that we have a decomposition, not precisely the same one scipy returns.
2019-08-02 11:16:15 -04:00
Peter Hawkins
38bffe9a8b Add a pytreedef.flatten_up_to() method that flattens a PyTree only up to the structure of a PyTreeDef.
Make the C++ version of tree_multimap accept tree suffixes of the primary tree. Document and test this behavior.
Remove unnecessary locking in custom node registry; we hold the GIL already so there's no point to the additional locking.
2019-08-01 12:17:00 -04:00
Peter Hawkins
3c3f01e6d3 Address review comments. 2019-07-30 10:15:37 -04:00
Peter Hawkins
510a9167c5 Add C++ implementation of pytree logic.
Move jaxlib version test into jax/lib/__init__.py. Make jax/lib mirror the structure of jaxlib; e.g., xla_client is now available as jax.lib.xla_client.
2019-07-29 15:06:05 -04:00
Peter Hawkins
5ceac99d0c Add newline to the end of jaxlib/BUILD. 2019-04-01 08:23:02 -07:00
Peter Hawkins
5eff830f0e Move jaxlib version.py into jaxlib, and install it in build/jaxlib as build action.
Update jaxlib version check to look in jaxlib.version.
2019-04-01 08:21:22 -07:00
Peter Hawkins
20935448b9 Whitespace fix to jaxlib/BUILD. 2018-12-19 14:51:46 -05:00
Peter Hawkins
3c388b98f1 Add support for calling LAPACK primitives from SciPy from JAX linalg. 2018-12-17 16:30:27 -05:00