[sparse] Update the default CUSPARSE SPMV and SPMM algorithms in jaxlib.

PiperOrigin-RevId: 484351696
This commit is contained in:
Tianjian Lu 2022-10-27 14:34:07 -07:00 committed by jax authors
parent fc8f40ce0e
commit 66e75edd0b
4 changed files with 35 additions and 13 deletions

View File

@ -294,7 +294,7 @@ std::pair<size_t, py::bytes> BuildCsrMatvecDescriptor(
SparseConst beta = ConstZero(y.type);
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMV_bufferSize(
handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type,
GPUSPARSE_MV_ALG_DEFAULT, &buffer_size)));
GPUSPARSE_SPMV_CSR_ALG, &buffer_size)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnVec(vec_x)));
@ -346,7 +346,7 @@ std::pair<size_t, py::bytes> BuildCsrMatmatDescriptor(
SparseConst beta = ConstZero(C.type);
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMM_bufferSize(
handle.get(), op_A, GPUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a,
mat_b, &beta, mat_c, C.type, GPUSPARSE_SPMM_ALG_DEFAULT, &buffer_size)));
mat_b, &beta, mat_c, C.type, GPUSPARSE_SPMM_CSR_ALG, &buffer_size)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b)));
@ -467,7 +467,7 @@ std::pair<size_t, py::bytes> BuildCooMatvecDescriptor(
SparseConst beta = ConstZero(y.type);
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMV_bufferSize(
handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type,
GPUSPARSE_MV_ALG_DEFAULT, &buffer_size)));
GPUSPARSE_SPMV_COO_ALG, &buffer_size)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnVec(vec_x)));
@ -537,7 +537,7 @@ std::pair<size_t, py::bytes> BuildCooMatmatDescriptor(
SparseConst beta = ConstZero(C.type);
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMM_bufferSize(
handle.get(), op_A, GPUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a,
mat_b, &beta, mat_c, C.type, GPUSPARSE_SPMM_ALG_DEFAULT, &buffer_size)));
mat_b, &beta, mat_c, C.type, GPUSPARSE_SPMM_COO_ALG, &buffer_size)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b)));

View File

@ -266,7 +266,7 @@ static absl::Status CsrMatvec_(gpuStream_t stream, void** buffers,
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
gpusparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x, &beta, vec_y,
d.y.type, GPUSPARSE_MV_ALG_DEFAULT, buf)));
d.y.type, GPUSPARSE_SPMV_CSR_ALG, buf)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnVec(vec_x)));
@ -324,7 +324,7 @@ static absl::Status CsrMatmat_(gpuStream_t stream, void** buffers,
/*ld=*/d.C.cols, Cbuf, d.C.type, GPUSPARSE_ORDER_ROW)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseSpMM(
handle.get(), d.op_A, /*opB=*/GPUSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
mat_a, mat_b, &beta, mat_c, d.C.type, GPUSPARSE_SPMM_ALG_DEFAULT, buf)));
mat_a, mat_b, &beta, mat_c, d.C.type, GPUSPARSE_SPMM_CSR_ALG, buf)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b)));
@ -463,7 +463,7 @@ static absl::Status CooMatvec_(gpuStream_t stream, void** buffers,
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
gpusparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x, &beta, vec_y,
d.y.type, GPUSPARSE_MV_ALG_DEFAULT, buf)));
d.y.type, GPUSPARSE_SPMV_COO_ALG, buf)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnVec(vec_x)));
@ -529,7 +529,7 @@ static absl::Status CooMatmat_(gpuStream_t stream, void** buffers,
/*batchStride=*/d.C.batch_stride)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseSpMM(
handle.get(), d.op_A, /*opB=*/GPUSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
mat_a, mat_b, &beta, mat_c, d.C.type, GPUSPARSE_SPMM_ALG_DEFAULT, buf)));
mat_a, mat_b, &beta, mat_c, d.C.type, GPUSPARSE_SPMM_COO_ALG, buf)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b)));

View File

