9 Commits

Author SHA1 Message Date
Peter Hawkins
45b871950e Fix a number of minor problems in the ROCM build.
Change in preparation for adding more presubmits for AMD ROCM.

PiperOrigin-RevId: 667766343
2024-08-26 17:04:01 -07:00
Srinivas Vasudevan
7dfc8ff49d Add batching rules to jax.lax.linalg.tridiagonal_solve.
PiperOrigin-RevId: 555700103
2023-08-10 16:25:59 -07:00
Peter Hawkins
f168a1560c [GPU] Add missing stream synchronization to tridiagonal_solve gtsv2 call.
May fix flaky failures in CI.

Make stream argument to Pool::Borrow() mandatory to minimize chance of forgetting it.

PiperOrigin-RevId: 530425766
2023-05-08 15:37:04 -07:00
Peter Hawkins
172a831219 Switch JAX to use the OpenXLA repository. 2023-03-13 18:38:26 +00:00
Tianjian Lu
ef0f64ec5c [sparse] Update the default CUSPARSE SPMV and SPMM algorithms in jaxlib.
PiperOrigin-RevId: 485441349
2022-11-01 16:01:50 -07:00
Jake VanderPlas
06c1d8efb5 Rollback of:
[sparse] Update the default CUSPARSE SPMV and SPMM algorithms in jaxlib.

Still breaks CUDA 11.1

PiperOrigin-RevId: 485151807
2022-10-31 14:38:47 -07:00
Tianjian Lu
66e75edd0b [sparse] Update the default CUSPARSE SPMV and SPMM algorithms in jaxlib.
PiperOrigin-RevId: 484351696
2022-10-27 14:34:44 -07:00
Peter Hawkins
0814770601 Fix FP8 compilation failure in jaxlib stemming from the CUDA/ROCM merge.
PiperOrigin-RevId: 484026031
2022-10-26 11:40:14 -07:00
Peter Hawkins
a852710a09 Merge CUDA and ROCM kernel code in jaxlib.
The code for both CUDA and ROCM is almost identical, so with a small shim library to handle the differences we can share almost everything.

PiperOrigin-RevId: 483666051
2022-10-25 07:23:34 -07:00