mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[sparse] Update the default CUSPARSE SPMV and SPMM algorithms in jaxlib.
PiperOrigin-RevId: 484351696
This commit is contained in:
parent
fc8f40ce0e
commit
66e75edd0b
@ -294,7 +294,7 @@ std::pair<size_t, py::bytes> BuildCsrMatvecDescriptor(
|
||||
SparseConst beta = ConstZero(y.type);
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMV_bufferSize(
|
||||
handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type,
|
||||
GPUSPARSE_MV_ALG_DEFAULT, &buffer_size)));
|
||||
GPUSPARSE_SPMV_CSR_ALG, &buffer_size)));
|
||||
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnVec(vec_x)));
|
||||
@ -346,7 +346,7 @@ std::pair<size_t, py::bytes> BuildCsrMatmatDescriptor(
|
||||
SparseConst beta = ConstZero(C.type);
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMM_bufferSize(
|
||||
handle.get(), op_A, GPUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a,
|
||||
mat_b, &beta, mat_c, C.type, GPUSPARSE_SPMM_ALG_DEFAULT, &buffer_size)));
|
||||
mat_b, &beta, mat_c, C.type, GPUSPARSE_SPMM_CSR_ALG, &buffer_size)));
|
||||
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b)));
|
||||
@ -467,7 +467,7 @@ std::pair<size_t, py::bytes> BuildCooMatvecDescriptor(
|
||||
SparseConst beta = ConstZero(y.type);
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMV_bufferSize(
|
||||
handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type,
|
||||
GPUSPARSE_MV_ALG_DEFAULT, &buffer_size)));
|
||||
GPUSPARSE_SPMV_COO_ALG, &buffer_size)));
|
||||
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnVec(vec_x)));
|
||||
@ -537,7 +537,7 @@ std::pair<size_t, py::bytes> BuildCooMatmatDescriptor(
|
||||
SparseConst beta = ConstZero(C.type);
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMM_bufferSize(
|
||||
handle.get(), op_A, GPUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a,
|
||||
mat_b, &beta, mat_c, C.type, GPUSPARSE_SPMM_ALG_DEFAULT, &buffer_size)));
|
||||
mat_b, &beta, mat_c, C.type, GPUSPARSE_SPMM_COO_ALG, &buffer_size)));
|
||||
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b)));
|
||||
|
@ -266,7 +266,7 @@ static absl::Status CsrMatvec_(gpuStream_t stream, void** buffers,
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
gpusparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x, &beta, vec_y,
|
||||
d.y.type, GPUSPARSE_MV_ALG_DEFAULT, buf)));
|
||||
d.y.type, GPUSPARSE_SPMV_CSR_ALG, buf)));
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnVec(vec_x)));
|
||||
@ -324,7 +324,7 @@ static absl::Status CsrMatmat_(gpuStream_t stream, void** buffers,
|
||||
/*ld=*/d.C.cols, Cbuf, d.C.type, GPUSPARSE_ORDER_ROW)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseSpMM(
|
||||
handle.get(), d.op_A, /*opB=*/GPUSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
|
||||
mat_a, mat_b, &beta, mat_c, d.C.type, GPUSPARSE_SPMM_ALG_DEFAULT, buf)));
|
||||
mat_a, mat_b, &beta, mat_c, d.C.type, GPUSPARSE_SPMM_CSR_ALG, buf)));
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b)));
|
||||
@ -463,7 +463,7 @@ static absl::Status CooMatvec_(gpuStream_t stream, void** buffers,
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
gpusparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x, &beta, vec_y,
|
||||
d.y.type, GPUSPARSE_MV_ALG_DEFAULT, buf)));
|
||||
d.y.type, GPUSPARSE_SPMV_COO_ALG, buf)));
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnVec(vec_x)));
|
||||
@ -529,7 +529,7 @@ static absl::Status CooMatmat_(gpuStream_t stream, void** buffers,
|
||||
/*batchStride=*/d.C.batch_stride)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseSpMM(
|
||||
handle.get(), d.op_A, /*opB=*/GPUSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
|
||||
mat_a, mat_b, &beta, mat_c, d.C.type, GPUSPARSE_SPMM_ALG_DEFAULT, buf)));
|
||||
mat_a, mat_b, &beta, mat_c, d.C.type, GPUSPARSE_SPMM_COO_ALG, buf)));
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b)));
|
||||
|
@ -221,12 +221,28 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t;
|
||||
#define GPUSPARSE_INDEX_64I CUSPARSE_INDEX_64I
|
||||
#define GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT CUSPARSE_DENSETOSPARSE_ALG_DEFAULT
|
||||
#define GPUSPARSE_INDEX_BASE_ZERO CUSPARSE_INDEX_BASE_ZERO
|
||||
#define GPUSPARSE_MV_ALG_DEFAULT CUSPARSE_MV_ALG_DEFAULT
|
||||
// Use CUSPARSE_SPMV_COO_ALG2 and CUSPARSE_SPMV_CSR_ALG2 for SPMV and
|
||||
// use CUSPARSE_SPMM_COO_ALG2 and CUSPARSE_SPMM_CSR_ALG3 for SPMM, which
|
||||
// provide deterministic (bit-wise) results for each run.
|
||||
// CUSPARSE_SPMV_COO_ALG2 is available since cuda version 11.2.1
|
||||
// CUSPARSE_SPMV_CSR_ALG2 is available since cuda version 11.2.1
|
||||
// CUSPARSE_SPMM_COO_ALG2 is available since cuda version 11.0.3
|
||||
// CUSPARSE_SPMM_CSR_ALG3 is available since cuda version 11.2.1
|
||||
#if JAX_GPU_HAVE_SPARSE
|
||||
#define GPUSPARSE_SPMV_COO_ALG CUSPARSE_SPMV_COO_ALG2
|
||||
#define GPUSPARSE_SPMV_CSR_ALG CUSPARSE_SPMV_CSR_ALG2
|
||||
#define GPUSPARSE_SPMM_COO_ALG CUSPARSE_SPMM_COO_ALG2
|
||||
#define GPUSPARSE_SPMM_CSR_ALG CUSPARSE_SPMM_CSR_ALG3
|
||||
#else
|
||||
#define GPUSPARSE_SPMV_COO_ALG CUSPARSE_MV_ALG_DEFAULT
|
||||
#define GPUSPARSE_SPMV_CSR_ALG CUSPARSE_MV_ALG_DEFAULT
|
||||
#define GPUSPARSE_SPMM_COO_ALG CUSPARSE_SPMM_ALG_DEFAULT
|
||||
#define GPUSPARSE_SPMM_CSR_ALG CUSPARSE_SPMM_ALG_DEFAULT
|
||||
#endif
|
||||
#define GPUSPARSE_OPERATION_NON_TRANSPOSE CUSPARSE_OPERATION_NON_TRANSPOSE
|
||||
#define GPUSPARSE_OPERATION_TRANSPOSE CUSPARSE_OPERATION_TRANSPOSE
|
||||
#define GPUSPARSE_ORDER_ROW CUSPARSE_ORDER_ROW
|
||||
#define GPUSPARSE_SPARSETODENSE_ALG_DEFAULT CUSPARSE_SPARSETODENSE_ALG_DEFAULT
|
||||
#define GPUSPARSE_SPMM_ALG_DEFAULT CUSPARSE_SPMM_ALG_DEFAULT
|
||||
#define GPUSPARSE_STATUS_SUCCESS CUSPARSE_STATUS_SUCCESS
|
||||
|
||||
#define gpuGetLastError cudaGetLastError
|
||||
@ -418,13 +434,15 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t;
|
||||
#define GPUSPARSE_INDEX_32I HIPSPARSE_INDEX_32I
|
||||
#define GPUSPARSE_INDEX_64I HIPSPARSE_INDEX_64I
|
||||
#define GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT HIPSPARSE_DENSETOSPARSE_ALG_DEFAULT
|
||||
#define GPUSPARSE_MV_ALG_DEFAULT HIPSPARSE_MV_ALG_DEFAULT
|
||||
#define GPUSPARSE_SPMV_COO_ALG HIPSPARSE_MV_ALG_DEFAULT
|
||||
#define GPUSPARSE_SPMV_CSR_ALG HIPSPARSE_MV_ALG_DEFAULT
|
||||
#define GPUSPARSE_SPMM_COO_ALG HIPSPARSE_SPMM_ALG_DEFAULT
|
||||
#define GPUSPARSE_SPMM_CSR_ALG HIPSPARSE_SPMM_ALG_DEFAULT
|
||||
#define GPUSPARSE_INDEX_BASE_ZERO HIPSPARSE_INDEX_BASE_ZERO
|
||||
#define GPUSPARSE_OPERATION_NON_TRANSPOSE HIPSPARSE_OPERATION_NON_TRANSPOSE
|
||||
#define GPUSPARSE_OPERATION_TRANSPOSE HIPSPARSE_OPERATION_TRANSPOSE
|
||||
#define GPUSPARSE_ORDER_ROW HIPSPARSE_ORDER_ROW
|
||||
#define GPUSPARSE_SPARSETODENSE_ALG_DEFAULT HIPSPARSE_SPARSETODENSE_ALG_DEFAULT
|
||||
#define GPUSPARSE_SPMM_ALG_DEFAULT HIPSPARSE_SPMM_ALG_DEFAULT
|
||||
#define GPUSPARSE_STATUS_SUCCESS HIPSPARSE_STATUS_SUCCESS
|
||||
|
||||
#define gpuGetLastError hipGetLastError
|
||||
|
@ -1208,7 +1208,11 @@ class BCOOTest(jtu.JaxTestCase):
|
||||
# TODO(tianjianlu): In some cases, this fails python_should_be_executing.
|
||||
# self._CompileAndCheck(f_sparse, args_maker)
|
||||
self._CheckAgainstNumpy(f_dense, f_sparse, args_maker)
|
||||
self._CheckAgainstNumpy(f_dense, jit(f_sparse), args_maker)
|
||||
if dtype is np.complex128:
|
||||
atol = 1E-1
|
||||
else:
|
||||
atol = 1E-2
|
||||
self._CheckAgainstNumpy(f_dense, jit(f_sparse), args_maker, atol=atol, rtol=1E-6)
|
||||
else:
|
||||
lhs_bcoo, lhs, rhs = args_maker()
|
||||
matmat_expected = f_dense(lhs_bcoo, lhs, rhs)
|
||||
|
Loading…
x
Reference in New Issue
Block a user