@ -221,12 +221,28 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t;
#define GPUSPARSE_INDEX_64I CUSPARSE_INDEX_64I
#define GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT CUSPARSE_DENSETOSPARSE_ALG_DEFAULT
#define GPUSPARSE_INDEX_BASE_ZERO CUSPARSE_INDEX_BASE_ZERO
#define GPUSPARSE_MV_ALG_DEFAULT CUSPARSE_MV_ALG_DEFAULT
// Use CUSPARSE_SPMV_COO_ALG2 and CUSPARSE_SPMV_CSR_ALG2 for SPMV and
// use CUSPARSE_SPMM_COO_ALG2 and CUSPARSE_SPMM_CSR_ALG3 for SPMM, which
// provide deterministic (bit-wise) results for each run.
// CUSPARSE_SPMV_COO_ALG2 is available since cuda version 11.2.1
// CUSPARSE_SPMV_CSR_ALG2 is available since cuda version 11.2.1
// CUSPARSE_SPMM_COO_ALG2 is available since cuda version 11.0.3
// CUSPARSE_SPMM_CSR_ALG3 is available since cuda version 11.2.1
#if JAX_GPU_HAVE_SPARSE
#define GPUSPARSE_SPMV_COO_ALG CUSPARSE_SPMV_COO_ALG2
#define GPUSPARSE_SPMV_CSR_ALG CUSPARSE_SPMV_CSR_ALG2
#define GPUSPARSE_SPMM_COO_ALG CUSPARSE_SPMM_COO_ALG2
#define GPUSPARSE_SPMM_CSR_ALG CUSPARSE_SPMM_CSR_ALG3
#else
#define GPUSPARSE_SPMV_COO_ALG CUSPARSE_MV_ALG_DEFAULT
#define GPUSPARSE_SPMV_CSR_ALG CUSPARSE_MV_ALG_DEFAULT
#define GPUSPARSE_SPMM_COO_ALG CUSPARSE_SPMM_ALG_DEFAULT
#define GPUSPARSE_SPMM_CSR_ALG CUSPARSE_SPMM_ALG_DEFAULT
#endif
#define GPUSPARSE_OPERATION_NON_TRANSPOSE CUSPARSE_OPERATION_NON_TRANSPOSE
#define GPUSPARSE_OPERATION_TRANSPOSE CUSPARSE_OPERATION_TRANSPOSE
#define GPUSPARSE_ORDER_ROW CUSPARSE_ORDER_ROW
#define GPUSPARSE_SPARSETODENSE_ALG_DEFAULT CUSPARSE_SPARSETODENSE_ALG_DEFAULT
#define GPUSPARSE_SPMM_ALG_DEFAULT CUSPARSE_SPMM_ALG_DEFAULT
#define GPUSPARSE_STATUS_SUCCESS CUSPARSE_STATUS_SUCCESS
#define gpuGetLastError cudaGetLastError
@ -418,13 +434,15 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t;
#define GPUSPARSE_INDEX_32I HIPSPARSE_INDEX_32I
#define GPUSPARSE_INDEX_64I HIPSPARSE_INDEX_64I
#define GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT HIPSPARSE_DENSETOSPARSE_ALG_DEFAULT
#define GPUSPARSE_MV_ALG_DEFAULT HIPSPARSE_MV_ALG_DEFAULT
#define GPUSPARSE_SPMV_COO_ALG HIPSPARSE_MV_ALG_DEFAULT
#define GPUSPARSE_SPMV_CSR_ALG HIPSPARSE_MV_ALG_DEFAULT
#define GPUSPARSE_SPMM_COO_ALG HIPSPARSE_SPMM_ALG_DEFAULT
#define GPUSPARSE_SPMM_CSR_ALG HIPSPARSE_SPMM_ALG_DEFAULT
#define GPUSPARSE_INDEX_BASE_ZERO HIPSPARSE_INDEX_BASE_ZERO
#define GPUSPARSE_OPERATION_NON_TRANSPOSE HIPSPARSE_OPERATION_NON_TRANSPOSE
#define GPUSPARSE_OPERATION_TRANSPOSE HIPSPARSE_OPERATION_TRANSPOSE
#define GPUSPARSE_ORDER_ROW HIPSPARSE_ORDER_ROW
#define GPUSPARSE_SPARSETODENSE_ALG_DEFAULT HIPSPARSE_SPARSETODENSE_ALG_DEFAULT
#define GPUSPARSE_SPMM_ALG_DEFAULT HIPSPARSE_SPMM_ALG_DEFAULT
#define GPUSPARSE_STATUS_SUCCESS HIPSPARSE_STATUS_SUCCESS
#define gpuGetLastError hipGetLastError

View File

@ -1208,7 +1208,11 @@ class BCOOTest(jtu.JaxTestCase):
# TODO(tianjianlu): In some cases, this fails python_should_be_executing.
# self._CompileAndCheck(f_sparse, args_maker)
self._CheckAgainstNumpy(f_dense, f_sparse, args_maker)
self._CheckAgainstNumpy(f_dense, jit(f_sparse), args_maker)
if dtype is np.complex128:
atol = 1E-1
else:
atol = 1E-2
self._CheckAgainstNumpy(f_dense, jit(f_sparse), args_maker, atol=atol, rtol=1E-6)
else:
lhs_bcoo, lhs, rhs = args_maker()
matmat_expected = f_dense(lhs_bcoo, lhs, rhs)