A bug in CUDA prevents us from calling gpuStreamGetCtx inside graph capture. We use cuCtxGetCurrent as workaround for now.
PiperOrigin-RevId: 605417225
Autotuning is not compatible with graph capture because it requires synchronizing.
We use cuThreadExchangeStreamCaptureMode to execute a sequence of commands that are not recorded to graphs, similar to what NCCL does here: b6d7438d31/src/include/alloc.h (L171)
PiperOrigin-RevId: 602436960
If the gesvdj() is preferable to gesvd() absent a batch dimension, even if there is a batch dimension we should prefer a loop of gesvdj() over a loop of gesvd().
PiperOrigin-RevId: 582279549
This is intended to flag cases where the wrong CUDA libraries are used, either because:
* the user self-installed CUDA and that installation is too old, or
* the user used the pip package installation, but due to LD_LIBRARY_PATH overrides or similar we didn't end up using the pip-installed version.
PiperOrigin-RevId: 568910422
Previously we requested CUSPARSE_SPMM_CSR_ALG3 in an attempt to get deterministic results from cusparse SpMM CSR matmuls. In the past, Cusparse silently ignored this algorithm choice and used a different algorithm in cases where ALG3 was not supported, but cusparse 12.2.1 removed the silent fallback behavior. Since we're not actually getting deterministic behavior anyway in all cases, use the default algorithm always.
PiperOrigin-RevId: 560867049
Change the contract of lax.linalg.tridiagonal to return the d and e vectors as well. Since we only just added this function and have never released JAX with it we can make this change without breaking compatibility.
Also fix wrong dtypes for d and e values in the CPU lapack sytrd wrapper.
PiperOrigin-RevId: 487621469
The code for both CUDA and ROCM is almost identical, so with a small shim library to handle the differences we can share almost everything.
PiperOrigin-RevId: 483666051