Srinivas Vasudevan
7dfc8ff49d
Add batching rules to jax.lax.linalg.tridiagonal_solve.
...
PiperOrigin-RevId: 555700103
2023-08-10 16:25:59 -07:00
Peter Hawkins
f168a1560c
[GPU] Add missing stream synchronization to tridiagonal_solve gtsv2 call.
...
May fix flaky failures in CI.
Make stream argument to Pool::Borrow() mandatory to minimize chance of forgetting it.
PiperOrigin-RevId: 530425766
2023-05-08 15:37:04 -07:00
Peter Hawkins
172a831219
Switch JAX to use the OpenXLA repository.
2023-03-13 18:38:26 +00:00
Tianjian Lu
ef0f64ec5c
[sparse] Update the default CUSPARSE SPMV and SPMM algorithms in jaxlib.
...
PiperOrigin-RevId: 485441349
2022-11-01 16:01:50 -07:00
Jake VanderPlas
06c1d8efb5
Rollback of:
...
[sparse] Update the default CUSPARSE SPMV and SPMM algorithms in jaxlib.
Still breaks CUDA 11.1
PiperOrigin-RevId: 485151807
2022-10-31 14:38:47 -07:00
Tianjian Lu
66e75edd0b
[sparse] Update the default CUSPARSE SPMV and SPMM algorithms in jaxlib.
...
PiperOrigin-RevId: 484351696
2022-10-27 14:34:44 -07:00
Peter Hawkins
0814770601
Fix FP8 compilation failure in jaxlib stemming from the CUDA/ROCM merge.
...
PiperOrigin-RevId: 484026031
2022-10-26 11:40:14 -07:00
Peter Hawkins
a852710a09
Merge CUDA and ROCM kernel code in jaxlib.
...
The code for both CUDA and ROCM is almost identical, so with a small shim library to handle the differences we can share almost everything.
PiperOrigin-RevId: 483666051
2022-10-25 07:23:34 -07:00