Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 13:26:06 +00:00)
Use the default CSR matmul algorithm.
Previously we requested CUSPARSE_SPMM_CSR_ALG3 in an attempt to get deterministic results from cuSPARSE SpMM CSR matmuls. In the past, cuSPARSE silently ignored this algorithm choice and used a different algorithm in cases where ALG3 was not supported, but cuSPARSE 12.2.1 removed the silent fallback behavior. Since we were not getting deterministic behavior in all cases anyway, always use the default algorithm.

PiperOrigin-RevId: 560867049
parent 046bcc0ad9
commit 46ac9e2170
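The commit message above hinges on what "deterministic" means for a sparse matmul. One common reason a parallel SpMM is not bitwise reproducible is that floating-point addition is not associative, so a kernel that accumulates a row's products in a varying order can return different results from run to run. The sketch below is a pure-Python toy that illustrates only that effect. It is not the cuSPARSE algorithm, and `csr_spmm` and its `order` parameter are hypothetical names introduced for illustration.

```python
def csr_spmm(indptr, indices, data, dense, order=None):
    """Multiply a CSR matrix by a dense matrix given as a list of rows.

    `order` is a hypothetical hook: it receives the positions of one row's
    nonzeros and returns them in the order they should be accumulated,
    standing in for a kernel whose reduction order is not fixed.
    """
    n_rows = len(indptr) - 1
    n_cols = len(dense[0])
    out = [[0.0] * n_cols for _ in range(n_rows)]
    for i in range(n_rows):
        nz = list(range(indptr[i], indptr[i + 1]))
        if order is not None:
            nz = order(nz)
        for k in nz:
            for j in range(n_cols):
                out[i][j] += data[k] * dense[indices[k]][j]
    return out


# A 1x3 sparse row [1e16, 1.0, -1e16] times a 3x1 column of ones.
indptr, indices, data = [0, 3], [0, 1, 2], [1e16, 1.0, -1e16]
dense = [[1.0], [1.0], [1.0]]

# Accumulate in storage order: (1e16 + 1.0) + (-1e16).
in_order = csr_spmm(indptr, indices, data, dense)

# Accumulate the two large terms first: (1e16 + (-1e16)) + 1.0.
reordered = csr_spmm(
    indptr, indices, data, dense, order=lambda nz: [nz[0], nz[2], nz[1]]
)

print(in_order)   # [[0.0]]
print(reordered)  # [[1.0]]
```

The same inputs give 0.0 or 1.0 depending only on accumulation order, which is why a caller that needs reproducible results cares which algorithm the library actually runs, and why a silent fallback to a different algorithm undermines that guarantee.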
@@ -58,6 +58,11 @@ Remember to align the itemized text with the first line of an item within a list
 
 ## jaxlib 0.4.15
 
+* Changes:
+  * Sparse CSR matrix multiplications via the experimental jax sparse APIs
+    no longer uses a deterministic algorithm on NVIDIA GPUs. This change was
+    made to improve compatibility with CUDA 12.2.1.
+
 ## jax 0.4.14 (July 27, 2023)
 
 * Changes
@@ -226,7 +226,12 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t;
 #define GPUSPARSE_SPMV_COO_ALG CUSPARSE_SPMV_COO_ALG2
 #define GPUSPARSE_SPMV_CSR_ALG CUSPARSE_SPMV_CSR_ALG2
 #define GPUSPARSE_SPMM_COO_ALG CUSPARSE_SPMM_COO_ALG2
-#define GPUSPARSE_SPMM_CSR_ALG CUSPARSE_SPMM_CSR_ALG3
+// In general Cusparse does not support a fully general deterministic CSR SpMM
+// algorithm.
+// In CUDA versions before 12.2.1, we used ALG3, which is deterministic, but
+// does not cover all cases and silently fell back to other algorithms for cases
+// it did not cover. CUDA 12.2.1 removed the fallback behavior.
+#define GPUSPARSE_SPMM_CSR_ALG CUSPARSE_SPMM_ALG_DEFAULT
 #else
 #define GPUSPARSE_SPMV_COO_ALG CUSPARSE_MV_ALG_DEFAULT
 #define GPUSPARSE_SPMV_CSR_ALG CUSPARSE_MV_ALG_DEFAULT