Dan Foreman-Mackey 21e98b5ce4 Fix overflow error in GPU batched linear algebra kernels.
As reported in https://github.com/jax-ml/jax/issues/24843, our LU decomposition on GPU hits overflow errors when the batch size approaches the int32 maximum. The cause was an issue in how we constructed the batched pointers used by cuBLAS.
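cuBLAS batched routines take an array holding one device pointer per matrix in the batch. Each pointer is the base address plus the batch index times the per-matrix stride. The commit does not say exactly where the overflow occurred. The sketch below is a host-side illustration under one assumption: that the index and offset arithmetic had been done in 32-bit `int`. With that assumption, `i * stride` wraps once the batch becomes large, and computing it in `std::int64_t` avoids the wrap. The names `FillBatchPointers`, `BatchOffset` and `Needs64BitOffsets` are hypothetical. They are not JAX's actual kernel code.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical: element offset of matrix `i` in a contiguous batch.
// Using 64-bit operands means i * stride_elems cannot wrap at INT32_MAX.
constexpr std::int64_t BatchOffset(std::int64_t i, std::int64_t stride_elems) {
  return i * stride_elems;
}

// Hypothetical: true if the last matrix's offset does not fit in a 32-bit int,
// i.e. 32-bit index arithmetic would overflow for this batch.
constexpr bool Needs64BitOffsets(std::int64_t batch, std::int64_t stride_elems) {
  return batch > 0 &&
         BatchOffset(batch - 1, stride_elems) >
             std::numeric_limits<std::int32_t>::max();
}

// Hypothetical host-side analogue of building the pointer array passed to a
// cuBLAS batched routine: out[i] = base + i * stride_elems.
template <typename T>
void FillBatchPointers(T* base, std::int64_t batch, std::int64_t stride_elems,
                       std::vector<T*>& out) {
  assert(batch >= 0 && stride_elems >= 0);
  out.resize(static_cast<std::size_t>(batch));
  // The loop index is 64-bit, so it cannot wrap for large batch sizes either.
  for (std::int64_t i = 0; i < batch; ++i) {
    out[static_cast<std::size_t>(i)] = base + BatchOffset(i, stride_elems);
  }
}
```

For example, with 4x4 matrices (16 elements each), `Needs64BitOffsets(std::int64_t{1} << 28, 16)` is true. The last offset, 2^32 - 16 elements, does not fit in an `int32`, although the batch is still far below the int32 maximum.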

PiperOrigin-RevId: 695694648
2024-11-12 05:33:49 -08:00