17 Commits

Author SHA1 Message Date
Peter Hawkins
a852710a09 Merge CUDA and ROCm kernel code in jaxlib.
The code for CUDA and ROCm is almost identical, so with a small shim library to handle the differences we can share almost everything.

PiperOrigin-RevId: 483666051
2022-10-25 07:23:34 -07:00
Tianjian Lu
e219d55c36 Roll-back #12892 because CUSPARSE_SPMV_COO_ALG2 is not available in CUDA 11.1
PiperOrigin-RevId: 482897448
2022-10-21 15:06:17 -07:00
Tianjian Lu
7093142f61 [sparse] Update the default cuSparse matvec algorithm in jaxlib.
PiperOrigin-RevId: 482553550
2022-10-20 11:49:09 -07:00
Peter Hawkins
5617a02fa4 Remove JAX custom call implementation of batched triangular solve.
XLA has supported batched triangular solve on GPU since February 2022, which predates the minimum supported jaxlib version. We can therefore delete our implementation and use XLA's instead.

PiperOrigin-RevId: 482031830
2022-10-18 15:04:14 -07:00
Peter Hawkins
2693afa263 Revert: Use input-output aliasing for jaxlib GPU custom calls.
Previously we had no way to tell XLA that inputs and outputs of GPU custom calls must alias. This now works in XLA:GPU so we can just ask XLA to enforce the aliasing we need.

This seems to be causing some test failures downstream, so reverting this for the moment until I can debug them.

PiperOrigin-RevId: 479670565
2022-10-07 14:36:24 -07:00
Peter Hawkins
93b839ace4 Use input-output aliasing for jaxlib GPU custom calls.
Previously we had no way to tell XLA that inputs and outputs of GPU custom calls must alias. This now works in XLA:GPU so we can just ask XLA to enforce the aliasing we need.

PiperOrigin-RevId: 479642543
2022-10-07 12:22:04 -07:00
Artem Belevich
2de91d26b7 Handle FP8 types.
PiperOrigin-RevId: 479148993
2022-10-05 14:48:30 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
David Dunleavy
a8aa774a57 Use tensorflow/compiler/xla/stream_executor instead of tensorflow/stream_executor
PiperOrigin-RevId: 470804752
2022-08-29 13:46:20 -07:00
Tianjian Lu
d37b711dd4 [sparse] Add batch count and batch stride to matrix descriptors.
PiperOrigin-RevId: 468760351
2022-08-19 12:26:17 -07:00
Deniz Oktay
d5de596d17 CUDA implementation of a sparse direct solver via QR factorization.
PiperOrigin-RevId: 468467698
2022-08-18 08:46:25 -07:00
Peter Hawkins
3bb0030014 Revert: Sparse direct solver using QR factorization from cuSOLVER. This is the jaxlib implementation. We will want to combine this with the sparse libraries already existing in JAX.
Reason: Breaks JAX tests.
PiperOrigin-RevId: 468346430
2022-08-17 18:54:29 -07:00
Deniz Oktay
2bc3e39cd9 Sparse direct solver using QR factorization from cuSOLVER. This is the jaxlib implementation. We will want to combine this with the sparse libraries already existing in JAX.
PiperOrigin-RevId: 468303019
2022-08-17 15:10:27 -07:00
Tianjian Lu
07da502323 [sparse] Enable batch mode of COO matmat from cusparse kernels.
PiperOrigin-RevId: 465405490
2022-08-04 14:30:02 -07:00
Peter Hawkins
aa7d291767 Replace references to absl::string_view with std::string_view.
PiperOrigin-RevId: 450768333
2022-05-24 14:21:32 -07:00
Peter Hawkins
bb0816227d Add a batched QR decomposition implementation on GPU.
PiperOrigin-RevId: 449583027
2022-05-18 14:50:18 -07:00
Peter Hawkins
08c3c2ec24 Split CUDA and HIP C++ code in jaxlib into separate directories.
PiperOrigin-RevId: 447062506
2022-05-06 13:48:00 -07:00