mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00

While implementing https://github.com/jax-ml/jax/pull/25787, I realized that although we lower `tridiagonal_solve` to cuSPARSE on GPU, on CPU we were using an explicit implementation of the Thomas algorithm. We should instead lower to LAPACK's `gtsv` on CPU. `gtsv` performs Gaussian elimination with partial pivoting, whereas the Thomas algorithm never pivots, so `gtsv` should be more numerically stable; it should also be faster.

PiperOrigin-RevId: 714069225
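To illustrate the stability point, here is a minimal pure-Python sketch. It is not JAX's or LAPACK's actual code. It contrasts the Thomas algorithm, which never pivots, with Gaussian elimination with partial pivoting in the style of `gtsv`. The function names `thomas_solve` and `pivoted_tridiagonal_solve` are hypothetical. The band layout (`dl[i] = A[i][i-1]`, `du[i] = A[i][i+1]`, with `dl[0]` and `du[-1]` unused) is assumed to follow the convention of `tridiagonal_solve`. Only a single right-hand side is handled.

```python
def thomas_solve(dl, d, du, b):
    """Thomas algorithm: forward sweep then back substitution, no pivoting.

    Assumed layout: dl[i] = A[i][i-1], d[i] = A[i][i], du[i] = A[i][i+1];
    dl[0] and du[-1] are unused.
    """
    n = len(d)
    c = [0.0] * n  # modified superdiagonal
    g = [0.0] * n  # modified right-hand side
    c[0] = du[0] / d[0] if n > 1 else 0.0  # divides by d[0] even if it is zero
    g[0] = b[0] / d[0]
    for i in range(1, n):
        denom = d[i] - dl[i] * c[i - 1]
        c[i] = du[i] / denom if i < n - 1 else 0.0
        g[i] = (b[i] - dl[i] * g[i - 1]) / denom
    x = [0.0] * n
    x[-1] = g[-1]
    for i in range(n - 2, -1, -1):
        x[i] = g[i] - c[i] * x[i + 1]
    return x


def pivoted_tridiagonal_solve(dl, d, du, b):
    """Gaussian elimination with partial pivoting on a tridiagonal system.

    Row swaps introduce fill-in on a second superdiagonal, so each row of
    the upper factor U is stored as its entries in columns (i, i+1, i+2).
    """
    n = len(d)
    U = []  # rows of the upper-triangular factor
    y = []  # transformed right-hand side
    cur = [d[0], du[0] if n > 1 else 0.0, 0.0]  # current row, cols 0..2
    rhs = b[0]
    for i in range(n - 1):
        # Original row i+1, restricted to columns i, i+1, i+2.
        nxt = [dl[i + 1], d[i + 1], du[i + 1] if i + 1 < n - 1 else 0.0]
        nrhs = b[i + 1]
        # Partial pivoting: keep the row with the larger entry in column i.
        if abs(nxt[0]) > abs(cur[0]):
            cur, nxt = nxt, cur
            rhs, nrhs = nrhs, rhs
        if cur[0] == 0.0:
            raise ValueError("matrix is singular")
        m = nxt[0] / cur[0]
        U.append(cur)
        y.append(rhs)
        # Eliminate column i from the other row; it now spans cols i+1..i+3.
        cur = [nxt[1] - m * cur[1], nxt[2] - m * cur[2], 0.0]
        rhs = nrhs - m * rhs
    if cur[0] == 0.0:
        raise ValueError("matrix is singular")
    U.append(cur)
    y.append(rhs)
    # Back substitution over the diagonal and up to two superdiagonals.
    x = [0.0] * n
    for i in range(n - 1, -1, -1):
        s = y[i]
        if i + 1 < n:
            s -= U[i][1] * x[i + 1]
        if i + 2 < n:
            s -= U[i][2] * x[i + 2]
        x[i] = s / U[i][0]
    return x


# A = [[0, 1, 0], [1, 0, 1], [0, 1, 1]] is nonsingular, but d[0] == 0.
dl, d, du = [0.0, 1.0, 1.0], [0.0, 0.0, 1.0], [1.0, 1.0, 0.0]
b = [2.0, 4.0, 5.0]  # equals A @ [1, 2, 3]
print(pivoted_tridiagonal_solve(dl, d, du, b))  # [1.0, 2.0, 3.0]
try:
    thomas_solve(dl, d, du, b)
except ZeroDivisionError:
    print("Thomas algorithm hit a zero pivot at d[0]")
```

On a diagonally dominant matrix both routines return the same answer; the difference appears when a pivot is zero or very small, which is the case partial pivoting is designed to handle. The speed claim is not illustrated here.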