Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. ( #2117 )
...
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
Peter Hawkins
832bb71c5d
Add missing BUILD dependency. ( #2089 )
2020-01-27 13:15:41 -05:00
Peter Hawkins
ffa198e8ef
Fix test failure on TPU. ( #2088 )
...
Update GUARDED_BY annotations to use newer ABSL_GUARDED_BY form.
2020-01-27 12:48:10 -05:00
Skye Wanderman-Milne
f04348ed53
Bump jaxlib version to 0.1.38 and update WORKSPACE.
2020-01-21 16:59:27 -08:00
AmKhan
dcda87d0e7
added batching to LAPACK triangular_solve ( #1985 )
...
* Added batching to CPU triangular_solve
* Addressed comments about int overflows and switched triangular solve back to XLA instead of LAPACK
* Added a TODO to benchmark LAPACK vs XLA
2020-01-14 11:18:47 -05:00
Peter Hawkins
dcc882cf6b
Drop Python 2 support from JAX. ( #1962 )
...
Remove six dependency.
2020-01-08 13:17:55 -05:00
Peter Hawkins
c5a9eba3a8
Implement batched cholesky decomposition using LAPACK/Cusolver ( #1956 )
...
* Implement batched Cholesky decomposition on CPU and GPU using LAPACK and cuSolver.
Adds support for complex batched Cholesky decomposition on both platforms.
Fix a concurrency bug in batched cuBlas kernels where a host-to-device memcpy could take place before the device buffer was ready.
2020-01-07 10:56:15 -05:00
Peter Hawkins
94203bf022
Update XLA. ( #1837 )
...
Update jaxlib BUILD for ead06270dc
2019-12-10 11:25:09 -05:00
Skye Wanderman-Milne
7a154f71bc
Fix jaxlib build by not exposing nvcc to pybind11. ( #1819 )
2019-12-05 18:59:29 -08:00
Skye Wanderman-Milne
12a62c1f33
Bump jaxlib version to 0.1.37 and update WORKSPACE.
2019-12-03 12:29:34 -08:00
Matthew Johnson
b757949269
fix pulldown bugs
2019-11-26 17:06:57 -08:00
Peter Hawkins
34dfbc8ae6
Add error checking to PRNG CUDA kernel. ( #1760 )
...
Refactor error checking code into a common helper library.
2019-11-25 11:48:45 -05:00
Peter Hawkins
3b7d92db79
Add missing pybind11 dependency.
2019-11-24 14:17:18 -05:00
Peter Hawkins
d1aa01874d
Fix BUILD file formatting.
2019-11-24 13:13:39 -05:00
Peter Hawkins
534d812b57
Add a handwritten ThreeFry2x32 CUDA kernel. ( #1756 )
...
In principle, JAX should not need a hand-written CUDA kernel for the ThreeFry2x32 algorithm. In practice, XLA inlines aggressively, which causes compilation times on GPU to blow up when compiling potentially many copies of the PRNG kernel in a program. As a workaround, we add a hand-written CUDA kernel, mostly to reduce compilation time.
When XLA becomes smarter about compiling this particular hash function, we should be able to remove the hand-written kernel once again.
2019-11-24 13:06:23 -05:00
Skye Wanderman-Milne
1314fb7cb1
Bump jaxlib version to 0.1.36 and update WORKSPACE.
2019-11-21 18:34:28 -08:00
Skye Wanderman-Milne
b7d11ab90d
Bump jaxlib version to 0.1.35
2019-11-20 15:58:34 -08:00
Skye Wanderman-Milne
b1888881da
Bump jaxlib version to 0.1.33 and update WORKSPACE.
...
Includes XLA fixes for CPU psum.
2019-11-19 15:30:10 -08:00
Peter Hawkins
9679a87901
Avoid out-of-bounds dereference for arity-0 nodes. ( #1713 )
2019-11-18 15:35:07 -05:00
Skye Wanderman-Milne
84437839ed
Bump jaxlib version to 0.1.33 and update WORKSPACE.
2019-11-12 18:25:43 -08:00
Peter Hawkins
4b66d95782
Fix integer overflow for large matrices in linear algebra kernels. ( #1648 )
2019-11-08 12:49:07 -08:00
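Note on 4b66d95782: the entry above does not say which quantity overflowed. As a rough illustration only, assuming the culprit is an element count such as batch * rows * cols held in a 32-bit signed integer, the sketch below emulates the wraparound in Python. The helper names are hypothetical and are not taken from jaxlib.

```python
INT32_MIN = -(2 ** 31)

def wrap_int32(value):
    """Emulate two's-complement wraparound of a 32-bit signed integer.

    Signed overflow is undefined behavior in C++; wraparound is merely the
    outcome most often observed in practice, which is what we model here.
    """
    return (value - INT32_MIN) % (2 ** 32) + INT32_MIN

def element_count(batch, rows, cols, use_64bit):
    """Number of matrix elements, computed in an emulated 32- or 64-bit width."""
    exact = batch * rows * cols  # Python ints never overflow
    # The 64-bit path assumes the exact product fits in 64 bits.
    return exact if use_64bit else wrap_int32(exact)

# A 1000 x 1000 matrix fits comfortably in 32 bits.
small = element_count(1, 1000, 1000, use_64bit=False)
# A 50000 x 50000 matrix has 2.5e9 elements, which exceeds 2**31 - 1.
large_32 = element_count(1, 50000, 50000, use_64bit=False)
large_64 = element_count(1, 50000, 50000, use_64bit=True)
```

Under this assumption, the 32-bit count for the large matrix comes out negative, while the 64-bit count is correct; widening the index type is the usual remedy for this class of bug.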
Peter Hawkins
affa2dcca4
Increment jax and jaxlib versions. ( #1603 )
...
* Update XLA version to 7acd3bb9d7
* Remove XRT reference from jaxlib build.
2019-10-30 15:50:00 -04:00
Peter Hawkins
1abf7cb2dd
Remove -Wno-c++98-c++11-compat directive from jaxlib BUILD file. ( #1544 )
...
We require C++14 now, so the directive is moot.
2019-10-21 11:41:28 -04:00
Skye Wanderman-Milne
d3fa506ed0
Bump jaxlib version to 0.1.31 and update WORKSPACE.
2019-10-08 09:38:39 -07:00
Skye Wanderman-Milne
2a8575bf04
Bump jaxlib version to 0.1.30 and update WORKSPACE.
2019-10-03 10:36:29 -07:00
Peter Hawkins
1428c11a2c
Update Jaxlib version to 0.1.29.
...
Bump XLA version. Enable C++14 mode since it is required by the new XLA version.
2019-09-28 15:11:09 -04:00
Skye Wanderman-Milne
796d369efa
Remove licenses() rule comment in BUILD files.
...
Internal tooling doesn't like it.
2019-09-26 14:54:07 -07:00
Peter Hawkins
c42444dc83
Fix compile error in cusolver.cc
2019-09-06 13:35:09 -04:00
Peter Hawkins
c0c4aac9ab
Implement batched singular value decomposition.
...
On GPU, switch to the Jacobi implementation of SVD for matrices smaller than 32x32. The Jacobi method has an efficient implementation for batches of small matrices.
2019-09-05 18:12:00 -04:00
Skye Wanderman-Milne
7dc95f1f27
Bump jaxlib version to 0.1.28 and update WORKSPACE.
...
This pulls in breaking changes to the XLA client.
2019-09-04 18:15:15 -07:00
Peter Hawkins
02426b390c
Use LAPACK and Cusolver to implement QR decomposition on CPU/GPU.
...
This should be faster; also adds support for complex QR decompositions.
2019-09-04 16:24:32 -04:00
Matthew Johnson
009ff0745a
update jaxlib version for macOS wheels
...
cf. #1254
2019-08-27 11:44:16 -07:00
Skye Wanderman-Milne
4a22f56557
Bump jaxlib version to 0.1.26 and update WORKSPACE.
2019-08-24 10:25:34 -07:00
Matthew Johnson
e66582e877
restore the behavior that Nones are pytrees
2019-08-23 15:47:42 -07:00
Matthew Johnson
b702f8de3e
De-tuplify the rest of the core
...
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-08-21 13:21:20 -07:00
Skye Wanderman-Milne
4720776098
Update jaxlib version and XLA.
2019-08-19 19:12:40 -07:00
Peter Hawkins
2725c7e648
Update XLA.
2019-08-09 15:38:20 -04:00
Peter Hawkins
bc668e5638
Increment Jaxlib version.
...
Update XLA.
2019-08-08 17:08:46 -04:00
Peter Hawkins
dd10bdba8d
Remove newline from build file.
2019-08-08 16:33:50 -04:00
Peter Hawkins
233598a753
Add newline to build file.
2019-08-08 16:33:04 -04:00
Peter Hawkins
fef315b6e6
Add ability to pass extra bazel options to build script.
...
Remove cublas/cusolver dependencies from Jaxlib python code.
2019-08-08 16:14:45 -04:00
Peter Hawkins
5ac356d680
Add support for batched triangular solve and LU decomposition on GPU using cuBlas.
2019-08-08 13:34:53 -04:00
Peter Hawkins
72047c6eca
Update XLA.
2019-08-07 12:55:09 -04:00
Peter Hawkins
7160077cad
Use Jacobi algorithm for symmetric eigendecomposition for matrices with n < 32.
...
Use the batched Jacobi algorithm for large batches of small matrices.
2019-08-07 11:33:48 -04:00
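Note on 7160077cad: for readers unfamiliar with the method, here is a minimal pure-Python sketch of the cyclic Jacobi eigenvalue iteration for a small real symmetric matrix. It illustrates the general algorithm family only. It is not the GPU implementation the commit switches to, and the function name, tolerance, and sweep limit are assumptions made for the example.

```python
import math

def jacobi_eigenvalues(matrix, tol=1e-12, max_sweeps=50):
    """Eigenvalues of a small real symmetric matrix via cyclic Jacobi rotations."""
    n = len(matrix)
    a = [[float(x) for x in row] for row in matrix]  # work on a copy
    for _ in range(max_sweeps):
        # Stop once the off-diagonal part is numerically zero.
        off = math.sqrt(sum(a[i][j] ** 2
                            for i in range(n) for j in range(n) if i != j))
        if off < tol:
            break
        for p in range(n - 1):
            for q in range(p + 1, n):
                if a[p][q] == 0.0:
                    continue
                # Choose the rotation angle that zeroes a[p][q].
                theta = (a[q][q] - a[p][p]) / (2.0 * a[p][q])
                t = math.copysign(1.0, theta) / (abs(theta) + math.hypot(theta, 1.0))
                c = 1.0 / math.hypot(t, 1.0)
                s = t * c
                for k in range(n):  # A <- A J (update columns p and q)
                    akp, akq = a[k][p], a[k][q]
                    a[k][p] = c * akp - s * akq
                    a[k][q] = s * akp + c * akq
                for k in range(n):  # A <- J^T A (update rows p and q)
                    apk, aqk = a[p][k], a[q][k]
                    a[p][k] = c * apk - s * aqk
                    a[q][k] = s * apk + c * aqk
    return sorted(a[i][i] for i in range(n))

eigs_2x2 = jacobi_eigenvalues([[2, 1], [1, 2]])
eigs_3x3 = jacobi_eigenvalues([[2, 0, 0], [0, 3, 4], [0, 4, 9]])
```

Each rotation touches only two rows and two columns, which is one reason the method maps well onto many small independent matrices processed in a batch.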
Peter Hawkins
6bc476261b
More build formatting fixes.
2019-08-02 13:32:14 -04:00
Peter Hawkins
e0b31ac310
Build formatting fixes.
2019-08-02 13:29:52 -04:00
Peter Hawkins
ed3e2308c1
Add support for linear algebra ops on GPU using Cusolver:
...
* LU decomposition
* Symmetric (Hermitian) eigendecomposition
* Singular value decomposition.
Make LU decomposition tests less sensitive to the exact decomposition; check that we have a decomposition, not precisely the same one scipy returns.
2019-08-02 11:16:15 -04:00
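Note on ed3e2308c1: the testing idea in the entry above (accept any valid decomposition rather than one specific set of factors) can be sketched as follows. This is a pure-Python illustration for square matrices, assuming the convention A = P L U; the helper names are hypothetical and this is not the actual JAX test code.

```python
def matmul(x, y):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x[i][k] * y[k][j] for k in range(len(y)))
             for j in range(len(y[0]))]
            for i in range(len(x))]

def is_valid_lu(a, p, l, u, tol=1e-9):
    """True if l is lower triangular, u is upper triangular, and p l u ~= a."""
    n = len(a)
    lower_ok = all(abs(l[i][j]) <= tol for i in range(n) for j in range(i + 1, n))
    upper_ok = all(abs(u[i][j]) <= tol for i in range(n) for j in range(i))
    plu = matmul(p, matmul(l, u))
    close = all(abs(plu[i][j] - a[i][j]) <= tol
                for i in range(n) for j in range(n))
    return lower_ok and upper_ok and close

a = [[1.0, 2.0], [3.0, 4.0]]

# Decomposition without row exchange.
p1 = [[1.0, 0.0], [0.0, 1.0]]
l1 = [[1.0, 0.0], [3.0, 1.0]]
u1 = [[1.0, 2.0], [0.0, -2.0]]

# Decomposition with the rows swapped (partial pivoting).
p2 = [[0.0, 1.0], [1.0, 0.0]]
l2 = [[1.0, 0.0], [1.0 / 3.0, 1.0]]
u2 = [[3.0, 4.0], [0.0, 2.0 / 3.0]]

both_valid = is_valid_lu(a, p1, l1, u1) and is_valid_lu(a, p2, l2, u2)
```

The two factorizations differ element by element, yet both reconstruct the same matrix, so a test that compares factors directly against a reference library could fail on a perfectly correct result.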
Peter Hawkins
cb53ca876f
Address review comments.
2019-08-01 16:48:18 -04:00
Peter Hawkins
38bffe9a8b
Add a pytreedef.flatten_up_to() method that flattens a PyTree only up to the structure of a PyTreeDef.
...
Make the C++ version of tree_multimap accept tree suffixes of the primary tree. Document and test this behavior.
Remove unnecessary locking in custom node registry; we hold the GIL already so there's no point to the additional locking.
2019-08-01 12:17:00 -04:00
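Note on 38bffe9a8b: the "flatten only up to the structure of a PyTreeDef" behavior described above can be approximated in pure Python. The sketch below is an assumption-laden simplification: it handles only lists, tuples, and dicts, treats everything else as a leaf, visits dict keys in sorted order, and is written as a free function rather than a method. It is not the C++ implementation from the commit.

```python
def flatten_up_to(prefix, full):
    """Flatten `full` only as deep as `prefix` goes.

    Wherever `prefix` has a leaf, the matching subtree of `full` is
    returned whole instead of being flattened further.
    """
    if isinstance(prefix, (list, tuple)):
        if not isinstance(full, type(prefix)) or len(full) != len(prefix):
            raise ValueError("structure mismatch: %r vs %r" % (prefix, full))
        out = []
        for sub_prefix, sub_full in zip(prefix, full):
            out.extend(flatten_up_to(sub_prefix, sub_full))
        return out
    if isinstance(prefix, dict):
        if not isinstance(full, dict) or sorted(prefix) != sorted(full):
            raise ValueError("structure mismatch: %r vs %r" % (prefix, full))
        out = []
        for key in sorted(prefix):
            out.extend(flatten_up_to(prefix[key], full[key]))
        return out
    return [full]  # prefix leaf: keep the whole subtree of `full`

# The prefix has three leaves, so exactly three subtrees come back.
prefix_tree = (0, [0, 0])
full_tree = ((1, 2), [3, {"a": 4}])
result = flatten_up_to(prefix_tree, full_tree)
```

Here `result` holds the subtrees `(1, 2)`, `3`, and `{"a": 4}`. This is the property that lets a multi-tree map accept trees whose structure extends the primary tree's structure at its leaves.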
Peter Hawkins
e8b4236789
Address additional review comments.
2019-07-30 20:43:50 -04:00