Add experimental/sparse_ops & cusparse wrappers in jaxlib

PiperOrigin-RevId: 368663407
Jake VanderPlas 2021-04-15 10:10:40 -07:00 committed by jax authors
parent 862df00965
commit 0d4bcde7ca
9 changed files with 1895 additions and 0 deletions


@@ -39,6 +39,7 @@ py_binary(
]) + if_cuda([
"//jaxlib:cublas_kernels",
"//jaxlib:cusolver_kernels",
"//jaxlib:cusparse_kernels",
"//jaxlib:cuda_lu_pivot_kernels",
"//jaxlib:cuda_prng_kernels",
]) + if_rocm([


@@ -193,6 +193,9 @@ def prepare_wheel(sources_path):
if r.Rlocation("__main__/jaxlib/rocblas_kernels.so") is not None:
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/rocblas_kernels.so"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/rocsolver.py"))
if r.Rlocation("__main__/jaxlib/cusparse.so") is not None:
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cusparse.so"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cusparse.py"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/version.py"))
if _is_windows():


@@ -67,6 +67,15 @@ pytype_library(
deps = [":jax"],
)
pytype_library(
name = "experimental_sparse_ops",
srcs = [
"experimental/sparse_ops.py",
],
srcs_version = "PY3",
deps = [":jax"],
)
pytype_library(
name = "optimizers",
srcs = ["experimental/optimizers.py"],


@@ -0,0 +1,442 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""JAX primitives related to sparse operations.
This is experimental work to explore sparse support in JAX.
The primitives defined here are deliberately low-level: i.e. for now there is
no JAX CSR or COO matrix class. Each primitive implements a common sparse
operation (sparse to dense, dense to sparse, sparse matrix/vector product,
sparse matrix/matrix product) for two common sparse representations
(CSR and COO).
These routines have reference implementations defined via XLA scatter/gather
operations that will work on any backend, although they are not particularly
performant. On GPU runtimes with jaxlib 0.1.66 or newer built against CUDA 11.0
or newer, each operation is computed efficiently via cusparse.
"""
from jax import core
from jax.interpreters import xla
from jax.lib import cusparse
from jax.lib import xla_bridge
from jax.lib import xla_client
import jax.numpy as jnp
import numpy as np
xb = xla_bridge
xops = xla_client.ops
#--------------------------------------------------------------------
# csr_todense
csr_todense_p = core.Primitive('csr_todense')
def csr_todense(data, indices, indptr, *, shape):
"""Convert CSR-format sparse matrix to a dense matrix.
Args:
data : array of shape ``(nnz,)``.
indices : array of shape ``(nnz,)``
indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``
shape : length-2 tuple representing the matrix shape
Returns:
mat : array with specified shape and dtype matching ``data``
"""
return csr_todense_p.bind(data, indices, indptr, shape=shape)
@csr_todense_p.def_impl
def _csr_todense_impl(data, indices, indptr, *, shape):
row = jnp.zeros_like(indices).at[indptr].add(1).cumsum() - 1
col = indices
return jnp.zeros(shape, data.dtype).at[row, col].add(data)
@csr_todense_p.def_abstract_eval
def _csr_todense_abstract_eval(data, indices, indptr, *, shape):
assert data.ndim == indices.ndim == indptr.ndim == 1
assert indices.dtype == indptr.dtype
assert data.shape == indices.shape
assert indptr.shape[0] == shape[0] + 1
return core.ShapedArray(shape, data.dtype)
def _csr_todense_gpu_translation_rule(c, data, indices, indptr, *, shape):
return cusparse.csr_todense(c, data, indices, indptr, shape=shape)
xla.translations[csr_todense_p] = xla.lower_fun(
_csr_todense_impl, multiple_results=False)
if cusparse and cusparse.is_supported:
xla.backend_specific_translations['gpu'][
csr_todense_p] = _csr_todense_gpu_translation_rule
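For reference, the row-recovery trick used by ``_csr_todense_impl`` above scatters a one at each row boundary in ``indptr`` and takes a cumulative sum; a small worked sketch with illustrative values:

import jax.numpy as jnp

indices = jnp.array([0, 2, 1, 0, 2])   # column of each stored entry
indptr = jnp.array([0, 2, 3, 3, 5])    # 4 rows; row 2 is empty

# The final boundary (indptr[-1] == nnz) is out of bounds and is dropped by
# JAX's scatter-add, so each nonzero ends up labeled with its row number.
row = jnp.zeros_like(indices).at[indptr].add(1).cumsum() - 1
# row -> [0, 0, 1, 3, 3]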
#--------------------------------------------------------------------
# csr_fromdense
csr_fromdense_p = core.Primitive('csr_fromdense')
csr_fromdense_p.multiple_results = True
def csr_fromdense(mat, *, nnz, index_dtype=np.int32):
"""Create CSR-format sparse matrix from a dense matrix.
Args:
mat : array to be converted to CSR.
nnz : number of nonzero entries in ``mat``
index_dtype : dtype of sparse indices
Returns:
data : array of shape ``(nnz,)`` and dtype ``mat.dtype``.
indices : array of shape ``(nnz,)`` and dtype ``index_dtype``
indptr : array of shape ``(mat.shape[0] + 1,)`` and dtype ``index_dtype``
"""
return csr_fromdense_p.bind(
mat,
nnz=nnz,
index_dtype=np.dtype(index_dtype))
@csr_fromdense_p.def_impl
def _csr_fromdense_impl(mat, *, nnz, index_dtype):
mat = jnp.asarray(mat)
assert mat.ndim == 2
data = jnp.zeros(nnz, dtype=mat.dtype)
indices = jnp.zeros(nnz, dtype=index_dtype)
indptr = jnp.zeros(mat.shape[0] + 1, dtype=index_dtype)
mat_flat = jnp.ravel(mat)
ind = jnp.sort(jnp.argsort(-abs(mat_flat))[:nnz])
i, j = jnp.meshgrid(
jnp.arange(mat.shape[0]), jnp.arange(mat.shape[1]), indexing='ij')
row, col = jnp.ravel(i)[ind], jnp.ravel(j)[ind]
data = data.at[:mat.size].set(mat_flat[ind])
indices = indices.at[:mat.size].set(col)
indptr = indptr.at[1:].set(jnp.cumsum(jnp.bincount(row, length=mat.shape[0])))
return data, indices, indptr
@csr_fromdense_p.def_abstract_eval
def _csr_fromdense_abstract_eval(mat, *, nnz, index_dtype):
data = core.ShapedArray((nnz,), mat.dtype)
indices = core.ShapedArray((nnz,), index_dtype)
indptr = core.ShapedArray((mat.shape[0] + 1,), index_dtype)
return data, indices, indptr
def _csr_fromdense_gpu_translation_rule(c, mat, *, nnz, index_dtype):
data, indices, indptr = cusparse.csr_fromdense(
c, mat, nnz=nnz, index_dtype=np.dtype(index_dtype))
return xops.Tuple(c, [data, indices, indptr])
xla.translations[csr_fromdense_p] = xla.lower_fun(
_csr_fromdense_impl, multiple_results=True)
if cusparse and cusparse.is_supported:
xla.backend_specific_translations['gpu'][
csr_fromdense_p] = _csr_fromdense_gpu_translation_rule
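For a concrete picture of the CSR layout produced above, a small worked sketch (hypothetical values; ``nnz`` must count the nonzeros exactly):

import jax.numpy as jnp
from jax.experimental import sparse_ops

M = jnp.array([[1., 0., 2.],
               [0., 0., 3.]])
data, indices, indptr = sparse_ops.csr_fromdense(M, nnz=3)
# data    -> [1., 2., 3.]   stored entries in row-major order
# indices -> [0, 2, 2]      column of each stored entry
# indptr  -> [0, 2, 3]      row i owns data[indptr[i]:indptr[i + 1]]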
#--------------------------------------------------------------------
# csr_matvec
csr_matvec_p = core.Primitive('csr_matvec')
def csr_matvec(data, indices, indptr, v, *, shape, transpose=False):
"""Product of CSR sparse matrix and a dense vector.
Args:
data : array of shape ``(nnz,)``.
indices : array of shape ``(nnz,)``
indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``
v : array of shape ``(shape[0] if transpose else shape[1],)``
and dtype ``data.dtype``
shape : length-2 tuple representing the matrix shape
transpose : boolean specifying whether to transpose the sparse matrix
before computing.
Returns:
y : array of shape ``(shape[1] if transpose else shape[0],)`` representing
the matrix vector product.
"""
return csr_matvec_p.bind(data, indices, indptr, v, shape=shape, transpose=transpose)
@csr_matvec_p.def_impl
def _csr_matvec_impl(data, indices, indptr, v, *, shape, transpose):
v = jnp.asarray(v)
out_shape = shape[1] if transpose else shape[0]
row = jnp.cumsum(jnp.zeros_like(indices).at[indptr].add(1)) - 1
col = indices
if transpose:
row, col = col, row
dv = data * v[col]
return jnp.zeros(out_shape, dv.dtype).at[row].add(dv)
@csr_matvec_p.def_abstract_eval
def _csr_matvec_abstract_eval(data, indices, indptr, v, *, shape, transpose):
assert len(shape) == 2
assert v.ndim == data.ndim == indices.ndim == indptr.ndim == 1
assert data.shape == indices.shape
assert data.dtype == v.dtype
assert indices.dtype == indptr.dtype
assert len(indptr) == shape[0] + 1
out_shape = shape[1] if transpose else shape[0]
assert v.shape == ((shape[0],) if transpose else (shape[1],))
return core.ShapedArray((out_shape,), data.dtype)
def _csr_matvec_gpu_translation_rule(c, data, indices, indptr, v, *, shape, transpose):
return cusparse.csr_matvec(c, data, indices, indptr, v, shape=shape, transpose=transpose)
xla.translations[csr_matvec_p] = xla.lower_fun(
_csr_matvec_impl, multiple_results=False)
if cusparse and cusparse.is_supported:
xla.backend_specific_translations['gpu'][
csr_matvec_p] = _csr_matvec_gpu_translation_rule
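To make the ``transpose`` semantics concrete, a short sketch reusing the small matrix from the previous example (illustrative values):

import jax.numpy as jnp
from jax.experimental import sparse_ops

data = jnp.array([1., 2., 3.])
indices = jnp.array([0, 2, 2])
indptr = jnp.array([0, 2, 3])
shape = (2, 3)   # dense form [[1., 0., 2.], [0., 0., 3.]]

# With transpose=True the operand v has length shape[0] and the result has
# length shape[1], matching M.T @ v.
v = jnp.array([1., 1.])
y = sparse_ops.csr_matvec(data, indices, indptr, v, shape=shape, transpose=True)
# y -> [1., 0., 5.]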
#--------------------------------------------------------------------
# csr_matmat
csr_matmat_p = core.Primitive('csr_matmat')
def csr_matmat(data, indices, indptr, B, *, shape, transpose=False):
"""Product of CSR sparse matrix and a dense matrix.
Args:
data : array of shape ``(nnz,)``.
indices : array of shape ``(nnz,)``
indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``
B : array of shape ``(shape[0] if transpose else shape[1], cols)`` and
dtype ``data.dtype``
shape : length-2 tuple representing the matrix shape
transpose : boolean specifying whether to transpose the sparse matrix
before computing.
Returns:
C : array of shape ``(shape[1] if transpose else shape[0], cols)``
representing the matrix-matrix product.
"""
return csr_matmat_p.bind(data, indices, indptr, B, shape=shape, transpose=transpose)
@csr_matmat_p.def_impl
def _csr_matmat_impl(data, indices, indptr, B, *, shape, transpose):
B = jnp.asarray(B)
out_shape = shape[1] if transpose else shape[0]
row = jnp.cumsum(jnp.zeros_like(indices).at[indptr].add(1)) - 1
col = indices
if transpose:
row, col = col, row
dB = data[:, None] * B[col]
return jnp.zeros((out_shape, B.shape[1]), dB.dtype).at[row].add(dB)
@csr_matmat_p.def_abstract_eval
def _csr_matmat_abstract_eval(data, indices, indptr, B, *, shape, transpose):
assert data.ndim == indices.ndim == indptr.ndim == 1
assert B.ndim == 2
assert data.shape == indices.shape
assert data.dtype == B.dtype
assert indices.dtype == indptr.dtype
assert len(indptr) == shape[0] + 1
out_shape = shape[1] if transpose else shape[0]
assert B.shape[0] == (shape[0] if transpose else shape[1])
return core.ShapedArray((out_shape, B.shape[1]), data.dtype)
def _csr_matmat_gpu_translation_rule(c, data, indices, indptr, B, *, shape, transpose):
return cusparse.csr_matmat(c, data, indices, indptr, B, shape=shape, transpose=transpose)
xla.translations[csr_matmat_p] = xla.lower_fun(
_csr_matmat_impl, multiple_results=False)
if cusparse and cusparse.is_supported:
xla.backend_specific_translations['gpu'][
csr_matmat_p] = _csr_matmat_gpu_translation_rule
#--------------------------------------------------------------------
# coo_todense
coo_todense_p = core.Primitive('coo_todense')
def coo_todense(data, row, col, *, shape):
"""Convert CSR-format sparse matrix to a dense matrix.
Args:
data : array of shape ``(nnz,)``.
row : array of shape ``(nnz,)``
col : array of shape ``(nnz,)`` and dtype ``row.dtype``
shape : length-2 tuple representing the matrix shape
Returns:
mat : array with specified shape and dtype matching ``data``
"""
return coo_todense_p.bind(data, row, col, shape=shape)
@coo_todense_p.def_impl
def _coo_todense_impl(data, row, col, *, shape):
return jnp.zeros(shape, data.dtype).at[row, col].set(data)
@coo_todense_p.def_abstract_eval
def _coo_todense_abstract_eval(data, row, col, *, shape):
return core.ShapedArray(shape, data.dtype)
def _coo_todense_gpu_translation_rule(c, data, row, col, *, shape):
return cusparse.coo_todense(c, data, row, col, shape=shape)
xla.translations[coo_todense_p] = xla.lower_fun(
_coo_todense_impl, multiple_results=False)
if cusparse and cusparse.is_supported:
xla.backend_specific_translations['gpu'][
coo_todense_p] = _coo_todense_gpu_translation_rule
#--------------------------------------------------------------------
# coo_fromdense
coo_fromdense_p = core.Primitive('coo_fromdense')
coo_fromdense_p.multiple_results = True
def coo_fromdense(mat, *, nnz, index_dtype=jnp.int32):
"""Create COO-format sparse matrix from a dense matrix.
Args:
mat : array to be converted to COO.
nnz : number of nonzero entries in ``mat``
index_dtype : dtype of sparse indices
Returns:
data : array of shape ``(nnz,)`` and dtype ``mat.dtype``
row : array of shape ``(nnz,)`` and dtype ``index_dtype``
col : array of shape ``(nnz,)`` and dtype ``index_dtype``
"""
return coo_fromdense_p.bind(mat, nnz=nnz, index_dtype=index_dtype)
@coo_fromdense_p.def_impl
def _coo_fromdense_impl(mat, *, nnz, index_dtype):
mat = jnp.asarray(mat)
m, n = mat.shape
ind = jnp.sort(jnp.argsort(abs(jnp.ravel(mat)))[m * n - nnz:]).astype(index_dtype)
return mat.ravel()[ind], ind // n, ind % n
@coo_fromdense_p.def_abstract_eval
def _coo_fromdense_abstract_eval(mat, *, nnz, index_dtype):
data = core.ShapedArray((nnz,), mat.dtype)
row = col = core.ShapedArray((nnz,), index_dtype)
return data, row, col
def _coo_fromdense_gpu_translation_rule(c, mat, *, nnz, index_dtype):
data, row, col = cusparse.coo_fromdense(
c, mat, nnz=nnz, index_dtype=np.dtype(index_dtype))
return xops.Tuple(c, [data, row, col])
xla.translations[coo_fromdense_p] = xla.lower_fun(
_coo_fromdense_impl, multiple_results=True)
if cusparse and cusparse.is_supported:
xla.backend_specific_translations['gpu'][
coo_fromdense_p] = _coo_fromdense_gpu_translation_rule
#--------------------------------------------------------------------
# coo_matvec
coo_matvec_p = core.Primitive('coo_matvec')
def coo_matvec(data, row, col, v, *, shape, transpose=False):
"""Product of COO sparse matrix and a dense vector.
Args:
data : array of shape ``(nnz,)``.
row : array of shape ``(nnz,)``
col : array of shape ``(nnz,)`` and dtype ``row.dtype``
v : array of shape ``(shape[0] if transpose else shape[1],)`` and
dtype ``data.dtype``
shape : length-2 tuple representing the matrix shape
transpose : boolean specifying whether to transpose the sparse matrix
before computing.
Returns:
y : array of shape ``(shape[1] if transpose else shape[0],)`` representing
the matrix vector product.
"""
return coo_matvec_p.bind(data, row, col, v, shape=shape, transpose=transpose)
@coo_matvec_p.def_impl
def _coo_matvec_impl(data, row, col, v, *, shape, transpose):
v = jnp.asarray(v)
if transpose:
row, col = col, row
out_shape = shape[1] if transpose else shape[0]
dv = data * v[col]
return jnp.zeros(out_shape, dv.dtype).at[row].add(dv)
@coo_matvec_p.def_abstract_eval
def _coo_matvec_abstract_eval(data, row, col, v, *, shape, transpose):
assert data.shape == row.shape == col.shape
assert data.dtype == v.dtype
assert row.dtype == col.dtype
assert len(shape) == 2
assert v.shape == ((shape[0],) if transpose else (shape[1],))
out_shape = shape[1] if transpose else shape[0]
return core.ShapedArray((out_shape,), data.dtype)
def _coo_matvec_gpu_translation_rule(c, data, row, col, v, *, shape, transpose):
return cusparse.coo_matvec(c, data, row, col, v, shape=shape, transpose=transpose)
xla.translations[coo_matvec_p] = xla.lower_fun(
_coo_matvec_impl, multiple_results=False)
if cusparse and cusparse.is_supported:
xla.backend_specific_translations['gpu'][
coo_matvec_p] = _coo_matvec_gpu_translation_rule
#--------------------------------------------------------------------
# coo_matmat
coo_matmat_p = core.Primitive('coo_matmat')
def coo_matmat(data, row, col, B, *, shape, transpose=False):
"""Product of COO sparse matrix and a dense matrix.
Args:
data : array of shape ``(nnz,)``.
row : array of shape ``(nnz,)``
col : array of shape ``(nnz,)`` and dtype ``row.dtype``
B : array of shape ``(shape[0] if transpose else shape[1], cols)`` and
dtype ``data.dtype``
shape : length-2 tuple representing the matrix shape
transpose : boolean specifying whether to transpose the sparse matrix
before computing.
Returns:
C : array of shape ``(shape[1] if transpose else shape[0], cols)``
representing the matrix-matrix product.
"""
return coo_matmat_p.bind(data, row, col, B, shape=shape, transpose=transpose)
@coo_matmat_p.def_impl
def _coo_matmat_impl(data, row, col, B, *, shape, transpose):
B = jnp.asarray(B)
if transpose:
row, col = col, row
out_shape = shape[1] if transpose else shape[0]
dB = data[:, None] * B[col]
return jnp.zeros((out_shape, B.shape[1]), dB.dtype).at[row].add(dB)
@coo_matmat_p.def_abstract_eval
def _coo_matmat_abstract_eval(data, row, col, B, *, shape, transpose):
assert data.shape == row.shape == col.shape
assert data.dtype == B.dtype
assert len(shape) == 2
assert B.shape[0] == (shape[0] if transpose else shape[1])
out_shape = shape[1] if transpose else shape[0]
return core.ShapedArray((out_shape, B.shape[1]), data.dtype)
def _coo_matmat_gpu_translation_rule(c, data, row, col, B, *, shape, transpose):
return cusparse.coo_matmat(c, data, row, col, B, shape=shape, transpose=transpose)
xla.translations[coo_matmat_p] = xla.lower_fun(
_coo_matmat_impl, multiple_results=False)
if cusparse and cusparse.is_supported:
xla.backend_specific_translations['gpu'][
coo_matmat_p] = _coo_matmat_gpu_translation_rule
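Since each primitive registers an impl, an abstract_eval rule, and a translation via ``xla.lower_fun``, the operations above compose with ``jit`` on any backend, with the cusparse path used automatically on supported GPUs; a minimal sketch:

from jax import jit
import jax.numpy as jnp
from jax.experimental import sparse_ops

@jit
def roundtrip(M):
  data, indices, indptr = sparse_ops.csr_fromdense(M, nnz=3)
  return sparse_ops.csr_todense(data, indices, indptr, shape=M.shape)

M = jnp.array([[1., 0., 2.],
               [0., 0., 3.]])
assert (roundtrip(M) == M).all()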


@@ -69,6 +69,11 @@ try:
except ImportError:
cusolver = None
try:
from jaxlib import cusparse # pytype: disable=import-error
except ImportError:
cusparse = None
try:
from jaxlib import rocsolver # pytype: disable=import-error
except ImportError:


@@ -119,6 +119,7 @@ py_library(
"cuda_linalg.py",
"cuda_prng.py",
"cusolver.py",
"cusparse.py",
]) + if_rocm_is_configured([
"rocsolver.py",
]),
@@ -137,6 +138,7 @@ py_library(
":cuda_lu_pivot_kernels",
":cuda_prng_kernels",
":cusolver_kernels",
":cusparse_kernels",
],
)
@@ -199,6 +201,36 @@ pybind_extension(
],
)
pybind_extension(
name = "cusparse_kernels",
srcs = ["cusparse.cc"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
module_name = "cusparse_kernels",
deps = [
":cuda_gpu_kernel_helpers",
":handle_pool",
":kernel_pybind11_helpers",
"//third_party/gpus/cuda:cusparse_static",
"@org_tensorflow//tensorflow/stream_executor/cuda:cudart_stub",
"@org_tensorflow//tensorflow/stream_executor/cuda:cusparse_lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@local_config_cuda//cuda:cuda_headers",
"@pybind11",
],
)
cuda_library(
name = "cuda_lu_pivot_kernels_lib",
srcs = ["cuda_lu_pivot_kernels.cu.cc"],

jaxlib/cusparse.cc Normal file

@@ -0,0 +1,896 @@
/* Copyright 2021 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/gpus/cuda/includes/cuda_headers/third_party/gpus/cuda/include/cusparse.h"
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <utility>
#include <vector>
#include "absl/base/casts.h"
#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "third_party/gpus/cuda/include/cusparse.h"
#include "third_party/gpus/cuda/includes/cuda_headers/third_party/gpus/cuda/include/cuComplex.h"
#include "jaxlib/cuda_gpu_kernel_helpers.h"
#include "jaxlib/handle_pool.h"
#include "jaxlib/kernel_pybind11_helpers.h"
#include "include/pybind11/numpy.h"
#include "include/pybind11/pybind11.h"
#include "include/pybind11/stl.h"
// Some functionality defined here is only available in CUDA 11 or newer.
#define JAX_USE_CUDA11 (CUDART_VERSION >= 11000)
namespace jax {
namespace {
namespace py = pybind11;
void ThrowIfErrorStatus(cusparseStatus_t status) {
switch (status) {
case CUSPARSE_STATUS_SUCCESS:
return;
case CUSPARSE_STATUS_NOT_INITIALIZED:
throw std::runtime_error("cuSparse has not been initialized");
case CUSPARSE_STATUS_ALLOC_FAILED:
throw std::runtime_error("cuSparse allocation failure");
case CUSPARSE_STATUS_INVALID_VALUE:
throw std::runtime_error("cuSparse invalid value error");
case CUSPARSE_STATUS_ARCH_MISMATCH:
throw std::runtime_error("cuSparse architecture mismatch");
case CUSPARSE_STATUS_MAPPING_ERROR:
throw std::runtime_error("cuSparse mapping error");
case CUSPARSE_STATUS_EXECUTION_FAILED:
throw std::runtime_error("cuSparse execution failed");
case CUSPARSE_STATUS_INTERNAL_ERROR:
throw std::runtime_error("cuSparse internal error");
case CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
throw std::runtime_error("cuSparse matrix type not supported error");
case CUSPARSE_STATUS_ZERO_PIVOT:
throw std::runtime_error("cuSparse zero pivot error");
default:
throw std::runtime_error("Unknown cuSparse error");
}
}
union CudaConst {
int8_t i8[2];
int16_t i16[2];
int32_t i32[2];
int64_t i64[2];
uint8_t u8[2];
uint16_t u16[2];
uint32_t u32[2];
uint64_t u64[2];
float f32[2];
double f64[2];
};
CudaConst CudaZero(cudaDataType type) {
CudaConst c;
std::memset(&c, 0, sizeof(c));
return c;
}
CudaConst CudaOne(cudaDataType type) {
CudaConst c;
std::memset(&c, 0, sizeof(c));
switch (type) {
#if JAX_USE_CUDA11
// TODO(jakevdp): 4I/4U here might break on big endian platforms.
case CUDA_R_4I:
case CUDA_C_4I:
#endif
case CUDA_R_8I:
case CUDA_C_8I:
c.i8[0] = 1;
break;
#if JAX_USE_CUDA11
case CUDA_R_4U:
case CUDA_C_4U:
#endif
case CUDA_R_8U:
case CUDA_C_8U:
c.u8[0] = 1;
break;
#if JAX_USE_CUDA11
case CUDA_R_16I:
case CUDA_C_16I:
c.i16[0] = 1;
break;
case CUDA_R_16U:
case CUDA_C_16U:
c.u16[0] = 1;
break;
#endif
case CUDA_R_32I:
case CUDA_C_32I:
c.i32[0] = 1;
break;
case CUDA_R_32U:
case CUDA_C_32U:
c.u32[0] = 1;
break;
#if JAX_USE_CUDA11
case CUDA_R_64I:
case CUDA_C_64I:
c.i64[0] = 1;
break;
case CUDA_R_64U:
case CUDA_C_64U:
c.u64[0] = 1;
break;
#endif
// TODO(jakevdp): 16F/16BF here might break on big endian platforms.
case CUDA_R_16F:
case CUDA_C_16F:
c.u16[0] = 0b11110000000000; // 1.0 in little-endian float16
break;
#if JAX_USE_CUDA11
case CUDA_R_16BF:
case CUDA_C_16BF:
c.u16[0] = 0b11111110000000; // 1.0 in little-endian bfloat16
break;
#endif
case CUDA_R_32F:
case CUDA_C_32F:
c.f32[0] = 1.0;
break;
case CUDA_R_64F:
case CUDA_C_64F:
c.f64[0] = 1.0;
break;
}
return c;
}
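The half-precision constants above are just the IEEE bit patterns for 1.0; a quick NumPy check (illustrative only, not part of this file):

import numpy as np

# float16: sign 0, exponent 01111, mantissa 0 -> 0x3C00 (0b0011110000000000)
assert int(np.array(1.0, np.float16).view(np.uint16)) == 0x3C00

# bfloat16 is the top 16 bits of the float32 pattern 0x3F800000 -> 0x3F80
assert int(np.array(1.0, np.float32).view(np.uint32)) >> 16 == 0x3F80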
using SparseHandlePool = HandlePool<cusparseHandle_t, cudaStream_t>;
template <>
/*static*/ SparseHandlePool::Handle SparseHandlePool::Borrow(
cudaStream_t stream) {
SparseHandlePool* pool = Instance();
absl::MutexLock lock(&pool->mu_);
cusparseHandle_t handle;
if (pool->handles_.empty()) {
ThrowIfErrorStatus(cusparseCreate(&handle));
} else {
handle = pool->handles_.back();
pool->handles_.pop_back();
}
if (stream) {
ThrowIfErrorStatus(cusparseSetStream(handle, stream));
}
return Handle(pool, handle);
}
cusparseIndexType_t DtypeToCuSparseIndexType(const py::dtype& np_type) {
static auto* types =
new absl::flat_hash_map<std::pair<char, int>, cusparseIndexType_t>({
{{'u', 2}, CUSPARSE_INDEX_16U},
{{'i', 4}, CUSPARSE_INDEX_32I},
{{'i', 8}, CUSPARSE_INDEX_64I},
});
auto it = types->find({np_type.kind(), np_type.itemsize()});
if (it == types->end()) {
throw std::invalid_argument(
absl::StrFormat("Unsupported index dtype: %s", py::repr(np_type)));
}
return it->second;
}
cudaDataType DtypeToCudaDataType(const py::dtype& np_type) {
static auto* types =
new absl::flat_hash_map<std::pair<char, int>, cudaDataType>({
{{'f', 2}, CUDA_R_16F}, {{'f', 4}, CUDA_R_32F},
{{'f', 8}, CUDA_R_64F}, {{'c', 8}, CUDA_C_32F},
{{'c', 16}, CUDA_C_64F}, {{'i', 1}, CUDA_R_8I},
{{'u', 1}, CUDA_R_8U}, {{'i', 4}, CUDA_R_32I},
{{'u', 4}, CUDA_R_32U},
#if JAX_USE_CUDA11
{{'V', 2}, CUDA_R_16BF},
#endif
});
auto it = types->find({np_type.kind(), np_type.itemsize()});
if (it == types->end()) {
throw std::invalid_argument(
absl::StrFormat("Unsupported data dtype: %s", py::repr(np_type)));
}
return it->second;
}
struct SparseMatDescriptor {
cudaDataType value_type;
cusparseIndexType_t index_type;
int rows, cols, nnz;
};
struct DenseMatDescriptor {
cudaDataType type;
int rows, cols;
};
struct DenseVecDescriptor {
cudaDataType type;
int size;
};
// Returns the descriptor for a Sparse matrix.
SparseMatDescriptor BuildSparseMatDescriptor(const py::dtype& data_dtype,
const py::dtype& index_dtype,
int rows, int cols, int nnz) {
cudaDataType value_type = DtypeToCudaDataType(data_dtype);
cusparseIndexType_t index_type = DtypeToCuSparseIndexType(index_dtype);
return SparseMatDescriptor{value_type, index_type, rows, cols, nnz};
}
// Returns the descriptor for a Dense matrix.
DenseMatDescriptor BuildDenseMatDescriptor(const py::dtype& data_dtype,
int rows, int cols) {
cudaDataType value_type = DtypeToCudaDataType(data_dtype);
return DenseMatDescriptor{value_type, rows, cols};
}
// Returns the descriptor for a Dense vector.
DenseVecDescriptor BuildDenseVecDescriptor(const py::dtype& data_dtype,
int size) {
cudaDataType value_type = DtypeToCudaDataType(data_dtype);
return DenseVecDescriptor{value_type, size};
}
#if JAX_USE_CUDA11
// CsrToDense: Convert CSR matrix to dense matrix
// Returns the descriptor for a Sparse matrix.
std::pair<size_t, py::bytes> BuildCsrToDenseDescriptor(
const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
int cols, int nnz) {
auto handle = SparseHandlePool::Borrow();
SparseMatDescriptor d =
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz);
cusparseSpMatDescr_t mat_a = 0;
cusparseDnMatDescr_t mat_b = 0;
// buffer_size does not reference these pointers, but does error on NULL.
// TODO(jakevdp): check whether this is documented.
int val = 0;
void* empty = &val;
ThrowIfErrorStatus(cusparseCreateCsr(&mat_a, d.rows, d.cols, d.nnz, empty,
empty, empty, d.index_type, d.index_type,
CUSPARSE_INDEX_BASE_ZERO, d.value_type));
ThrowIfErrorStatus(cusparseCreateDnMat(&mat_b, d.rows, d.cols,
/*ld=*/d.cols, empty, d.value_type,
CUSPARSE_ORDER_ROW));
size_t buffer_size;
ThrowIfErrorStatus(cusparseSparseToDense_bufferSize(
handle.get(), mat_a, mat_b, CUSPARSE_SPARSETODENSE_ALG_DEFAULT,
&buffer_size));
ThrowIfErrorStatus(cusparseDestroySpMat(mat_a));
ThrowIfErrorStatus(cusparseDestroyDnMat(mat_b));
return {buffer_size, PackDescriptor(d)};
}
void CsrToDense(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len) {
const SparseMatDescriptor& d =
*UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
auto handle = SparseHandlePool::Borrow(stream);
cusparseSpMatDescr_t mat_a = 0;
cusparseDnMatDescr_t mat_b = 0;
ThrowIfErrorStatus(cusparseCreateCsr(&mat_a, d.rows, d.cols, d.nnz,
/*csrRowOffsets=*/buffers[2],
/*csrColInd=*/buffers[1],
/*csrValues=*/buffers[0], d.index_type,
d.index_type, CUSPARSE_INDEX_BASE_ZERO,
d.value_type));
ThrowIfErrorStatus(cusparseCreateDnMat(&mat_b, d.rows, d.cols,
/*ld=*/d.cols, buffers[3],
d.value_type, CUSPARSE_ORDER_ROW));
ThrowIfErrorStatus(cusparseSparseToDense(handle.get(), mat_a, mat_b,
CUSPARSE_SPARSETODENSE_ALG_DEFAULT,
buffers[4]));
ThrowIfErrorStatus(cusparseDestroySpMat(mat_a));
ThrowIfErrorStatus(cusparseDestroyDnMat(mat_b));
}
// CsrFromDense: Convert dense matrix to CSR matrix
// Returns the descriptor for a CsrFromDense operation.
std::pair<size_t, py::bytes> BuildCsrFromDenseDescriptor(
const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
int cols, int nnz) {
auto handle = SparseHandlePool::Borrow();
SparseMatDescriptor d =
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz);
cusparseDnMatDescr_t mat_a = 0;
cusparseSpMatDescr_t mat_b = 0;
// bufferSize does not reference these pointers, but does error on NULL.
int val = 0;
void* empty = &val;
ThrowIfErrorStatus(cusparseCreateDnMat(&mat_a, d.rows, d.cols,
/*ld=*/d.cols, empty, d.value_type,
CUSPARSE_ORDER_ROW));
ThrowIfErrorStatus(cusparseCreateCsr(&mat_b, d.rows, d.cols, d.nnz, empty,
empty, empty, d.index_type, d.index_type,
CUSPARSE_INDEX_BASE_ZERO, d.value_type));
size_t buffer_size;
ThrowIfErrorStatus(cusparseDenseToSparse_bufferSize(
handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
&buffer_size));
ThrowIfErrorStatus(cusparseDestroyDnMat(mat_a));
ThrowIfErrorStatus(cusparseDestroySpMat(mat_b));
return {buffer_size, PackDescriptor(d)};
}
void CsrFromDense(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len) {
const SparseMatDescriptor& d =
*UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
auto handle = SparseHandlePool::Borrow(stream);
cusparseDnMatDescr_t mat_a = 0;
cusparseSpMatDescr_t mat_b = 0;
ThrowIfErrorStatus(cusparseCreateDnMat(&mat_a, d.rows, d.cols,
/*ld=*/d.cols, buffers[0],
d.value_type, CUSPARSE_ORDER_ROW));
ThrowIfErrorStatus(cusparseCreateCsr(&mat_b, d.rows, d.cols, d.nnz,
/*csrRowOffsets=*/buffers[3],
/*csrColInd=*/buffers[2],
/*csrValues=*/buffers[1], d.index_type,
d.index_type, CUSPARSE_INDEX_BASE_ZERO,
d.value_type));
ThrowIfErrorStatus(cusparseDenseToSparse_analysis(
handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
buffers[4]));
ThrowIfErrorStatus(cusparseDenseToSparse_convert(
handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
buffers[4]));
ThrowIfErrorStatus(cusparseDestroyDnMat(mat_a));
ThrowIfErrorStatus(cusparseDestroySpMat(mat_b));
}
// CsrMatvec: Product of CSR matrix and dense vector.
struct CsrMatvecDescriptor {
SparseMatDescriptor A;
DenseVecDescriptor x, y;
cusparseOperation_t op;
};
// Returns the descriptor for a CsrMatvec operation.
std::pair<size_t, py::bytes> BuildCsrMatvecDescriptor(
const py::dtype& data_dtype, const py::dtype& x_dtype,
const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
int cols, int nnz, bool transpose) {
auto handle = SparseHandlePool::Borrow();
SparseMatDescriptor A =
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz);
DenseVecDescriptor x =
BuildDenseVecDescriptor(x_dtype, transpose ? rows : cols);
DenseVecDescriptor y =
BuildDenseVecDescriptor(compute_dtype, transpose ? cols : rows);
cusparseSpMatDescr_t mat_a = 0;
cusparseDnVecDescr_t vec_x = 0;
cusparseDnVecDescr_t vec_y = 0;
cusparseOperation_t op = transpose ? CUSPARSE_OPERATION_TRANSPOSE
: CUSPARSE_OPERATION_NON_TRANSPOSE;
// bufferSize does not reference these pointers, but does error on NULL.
int val = 0;
void* empty = &val;
ThrowIfErrorStatus(cusparseCreateCsr(&mat_a, A.rows, A.cols, A.nnz, empty,
empty, empty, A.index_type, A.index_type,
CUSPARSE_INDEX_BASE_ZERO, A.value_type));
ThrowIfErrorStatus(cusparseCreateDnVec(&vec_x, x.size, empty, x.type));
ThrowIfErrorStatus(cusparseCreateDnVec(&vec_y, y.size, empty, y.type));
size_t buffer_size;
CudaConst alpha = CudaOne(y.type);
CudaConst beta = CudaZero(y.type);
ThrowIfErrorStatus(cusparseSpMV_bufferSize(
handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type,
CUSPARSE_MV_ALG_DEFAULT, &buffer_size));
ThrowIfErrorStatus(cusparseDestroySpMat(mat_a));
ThrowIfErrorStatus(cusparseDestroyDnVec(vec_x));
ThrowIfErrorStatus(cusparseDestroyDnVec(vec_y));
return {buffer_size, PackDescriptor(CsrMatvecDescriptor{A, x, y, op})};
}
void CsrMatvec(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len) {
const CsrMatvecDescriptor& d =
*UnpackDescriptor<CsrMatvecDescriptor>(opaque, opaque_len);
auto handle = SparseHandlePool::Borrow(stream);
void* csr_values = buffers[0];
void* csr_col_ind = buffers[1];
void* csr_row_offsets = buffers[2];
void* xbuf = buffers[3];
void* ybuf = buffers[4];
void* buf = buffers[5];
// TODO(jakevdp): alpha and beta should be user-specifiable, but constants
// are sufficient for basic matvec operations.
// Note that, contrary to cusparse docs, alpha and beta must be host pointers
// or else the operation will segfault.
CudaConst alpha = CudaOne(d.y.type);
CudaConst beta = CudaZero(d.y.type);
cusparseSpMatDescr_t mat_a = 0;
cusparseDnVecDescr_t vec_x = 0;
cusparseDnVecDescr_t vec_y = 0;
ThrowIfErrorStatus(
cusparseCreateCsr(&mat_a, d.A.rows, d.A.cols, d.A.nnz, csr_row_offsets,
csr_col_ind, csr_values, d.A.index_type, d.A.index_type,
CUSPARSE_INDEX_BASE_ZERO, d.A.value_type));
ThrowIfErrorStatus(cusparseCreateDnVec(&vec_x, d.x.size, xbuf, d.x.type));
ThrowIfErrorStatus(cusparseCreateDnVec(&vec_y, d.y.size, ybuf, d.y.type));
ThrowIfErrorStatus(cusparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x,
&beta, vec_y, d.y.type,
CUSPARSE_MV_ALG_DEFAULT, buf));
ThrowIfErrorStatus(cusparseDestroySpMat(mat_a));
ThrowIfErrorStatus(cusparseDestroyDnVec(vec_x));
ThrowIfErrorStatus(cusparseDestroyDnVec(vec_y));
}
// CsrMatmat: Product of CSR matrix and dense matrix.
struct CsrMatmatDescriptor {
SparseMatDescriptor A;
DenseMatDescriptor B, C;
cusparseOperation_t op_A;
};
// Returns the descriptor for a CsrMatmat operation.
std::pair<size_t, py::bytes> BuildCsrMatmatDescriptor(
const py::dtype& data_dtype, const py::dtype& b_dtype,
const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
int cols, int BCcols, int nnz, bool transpose) {
auto handle = SparseHandlePool::Borrow();
SparseMatDescriptor A =
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz);
DenseMatDescriptor B =
BuildDenseMatDescriptor(b_dtype, transpose ? rows : cols, BCcols);
DenseMatDescriptor C =
BuildDenseMatDescriptor(compute_dtype, transpose ? cols : rows, BCcols);
cusparseOperation_t op_A = transpose ? CUSPARSE_OPERATION_TRANSPOSE
: CUSPARSE_OPERATION_NON_TRANSPOSE;
cusparseSpMatDescr_t mat_a = 0;
cusparseDnMatDescr_t mat_b = 0;
cusparseDnMatDescr_t mat_c = 0;
// bufferSize does not reference these pointers, but does error on NULL.
int val = 0;
void* empty = &val;
ThrowIfErrorStatus(cusparseCreateCsr(&mat_a, A.rows, A.cols, A.nnz, empty,
empty, empty, A.index_type, A.index_type,
CUSPARSE_INDEX_BASE_ZERO, A.value_type));
ThrowIfErrorStatus(cusparseCreateDnMat(&mat_b, B.rows, B.cols, /*ld=*/B.cols,
empty, B.type, CUSPARSE_ORDER_ROW));
ThrowIfErrorStatus(cusparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols,
empty, C.type, CUSPARSE_ORDER_ROW));
size_t buffer_size;
CudaConst alpha = CudaOne(C.type);
CudaConst beta = CudaZero(C.type);
ThrowIfErrorStatus(cusparseSpMM_bufferSize(
handle.get(), op_A, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a,
mat_b, &beta, mat_c, C.type, CUSPARSE_SPMM_ALG_DEFAULT, &buffer_size));
ThrowIfErrorStatus(cusparseDestroySpMat(mat_a));
ThrowIfErrorStatus(cusparseDestroyDnMat(mat_b));
ThrowIfErrorStatus(cusparseDestroyDnMat(mat_c));
return {buffer_size, PackDescriptor(CsrMatmatDescriptor{A, B, C, op_A})};
}
void CsrMatmat(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len) {
const CsrMatmatDescriptor& d =
*UnpackDescriptor<CsrMatmatDescriptor>(opaque, opaque_len);
auto handle = SparseHandlePool::Borrow(stream);
void* csr_values = buffers[0];
void* csr_col_ind = buffers[1];
void* csr_row_offsets = buffers[2];
void* Bbuf = buffers[3];
void* Cbuf = buffers[4];
void* buf = buffers[5];
// TODO(jakevdp): alpha and beta should be user-specifiable, but constants
// are sufficient for basic matmat operations.
// Note that, contrary to cusparse docs, alpha and beta must be host pointers
// or else the operation will segfault.
CudaConst alpha = CudaOne(d.C.type);
CudaConst beta = CudaZero(d.C.type);
cusparseSpMatDescr_t mat_a = 0;
cusparseDnMatDescr_t mat_b = 0;
cusparseDnMatDescr_t mat_c = 0;
ThrowIfErrorStatus(
cusparseCreateCsr(&mat_a, d.A.rows, d.A.cols, d.A.nnz, csr_row_offsets,
csr_col_ind, csr_values, d.A.index_type, d.A.index_type,
CUSPARSE_INDEX_BASE_ZERO, d.A.value_type));
ThrowIfErrorStatus(cusparseCreateDnMat(&mat_b, d.B.rows, d.B.cols,
/*ld=*/d.B.cols, Bbuf, d.B.type,
CUSPARSE_ORDER_ROW));
ThrowIfErrorStatus(cusparseCreateDnMat(&mat_c, d.C.rows, d.C.cols,
/*ld=*/d.C.cols, Cbuf, d.C.type,
CUSPARSE_ORDER_ROW));
ThrowIfErrorStatus(cusparseSpMM(
handle.get(), d.op_A, /*opB=*/CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
mat_a, mat_b, &beta, mat_c, d.C.type, CUSPARSE_SPMM_ALG_DEFAULT, buf));
ThrowIfErrorStatus(cusparseDestroySpMat(mat_a));
ThrowIfErrorStatus(cusparseDestroyDnMat(mat_b));
ThrowIfErrorStatus(cusparseDestroyDnMat(mat_c));
}
// CooToDense: Convert COO matrix to dense matrix
// Returns the descriptor for a CooToDense operation.
std::pair<size_t, py::bytes> BuildCooToDenseDescriptor(
const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
int cols, int nnz) {
auto handle = SparseHandlePool::Borrow();
SparseMatDescriptor d =
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz);
cusparseSpMatDescr_t mat_a = 0;
cusparseDnMatDescr_t mat_b = 0;
// bufferSize does not reference these pointers, but does error on NULL.
int val = 0;
void* empty = &val;
ThrowIfErrorStatus(cusparseCreateCoo(&mat_a, d.rows, d.cols, d.nnz, empty,
empty, empty, d.index_type,
CUSPARSE_INDEX_BASE_ZERO, d.value_type));
ThrowIfErrorStatus(cusparseCreateDnMat(&mat_b, d.rows, d.cols,
/*ld=*/d.cols, empty, d.value_type,
CUSPARSE_ORDER_ROW));
size_t buffer_size;
ThrowIfErrorStatus(cusparseSparseToDense_bufferSize(
handle.get(), mat_a, mat_b, CUSPARSE_SPARSETODENSE_ALG_DEFAULT,
&buffer_size));
ThrowIfErrorStatus(cusparseDestroySpMat(mat_a));
ThrowIfErrorStatus(cusparseDestroyDnMat(mat_b));
return {buffer_size, PackDescriptor(d)};
}
void CooToDense(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len) {
const SparseMatDescriptor& d =
*UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
auto handle = SparseHandlePool::Borrow(stream);
cusparseSpMatDescr_t mat_a = 0;
cusparseDnMatDescr_t mat_b = 0;
ThrowIfErrorStatus(cusparseCreateCoo(&mat_a, d.rows, d.cols, d.nnz,
/*cooRowInd=*/buffers[1],
/*cooColInd=*/buffers[2],
/*cooValues=*/buffers[0], d.index_type,
CUSPARSE_INDEX_BASE_ZERO, d.value_type));
ThrowIfErrorStatus(cusparseCreateDnMat(&mat_b, d.rows, d.cols,
/*ld=*/d.cols, buffers[3],
d.value_type, CUSPARSE_ORDER_ROW));
ThrowIfErrorStatus(cusparseSparseToDense(handle.get(), mat_a, mat_b,
CUSPARSE_SPARSETODENSE_ALG_DEFAULT,
buffers[4]));
ThrowIfErrorStatus(cusparseDestroySpMat(mat_a));
ThrowIfErrorStatus(cusparseDestroyDnMat(mat_b));
}
// CooFromDense: Convert dense matrix to COO matrix
// Returns the descriptor for a CooFromDense operation.
std::pair<size_t, py::bytes> BuildCooFromDenseDescriptor(
const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
int cols, int nnz) {
auto handle = SparseHandlePool::Borrow();
SparseMatDescriptor d =
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz);
cusparseDnMatDescr_t mat_a = 0;
cusparseSpMatDescr_t mat_b = 0;
// bufferSize does not reference these pointers, but does error on NULL.
int val = 0;
void* empty = &val;
ThrowIfErrorStatus(cusparseCreateDnMat(&mat_a, d.rows, d.cols,
/*ld=*/d.cols, empty, d.value_type,
CUSPARSE_ORDER_ROW));
ThrowIfErrorStatus(cusparseCreateCoo(&mat_b, d.rows, d.cols, d.nnz, empty,
empty, empty, d.index_type,
CUSPARSE_INDEX_BASE_ZERO, d.value_type));
size_t buffer_size;
ThrowIfErrorStatus(cusparseDenseToSparse_bufferSize(
handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
&buffer_size));
ThrowIfErrorStatus(cusparseDestroyDnMat(mat_a));
ThrowIfErrorStatus(cusparseDestroySpMat(mat_b));
return {buffer_size, PackDescriptor(d)};
}
void CooFromDense(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len) {
const SparseMatDescriptor& d =
*UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
auto handle = SparseHandlePool::Borrow(stream);
cusparseDnMatDescr_t mat_a = 0;
cusparseSpMatDescr_t mat_b = 0;
ThrowIfErrorStatus(cusparseCreateDnMat(&mat_a, d.rows, d.cols,
/*ld=*/d.cols, buffers[0],
d.value_type, CUSPARSE_ORDER_ROW));
ThrowIfErrorStatus(cusparseCreateCoo(&mat_b, d.rows, d.cols, d.nnz,
/*cooRowInd=*/buffers[2],
/*cooColInd=*/buffers[3],
/*cooValues=*/buffers[1], d.index_type,
CUSPARSE_INDEX_BASE_ZERO, d.value_type));
ThrowIfErrorStatus(cusparseDenseToSparse_analysis(
handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
buffers[4]));
ThrowIfErrorStatus(cusparseDenseToSparse_convert(
handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
buffers[4]));
ThrowIfErrorStatus(cusparseDestroyDnMat(mat_a));
ThrowIfErrorStatus(cusparseDestroySpMat(mat_b));
}
// CooMatvec: Product of COO matrix and dense vector.
struct CooMatvecDescriptor {
SparseMatDescriptor A;
DenseVecDescriptor x, y;
cusparseOperation_t op;
};
// Returns the descriptor for a CooMatvec operation.
std::pair<size_t, py::bytes> BuildCooMatvecDescriptor(
const py::dtype& data_dtype, const py::dtype& x_dtype,
const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
int cols, int nnz, bool transpose) {
auto handle = SparseHandlePool::Borrow();
SparseMatDescriptor A =
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz);
DenseVecDescriptor x =
BuildDenseVecDescriptor(x_dtype, transpose ? rows : cols);
DenseVecDescriptor y =
BuildDenseVecDescriptor(compute_dtype, transpose ? cols : rows);
cusparseSpMatDescr_t mat_a = 0;
cusparseDnVecDescr_t vec_x = 0;
cusparseDnVecDescr_t vec_y = 0;
cusparseOperation_t op = transpose ? CUSPARSE_OPERATION_TRANSPOSE
: CUSPARSE_OPERATION_NON_TRANSPOSE;
// bufferSize does not reference these pointers, but does error on NULL.
int val = 0;
void* empty = &val;
ThrowIfErrorStatus(cusparseCreateCoo(&mat_a, A.rows, A.cols, A.nnz, empty,
empty, empty, A.index_type,
CUSPARSE_INDEX_BASE_ZERO, A.value_type));
ThrowIfErrorStatus(cusparseCreateDnVec(&vec_x, x.size, empty, x.type));
ThrowIfErrorStatus(cusparseCreateDnVec(&vec_y, y.size, empty, y.type));
size_t buffer_size;
CudaConst alpha = CudaOne(y.type);
CudaConst beta = CudaZero(y.type);
ThrowIfErrorStatus(cusparseSpMV_bufferSize(
handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type,
CUSPARSE_MV_ALG_DEFAULT, &buffer_size));
ThrowIfErrorStatus(cusparseDestroySpMat(mat_a));
ThrowIfErrorStatus(cusparseDestroyDnVec(vec_x));
ThrowIfErrorStatus(cusparseDestroyDnVec(vec_y));
return {buffer_size, PackDescriptor(CooMatvecDescriptor{A, x, y, op})};
}
void CooMatvec(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len) {
const CooMatvecDescriptor& d =
*UnpackDescriptor<CooMatvecDescriptor>(opaque, opaque_len);
auto handle = SparseHandlePool::Borrow(stream);
void* coo_values = buffers[0];
void* coo_row_ind = buffers[1];
void* coo_col_ind = buffers[2];
void* xbuf = buffers[3];
void* ybuf = buffers[4];
void* buf = buffers[5];
// TODO(jakevdp): alpha and beta should be user-specifiable, but constants
// are sufficient for basic matvec operations.
// Note that, contrary to cusparse docs, alpha and beta must be host pointers
// or else the operation will segfault.
CudaConst alpha = CudaOne(d.y.type);
CudaConst beta = CudaZero(d.y.type);
cusparseSpMatDescr_t mat_a = 0;
cusparseDnVecDescr_t vec_x = 0;
cusparseDnVecDescr_t vec_y = 0;
ThrowIfErrorStatus(cusparseCreateCoo(
&mat_a, d.A.rows, d.A.cols, d.A.nnz, coo_row_ind, coo_col_ind, coo_values,
d.A.index_type, CUSPARSE_INDEX_BASE_ZERO, d.A.value_type));
ThrowIfErrorStatus(cusparseCreateDnVec(&vec_x, d.x.size, xbuf, d.x.type));
ThrowIfErrorStatus(cusparseCreateDnVec(&vec_y, d.y.size, ybuf, d.y.type));
ThrowIfErrorStatus(cusparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x,
&beta, vec_y, d.y.type,
CUSPARSE_MV_ALG_DEFAULT, buf));
ThrowIfErrorStatus(cusparseDestroySpMat(mat_a));
ThrowIfErrorStatus(cusparseDestroyDnVec(vec_x));
ThrowIfErrorStatus(cusparseDestroyDnVec(vec_y));
}
// CooMatmat: Product of COO matrix and dense matrix.
struct CooMatmatDescriptor {
SparseMatDescriptor A;
DenseMatDescriptor B, C;
cusparseOperation_t op_A;
};
// Returns the descriptor for a CooMatmat operation.
std::pair<size_t, py::bytes> BuildCooMatmatDescriptor(
const py::dtype& data_dtype, const py::dtype& b_dtype,
const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
int cols, int BCcols, int nnz, bool transpose) {
auto handle = SparseHandlePool::Borrow();
SparseMatDescriptor A =
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz);
DenseMatDescriptor B =
BuildDenseMatDescriptor(b_dtype, transpose ? rows : cols, BCcols);
DenseMatDescriptor C =
BuildDenseMatDescriptor(compute_dtype, transpose ? cols : rows, BCcols);
cusparseOperation_t op_A = transpose ? CUSPARSE_OPERATION_TRANSPOSE
: CUSPARSE_OPERATION_NON_TRANSPOSE;
cusparseSpMatDescr_t mat_a = 0;
cusparseDnMatDescr_t mat_b = 0;
cusparseDnMatDescr_t mat_c = 0;
// bufferSize does not reference these pointers, but does error on NULL.
int val = 0;
void* empty = &val;
ThrowIfErrorStatus(cusparseCreateCoo(&mat_a, A.rows, A.cols, A.nnz, empty,
empty, empty, A.index_type,
CUSPARSE_INDEX_BASE_ZERO, A.value_type));
ThrowIfErrorStatus(cusparseCreateDnMat(&mat_b, B.rows, B.cols, /*ld=*/B.cols,
empty, B.type, CUSPARSE_ORDER_ROW));
ThrowIfErrorStatus(cusparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols,
empty, C.type, CUSPARSE_ORDER_ROW));
size_t buffer_size;
CudaConst alpha = CudaOne(C.type);
CudaConst beta = CudaZero(C.type);
ThrowIfErrorStatus(cusparseSpMM_bufferSize(
handle.get(), op_A, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a,
mat_b, &beta, mat_c, C.type, CUSPARSE_SPMM_ALG_DEFAULT, &buffer_size));
ThrowIfErrorStatus(cusparseDestroySpMat(mat_a));
ThrowIfErrorStatus(cusparseDestroyDnMat(mat_b));
ThrowIfErrorStatus(cusparseDestroyDnMat(mat_c));
return {buffer_size, PackDescriptor(CooMatmatDescriptor{A, B, C, op_A})};
}
void CooMatmat(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len) {
const CooMatmatDescriptor& d =
*UnpackDescriptor<CooMatmatDescriptor>(opaque, opaque_len);
auto handle = SparseHandlePool::Borrow(stream);
void* coo_values = buffers[0];
void* coo_row_ind = buffers[1];
void* coo_col_ind = buffers[2];
void* Bbuf = buffers[3];
void* Cbuf = buffers[4];
void* buf = buffers[5];
// TODO(jakevdp): alpha and beta should be user-specifiable, but constants
// are sufficient for basic matmat operations.
// Note that, contrary to cusparse docs, alpha and beta must be host pointers
// or else the operation will segfault.
CudaConst alpha = CudaOne(d.C.type);
CudaConst beta = CudaZero(d.C.type);
cusparseSpMatDescr_t mat_a = 0;
cusparseDnMatDescr_t mat_b = 0;
cusparseDnMatDescr_t mat_c = 0;
ThrowIfErrorStatus(cusparseCreateCoo(
&mat_a, d.A.rows, d.A.cols, d.A.nnz, coo_row_ind, coo_col_ind, coo_values,
d.A.index_type, CUSPARSE_INDEX_BASE_ZERO, d.A.value_type));
ThrowIfErrorStatus(cusparseCreateDnMat(&mat_b, d.B.rows, d.B.cols,
/*ld=*/d.B.cols, Bbuf, d.B.type,
CUSPARSE_ORDER_ROW));
ThrowIfErrorStatus(cusparseCreateDnMat(&mat_c, d.C.rows, d.C.cols,
/*ld=*/d.C.cols, Cbuf, d.C.type,
CUSPARSE_ORDER_ROW));
ThrowIfErrorStatus(cusparseSpMM(
handle.get(), d.op_A, /*opB=*/CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
mat_a, mat_b, &beta, mat_c, d.C.type, CUSPARSE_SPMM_ALG_DEFAULT, buf));
ThrowIfErrorStatus(cusparseDestroySpMat(mat_a));
ThrowIfErrorStatus(cusparseDestroyDnMat(mat_b));
ThrowIfErrorStatus(cusparseDestroyDnMat(mat_c));
}
#endif
py::dict Registrations() {
py::dict dict;
#if JAX_USE_CUDA11
dict["cusparse_csr_todense"] = EncapsulateFunction(CsrToDense);
dict["cusparse_csr_fromdense"] = EncapsulateFunction(CsrFromDense);
dict["cusparse_csr_matvec"] = EncapsulateFunction(CsrMatvec);
dict["cusparse_csr_matmat"] = EncapsulateFunction(CsrMatmat);
dict["cusparse_coo_todense"] = EncapsulateFunction(CooToDense);
dict["cusparse_coo_fromdense"] = EncapsulateFunction(CooFromDense);
dict["cusparse_coo_matvec"] = EncapsulateFunction(CooMatvec);
dict["cusparse_coo_matmat"] = EncapsulateFunction(CooMatmat);
#endif
return dict;
}
PYBIND11_MODULE(cusparse_kernels, m) {
m.attr("cusparse_supported") = py::bool_(JAX_USE_CUDA11);
m.def("registrations", &Registrations);
#if JAX_USE_CUDA11
m.def("build_csr_todense_descriptor", &BuildCsrToDenseDescriptor);
m.def("build_csr_fromdense_descriptor", &BuildCsrFromDenseDescriptor);
m.def("build_csr_matvec_descriptor", &BuildCsrMatvecDescriptor);
m.def("build_csr_matmat_descriptor", &BuildCsrMatmatDescriptor);
m.def("build_coo_todense_descriptor", &BuildCooToDenseDescriptor);
m.def("build_coo_fromdense_descriptor", &BuildCooFromDenseDescriptor);
m.def("build_coo_matvec_descriptor", &BuildCooMatvecDescriptor);
m.def("build_coo_matmat_descriptor", &BuildCooMatmatDescriptor);
#endif
}
} // namespace
} // namespace jax

jaxlib/cusparse.py Normal file

@@ -0,0 +1,289 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
cusparse wrappers for performing sparse matrix computations in JAX
"""
import numpy as np
from jax.lib import xla_client
try:
from . import cusparse_kernels
except ImportError:
cusparse_kernels = None
else:
for _name, _value in cusparse_kernels.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
is_supported : bool = cusparse_kernels and cusparse_kernels.cusparse_supported
_ops = xla_client.ops
_Shape = xla_client.Shape
def csr_todense(c, data, indices, indptr, *, shape):
"""CSR to dense matrix."""
data_dtype = np.dtype(c.get_shape(data).element_type())
index_dtype = np.dtype(c.get_shape(indices).element_type())
rows, cols = shape
nnz = c.get_shape(data).dimensions()[0]
buffer_size, opaque = cusparse_kernels.build_csr_todense_descriptor(
data_dtype, index_dtype, rows, cols, nnz)
out = xla_client.ops.CustomCallWithLayout(
c,
b"cusparse_csr_todense",
operands=(data, indices, indptr),
operand_shapes_with_layout=(
# All are 1D, so no layout necessary
c.get_shape(data),
c.get_shape(indices),
c.get_shape(indptr),
),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(data_dtype, shape, (1, 0)),
_Shape.array_shape(np.dtype(np.int8), (buffer_size,), (0,)),
)),
opaque=opaque,
)
return _ops.GetTupleElement(out, 0)
def csr_fromdense(c, mat, *, nnz, index_dtype):
"""CSR from dense matrix."""
data_dtype = np.dtype(c.get_shape(mat).element_type())
shape = c.get_shape(mat).dimensions()
rows, cols = shape
buffer_size, opaque = cusparse_kernels.build_csr_fromdense_descriptor(
data_dtype, index_dtype, rows, cols, nnz)
out = xla_client.ops.CustomCallWithLayout(
c,
b"cusparse_csr_fromdense",
operands=(mat,),
operand_shapes_with_layout=(
_Shape.array_shape(data_dtype, shape, (1, 0)),
),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(data_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (shape[0] + 1,), (0,)),
_Shape.array_shape(np.dtype(np.int8), (buffer_size,), (0,)),
)),
opaque=opaque,
)
return tuple(_ops.GetTupleElement(out, i) for i in range(3))
def csr_matvec(c, data, indices, indptr, x, *, shape, transpose=False, compute_dtype=None):
"""CSR matrix/vector multiply."""
dtype = np.dtype(c.get_shape(data).element_type())
index_dtype = np.dtype(c.get_shape(indices).element_type())
x_dtype = np.dtype(c.get_shape(x).element_type())
rows, cols = shape
nnz, = c.get_shape(data).dimensions()
if compute_dtype is None:
compute_dtype = dtype
buffer_size, opaque = cusparse_kernels.build_csr_matvec_descriptor(
dtype, x_dtype, compute_dtype, index_dtype,
rows, cols, nnz, transpose)
out_size = cols if transpose else rows
out = xla_client.ops.CustomCallWithLayout(
c,
b"cusparse_csr_matvec",
operands=(data, indices, indptr, x),
operand_shapes_with_layout=(
# All are 1D, so no layout necessary
c.get_shape(data),
c.get_shape(indices),
c.get_shape(indptr),
c.get_shape(x),
),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(compute_dtype, (out_size,), (0,)),
_Shape.array_shape(np.dtype(np.uint8), (buffer_size,), (0,)))),
opaque=opaque,
)
return _ops.GetTupleElement(out, 0)
def csr_matmat(c, data, indices, indptr, B, *, shape, transpose=False, compute_dtype=None):
"""CSR from dense matrix."""
dtype = np.dtype(c.get_shape(data).element_type())
index_dtype = np.dtype(c.get_shape(indices).element_type())
B_dtype = np.dtype(c.get_shape(B).element_type())
rows, cols = shape
_, Ccols = c.get_shape(B).dimensions()
nnz, = c.get_shape(data).dimensions()
if compute_dtype is None:
compute_dtype = dtype
buffer_size, opaque = cusparse_kernels.build_csr_matmat_descriptor(
dtype, B_dtype, compute_dtype, index_dtype,
rows, cols, Ccols, nnz, transpose)
out_size = cols if transpose else rows
out = xla_client.ops.CustomCallWithLayout(
c,
b"cusparse_csr_matmat",
operands=(data, indices, indptr, B),
operand_shapes_with_layout=(
# All are 1D, so no layout necessary
c.get_shape(data),
c.get_shape(indices),
c.get_shape(indptr),
c.get_shape(B),
),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(compute_dtype, (out_size, Ccols), (1, 0)),
_Shape.array_shape(np.dtype(np.uint8), (buffer_size,), (0,)))),
opaque=opaque,
)
return _ops.GetTupleElement(out, 0)
def coo_todense(c, data, row, col, *, shape):
"""COO to dense matrix."""
data_dtype = np.dtype(c.get_shape(data).element_type())
index_dtype = np.dtype(c.get_shape(row).element_type())
rows, cols = shape
nnz = c.get_shape(data).dimensions()[0]
buffer_size, opaque = cusparse_kernels.build_coo_todense_descriptor(
data_dtype, index_dtype, rows, cols, nnz)
out = xla_client.ops.CustomCallWithLayout(
c,
b"cusparse_coo_todense",
operands=(data, row, col),
operand_shapes_with_layout=(
# All are 1D, so no layout necessary
c.get_shape(data),
c.get_shape(row),
c.get_shape(col),
),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(data_dtype, shape, (1, 0)),
_Shape.array_shape(np.dtype(np.int8), (buffer_size,), (0,)),
)),
opaque=opaque,
)
return _ops.GetTupleElement(out, 0)
def coo_fromdense(c, mat, *, nnz, index_dtype):
"""COO from dense matrix."""
data_dtype = np.dtype(c.get_shape(mat).element_type())
shape = c.get_shape(mat).dimensions()
rows, cols = shape
buffer_size, opaque = cusparse_kernels.build_coo_fromdense_descriptor(
data_dtype, index_dtype, rows, cols, nnz)
out = xla_client.ops.CustomCallWithLayout(
c,
b"cusparse_coo_fromdense",
operands=(mat,),
operand_shapes_with_layout=(
_Shape.array_shape(data_dtype, shape, (1, 0)),
),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(data_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (nnz,), (0,)),
_Shape.array_shape(np.dtype(np.int8), (buffer_size,), (0,)),
)),
opaque=opaque,
)
return tuple(_ops.GetTupleElement(out, i) for i in range(3))
def coo_matvec(c, data, row, col, x, *, shape, transpose=False, compute_dtype=None):
"""CSR matrix/vector multiply."""
dtype = np.dtype(c.get_shape(data).element_type())
index_dtype = np.dtype(c.get_shape(row).element_type())
x_dtype = np.dtype(c.get_shape(x).element_type())
rows, cols = shape
nnz, = c.get_shape(data).dimensions()
if compute_dtype is None:
compute_dtype = dtype
buffer_size, opaque = cusparse_kernels.build_coo_matvec_descriptor(
dtype, x_dtype, compute_dtype, index_dtype,
rows, cols, nnz, transpose)
out_size = cols if transpose else rows
out = xla_client.ops.CustomCallWithLayout(
c,
b"cusparse_coo_matvec",
operands=(data, row, col, x),
operand_shapes_with_layout=(
# All are 1D, so no layout necessary
c.get_shape(data),
c.get_shape(row),
c.get_shape(col),
c.get_shape(x),
),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(compute_dtype, (out_size,), (0,)),
_Shape.array_shape(np.dtype(np.uint8), (buffer_size,), (0,)))),
opaque=opaque,
)
return _ops.GetTupleElement(out, 0)
def coo_matmat(c, data, row, col, B, *, shape, transpose=False, compute_dtype=None):
"""CSR from dense matrix."""
dtype = np.dtype(c.get_shape(data).element_type())
index_dtype = np.dtype(c.get_shape(row).element_type())
B_dtype = np.dtype(c.get_shape(B).element_type())
rows, cols = shape
_, Ccols = c.get_shape(B).dimensions()
nnz, = c.get_shape(data).dimensions()
if compute_dtype is None:
compute_dtype = dtype
buffer_size, opaque = cusparse_kernels.build_coo_matmat_descriptor(
dtype, B_dtype, compute_dtype, index_dtype,
rows, cols, Ccols, nnz, transpose)
out_size = cols if transpose else rows
out = xla_client.ops.CustomCallWithLayout(
c,
b"cusparse_coo_matmat",
operands=(data, row, col, B),
operand_shapes_with_layout=(
# All are 1D, so no layout necessary
c.get_shape(data),
c.get_shape(row),
c.get_shape(col),
c.get_shape(B),
),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(compute_dtype, (out_size, Ccols), (1, 0)),
_Shape.array_shape(np.dtype(np.uint8), (buffer_size,), (0,)))),
opaque=opaque,
)
return _ops.GetTupleElement(out, 0)
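All of the wrappers above follow the same pattern: build an opaque descriptor plus a workspace size on the host, emit a ``CustomCallWithLayout`` whose result tuple ends with a scratch buffer of ``buffer_size`` bytes, and return only the leading elements. A generic sketch of that pattern (the helper name and signature are hypothetical, not part of this module):

def _custom_call_with_scratch(c, target_name, operands, result_shapes,
                              buffer_size, opaque):
  """Hypothetical helper illustrating the shared wrapper pattern above."""
  out = xla_client.ops.CustomCallWithLayout(
      c,
      target_name,
      operands=operands,
      operand_shapes_with_layout=tuple(c.get_shape(op) for op in operands),
      shape_with_layout=_Shape.tuple_shape((
          *result_shapes,
          _Shape.array_shape(np.dtype(np.uint8), (buffer_size,), (0,)),
      )),
      opaque=opaque,
  )
  return tuple(_ops.GetTupleElement(out, i) for i in range(len(result_shapes)))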

tests/sparse_ops_test.py Normal file
View File

@ -0,0 +1,218 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from absl.testing import absltest
from absl.testing import parameterized
from jax import config
from jax.experimental import sparse_ops
from jax.lib import xla_bridge
from jax import jit
from jax import test_util as jtu
from jax import xla
import jax.numpy as jnp
import numpy as np
from scipy import sparse
config.parse_flags_with_absl()
FLAGS = config.FLAGS
def rand_sparse(rng, nnz=0.1, post=lambda x: x):
def _rand_sparse(shape, dtype, nnz=nnz):
rand = jtu.rand_default(rng)
size = np.prod(shape)
if 0 <= nnz < 1:
nnz = nnz * size
nnz = min(size, int(nnz))
M = rand(shape, dtype)
indices = rng.choice(size, size - nnz, replace=False)
M.flat[indices] = 0
return post(M)
return _rand_sparse
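# `rand_sparse` mirrors the calling convention of jtu.rand_default: it returns
# a factory taking (shape, dtype) that produces a dense array with `nnz`
# nonzero entries (interpreted as a fraction of the total size when
# 0 <= nnz < 1), passed through `post`.  Hypothetical usage, not part of the
# test suite:
#
#   rng = rand_sparse(np.random.RandomState(0), post=sparse.csr_matrix)
#   M = rng((5, 8), np.float32)   # scipy.sparse.csr_matrix, ~10% nonzero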
class cuSparseTest(jtu.JaxTestCase):
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype}
for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
def test_csr_todense(self, shape, dtype):
rng = rand_sparse(self.rng(), post=sparse.csr_matrix)
M = rng(shape, dtype)
args = (M.data, M.indices, M.indptr)
todense = lambda *args: sparse_ops.csr_todense(*args, shape=M.shape)
self.assertArraysEqual(M.toarray(), todense(*args))
self.assertArraysEqual(M.toarray(), jit(todense)(*args))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype}
for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
def test_csr_fromdense(self, shape, dtype):
rng = rand_sparse(self.rng())
M = rng(shape, dtype)
M_csr = sparse.csr_matrix(M)
nnz = M_csr.nnz
index_dtype = jnp.int32
    fromdense = lambda M: sparse_ops.csr_fromdense(M, nnz=nnz, index_dtype=index_dtype)
data, indices, indptr = fromdense(M)
self.assertArraysEqual(data, M_csr.data.astype(dtype))
self.assertArraysEqual(indices, M_csr.indices.astype(index_dtype))
self.assertArraysEqual(indptr, M_csr.indptr.astype(index_dtype))
data, indices, indptr = jit(fromdense)(M)
self.assertArraysEqual(data, M_csr.data.astype(dtype))
self.assertArraysEqual(indices, M_csr.indices.astype(index_dtype))
self.assertArraysEqual(indptr, M_csr.indptr.astype(index_dtype))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype), transpose),
"shape": shape, "dtype": dtype, "transpose": transpose}
for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
for dtype in jtu.dtypes.floating + jtu.dtypes.complex
for transpose in [True, False]))
def test_csr_matvec(self, shape, dtype, transpose):
op = lambda M: M.T if transpose else M
v_rng = jtu.rand_default(self.rng())
rng = rand_sparse(self.rng(), post=sparse.csr_matrix)
M = rng(shape, dtype)
v = v_rng(op(M).shape[1], dtype)
args = (M.data, M.indices, M.indptr, v)
matvec = lambda *args: sparse_ops.csr_matvec(*args, shape=M.shape, transpose=transpose)
self.assertAllClose(op(M) @ v, matvec(*args))
self.assertAllClose(op(M) @ v, jit(matvec)(*args))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype), transpose),
"shape": shape, "dtype": dtype, "transpose": transpose}
for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
for dtype in jtu.dtypes.floating + jtu.dtypes.complex
for transpose in [True, False]))
def test_csr_matmat(self, shape, dtype, transpose):
op = lambda M: M.T if transpose else M
B_rng = jtu.rand_default(self.rng())
rng = rand_sparse(self.rng(), post=sparse.csr_matrix)
M = rng(shape, dtype)
B = B_rng((op(M).shape[1], 4), dtype)
args = (M.data, M.indices, M.indptr, B)
matmat = lambda *args: sparse_ops.csr_matmat(*args, shape=shape, transpose=transpose)
self.assertAllClose(op(M) @ B, matmat(*args))
self.assertAllClose(op(M) @ B, jit(matmat)(*args))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype}
for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
def test_coo_todense(self, shape, dtype):
rng = rand_sparse(self.rng(), post=sparse.coo_matrix)
M = rng(shape, dtype)
args = (M.data, M.row, M.col)
todense = lambda *args: sparse_ops.coo_todense(*args, shape=M.shape)
self.assertArraysEqual(M.toarray(), todense(*args))
self.assertArraysEqual(M.toarray(), jit(todense)(*args))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype}
for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
def test_coo_fromdense(self, shape, dtype):
rng = rand_sparse(self.rng())
M = rng(shape, dtype)
M_coo = sparse.coo_matrix(M)
nnz = M_coo.nnz
index_dtype = jnp.int32
    fromdense = lambda M: sparse_ops.coo_fromdense(M, nnz=nnz, index_dtype=index_dtype)
data, row, col = fromdense(M)
self.assertArraysEqual(data, M_coo.data.astype(dtype))
self.assertArraysEqual(row, M_coo.row.astype(index_dtype))
self.assertArraysEqual(col, M_coo.col.astype(index_dtype))
    data, row, col = jit(fromdense)(M)
self.assertArraysEqual(data, M_coo.data.astype(dtype))
self.assertArraysEqual(row, M_coo.row.astype(index_dtype))
self.assertArraysEqual(col, M_coo.col.astype(index_dtype))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype), transpose),
"shape": shape, "dtype": dtype, "transpose": transpose}
for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
for dtype in jtu.dtypes.floating + jtu.dtypes.complex
for transpose in [True, False]))
def test_coo_matvec(self, shape, dtype, transpose):
op = lambda M: M.T if transpose else M
v_rng = jtu.rand_default(self.rng())
rng = rand_sparse(self.rng(), post=sparse.coo_matrix)
M = rng(shape, dtype)
v = v_rng(op(M).shape[1], dtype)
args = (M.data, M.row, M.col, v)
matvec = lambda *args: sparse_ops.coo_matvec(*args, shape=M.shape, transpose=transpose)
self.assertAllClose(op(M) @ v, matvec(*args))
self.assertAllClose(op(M) @ v, jit(matvec)(*args))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype), transpose),
"shape": shape, "dtype": dtype, "transpose": transpose}
for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
for dtype in jtu.dtypes.floating + jtu.dtypes.complex
for transpose in [True, False]))
def test_coo_matmat(self, shape, dtype, transpose):
op = lambda M: M.T if transpose else M
B_rng = jtu.rand_default(self.rng())
rng = rand_sparse(self.rng(), post=sparse.coo_matrix)
M = rng(shape, dtype)
B = B_rng((op(M).shape[1], 4), dtype)
args = (M.data, M.row, M.col, B)
matmat = lambda *args: sparse_ops.coo_matmat(*args, shape=shape, transpose=transpose)
self.assertAllClose(op(M) @ B, matmat(*args))
self.assertAllClose(op(M) @ B, jit(matmat)(*args))
@unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU")
def test_gpu_translation_rule(self):
version = xla_bridge.get_backend().platform_version
cuda_version = None if version == "<unknown>" else int(version.split()[-1])
if cuda_version is None or cuda_version < 11000:
self.assertNotIn(sparse_ops.csr_todense_p, xla.backend_specific_translations["gpu"])
else:
self.assertIn(sparse_ops.csr_todense_p, xla.backend_specific_translations["gpu"])
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
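
For a quick end-to-end check outside the test harness, a minimal sketch of exercising the new module directly (assumes numpy and scipy are installed; the matrix, vector, and tolerance values are illustrative only):

import numpy as np
from scipy import sparse
from jax.experimental import sparse_ops

M = sparse.random(5, 8, density=0.2, format="csr", dtype=np.float32)
v = np.arange(8, dtype=np.float32)

# csr_matvec consumes the raw CSR buffers plus the matrix shape.
out = sparse_ops.csr_matvec(M.data, M.indices, M.indptr, v, shape=M.shape)
np.testing.assert_allclose(out, M @ v, rtol=1e-6)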