22 Commits

Author SHA1 Message Date
Ruturaj4
99c2b7b4e9 [ROCm] Bring-up pjrt support 2024-06-17 16:49:22 +00:00
Ruturaj4
79fccf6c82 add cholesky changes in bazel 2024-05-18 00:37:09 +00:00
Sergei Lebedev
51fc4f85ad Ported LuPivotsToPermutation to the typed XLA FFI
The typed FFI

* allows passing custom call attributes directly to backend_config= instead
  of serializing them into a C++ struct.
* It also handles validation and deserialization of custom call operands.

PiperOrigin-RevId: 630067005
2024-05-02 08:12:05 -07:00
Ruturaj4
97bf2d2bb8 [ROCm]: fix tsl path 2024-04-08 19:58:41 -05:00
Rahul Batra
b4b97cd8e8 [ROCm]: Add jax-triton support for ROCm 2023-10-18 07:09:20 +00:00
Rahul Batra
4091ac646c [ROCm]: Fix duplicate deps include 2023-09-08 22:56:59 +00:00
Peter Hawkins
70b7d50181 Switch jaxlib to use nanobind instead of pybind11.
nanobind has a number of advantages (https://nanobind.readthedocs.io/en/latest/why.html), notably speed of compilation and dispatch, but the main reason to do this for these bindings is because nanobind can target the Python Stable ABI starting with Python 3.12. This means that we will not need to ship per-Python version CUDA plugins starting with Python 3.12.

PiperOrigin-RevId: 559898790
2023-08-24 16:07:56 -07:00
Chris Jones
6b13d4eb86 Add branch prediction to JAX status macros.
PiperOrigin-RevId: 535233546
2023-05-25 06:23:23 -07:00
Peter Hawkins
172a831219 Switch JAX to use the OpenXLA repository. 2023-03-13 18:38:26 +00:00
Peter Hawkins
a852710a09 Merge CUDA and ROCM kernel code in jaxlib.
The code for both CUDA and ROCM is almost identical, so with a small shim library to handle the differences we can share almost everything.

PiperOrigin-RevId: 483666051
2022-10-25 07:23:34 -07:00
Peter Hawkins
5617a02fa4 Remove JAX custom call implementation of batched triangular solve.
XLA supports batched triangular solve on GPU and has since February 2022, which is older than the minimum jaxlib version. We can therefore delete our implementation and just use XLA's implementation.

PiperOrigin-RevId: 482031830
2022-10-18 15:04:14 -07:00
Peter Hawkins
2693afa263 Revert: Use input-output aliasing for jaxlib GPU custom calls.
Previously we had no way to tell XLA that inputs and outputs of GPU custom calls must alias. This now works in XLA:GPU so we can just ask XLA to enforce the aliasing we need.

This seems to be causing some test failures downstream, so reverting this for the moment until I can debug them.

PiperOrigin-RevId: 479670565
2022-10-07 14:36:24 -07:00
Peter Hawkins
93b839ace4 Use input-output aliasing for jaxlib GPU custom calls.
Previously we had no way to tell XLA that inputs and outputs of GPU custom calls must alias. This now works in XLA:GPU so we can just ask XLA to enforce the aliasing we need.

PiperOrigin-RevId: 479642543
2022-10-07 12:22:04 -07:00
Rohit Santhanam
b815ac9d8e [ROCm] Upgrade to ROCm 5.3 and associated enhancements 2022-10-01 04:45:26 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Rohit Santhanam
82adc6a1d0 [ROCm] Enhance hipsparse to reach parity with cusparse based on commit d37b711dd4. 2022-08-21 22:46:00 +00:00
Rohit Santhanam
080cf47002 [ROCm] Fixes for compilation failures caused by compiler changes in ROCm Tensorflow fork. 2022-06-29 14:34:08 +00:00
Peter Hawkins
aa7d291767 Replace references to absl::string_view with std::string_view.
PiperOrigin-RevId: 450768333
2022-05-24 14:21:32 -07:00
Rohit Santhanam
f9321f2536 Fix hipblas kernels for ROCm. 2022-05-19 14:47:21 +00:00
Peter Hawkins
bb0816227d Add a batched QR decomposition implementation on GPU.
PiperOrigin-RevId: 449583027
2022-05-18 14:50:18 -07:00
Rohit Santhanam
bbdcec84f8 Fixes to enable JAX to build on ROCm. 2022-05-06 22:57:51 +00:00
Peter Hawkins
08c3c2ec24 Split CUDA and HIP C++ code in jaxlib into separate directories.
PiperOrigin-RevId: 447062506
2022-05-06 13:48:00 -07:00