mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00

The code for both CUDA and ROCM is almost identical, so with a small shim library to handle the differences we can share almost everything. PiperOrigin-RevId: 483666051
616 lines
25 KiB
C++
616 lines
25 KiB
C++
/* Copyright 2021 The JAX Authors.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#include <algorithm>
|
|
#include <cstdint>
|
|
#include <stdexcept>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
#include "absl/base/casts.h"
|
|
#include "absl/base/thread_annotations.h"
|
|
#include "absl/container/flat_hash_map.h"
|
|
#include "absl/strings/str_format.h"
|
|
#include "absl/synchronization/mutex.h"
|
|
#include "jaxlib/gpu/gpu_kernel_helpers.h"
|
|
#include "jaxlib/gpu/sparse_kernels.h"
|
|
#include "jaxlib/gpu/vendor.h"
|
|
#include "jaxlib/kernel_pybind11_helpers.h"
|
|
#include "include/pybind11/numpy.h"
|
|
#include "include/pybind11/pybind11.h"
|
|
#include "include/pybind11/stl.h"
|
|
|
|
namespace py = pybind11;
|
|
|
|
namespace jax {
|
|
namespace JAX_GPU_NAMESPACE {
|
|
namespace {
|
|
|
|
// Translates a NumPy index dtype into the sparse library's index-type enum.
//
// The lookup key is the dtype's (kind, itemsize) pair; only the entries listed
// in the table below are accepted. Any other dtype raises
// std::invalid_argument.
gpusparseIndexType_t DtypeToCuSparseIndexType(const py::dtype& np_type) {
  using IndexTypeTable =
      absl::flat_hash_map<std::pair<char, int>, gpusparseIndexType_t>;
  // Built on first use; the heap allocation is never freed.
  static const IndexTypeTable* const index_types = new IndexTypeTable({
      {{'u', 2}, GPUSPARSE_INDEX_16U},
      {{'i', 4}, GPUSPARSE_INDEX_32I},
      {{'i', 8}, GPUSPARSE_INDEX_64I},
  });

  const auto entry = index_types->find({np_type.kind(), np_type.itemsize()});
  if (entry != index_types->end()) {
    return entry->second;
  }
  throw std::invalid_argument(
      absl::StrFormat("Unsupported index dtype: %s", py::repr(np_type)));
}
|
|
|
|
// Translates a NumPy data dtype into the GPU library's data-type enum.
//
// The lookup key is the dtype's (kind, itemsize) pair. Real and complex
// floating-point entries are always available; the integer entries are only
// compiled in for CUDA builds, and the {'V', 2} -> CUDA_R_16BF entry
// additionally requires JAX_GPU_HAVE_SPARSE. Any dtype missing from the table
// raises std::invalid_argument.
gpuDataType DtypeToCudaDataType(const py::dtype& np_type) {
  using DataTypeTable = absl::flat_hash_map<std::pair<char, int>, gpuDataType>;
  // Built on first use; the heap allocation is never freed.
  static const DataTypeTable* const data_types = new DataTypeTable({
      {{'f', 2}, GPU_R_16F},
      {{'c', 4}, GPU_C_16F},
      {{'f', 4}, GPU_R_32F},
      {{'c', 8}, GPU_C_32F},
      {{'f', 8}, GPU_R_64F},
      {{'c', 16}, GPU_C_64F},
#ifdef JAX_GPU_CUDA
      {{'i', 1}, CUDA_R_8I},
      {{'u', 1}, CUDA_R_8U},
      {{'i', 4}, CUDA_R_32I},
      {{'u', 4}, CUDA_R_32U},
#if JAX_GPU_HAVE_SPARSE
      {{'V', 2}, CUDA_R_16BF},
#endif  // JAX_GPU_HAVE_SPARSE
#endif  // JAX_GPU_CUDA
  });

  const auto entry = data_types->find({np_type.kind(), np_type.itemsize()});
  if (entry != data_types->end()) {
    return entry->second;
  }
  throw std::invalid_argument(
      absl::StrFormat("Unsupported data dtype: %s", py::repr(np_type)));
}
|
|
// Returns the descriptor for a Sparse matrix.
|
|
SparseMatDescriptor BuildSparseMatDescriptor(const py::dtype& data_dtype,
|
|
const py::dtype& index_dtype,
|
|
int rows, int cols, int nnz,
|
|
int batch_count,
|
|
int batch_stride) {
|
|
gpuDataType value_type = DtypeToCudaDataType(data_dtype);
|
|
gpusparseIndexType_t index_type = DtypeToCuSparseIndexType(index_dtype);
|
|
return SparseMatDescriptor{value_type, index_type, rows, cols,
|
|
nnz, batch_count, batch_stride};
|
|
}
|
|
|
|
// Returns the descriptor for a Dense matrix.
|
|
DenseMatDescriptor BuildDenseMatDescriptor(const py::dtype& data_dtype,
|
|
int rows, int cols, int batch_count,
|
|
int batch_stride) {
|
|
gpuDataType value_type = DtypeToCudaDataType(data_dtype);
|
|
return DenseMatDescriptor{value_type, rows, cols, batch_count, batch_stride};
|
|
}
|
|
|
|
// Returns the descriptor for a Dense vector.
|
|
DenseVecDescriptor BuildDenseVecDescriptor(const py::dtype& data_dtype,
|
|
int size) {
|
|
gpuDataType value_type = DtypeToCudaDataType(data_dtype);
|
|
return DenseVecDescriptor{value_type, size};
|
|
}
|
|
|
|
#if JAX_GPU_HAVE_SPARSE
|
|
// CsrToDense: Convert CSR matrix to dense matrix
|
|
|
|
// Returns the descriptor for a Sparse matrix.
|
|
std::pair<size_t, py::bytes> BuildCsrToDenseDescriptor(
|
|
const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
|
|
int cols, int nnz) {
|
|
auto h = SparseHandlePool::Borrow();
|
|
JAX_THROW_IF_ERROR(h.status());
|
|
auto& handle = *h;
|
|
SparseMatDescriptor d =
|
|
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz,
|
|
/*batch_count*/ 1, /*batch_stride*/ 0);
|
|
|
|
gpusparseSpMatDescr_t mat_a = 0;
|
|
gpusparseDnMatDescr_t mat_b = 0;
|
|
|
|
// buffer_size does not reference these pointers, but does error on NULL.
|
|
// TODO(jakevdp): check whether this is documented.
|
|
int val = 0;
|
|
void* empty = &val;
|
|
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCsr(
|
|
&mat_a, d.rows, d.cols, d.nnz, empty, empty, empty, d.index_type,
|
|
d.index_type, GPUSPARSE_INDEX_BASE_ZERO, d.value_type)));
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat(
|
|
&mat_b, d.rows, d.cols,
|
|
/*ld=*/d.cols, empty, d.value_type, GPUSPARSE_ORDER_ROW)));
|
|
size_t buffer_size;
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSparseToDense_bufferSize(
|
|
handle.get(), mat_a, mat_b, GPUSPARSE_SPARSETODENSE_ALG_DEFAULT,
|
|
&buffer_size)));
|
|
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b)));
|
|
|
|
return {buffer_size, PackDescriptor(d)};
|
|
}
|
|
|
|
// Converts a CSR matrix to a dense matrix on `stream`.
//
// `opaque` must hold a packed SparseMatDescriptor. Per the argument
// annotations below, the buffers are used as: buffers[0] CSR values,
// buffers[1] CSR column indices, buffers[2] CSR row offsets, buffers[3] the
// dense matrix storage, and buffers[4] the external buffer handed to
// gpusparseSparseToDense.
//
// Returns the first error encountered. Destruction of the library descriptors
// created here is always attempted before returning, including on failure.
absl::Status CsrToDense_(gpuStream_t stream, void** buffers, const char* opaque,
                         size_t opaque_len) {
  auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
  JAX_RETURN_IF_ERROR(s.status());
  const SparseMatDescriptor& d = **s;
  auto h = SparseHandlePool::Borrow(stream);
  JAX_RETURN_IF_ERROR(h.status());
  auto& handle = *h;

  gpusparseSpMatDescr_t mat_a = 0;
  gpusparseDnMatDescr_t mat_b = 0;

  // The fallible library calls run inside a lambda so that an early error
  // return cannot skip the descriptor cleanup below.
  auto convert = [&]() -> absl::Status {
    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
        gpusparseCreateCsr(&mat_a, d.rows, d.cols, d.nnz,
                           /*csrRowOffsets=*/buffers[2],
                           /*csrColInd=*/buffers[1],
                           /*csrValues=*/buffers[0], d.index_type, d.index_type,
                           GPUSPARSE_INDEX_BASE_ZERO, d.value_type)));
    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat(
        &mat_b, d.rows, d.cols,
        /*ld=*/d.cols, buffers[3], d.value_type, GPUSPARSE_ORDER_ROW)));
    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
        gpusparseSparseToDense(handle.get(), mat_a, mat_b,
                               GPUSPARSE_SPARSETODENSE_ALG_DEFAULT, buffers[4])));
    return absl::OkStatus();
  };
  const absl::Status convert_status = convert();

  // Release every descriptor that was actually created, whether or not the
  // conversion succeeded.
  absl::Status destroy_a_status = absl::OkStatus();
  if (mat_a != nullptr) {
    destroy_a_status = JAX_AS_STATUS(gpusparseDestroySpMat(mat_a));
  }
  absl::Status destroy_b_status = absl::OkStatus();
  if (mat_b != nullptr) {
    destroy_b_status = JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b));
  }

  // Report the earliest failure first, matching the original call order.
  JAX_RETURN_IF_ERROR(convert_status);
  JAX_RETURN_IF_ERROR(destroy_a_status);
  JAX_RETURN_IF_ERROR(destroy_b_status);
  return absl::OkStatus();
}
|
|
|
|
// Custom-call entry point for CSR -> dense conversion.
//
// Delegates to CsrToDense_ and, if that returns a non-OK status, records the
// status message on `status` via XlaCustomCallStatusSetFailure.
void CsrToDense(gpuStream_t stream, void** buffers, const char* opaque,
                size_t opaque_len, XlaCustomCallStatus* status) {
  const absl::Status result = CsrToDense_(stream, buffers, opaque, opaque_len);
  if (result.ok()) {
    return;
  }
  // Copy the message into a std::string to obtain a NUL-terminated buffer.
  const std::string message(result.message());
  XlaCustomCallStatusSetFailure(status, message.c_str(), message.length());
}
|
|
|
|
// CsrFromDense: Convert dense matrix to CSR matrix
|
|
|
|
// Returns the descriptor for a CsrFromDense operation.
|
|
std::pair<size_t, py::bytes> BuildCsrFromDenseDescriptor(
|
|
const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
|
|
int cols, int nnz) {
|
|
auto h = SparseHandlePool::Borrow();
|
|
JAX_THROW_IF_ERROR(h.status());
|
|
auto& handle = *h;
|
|
SparseMatDescriptor d =
|
|
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz,
|
|
/*batch_count=*/1, /*batch_stride=*/0);
|
|
|
|
gpusparseDnMatDescr_t mat_a = 0;
|
|
gpusparseSpMatDescr_t mat_b = 0;
|
|
|
|
// bufferSize does not reference these pointers, but does error on NULL.
|
|
int val = 0;
|
|
void* empty = &val;
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat(
|
|
&mat_a, d.rows, d.cols,
|
|
/*ld=*/d.cols, empty, d.value_type, GPUSPARSE_ORDER_ROW)));
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCsr(
|
|
&mat_b, d.rows, d.cols, d.nnz, empty, empty, empty, d.index_type,
|
|
d.index_type, GPUSPARSE_INDEX_BASE_ZERO, d.value_type)));
|
|
size_t buffer_size;
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDenseToSparse_bufferSize(
|
|
handle.get(), mat_a, mat_b, GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
|
|
&buffer_size)));
|
|
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_a)));
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_b)));
|
|
|
|
return {buffer_size, PackDescriptor(d)};
|
|
}
|
|
|
|
// Converts a dense matrix to a CSR matrix on `stream`.
//
// `opaque` must hold a packed SparseMatDescriptor. Per the argument
// annotations below, the buffers are used as: buffers[0] the dense matrix
// storage, buffers[1] CSR values, buffers[2] CSR column indices, buffers[3]
// CSR row offsets, and buffers[4] the external buffer handed to the
// gpusparseDenseToSparse analysis and convert steps.
//
// Returns the first error encountered. Destruction of the library descriptors
// created here is always attempted before returning, including on failure.
absl::Status CsrFromDense_(gpuStream_t stream, void** buffers,
                           const char* opaque, size_t opaque_len) {
  auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
  JAX_RETURN_IF_ERROR(s.status());
  const SparseMatDescriptor& d = **s;
  auto h = SparseHandlePool::Borrow(stream);
  JAX_RETURN_IF_ERROR(h.status());
  auto& handle = *h;

  gpusparseDnMatDescr_t mat_a = 0;
  gpusparseSpMatDescr_t mat_b = 0;

  // The fallible library calls run inside a lambda so that an early error
  // return cannot skip the descriptor cleanup below.
  auto convert = [&]() -> absl::Status {
    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat(
        &mat_a, d.rows, d.cols,
        /*ld=*/d.cols, buffers[0], d.value_type, GPUSPARSE_ORDER_ROW)));
    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
        gpusparseCreateCsr(&mat_b, d.rows, d.cols, d.nnz,
                           /*csrRowOffsets=*/buffers[3],
                           /*csrColInd=*/buffers[2],
                           /*csrValues=*/buffers[1], d.index_type, d.index_type,
                           GPUSPARSE_INDEX_BASE_ZERO, d.value_type)));
    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDenseToSparse_analysis(
        handle.get(), mat_a, mat_b, GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
        buffers[4])));
    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDenseToSparse_convert(
        handle.get(), mat_a, mat_b, GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
        buffers[4])));
    return absl::OkStatus();
  };
  const absl::Status convert_status = convert();

  // Release every descriptor that was actually created, whether or not the
  // conversion succeeded.
  absl::Status destroy_a_status = absl::OkStatus();
  if (mat_a != nullptr) {
    destroy_a_status = JAX_AS_STATUS(gpusparseDestroyDnMat(mat_a));
  }
  absl::Status destroy_b_status = absl::OkStatus();
  if (mat_b != nullptr) {
    destroy_b_status = JAX_AS_STATUS(gpusparseDestroySpMat(mat_b));
  }

  // Report the earliest failure first, matching the original call order.
  JAX_RETURN_IF_ERROR(convert_status);
  JAX_RETURN_IF_ERROR(destroy_a_status);
  JAX_RETURN_IF_ERROR(destroy_b_status);
  return absl::OkStatus();
}
|
|
|
|
// Custom-call entry point for dense -> CSR conversion.
//
// Delegates to CsrFromDense_ and, if that returns a non-OK status, records the
// status message on `status` via XlaCustomCallStatusSetFailure.
void CsrFromDense(gpuStream_t stream, void** buffers, const char* opaque,
                  size_t opaque_len, XlaCustomCallStatus* status) {
  const absl::Status result =
      CsrFromDense_(stream, buffers, opaque, opaque_len);
  if (result.ok()) {
    return;
  }
  // Copy the message into a std::string to obtain a NUL-terminated buffer.
  const std::string message(result.message());
  XlaCustomCallStatusSetFailure(status, message.c_str(), message.length());
}
|
|
|
|
// CsrMatvec: Product of CSR matrix and dense vector.
|
|
|
|
// Returns the descriptor for a CsrMatvec operation.
|
|
std::pair<size_t, py::bytes> BuildCsrMatvecDescriptor(
|
|
const py::dtype& data_dtype, const py::dtype& x_dtype,
|
|
const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
|
|
int cols, int nnz, bool transpose) {
|
|
auto h = SparseHandlePool::Borrow();
|
|
JAX_THROW_IF_ERROR(h.status());
|
|
auto& handle = *h;
|
|
SparseMatDescriptor A =
|
|
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz,
|
|
/*batch_count=*/1, /*batch_stride=*/0);
|
|
DenseVecDescriptor x =
|
|
BuildDenseVecDescriptor(x_dtype, transpose ? rows : cols);
|
|
DenseVecDescriptor y =
|
|
BuildDenseVecDescriptor(compute_dtype, transpose ? cols : rows);
|
|
|
|
gpusparseSpMatDescr_t mat_a = 0;
|
|
gpusparseDnVecDescr_t vec_x = 0;
|
|
gpusparseDnVecDescr_t vec_y = 0;
|
|
gpusparseOperation_t op = transpose ? GPUSPARSE_OPERATION_TRANSPOSE
|
|
: GPUSPARSE_OPERATION_NON_TRANSPOSE;
|
|
|
|
// bufferSize does not reference these pointers, but does error on NULL.
|
|
int val = 0;
|
|
void* empty = &val;
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCsr(
|
|
&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty, A.index_type,
|
|
A.index_type, GPUSPARSE_INDEX_BASE_ZERO, A.value_type)));
|
|
JAX_THROW_IF_ERROR(
|
|
JAX_AS_STATUS(gpusparseCreateDnVec(&vec_x, x.size, empty, x.type)));
|
|
JAX_THROW_IF_ERROR(
|
|
JAX_AS_STATUS(gpusparseCreateDnVec(&vec_y, y.size, empty, y.type)));
|
|
size_t buffer_size;
|
|
SparseConst alpha = ConstOne(y.type);
|
|
SparseConst beta = ConstZero(y.type);
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMV_bufferSize(
|
|
handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type,
|
|
GPUSPARSE_MV_ALG_DEFAULT, &buffer_size)));
|
|
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnVec(vec_x)));
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnVec(vec_y)));
|
|
|
|
return {buffer_size, PackDescriptor(CsrMatvecDescriptor{A, x, y, op})};
|
|
}
|
|
|
|
// CsrMatmat: Product of CSR matrix and dense matrix.
|
|
|
|
// Returns the descriptor for a CsrMatmat operation.
|
|
std::pair<size_t, py::bytes> BuildCsrMatmatDescriptor(
|
|
const py::dtype& data_dtype, const py::dtype& b_dtype,
|
|
const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
|
|
int cols, int BCcols, int nnz, bool transpose) {
|
|
auto h = SparseHandlePool::Borrow();
|
|
JAX_THROW_IF_ERROR(h.status());
|
|
auto& handle = *h;
|
|
SparseMatDescriptor A =
|
|
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz,
|
|
/*batch_count=*/1, /*batch_stride=*/0);
|
|
DenseMatDescriptor B =
|
|
BuildDenseMatDescriptor(b_dtype, transpose ? rows : cols, BCcols,
|
|
/*batch_count=*/1, /*batch_stride=*/0);
|
|
DenseMatDescriptor C =
|
|
BuildDenseMatDescriptor(compute_dtype, transpose ? cols : rows, BCcols,
|
|
/*batch_count=*/1, /*batch_stride=*/0);
|
|
gpusparseOperation_t op_A = transpose ? GPUSPARSE_OPERATION_TRANSPOSE
|
|
: GPUSPARSE_OPERATION_NON_TRANSPOSE;
|
|
|
|
gpusparseSpMatDescr_t mat_a = 0;
|
|
gpusparseDnMatDescr_t mat_b = 0;
|
|
gpusparseDnMatDescr_t mat_c = 0;
|
|
|
|
// bufferSize does not reference these pointers, but does error on NULL.
|
|
int val = 0;
|
|
void* empty = &val;
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCsr(
|
|
&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty, A.index_type,
|
|
A.index_type, GPUSPARSE_INDEX_BASE_ZERO, A.value_type)));
|
|
JAX_THROW_IF_ERROR(
|
|
JAX_AS_STATUS(gpusparseCreateDnMat(&mat_b, B.rows, B.cols, /*ld=*/B.cols,
|
|
empty, B.type, GPUSPARSE_ORDER_ROW)));
|
|
JAX_THROW_IF_ERROR(
|
|
JAX_AS_STATUS(gpusparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols,
|
|
empty, C.type, GPUSPARSE_ORDER_ROW)));
|
|
size_t buffer_size;
|
|
SparseConst alpha = ConstOne(C.type);
|
|
SparseConst beta = ConstZero(C.type);
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMM_bufferSize(
|
|
handle.get(), op_A, GPUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a,
|
|
mat_b, &beta, mat_c, C.type, GPUSPARSE_SPMM_ALG_DEFAULT, &buffer_size)));
|
|
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b)));
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_c)));
|
|
|
|
return {buffer_size, PackDescriptor(CsrMatmatDescriptor{A, B, C, op_A})};
|
|
}
|
|
|
|
// CooToDense: Convert COO matrix to dense matrix
|
|
|
|
// Returns the descriptor for a CooToDense operation.
|
|
std::pair<size_t, py::bytes> BuildCooToDenseDescriptor(
|
|
const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
|
|
int cols, int nnz) {
|
|
auto h = SparseHandlePool::Borrow();
|
|
JAX_THROW_IF_ERROR(h.status());
|
|
auto& handle = *h;
|
|
SparseMatDescriptor d =
|
|
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz,
|
|
/*batch_count=*/1, /*batch_stride=*/0);
|
|
|
|
gpusparseSpMatDescr_t mat_a = 0;
|
|
gpusparseDnMatDescr_t mat_b = 0;
|
|
|
|
// bufferSize does not reference these pointers, but does error on NULL.
|
|
int val = 0;
|
|
void* empty = &val;
|
|
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCoo(
|
|
&mat_a, d.rows, d.cols, d.nnz, empty, empty, empty, d.index_type,
|
|
GPUSPARSE_INDEX_BASE_ZERO, d.value_type)));
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat(
|
|
&mat_b, d.rows, d.cols,
|
|
/*ld=*/d.cols, empty, d.value_type, GPUSPARSE_ORDER_ROW)));
|
|
size_t buffer_size;
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSparseToDense_bufferSize(
|
|
handle.get(), mat_a, mat_b, GPUSPARSE_SPARSETODENSE_ALG_DEFAULT,
|
|
&buffer_size)));
|
|
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b)));
|
|
|
|
return {buffer_size, PackDescriptor(d)};
|
|
}
|
|
|
|
// CooFromDense: Convert dense matrix to COO matrix
|
|
|
|
// Returns the descriptor for a CooFromDense operation.
|
|
std::pair<size_t, py::bytes> BuildCooFromDenseDescriptor(
|
|
const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
|
|
int cols, int nnz) {
|
|
auto h = SparseHandlePool::Borrow();
|
|
JAX_THROW_IF_ERROR(h.status());
|
|
auto& handle = *h;
|
|
SparseMatDescriptor d =
|
|
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz,
|
|
/*batch_count=*/1, /*batch_stride=*/0);
|
|
|
|
gpusparseDnMatDescr_t mat_a = 0;
|
|
gpusparseSpMatDescr_t mat_b = 0;
|
|
|
|
// bufferSize does not reference these pointers, but does error on NULL.
|
|
int val = 0;
|
|
void* empty = &val;
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat(
|
|
&mat_a, d.rows, d.cols,
|
|
/*ld=*/d.cols, empty, d.value_type, GPUSPARSE_ORDER_ROW)));
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCoo(
|
|
&mat_b, d.rows, d.cols, d.nnz, empty, empty, empty, d.index_type,
|
|
GPUSPARSE_INDEX_BASE_ZERO, d.value_type)));
|
|
size_t buffer_size;
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDenseToSparse_bufferSize(
|
|
handle.get(), mat_a, mat_b, GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
|
|
&buffer_size)));
|
|
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_a)));
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_b)));
|
|
|
|
return {buffer_size, PackDescriptor(d)};
|
|
}
|
|
|
|
// CooMatvec: Product of COO matrix and dense vector.
|
|
|
|
// Returns the descriptor for a CooMatvec operation.
|
|
std::pair<size_t, py::bytes> BuildCooMatvecDescriptor(
|
|
const py::dtype& data_dtype, const py::dtype& x_dtype,
|
|
const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
|
|
int cols, int nnz, bool transpose) {
|
|
auto h = SparseHandlePool::Borrow();
|
|
JAX_THROW_IF_ERROR(h.status());
|
|
auto& handle = *h;
|
|
SparseMatDescriptor A =
|
|
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz,
|
|
/*batch_count=*/1, /*batch_stride=*/0);
|
|
DenseVecDescriptor x =
|
|
BuildDenseVecDescriptor(x_dtype, transpose ? rows : cols);
|
|
DenseVecDescriptor y =
|
|
BuildDenseVecDescriptor(compute_dtype, transpose ? cols : rows);
|
|
|
|
gpusparseSpMatDescr_t mat_a = 0;
|
|
gpusparseDnVecDescr_t vec_x = 0;
|
|
gpusparseDnVecDescr_t vec_y = 0;
|
|
gpusparseOperation_t op = transpose ? GPUSPARSE_OPERATION_TRANSPOSE
|
|
: GPUSPARSE_OPERATION_NON_TRANSPOSE;
|
|
|
|
// bufferSize does not reference these pointers, but does error on NULL.
|
|
int val = 0;
|
|
void* empty = &val;
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCoo(
|
|
&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty, A.index_type,
|
|
GPUSPARSE_INDEX_BASE_ZERO, A.value_type)));
|
|
JAX_THROW_IF_ERROR(
|
|
JAX_AS_STATUS(gpusparseCreateDnVec(&vec_x, x.size, empty, x.type)));
|
|
JAX_THROW_IF_ERROR(
|
|
JAX_AS_STATUS(gpusparseCreateDnVec(&vec_y, y.size, empty, y.type)));
|
|
size_t buffer_size;
|
|
SparseConst alpha = ConstOne(y.type);
|
|
SparseConst beta = ConstZero(y.type);
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMV_bufferSize(
|
|
handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type,
|
|
GPUSPARSE_MV_ALG_DEFAULT, &buffer_size)));
|
|
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnVec(vec_x)));
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnVec(vec_y)));
|
|
|
|
return {buffer_size, PackDescriptor(CooMatvecDescriptor{A, x, y, op})};
|
|
}
|
|
|
|
// CooMatmat: Product of COO matrix and dense matrix.
|
|
|
|
// Returns the descriptor for a CooMatmat operation.
|
|
std::pair<size_t, py::bytes> BuildCooMatmatDescriptor(
|
|
const py::dtype& data_dtype, const py::dtype& b_dtype,
|
|
const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
|
|
int cols, int BCcols, int nnz, bool transpose, int batch_count,
|
|
int lhs_batch_stride, int rhs_batch_stride) {
|
|
// Three batch modes are supported, C_i = A_i B, C_i = A B_i, and
|
|
// Ci = A_i B_i, where `i` denotes the batch dimension.
|
|
// All three matrices A, B, and C must have the same batch count.
|
|
// Use batch stride to trigger individual mode, e.g.,
|
|
// `rhs_batch_stride = 0` for C_i = A_i B.
|
|
auto h = SparseHandlePool::Borrow();
|
|
JAX_THROW_IF_ERROR(h.status());
|
|
auto& handle = *h;
|
|
|
|
SparseMatDescriptor A = BuildSparseMatDescriptor(
|
|
data_dtype, index_dtype, rows, cols, nnz, batch_count, lhs_batch_stride);
|
|
DenseMatDescriptor B = BuildDenseMatDescriptor(
|
|
b_dtype, transpose ? rows : cols, BCcols, batch_count, rhs_batch_stride);
|
|
int C_rows = (transpose == true) ? cols : rows;
|
|
// TODO(tianjianlu): enable the selection of batch stride.
|
|
// The issue
|
|
// (https://github.com/NVIDIA/CUDALibrarySamples/issues/81#issuecomment-1205562643)
|
|
// in cusparse library does not allow batch_stride = 0.
|
|
// int C_batch_stride = (batch_count > 1)? C_rows * BCcols : 0;
|
|
int C_batch_stride = C_rows * BCcols;
|
|
DenseMatDescriptor C =
|
|
BuildDenseMatDescriptor(compute_dtype, /*rows=*/C_rows, /*cols=*/BCcols,
|
|
batch_count, C_batch_stride);
|
|
gpusparseOperation_t op_A = transpose ? GPUSPARSE_OPERATION_TRANSPOSE
|
|
: GPUSPARSE_OPERATION_NON_TRANSPOSE;
|
|
|
|
gpusparseSpMatDescr_t mat_a = 0;
|
|
gpusparseDnMatDescr_t mat_b = 0;
|
|
gpusparseDnMatDescr_t mat_c = 0;
|
|
|
|
// bufferSize does not reference these pointers, but does error on NULL.
|
|
int val = 0;
|
|
void* empty = &val;
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCoo(
|
|
&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty, A.index_type,
|
|
GPUSPARSE_INDEX_BASE_ZERO, A.value_type)));
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCooSetStridedBatch(
|
|
mat_a, /*batchCount=*/batch_count, /*batchStride=*/A.batch_stride)));
|
|
JAX_THROW_IF_ERROR(
|
|
JAX_AS_STATUS(gpusparseCreateDnMat(&mat_b, B.rows, B.cols, /*ld=*/B.cols,
|
|
empty, B.type, GPUSPARSE_ORDER_ROW)));
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDnMatSetStridedBatch(
|
|
mat_b, /*batchCount=*/batch_count, /*batchStride=*/B.batch_stride)));
|
|
JAX_THROW_IF_ERROR(
|
|
JAX_AS_STATUS(gpusparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols,
|
|
empty, C.type, GPUSPARSE_ORDER_ROW)));
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDnMatSetStridedBatch(
|
|
mat_c, /*batchCount=*/batch_count, /*batchStride=*/C.batch_stride)));
|
|
size_t buffer_size;
|
|
SparseConst alpha = ConstOne(C.type);
|
|
SparseConst beta = ConstZero(C.type);
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMM_bufferSize(
|
|
handle.get(), op_A, GPUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a,
|
|
mat_b, &beta, mat_c, C.type, GPUSPARSE_SPMM_ALG_DEFAULT, &buffer_size)));
|
|
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b)));
|
|
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_c)));
|
|
|
|
return {buffer_size, PackDescriptor(CooMatmatDescriptor{A, B, C, op_A})};
|
|
}
|
|
|
|
#endif // if JAX_GPU_HAVE_SPARSE
|
|
|
|
// Packs the gtsv2 arguments (m, n, ldb) into the opaque byte string passed to
// the gtsv2 kernels.
py::bytes BuildGtsv2Descriptor(int m, int n, int ldb) {
  const Gtsv2Descriptor descriptor{m, n, ldb};
  return PackDescriptor(descriptor);
}
|
|
|
|
template <typename F>
|
|
size_t Gtsv2BufferSize(F f, int m, int n, int ldb) {
|
|
auto h = SparseHandlePool::Borrow();
|
|
JAX_THROW_IF_ERROR(h.status());
|
|
auto& handle = *h;
|
|
size_t size;
|
|
JAX_THROW_IF_ERROR(
|
|
JAX_AS_STATUS(f(handle.get(), m, n, /*dl=*/nullptr, /*d=*/nullptr,
|
|
/*du=*/nullptr, /*B=*/nullptr, ldb, &size)));
|
|
return size;
|
|
}
|
|
|
|
size_t Gtsv2BufferSizeF32(int m, int n, int ldb) {
|
|
return Gtsv2BufferSize(gpusparseSgtsv2_bufferSizeExt, m, n, ldb);
|
|
}
|
|
|
|
size_t Gtsv2BufferSizeF64(int m, int n, int ldb) {
|
|
return Gtsv2BufferSize(gpusparseDgtsv2_bufferSizeExt, m, n, ldb);
|
|
}
|
|
|
|
// Builds the dictionary that maps custom-call target names (each prefixed
// with JAX_GPU_PREFIX) to their encapsulated kernel entry points.
//
// The CSR/COO kernels are only registered when JAX_GPU_HAVE_SPARSE is set;
// the gtsv2 kernels are always registered.
py::dict Registrations() {
  py::dict registry;
#if JAX_GPU_HAVE_SPARSE
  // CSR kernels.
  registry[JAX_GPU_PREFIX "sparse_csr_todense"] =
      EncapsulateFunction(CsrToDense);
  registry[JAX_GPU_PREFIX "sparse_csr_fromdense"] =
      EncapsulateFunction(CsrFromDense);
  registry[JAX_GPU_PREFIX "sparse_csr_matvec"] =
      EncapsulateFunction(CsrMatvec);
  registry[JAX_GPU_PREFIX "sparse_csr_matmat"] =
      EncapsulateFunction(CsrMatmat);
  // COO kernels.
  registry[JAX_GPU_PREFIX "sparse_coo_todense"] =
      EncapsulateFunction(CooToDense);
  registry[JAX_GPU_PREFIX "sparse_coo_fromdense"] =
      EncapsulateFunction(CooFromDense);
  registry[JAX_GPU_PREFIX "sparse_coo_matvec"] =
      EncapsulateFunction(CooMatvec);
  registry[JAX_GPU_PREFIX "sparse_coo_matmat"] =
      EncapsulateFunction(CooMatmat);
#endif
  // Tridiagonal solver kernels.
  registry[JAX_GPU_PREFIX "sparse_gtsv2_f32"] = EncapsulateFunction(gtsv2_f32);
  registry[JAX_GPU_PREFIX "sparse_gtsv2_f64"] = EncapsulateFunction(gtsv2_f64);
  // TODO(tomhennigan): Add support for gtsv2 complex 32/64.
  return registry;
}
|
|
|
|
// Python module `_sparse`.
//
// Exposes `sparse_supported` (the value of JAX_GPU_HAVE_SPARSE), the kernel
// `registrations` dictionary, and the descriptor/buffer-size helpers defined
// above. The CSR/COO descriptor builders are only bound when
// JAX_GPU_HAVE_SPARSE is set.
PYBIND11_MODULE(_sparse, m) {
  m.attr("sparse_supported") = py::bool_(JAX_GPU_HAVE_SPARSE);
  m.def("registrations", &Registrations);

#if JAX_GPU_HAVE_SPARSE
  // CSR descriptor builders.
  m.def("build_csr_todense_descriptor", &BuildCsrToDenseDescriptor);
  m.def("build_csr_fromdense_descriptor", &BuildCsrFromDenseDescriptor);
  m.def("build_csr_matvec_descriptor", &BuildCsrMatvecDescriptor);
  m.def("build_csr_matmat_descriptor", &BuildCsrMatmatDescriptor);

  // COO descriptor builders.
  m.def("build_coo_todense_descriptor", &BuildCooToDenseDescriptor);
  m.def("build_coo_fromdense_descriptor", &BuildCooFromDenseDescriptor);
  m.def("build_coo_matvec_descriptor", &BuildCooMatvecDescriptor);
  m.def("build_coo_matmat_descriptor", &BuildCooMatmatDescriptor);
#endif

  // gtsv2 helpers.
  m.def("gtsv2_f32_buffer_size", &Gtsv2BufferSizeF32);
  m.def("gtsv2_f64_buffer_size", &Gtsv2BufferSizeF64);
  m.def("build_gtsv2_descriptor", &BuildGtsv2Descriptor);
}
|
|
|
|
} // namespace
|
|
} // namespace JAX_GPU_NAMESPACE
|
|
} // namespace jax
|