15 Commits

Author SHA1 Message Date
Chris Jones
6b13d4eb86 Add branch prediction to JAX status macros.
PiperOrigin-RevId: 535233546
2023-05-25 06:23:23 -07:00
Peter Hawkins
172a831219 Switch JAX to use the OpenXLA repository. 2023-03-13 18:38:26 +00:00
Peter Hawkins
a852710a09 Merge CUDA and ROCM kernel code in jaxlib.
The code for both CUDA and ROCM is almost identical, so with a small shim library to handle the differences we can share almost everything.

PiperOrigin-RevId: 483666051
2022-10-25 07:23:34 -07:00
Peter Hawkins
5617a02fa4 Remove JAX custom call implementation of batched triangular solve.
XLA supports batched triangular solve on GPU and has since February 2022, which is older than the minimum jaxlib version. We can therefore delete our implementation and just use XLA's implementation.

PiperOrigin-RevId: 482031830
2022-10-18 15:04:14 -07:00
Peter Hawkins
2693afa263 Revert: Use input-output aliasing for jaxlib GPU custom calls.
Previously we had no way to tell XLA that inputs and outputs of GPU custom calls must alias. This now works in XLA:GPU so we can just ask XLA to enforce the aliasing we need.

This seems to be causing some test failures downstream, so reverting this for the moment until I can debug them.

PiperOrigin-RevId: 479670565
2022-10-07 14:36:24 -07:00
Peter Hawkins
93b839ace4 Use input-output aliasing for jaxlib GPU custom calls.
Previously we had no way to tell XLA that inputs and outputs of GPU custom calls must alias. This now works in XLA:GPU so we can just ask XLA to enforce the aliasing we need.

PiperOrigin-RevId: 479642543
2022-10-07 12:22:04 -07:00
Rohit Santhanam
b815ac9d8e [ROCm] Upgrade to ROCm 5.3 and associated enhancements 2022-10-01 04:45:26 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Rohit Santhanam
82adc6a1d0 [ROCm] Enhance hipsparse to reach parity with cusparse based on commit d37b711dd4. 2022-08-21 22:46:00 +00:00
Rohit Santhanam
080cf47002 [ROCm] Fixes for compilation failures caused by compiler changes in ROCm Tensorflow fork. 2022-06-29 14:34:08 +00:00
Peter Hawkins
aa7d291767 Replace references to absl::string_view with std::string_view.
PiperOrigin-RevId: 450768333
2022-05-24 14:21:32 -07:00
Rohit Santhanam
f9321f2536 Fix hipblas kernels for ROCm. 2022-05-19 14:47:21 +00:00
Peter Hawkins
bb0816227d Add a batched QR decomposition implementation on GPU.
PiperOrigin-RevId: 449583027
2022-05-18 14:50:18 -07:00
Rohit Santhanam
bbdcec84f8 Fixes to enable JAX to build on ROCm. 2022-05-06 22:57:51 +00:00
Peter Hawkins
08c3c2ec24 Split CUDA and HIP C++ code in jaxlib into separate directories.
PiperOrigin-RevId: 447062506
2022-05-06 13:48:00 -07:00