Add experimental/sparse_ops & cusparse wrappers in jaxlib

PiperOrigin-RevId: 368663407

commit 0d4bcde7ca (parent 862df00965)
@@ -39,6 +39,7 @@ py_binary(
     ]) + if_cuda([
         "//jaxlib:cublas_kernels",
         "//jaxlib:cusolver_kernels",
+        "//jaxlib:cusparse_kernels",
         "//jaxlib:cuda_lu_pivot_kernels",
         "//jaxlib:cuda_prng_kernels",
     ]) + if_rocm([
@@ -193,6 +193,9 @@ def prepare_wheel(sources_path):
   if r.Rlocation("__main__/jaxlib/rocblas_kernels.so") is not None:
     copy_to_jaxlib(r.Rlocation("__main__/jaxlib/rocblas_kernels.so"))
   copy_to_jaxlib(r.Rlocation("__main__/jaxlib/rocsolver.py"))
+  if r.Rlocation("__main__/jaxlib/cusparse.so") is not None:
+    copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cusparse.so"))
+  copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cusparse.py"))
   copy_to_jaxlib(r.Rlocation("__main__/jaxlib/version.py"))
 
   if _is_windows():
@@ -67,6 +67,15 @@ pytype_library(
     deps = [":jax"],
 )
 
+pytype_library(
+    name = "experimental_sparse_ops",
+    srcs = [
+        "experimental/sparse_ops.py",
+    ],
+    srcs_version = "PY3",
+    deps = [":jax"],
+)
+
 pytype_library(
     name = "optimizers",
     srcs = ["experimental/optimizers.py"],
jax/experimental/sparse_ops.py (new file, 442 lines)
@@ -0,0 +1,442 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""JAX primitives related to sparse operations.

This is experimental work to explore sparse support in JAX.

The primitives defined here are deliberately low-level: i.e. for now there is
no JAX CSR or COO matrix class. Each primitive implements a common sparse
operation (sparse to dense, dense to sparse, sparse matrix/vector product,
sparse matrix/matrix product) for two common sparse representations
(CSR and COO).

These routines have reference implementations defined via XLA scatter/gather
operations that will work on any backend, although they are not particularly
performant. On GPU runtimes with jaxlib 0.1.66 or newer built against CUDA 11.0
or newer, each operation is computed efficiently via cusparse.
"""

from jax import core
from jax.interpreters import xla
from jax.lib import cusparse
from jax.lib import xla_bridge
from jax.lib import xla_client
import jax.numpy as jnp
import numpy as np

xb = xla_bridge
xops = xla_client.ops
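
To make the buffer layout concrete, here is a small worked example (a sketch added for illustration, not part of the commit): the CSR and COO buffers for one 3x3 matrix, following the conventions in the docstrings below.

    import numpy as np

    # Dense matrix, nnz = 5:
    #   [[1, 0, 2],
    #    [0, 0, 3],
    #    [4, 5, 0]]
    data    = np.array([1., 2., 3., 4., 5.])  # nonzero values, row-major order
    indices = np.array([0, 2, 2, 0, 1])       # CSR column indices, shape (nnz,)
    indptr  = np.array([0, 2, 3, 5])          # CSR row offsets, shape (shape[0] + 1,)
    row     = np.array([0, 0, 1, 2, 2])       # COO row indices, shape (nnz,)
    col     = np.array([0, 2, 2, 0, 1])       # COO column indices, shape (nnz,)
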
#--------------------------------------------------------------------
# csr_todense

csr_todense_p = core.Primitive('csr_todense')

def csr_todense(data, indices, indptr, *, shape):
  """Convert CSR-format sparse matrix to a dense matrix.

  Args:
    data : array of shape ``(nnz,)``.
    indices : array of shape ``(nnz,)``
    indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``
    shape : length-2 tuple representing the matrix shape

  Returns:
    mat : array with specified shape and dtype matching ``data``
  """
  return csr_todense_p.bind(data, indices, indptr, shape=shape)

@csr_todense_p.def_impl
def _csr_todense_impl(data, indices, indptr, *, shape):
  row = jnp.zeros_like(indices).at[indptr].add(1).cumsum() - 1
  col = indices
  return jnp.zeros(shape, data.dtype).at[row, col].add(data)

@csr_todense_p.def_abstract_eval
def _csr_todense_abstract_eval(data, indices, indptr, *, shape):
  assert data.ndim == indices.ndim == indptr.ndim == 1
  assert indices.dtype == indptr.dtype
  assert data.shape == indices.shape
  assert indptr.shape[0] == shape[0] + 1
  return core.ShapedArray(shape, data.dtype)

def _csr_todense_gpu_translation_rule(c, data, indices, indptr, *, shape):
  return cusparse.csr_todense(c, data, indices, indptr, shape=shape)

xla.translations[csr_todense_p] = xla.lower_fun(
    _csr_todense_impl, multiple_results=False)
if cusparse and cusparse.is_supported:
  xla.backend_specific_translations['gpu'][
      csr_todense_p] = _csr_todense_gpu_translation_rule
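
A minimal usage sketch (illustrative, assuming the new file is importable as jax.experimental.sparse_ops): converting the CSR buffers above back to a dense matrix via the reference lowering.

    import jax.numpy as jnp
    from jax.experimental import sparse_ops

    data = jnp.array([1., 2., 3., 4., 5.])
    indices = jnp.array([0, 2, 2, 0, 1])
    indptr = jnp.array([0, 2, 3, 5])
    mat = sparse_ops.csr_todense(data, indices, indptr, shape=(3, 3))
    # mat == [[1., 0., 2.],
    #         [0., 0., 3.],
    #         [4., 5., 0.]]
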
#--------------------------------------------------------------------
# csr_fromdense

csr_fromdense_p = core.Primitive('csr_fromdense')
csr_fromdense_p.multiple_results = True

def csr_fromdense(mat, *, nnz, index_dtype=np.int32):
  """Create CSR-format sparse matrix from a dense matrix.

  Args:
    mat : array to be converted to CSR.
    nnz : number of nonzero entries in ``mat``
    index_dtype : dtype of sparse indices

  Returns:
    data : array of shape ``(nnz,)`` and dtype ``mat.dtype``.
    indices : array of shape ``(nnz,)`` and dtype ``index_dtype``
    indptr : array of shape ``(mat.shape[0] + 1,)`` and dtype ``index_dtype``
  """
  return csr_fromdense_p.bind(
      mat,
      nnz=nnz,
      index_dtype=np.dtype(index_dtype))

@csr_fromdense_p.def_impl
def _csr_fromdense_impl(mat, *, nnz, index_dtype):
  mat = jnp.asarray(mat)
  assert mat.ndim == 2

  data = jnp.zeros(nnz, dtype=mat.dtype)
  indices = jnp.zeros(nnz, dtype=index_dtype)
  indptr = jnp.zeros(mat.shape[0] + 1, dtype=index_dtype)

  mat_flat = jnp.ravel(mat)
  ind = jnp.sort(jnp.argsort(-abs(mat_flat))[:nnz])
  i, j = jnp.meshgrid(
      jnp.arange(mat.shape[0]), jnp.arange(mat.shape[1]), indexing='ij')
  row, col = jnp.ravel(i)[ind], jnp.ravel(j)[ind]

  data = data.at[:mat.size].set(mat_flat[ind])
  indices = indices.at[:mat.size].set(col)
  indptr = indptr.at[1:].set(jnp.cumsum(jnp.bincount(row, length=mat.shape[0])))
  return data, indices, indptr

@csr_fromdense_p.def_abstract_eval
def _csr_fromdense_abstract_eval(mat, *, nnz, index_dtype):
  data = core.ShapedArray((nnz,), mat.dtype)
  indices = core.ShapedArray((nnz,), index_dtype)
  indptr = core.ShapedArray((mat.shape[0] + 1,), index_dtype)
  return data, indices, indptr

def _csr_fromdense_gpu_translation_rule(c, mat, *, nnz, index_dtype):
  data, indices, indptr = cusparse.csr_fromdense(
      c, mat, nnz=nnz, index_dtype=np.dtype(index_dtype))
  return xops.Tuple(c, [data, indices, indptr])

xla.translations[csr_fromdense_p] = xla.lower_fun(
    _csr_fromdense_impl, multiple_results=True)
if cusparse and cusparse.is_supported:
  xla.backend_specific_translations['gpu'][
      csr_fromdense_p] = _csr_fromdense_gpu_translation_rule
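
Continuing the sketch above (illustrative): csr_fromdense followed by csr_todense round-trips the matrix when nnz equals the true number of nonzeros.

    M = jnp.array([[1., 0., 2.],
                   [0., 0., 3.],
                   [4., 5., 0.]])
    data, indices, indptr = sparse_ops.csr_fromdense(M, nnz=5)
    assert jnp.all(sparse_ops.csr_todense(data, indices, indptr, shape=M.shape) == M)
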
#--------------------------------------------------------------------
# csr_matvec

csr_matvec_p = core.Primitive('csr_matvec')

def csr_matvec(data, indices, indptr, v, *, shape, transpose=False):
  """Product of CSR sparse matrix and a dense vector.

  Args:
    data : array of shape ``(nnz,)``.
    indices : array of shape ``(nnz,)``
    indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``
    v : array of shape ``(shape[0] if transpose else shape[1],)``
      and dtype ``data.dtype``
    shape : length-2 tuple representing the matrix shape
    transpose : boolean specifying whether to transpose the sparse matrix
      before computing.

  Returns:
    y : array of shape ``(shape[1] if transpose else shape[0],)`` representing
      the matrix vector product.
  """
  return csr_matvec_p.bind(data, indices, indptr, v, shape=shape, transpose=transpose)

@csr_matvec_p.def_impl
def _csr_matvec_impl(data, indices, indptr, v, *, shape, transpose):
  v = jnp.asarray(v)
  out_shape = shape[1] if transpose else shape[0]
  row = jnp.cumsum(jnp.zeros_like(indices).at[indptr].add(1)) - 1
  col = indices
  if transpose:
    row, col = col, row
  dv = data * v[col]
  return jnp.zeros(out_shape, dv.dtype).at[row].add(dv)

@csr_matvec_p.def_abstract_eval
def _csr_matvec_abstract_eval(data, indices, indptr, v, *, shape, transpose):
  assert len(shape) == 2
  assert v.ndim == data.ndim == indices.ndim == indptr.ndim == 1
  assert data.shape == indices.shape
  assert data.dtype == v.dtype
  assert indices.dtype == indptr.dtype
  assert len(indptr) == shape[0] + 1
  out_shape = shape[1] if transpose else shape[0]
  # Parenthesized: without the parentheses this assert is vacuous when
  # transpose is False, since a non-empty tuple is always truthy.
  assert v.shape == ((shape[0],) if transpose else (shape[1],))
  return core.ShapedArray((out_shape,), data.dtype)

def _csr_matvec_gpu_translation_rule(c, data, indices, indptr, v, *, shape, transpose):
  return cusparse.csr_matvec(c, data, indices, indptr, v, shape=shape, transpose=transpose)

xla.translations[csr_matvec_p] = xla.lower_fun(
    _csr_matvec_impl, multiple_results=False)
if cusparse and cusparse.is_supported:
  xla.backend_specific_translations['gpu'][
      csr_matvec_p] = _csr_matvec_gpu_translation_rule
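
Sketch (illustrative, continuing the example above): csr_matvec agrees with the dense product; note that v must share data's dtype per the abstract eval rule.

    v = jnp.array([1., 2., 3.])
    y = sparse_ops.csr_matvec(data, indices, indptr, v, shape=(3, 3))
    assert jnp.allclose(y, M @ v)
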
#--------------------------------------------------------------------
# csr_matmat

csr_matmat_p = core.Primitive('csr_matmat')

def csr_matmat(data, indices, indptr, B, *, shape, transpose=False):
  """Product of CSR sparse matrix and a dense matrix.

  Args:
    data : array of shape ``(nnz,)``.
    indices : array of shape ``(nnz,)``
    indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``
    B : array of shape ``(shape[0] if transpose else shape[1], cols)`` and
      dtype ``data.dtype``
    shape : length-2 tuple representing the matrix shape
    transpose : boolean specifying whether to transpose the sparse matrix
      before computing.

  Returns:
    C : array of shape ``(shape[1] if transpose else shape[0], cols)``
      representing the matrix-matrix product.
  """
  return csr_matmat_p.bind(data, indices, indptr, B, shape=shape, transpose=transpose)

@csr_matmat_p.def_impl
def _csr_matmat_impl(data, indices, indptr, B, *, shape, transpose):
  B = jnp.asarray(B)
  out_shape = shape[1] if transpose else shape[0]
  row = jnp.cumsum(jnp.zeros_like(indices).at[indptr].add(1)) - 1
  col = indices
  if transpose:
    row, col = col, row
  dB = data[:, None] * B[col]
  return jnp.zeros((out_shape, B.shape[1]), dB.dtype).at[row].add(dB)

@csr_matmat_p.def_abstract_eval
def _csr_matmat_abstract_eval(data, indices, indptr, B, *, shape, transpose):
  assert data.ndim == indices.ndim == indptr.ndim == 1
  assert B.ndim == 2
  assert data.shape == indices.shape
  assert data.dtype == B.dtype
  assert indices.dtype == indptr.dtype
  assert len(indptr) == shape[0] + 1
  out_shape = shape[1] if transpose else shape[0]
  # Parenthesized so the comparison, not the assert, is conditional.
  assert B.shape[0] == (shape[0] if transpose else shape[1])
  return core.ShapedArray((out_shape, B.shape[1]), data.dtype)

def _csr_matmat_gpu_translation_rule(c, data, indices, indptr, B, *, shape, transpose):
  return cusparse.csr_matmat(c, data, indices, indptr, B, shape=shape, transpose=transpose)

xla.translations[csr_matmat_p] = xla.lower_fun(
    _csr_matmat_impl, multiple_results=False)
if cusparse and cusparse.is_supported:
  xla.backend_specific_translations['gpu'][
      csr_matmat_p] = _csr_matmat_gpu_translation_rule
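
Sketch (illustrative): with transpose=True the sparse operand is transposed before the product, so the result matches M.T @ B.

    B = jnp.ones((3, 2))
    C = sparse_ops.csr_matmat(data, indices, indptr, B, shape=(3, 3), transpose=True)
    assert jnp.allclose(C, M.T @ B)
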
#--------------------------------------------------------------------
# coo_todense

coo_todense_p = core.Primitive('coo_todense')

def coo_todense(data, row, col, *, shape):
  """Convert COO-format sparse matrix to a dense matrix.

  Args:
    data : array of shape ``(nnz,)``.
    row : array of shape ``(nnz,)``
    col : array of shape ``(nnz,)`` and dtype ``row.dtype``
    shape : length-2 tuple representing the matrix shape

  Returns:
    mat : array with specified shape and dtype matching ``data``
  """
  return coo_todense_p.bind(data, row, col, shape=shape)

@coo_todense_p.def_impl
def _coo_todense_impl(data, row, col, *, shape):
  return jnp.zeros(shape, data.dtype).at[row, col].set(data)

@coo_todense_p.def_abstract_eval
def _coo_todense_abstract_eval(data, row, col, *, shape):
  return core.ShapedArray(shape, data.dtype)

def _coo_todense_gpu_translation_rule(c, data, row, col, *, shape):
  return cusparse.coo_todense(c, data, row, col, shape=shape)

xla.translations[coo_todense_p] = xla.lower_fun(
    _coo_todense_impl, multiple_results=False)
if cusparse and cusparse.is_supported:
  xla.backend_specific_translations['gpu'][
      coo_todense_p] = _coo_todense_gpu_translation_rule
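
Sketch (illustrative): the same matrix from its COO buffers. Note the reference implementation above uses .at[...].set, so duplicate (row, col) pairs overwrite rather than accumulate, whereas the CSR reference lowering accumulates with .at[...].add.

    row = jnp.array([0, 0, 1, 2, 2])
    col = jnp.array([0, 2, 2, 0, 1])
    assert jnp.all(sparse_ops.coo_todense(data, row, col, shape=(3, 3)) == M)
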
#--------------------------------------------------------------------
# coo_fromdense

coo_fromdense_p = core.Primitive('coo_fromdense')
coo_fromdense_p.multiple_results = True

def coo_fromdense(mat, *, nnz, index_dtype=jnp.int32):
  """Create COO-format sparse matrix from a dense matrix.

  Args:
    mat : array to be converted to COO.
    nnz : number of nonzero entries in ``mat``
    index_dtype : dtype of sparse indices

  Returns:
    data : array of shape ``(nnz,)`` and dtype ``mat.dtype``
    row : array of shape ``(nnz,)`` and dtype ``index_dtype``
    col : array of shape ``(nnz,)`` and dtype ``index_dtype``
  """
  return coo_fromdense_p.bind(mat, nnz=nnz, index_dtype=index_dtype)

@coo_fromdense_p.def_impl
def _coo_fromdense_impl(mat, *, nnz, index_dtype):
  mat = jnp.asarray(mat)
  m, n = mat.shape
  ind = jnp.sort(jnp.argsort(abs(jnp.ravel(mat)))[m * n - nnz:]).astype(index_dtype)
  return mat.ravel()[ind], ind // n, ind % n

@coo_fromdense_p.def_abstract_eval
def _coo_fromdense_abstract_eval(mat, *, nnz, index_dtype):
  data = core.ShapedArray((nnz,), mat.dtype)
  row = col = core.ShapedArray((nnz,), index_dtype)
  return data, row, col

def _coo_fromdense_gpu_translation_rule(c, mat, *, nnz, index_dtype):
  data, row, col = cusparse.coo_fromdense(
      c, mat, nnz=nnz, index_dtype=np.dtype(index_dtype))
  return xops.Tuple(c, [data, row, col])

xla.translations[coo_fromdense_p] = xla.lower_fun(
    _coo_fromdense_impl, multiple_results=True)
if cusparse and cusparse.is_supported:
  xla.backend_specific_translations['gpu'][
      coo_fromdense_p] = _coo_fromdense_gpu_translation_rule
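
Sketch (illustrative): coo_fromdense keeps the nnz largest-magnitude entries, returned in row-major order.

    data, row, col = sparse_ops.coo_fromdense(M, nnz=5)
    # data == [1., 2., 3., 4., 5.]; row == [0, 0, 1, 2, 2]; col == [0, 2, 2, 0, 1]
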
#--------------------------------------------------------------------
# coo_matvec

coo_matvec_p = core.Primitive('coo_matvec')

def coo_matvec(data, row, col, v, *, shape, transpose=False):
  """Product of COO sparse matrix and a dense vector.

  Args:
    data : array of shape ``(nnz,)``.
    row : array of shape ``(nnz,)``
    col : array of shape ``(nnz,)`` and dtype ``row.dtype``
    v : array of shape ``(shape[0] if transpose else shape[1],)`` and
      dtype ``data.dtype``
    shape : length-2 tuple representing the matrix shape
    transpose : boolean specifying whether to transpose the sparse matrix
      before computing.

  Returns:
    y : array of shape ``(shape[1] if transpose else shape[0],)`` representing
      the matrix vector product.
  """
  return coo_matvec_p.bind(data, row, col, v, shape=shape, transpose=transpose)

@coo_matvec_p.def_impl
def _coo_matvec_impl(data, row, col, v, *, shape, transpose):
  v = jnp.asarray(v)
  if transpose:
    row, col = col, row
  out_shape = shape[1] if transpose else shape[0]
  dv = data * v[col]
  return jnp.zeros(out_shape, dv.dtype).at[row].add(dv)

@coo_matvec_p.def_abstract_eval
def _coo_matvec_abstract_eval(data, row, col, v, *, shape, transpose):
  assert data.shape == row.shape == col.shape
  assert data.dtype == v.dtype
  assert row.dtype == col.dtype
  assert len(shape) == 2
  # Parenthesized: without the parentheses this assert is vacuous when
  # transpose is False, since a non-empty tuple is always truthy.
  assert v.shape == ((shape[0],) if transpose else (shape[1],))
  out_shape = shape[1] if transpose else shape[0]
  return core.ShapedArray((out_shape,), data.dtype)

def _coo_matvec_gpu_translation_rule(c, data, row, col, v, *, shape, transpose):
  return cusparse.coo_matvec(c, data, row, col, v, shape=shape, transpose=transpose)

xla.translations[coo_matvec_p] = xla.lower_fun(
    _coo_matvec_impl, multiple_results=False)
if cusparse and cusparse.is_supported:
  xla.backend_specific_translations['gpu'][
      coo_matvec_p] = _coo_matvec_gpu_translation_rule
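
Sketch (illustrative): the transposed COO matvec.

    y = sparse_ops.coo_matvec(data, row, col, v, shape=(3, 3), transpose=True)
    assert jnp.allclose(y, M.T @ v)
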
#--------------------------------------------------------------------
# coo_matmat

coo_matmat_p = core.Primitive('coo_matmat')

def coo_matmat(data, row, col, B, *, shape, transpose=False):
  """Product of COO sparse matrix and a dense matrix.

  Args:
    data : array of shape ``(nnz,)``.
    row : array of shape ``(nnz,)``
    col : array of shape ``(nnz,)`` and dtype ``row.dtype``
    B : array of shape ``(shape[0] if transpose else shape[1], cols)`` and
      dtype ``data.dtype``
    shape : length-2 tuple representing the matrix shape
    transpose : boolean specifying whether to transpose the sparse matrix
      before computing.

  Returns:
    C : array of shape ``(shape[1] if transpose else shape[0], cols)``
      representing the matrix-matrix product.
  """
  return coo_matmat_p.bind(data, row, col, B, shape=shape, transpose=transpose)

@coo_matmat_p.def_impl
def _coo_matmat_impl(data, row, col, B, *, shape, transpose):
  B = jnp.asarray(B)
  if transpose:
    row, col = col, row
  out_shape = shape[1] if transpose else shape[0]
  dB = data[:, None] * B[col]
  return jnp.zeros((out_shape, B.shape[1]), dB.dtype).at[row].add(dB)

@coo_matmat_p.def_abstract_eval
def _coo_matmat_abstract_eval(data, row, col, B, *, shape, transpose):
  assert data.shape == row.shape == col.shape
  assert data.dtype == B.dtype
  assert len(shape) == 2
  # Parenthesized so the comparison, not the assert, is conditional.
  assert B.shape[0] == (shape[0] if transpose else shape[1])
  out_shape = shape[1] if transpose else shape[0]
  return core.ShapedArray((out_shape, B.shape[1]), data.dtype)

def _coo_matmat_gpu_translation_rule(c, data, row, col, B, *, shape, transpose):
  return cusparse.coo_matmat(c, data, row, col, B, shape=shape, transpose=transpose)

xla.translations[coo_matmat_p] = xla.lower_fun(
    _coo_matmat_impl, multiple_results=False)
if cusparse and cusparse.is_supported:
  xla.backend_specific_translations['gpu'][
      coo_matmat_p] = _coo_matmat_gpu_translation_rule
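
Because every primitive above defines an abstract eval rule, these operations also compose with jax.jit: on GPU the registered translation rules dispatch to the cusparse kernels added below, while other backends fall back to the reference lowering. A minimal sketch (illustrative):

    import jax

    @jax.jit
    def matvec(data, row, col, v):
      # `shape` is closed over, so it is a static (trace-time) constant.
      return sparse_ops.coo_matvec(data, row, col, v, shape=(3, 3))

    y = matvec(data, row, col, v)
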
@@ -69,6 +69,11 @@ try:
 except ImportError:
   cusolver = None
 
+try:
+  from jaxlib import cusparse  # pytype: disable=import-error
+except ImportError:
+  cusparse = None
+
 try:
   from jaxlib import rocsolver  # pytype: disable=import-error
 except ImportError:
jaxlib/BUILD (+32 lines)
@@ -119,6 +119,7 @@ py_library(
         "cuda_linalg.py",
         "cuda_prng.py",
         "cusolver.py",
+        "cusparse.py",
     ]) + if_rocm_is_configured([
         "rocsolver.py",
     ]),
@@ -137,6 +138,7 @@ py_library(
         ":cuda_lu_pivot_kernels",
         ":cuda_prng_kernels",
         ":cusolver_kernels",
+        ":cusparse_kernels",
     ],
 )
 
@@ -199,6 +201,36 @@ pybind_extension(
     ],
 )
 
+pybind_extension(
+    name = "cusparse_kernels",
+    srcs = ["cusparse.cc"],
+    copts = [
+        "-fexceptions",
+        "-fno-strict-aliasing",
+    ],
+    features = ["-use_header_modules"],
+    module_name = "cusparse_kernels",
+    deps = [
+        ":cuda_gpu_kernel_helpers",
+        ":handle_pool",
+        ":kernel_pybind11_helpers",
+        "//third_party/gpus/cuda:cusparse_static",
+        "@org_tensorflow//tensorflow/stream_executor/cuda:cudart_stub",
+        "@org_tensorflow//tensorflow/stream_executor/cuda:cusparse_lib",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/base",
+        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/hash",
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/synchronization",
+        "@local_config_cuda//cuda:cuda_headers",
+        "@pybind11",
+    ],
+)
+
 cuda_library(
     name = "cuda_lu_pivot_kernels_lib",
     srcs = ["cuda_lu_pivot_kernels.cu.cc"],
jaxlib/cusparse.cc (new file, 896 lines)
@@ -0,0 +1,896 @@
/* Copyright 2021 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "third_party/gpus/cuda/includes/cuda_headers/third_party/gpus/cuda/include/cusparse.h"

#include <algorithm>
#include <stdexcept>
#include <utility>
#include <vector>

#include "absl/base/casts.h"
#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "third_party/gpus/cuda/include/cusparse.h"
#include "third_party/gpus/cuda/includes/cuda_headers/third_party/gpus/cuda/include/cuComplex.h"
#include "jaxlib/cuda_gpu_kernel_helpers.h"
#include "jaxlib/handle_pool.h"
#include "jaxlib/kernel_pybind11_helpers.h"
#include "include/pybind11/numpy.h"
#include "include/pybind11/pybind11.h"
#include "include/pybind11/stl.h"

// Some functionality defined here is only available in CUDA 11 or newer.
#define JAX_USE_CUDA11 (CUDART_VERSION >= 11000)

namespace jax {
namespace {

namespace py = pybind11;

void ThrowIfErrorStatus(cusparseStatus_t status) {
  switch (status) {
    case CUSPARSE_STATUS_SUCCESS:
      return;
    case CUSPARSE_STATUS_NOT_INITIALIZED:
      throw std::runtime_error("cuSparse has not been initialized");
    case CUSPARSE_STATUS_ALLOC_FAILED:
      throw std::runtime_error("cuSparse allocation failure");
    case CUSPARSE_STATUS_INVALID_VALUE:
      throw std::runtime_error("cuSparse invalid value error");
    case CUSPARSE_STATUS_ARCH_MISMATCH:
      throw std::runtime_error("cuSparse architecture mismatch");
    case CUSPARSE_STATUS_MAPPING_ERROR:
      throw std::runtime_error("cuSparse mapping error");
    case CUSPARSE_STATUS_EXECUTION_FAILED:
      throw std::runtime_error("cuSparse execution failed");
    case CUSPARSE_STATUS_INTERNAL_ERROR:
      throw std::runtime_error("cuSparse internal error");
    case CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
      throw std::runtime_error("cuSparse matrix type not supported error");
    case CUSPARSE_STATUS_ZERO_PIVOT:
      throw std::runtime_error("cuSparse zero pivot error");
    default:
      throw std::runtime_error("Unknown cuSparse error");
  }
}

union CudaConst {
  int8 i8[2];
  int16 i16[2];
  int32 i32[2];
  int64 i64[2];
  uint8 u8[2];
  uint16 u16[2];
  uint32 u32[2];
  uint64 u64[2];
  float f32[2];
  double f64[2];
};

CudaConst CudaZero(cudaDataType type) {
  CudaConst c;
  std::memset(&c, 0, sizeof(c));
  return c;
}

CudaConst CudaOne(cudaDataType type) {
  CudaConst c;
  std::memset(&c, 0, sizeof(c));
  switch (type) {
#if JAX_USE_CUDA11
    // TODO(jakevdp): 4I/4U here might break on big endian platforms.
    case CUDA_R_4I:
    case CUDA_C_4I:
#endif
    case CUDA_R_8I:
    case CUDA_C_8I:
      c.i8[0] = 1;
      break;
#if JAX_USE_CUDA11
    case CUDA_R_4U:
    case CUDA_C_4U:
#endif
    case CUDA_R_8U:
    case CUDA_C_8U:
      c.u8[0] = 1;
      break;
#if JAX_USE_CUDA11
    case CUDA_R_16I:
    case CUDA_C_16I:
      c.i16[0] = 1;
      break;
    case CUDA_R_16U:
    case CUDA_C_16U:
      c.u16[0] = 1;
      break;
#endif
    case CUDA_R_32I:
    case CUDA_C_32I:
      c.i32[0] = 1;
      break;
    case CUDA_R_32U:
    case CUDA_C_32U:
      c.u32[0] = 1;
      break;
#if JAX_USE_CUDA11
    case CUDA_R_64I:
    case CUDA_C_64I:
      c.i64[0] = 1;
      break;
    case CUDA_R_64U:
    case CUDA_C_64U:
      c.u64[0] = 1;
      break;
#endif
    // TODO(jakevdp): 16F/16BF here might break on big endian platforms.
    case CUDA_R_16F:
    case CUDA_C_16F:
      c.u16[0] = 0b11110000000000;  // 1.0 in little-endian float16
      break;
#if JAX_USE_CUDA11
    case CUDA_R_16BF:
    case CUDA_C_16BF:
      c.u16[0] = 0b11111110000000;  // 1.0 in little-endian bfloat16
      break;
#endif
    case CUDA_R_32F:
    case CUDA_C_32F:
      c.f32[0] = 1.0;
      break;
    case CUDA_R_64F:
    case CUDA_C_64F:
      c.f64[0] = 1.0;
      break;
  }
  return c;
}

using SparseHandlePool = HandlePool<cusparseHandle_t, cudaStream_t>;

template <>
/*static*/ SparseHandlePool::Handle SparseHandlePool::Borrow(
    cudaStream_t stream) {
  SparseHandlePool* pool = Instance();
  absl::MutexLock lock(&pool->mu_);
  cusparseHandle_t handle;
  if (pool->handles_.empty()) {
    ThrowIfErrorStatus(cusparseCreate(&handle));
  } else {
    handle = pool->handles_.back();
    pool->handles_.pop_back();
  }
  if (stream) {
    ThrowIfErrorStatus(cusparseSetStream(handle, stream));
  }
  return Handle(pool, handle);
}

cusparseIndexType_t DtypeToCuSparseIndexType(const py::dtype& np_type) {
  static auto* types =
      new absl::flat_hash_map<std::pair<char, int>, cusparseIndexType_t>({
          {{'u', 2}, CUSPARSE_INDEX_16U},
          {{'i', 4}, CUSPARSE_INDEX_32I},
          {{'i', 8}, CUSPARSE_INDEX_64I},
      });
  auto it = types->find({np_type.kind(), np_type.itemsize()});
  if (it == types->end()) {
    throw std::invalid_argument(
        absl::StrFormat("Unsupported index dtype: %s", py::repr(np_type)));
  }
  return it->second;
}

cudaDataType DtypeToCudaDataType(const py::dtype& np_type) {
  static auto* types =
      new absl::flat_hash_map<std::pair<char, int>, cudaDataType>({
          // Note: a duplicated {{'f', 4}, CUDA_R_32F} initializer entry has
          // been removed here; flat_hash_map keeps only one entry per key.
          {{'f', 2}, CUDA_R_16F}, {{'f', 4}, CUDA_R_32F},
          {{'c', 8}, CUDA_C_32F}, {{'f', 8}, CUDA_R_64F},
          {{'c', 16}, CUDA_C_64F}, {{'i', 1}, CUDA_R_8I},
          {{'u', 1}, CUDA_R_8U}, {{'i', 4}, CUDA_R_32I},
          {{'u', 4}, CUDA_R_32U},
#if JAX_USE_CUDA11
          {{'V', 2}, CUDA_R_16BF},
#endif
      });
  auto it = types->find({np_type.kind(), np_type.itemsize()});
  if (it == types->end()) {
    throw std::invalid_argument(
        absl::StrFormat("Unsupported data dtype: %s", py::repr(np_type)));
  }
  return it->second;
}

struct SparseMatDescriptor {
  cudaDataType value_type;
  cusparseIndexType_t index_type;
  int rows, cols, nnz;
};

struct DenseMatDescriptor {
  cudaDataType type;
  int rows, cols;
};

struct DenseVecDescriptor {
  cudaDataType type;
  int size;
};

// Returns the descriptor for a Sparse matrix.
SparseMatDescriptor BuildSparseMatDescriptor(const py::dtype& data_dtype,
                                             const py::dtype& index_dtype,
                                             int rows, int cols, int nnz) {
  cudaDataType value_type = DtypeToCudaDataType(data_dtype);
  cusparseIndexType_t index_type = DtypeToCuSparseIndexType(index_dtype);
  return SparseMatDescriptor{value_type, index_type, rows, cols, nnz};
}

// Returns the descriptor for a Dense matrix.
DenseMatDescriptor BuildDenseMatDescriptor(const py::dtype& data_dtype,
                                           int rows, int cols) {
  cudaDataType value_type = DtypeToCudaDataType(data_dtype);
  return DenseMatDescriptor{value_type, rows, cols};
}

// Returns the descriptor for a Dense vector.
DenseVecDescriptor BuildDenseVecDescriptor(const py::dtype& data_dtype,
                                           int size) {
  cudaDataType value_type = DtypeToCudaDataType(data_dtype);
  return DenseVecDescriptor{value_type, size};
}

#if JAX_USE_CUDA11
// CsrToDense: Convert CSR matrix to dense matrix

// Returns the descriptor for a CsrToDense operation.
std::pair<size_t, py::bytes> BuildCsrToDenseDescriptor(
|
||||
const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
|
||||
int cols, int nnz) {
|
||||
auto handle = SparseHandlePool::Borrow();
|
||||
SparseMatDescriptor d =
|
||||
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz);
|
||||
|
||||
cusparseSpMatDescr_t mat_a = 0;
|
||||
cusparseDnMatDescr_t mat_b = 0;
|
||||
|
||||
// buffer_size does not reference these pointers, but does error on NULL.
|
||||
// TODO(jakevdp): check whether this is documented.
|
||||
int val = 0;
|
||||
void* empty = &val;
|
||||
|
||||
ThrowIfErrorStatus(cusparseCreateCsr(&mat_a, d.rows, d.cols, d.nnz, empty,
|
||||
empty, empty, d.index_type, d.index_type,
|
||||
CUSPARSE_INDEX_BASE_ZERO, d.value_type));
|
||||
ThrowIfErrorStatus(cusparseCreateDnMat(&mat_b, d.rows, d.cols,
|
||||
/*ld=*/d.cols, empty, d.value_type,
|
||||
CUSPARSE_ORDER_ROW));
|
||||
size_t buffer_size;
|
||||
ThrowIfErrorStatus(cusparseSparseToDense_bufferSize(
|
||||
handle.get(), mat_a, mat_b, CUSPARSE_SPARSETODENSE_ALG_DEFAULT,
|
||||
&buffer_size));
|
||||
|
||||
ThrowIfErrorStatus(cusparseDestroySpMat(mat_a));
|
||||
ThrowIfErrorStatus(cusparseDestroyDnMat(mat_b));
|
||||
|
||||
return {buffer_size, PackDescriptor(d)};
|
||||
}
|
||||
|
||||
void CsrToDense(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len) {
|
||||
const SparseMatDescriptor& d =
|
||||
*UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
|
||||
auto handle = SparseHandlePool::Borrow(stream);
|
||||
|
||||
cusparseSpMatDescr_t mat_a = 0;
|
||||
cusparseDnMatDescr_t mat_b = 0;
|
||||
ThrowIfErrorStatus(cusparseCreateCsr(&mat_a, d.rows, d.cols, d.nnz,
|
||||
/*csrRowOffsets=*/buffers[2],
|
||||
/*csrColInd=*/buffers[1],
|
||||
/*csrValues=*/buffers[0], d.index_type,
|
||||
d.index_type, CUSPARSE_INDEX_BASE_ZERO,
|
||||
d.value_type));
|
||||
ThrowIfErrorStatus(cusparseCreateDnMat(&mat_b, d.rows, d.cols,
|
||||
/*ld=*/d.cols, buffers[3],
|
||||
d.value_type, CUSPARSE_ORDER_ROW));
|
||||
|
||||
ThrowIfErrorStatus(cusparseSparseToDense(handle.get(), mat_a, mat_b,
|
||||
CUSPARSE_SPARSETODENSE_ALG_DEFAULT,
|
||||
buffers[4]));
|
||||
|
||||
ThrowIfErrorStatus(cusparseDestroySpMat(mat_a));
|
||||
ThrowIfErrorStatus(cusparseDestroyDnMat(mat_b));
|
||||
}
|
||||
|
||||
// CsrFromDense: Convert dense matrix to CSR matrix
|
||||
|
||||
// Returns the descriptor for a CsrFromDense operation.
|
||||
std::pair<size_t, py::bytes> BuildCsrFromDenseDescriptor(
|
||||
const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
|
||||
int cols, int nnz) {
|
||||
auto handle = SparseHandlePool::Borrow();
|
||||
SparseMatDescriptor d =
|
||||
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz);
|
||||
|
||||
cusparseDnMatDescr_t mat_a = 0;
|
||||
cusparseSpMatDescr_t mat_b = 0;
|
||||
|
||||
// bufferSize does not reference these pointers, but does error on NULL.
|
||||
int val = 0;
|
||||
void* empty = &val;
|
||||
ThrowIfErrorStatus(cusparseCreateDnMat(&mat_a, d.rows, d.cols,
|
||||
/*ld=*/d.cols, empty, d.value_type,
|
||||
CUSPARSE_ORDER_ROW));
|
||||
ThrowIfErrorStatus(cusparseCreateCsr(&mat_b, d.rows, d.cols, d.nnz, empty,
|
||||
empty, empty, d.index_type, d.index_type,
|
||||
CUSPARSE_INDEX_BASE_ZERO, d.value_type));
|
||||
size_t buffer_size;
|
||||
ThrowIfErrorStatus(cusparseDenseToSparse_bufferSize(
|
||||
handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
|
||||
&buffer_size));
|
||||
|
||||
ThrowIfErrorStatus(cusparseDestroyDnMat(mat_a));
|
||||
ThrowIfErrorStatus(cusparseDestroySpMat(mat_b));
|
||||
|
||||
return {buffer_size, PackDescriptor(d)};
|
||||
}
|
||||
|
||||
void CsrFromDense(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len) {
|
||||
const SparseMatDescriptor& d =
|
||||
*UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
|
||||
auto handle = SparseHandlePool::Borrow(stream);
|
||||
|
||||
cusparseDnMatDescr_t mat_a = 0;
|
||||
cusparseSpMatDescr_t mat_b = 0;
|
||||
ThrowIfErrorStatus(cusparseCreateDnMat(&mat_a, d.rows, d.cols,
|
||||
/*ld=*/d.cols, buffers[0],
|
||||
d.value_type, CUSPARSE_ORDER_ROW));
|
||||
ThrowIfErrorStatus(cusparseCreateCsr(&mat_b, d.rows, d.cols, d.nnz,
|
||||
/*csrRowOffsets=*/buffers[3],
|
||||
/*csrColInd=*/buffers[2],
|
||||
/*csrValues=*/buffers[1], d.index_type,
|
||||
d.index_type, CUSPARSE_INDEX_BASE_ZERO,
|
||||
d.value_type));
|
||||
ThrowIfErrorStatus(cusparseDenseToSparse_analysis(
|
||||
handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
|
||||
buffers[4]));
|
||||
ThrowIfErrorStatus(cusparseDenseToSparse_convert(
|
||||
handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
|
||||
buffers[4]));
|
||||
ThrowIfErrorStatus(cusparseDestroyDnMat(mat_a));
|
||||
ThrowIfErrorStatus(cusparseDestroySpMat(mat_b));
|
||||
}
|
||||
|
||||
// CsrMatvec: Product of CSR matrix and dense vector.
|
||||
|
||||
struct CsrMatvecDescriptor {
|
||||
SparseMatDescriptor A;
|
||||
DenseVecDescriptor x, y;
|
||||
cusparseOperation_t op;
|
||||
};
|
||||
|
||||
// Returns the descriptor for a CsrMatvec operation.
|
||||
std::pair<size_t, py::bytes> BuildCsrMatvecDescriptor(
|
||||
const py::dtype& data_dtype, const py::dtype& x_dtype,
|
||||
const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
|
||||
int cols, int nnz, bool transpose) {
|
||||
auto handle = SparseHandlePool::Borrow();
|
||||
SparseMatDescriptor A =
|
||||
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz);
|
||||
DenseVecDescriptor x =
|
||||
BuildDenseVecDescriptor(x_dtype, transpose ? rows : cols);
|
||||
DenseVecDescriptor y =
|
||||
BuildDenseVecDescriptor(compute_dtype, transpose ? cols : rows);
|
||||
|
||||
cusparseSpMatDescr_t mat_a = 0;
|
||||
cusparseDnVecDescr_t vec_x = 0;
|
||||
cusparseDnVecDescr_t vec_y = 0;
|
||||
cusparseOperation_t op = transpose ? CUSPARSE_OPERATION_TRANSPOSE
|
||||
: CUSPARSE_OPERATION_NON_TRANSPOSE;
|
||||
|
||||
// bufferSize does not reference these pointers, but does error on NULL.
|
||||
int val = 0;
|
||||
void* empty = &val;
|
||||
ThrowIfErrorStatus(cusparseCreateCsr(&mat_a, A.rows, A.cols, A.nnz, empty,
|
||||
empty, empty, A.index_type, A.index_type,
|
||||
CUSPARSE_INDEX_BASE_ZERO, A.value_type));
|
||||
ThrowIfErrorStatus(cusparseCreateDnVec(&vec_x, x.size, empty, x.type));
|
||||
ThrowIfErrorStatus(cusparseCreateDnVec(&vec_y, y.size, empty, y.type));
|
||||
size_t buffer_size;
|
||||
CudaConst alpha = CudaOne(y.type);
|
||||
CudaConst beta = CudaZero(y.type);
|
||||
ThrowIfErrorStatus(cusparseSpMV_bufferSize(
|
||||
handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type,
|
||||
CUSPARSE_MV_ALG_DEFAULT, &buffer_size));
|
||||
|
||||
ThrowIfErrorStatus(cusparseDestroySpMat(mat_a));
|
||||
ThrowIfErrorStatus(cusparseDestroyDnVec(vec_x));
|
||||
ThrowIfErrorStatus(cusparseDestroyDnVec(vec_y));
|
||||
|
||||
return {buffer_size, PackDescriptor(CsrMatvecDescriptor{A, x, y, op})};
|
||||
}
|
||||
|
||||
void CsrMatvec(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len) {
|
||||
const CsrMatvecDescriptor& d =
|
||||
*UnpackDescriptor<CsrMatvecDescriptor>(opaque, opaque_len);
|
||||
auto handle = SparseHandlePool::Borrow(stream);
|
||||
|
||||
void* csr_values = buffers[0];
|
||||
void* csr_col_ind = buffers[1];
|
||||
void* csr_row_offsets = buffers[2];
|
||||
void* xbuf = buffers[3];
|
||||
void* ybuf = buffers[4];
|
||||
void* buf = buffers[5];
|
||||
|
||||
// TODO(jakevdp): alpha and beta should be user-specifiable, but constants
|
||||
// are sufficient for basic matvec operations.
|
||||
// Note that, contrary to cusparse docs, alpha and beta must be host pointers
|
||||
// or else the operation will segfault.
|
||||
CudaConst alpha = CudaOne(d.y.type);
|
||||
CudaConst beta = CudaZero(d.y.type);
|
||||
|
||||
cusparseSpMatDescr_t mat_a = 0;
|
||||
cusparseDnVecDescr_t vec_x = 0;
|
||||
cusparseDnVecDescr_t vec_y = 0;
|
||||
|
||||
ThrowIfErrorStatus(
|
||||
cusparseCreateCsr(&mat_a, d.A.rows, d.A.cols, d.A.nnz, csr_row_offsets,
|
||||
csr_col_ind, csr_values, d.A.index_type, d.A.index_type,
|
||||
CUSPARSE_INDEX_BASE_ZERO, d.A.value_type));
|
||||
ThrowIfErrorStatus(cusparseCreateDnVec(&vec_x, d.x.size, xbuf, d.x.type));
|
||||
ThrowIfErrorStatus(cusparseCreateDnVec(&vec_y, d.y.size, ybuf, d.y.type));
|
||||
|
||||
ThrowIfErrorStatus(cusparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x,
|
||||
&beta, vec_y, d.y.type,
|
||||
CUSPARSE_MV_ALG_DEFAULT, buf));
|
||||
|
||||
ThrowIfErrorStatus(cusparseDestroySpMat(mat_a));
|
||||
ThrowIfErrorStatus(cusparseDestroyDnVec(vec_x));
|
||||
ThrowIfErrorStatus(cusparseDestroyDnVec(vec_y));
|
||||
}
|
||||
|
||||
// CsrMatmat: Product of CSR matrix and dense matrix.
|
||||
|
||||
struct CsrMatmatDescriptor {
|
||||
SparseMatDescriptor A;
|
||||
DenseMatDescriptor B, C;
|
||||
cusparseOperation_t op_A;
|
||||
};
|
||||
|
||||
// Returns the descriptor for a CsrMatmat operation.
|
||||
std::pair<size_t, py::bytes> BuildCsrMatmatDescriptor(
|
||||
const py::dtype& data_dtype, const py::dtype& b_dtype,
|
||||
const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
|
||||
int cols, int BCcols, int nnz, bool transpose) {
|
||||
auto handle = SparseHandlePool::Borrow();
|
||||
SparseMatDescriptor A =
|
||||
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz);
|
||||
DenseMatDescriptor B =
|
||||
BuildDenseMatDescriptor(b_dtype, transpose ? rows : cols, BCcols);
|
||||
DenseMatDescriptor C =
|
||||
BuildDenseMatDescriptor(compute_dtype, transpose ? cols : rows, BCcols);
|
||||
cusparseOperation_t op_A = transpose ? CUSPARSE_OPERATION_TRANSPOSE
|
||||
: CUSPARSE_OPERATION_NON_TRANSPOSE;
|
||||
|
||||
cusparseSpMatDescr_t mat_a = 0;
|
||||
cusparseDnMatDescr_t mat_b = 0;
|
||||
cusparseDnMatDescr_t mat_c = 0;
|
||||
|
||||
// bufferSize does not reference these pointers, but does error on NULL.
|
||||
int val = 0;
|
||||
void* empty = &val;
|
||||
ThrowIfErrorStatus(cusparseCreateCsr(&mat_a, A.rows, A.cols, A.nnz, empty,
|
||||
empty, empty, A.index_type, A.index_type,
|
||||
CUSPARSE_INDEX_BASE_ZERO, A.value_type));
|
||||
ThrowIfErrorStatus(cusparseCreateDnMat(&mat_b, B.rows, B.cols, /*ld=*/B.cols,
|
||||
empty, B.type, CUSPARSE_ORDER_ROW));
|
||||
ThrowIfErrorStatus(cusparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols,
|
||||
empty, C.type, CUSPARSE_ORDER_ROW));
|
||||
size_t buffer_size;
|
||||
CudaConst alpha = CudaOne(C.type);
|
||||
CudaConst beta = CudaZero(C.type);
|
||||
ThrowIfErrorStatus(cusparseSpMM_bufferSize(
|
||||
handle.get(), op_A, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a,
|
||||
mat_b, &beta, mat_c, C.type, CUSPARSE_SPMM_ALG_DEFAULT, &buffer_size));
|
||||
|
||||
ThrowIfErrorStatus(cusparseDestroySpMat(mat_a));
|
||||
ThrowIfErrorStatus(cusparseDestroyDnMat(mat_b));
|
||||
ThrowIfErrorStatus(cusparseDestroyDnMat(mat_c));
|
||||
|
||||
return {buffer_size, PackDescriptor(CsrMatmatDescriptor{A, B, C, op_A})};
|
||||
}
|
||||
|
||||
void CsrMatmat(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len) {
|
||||
const CsrMatmatDescriptor& d =
|
||||
*UnpackDescriptor<CsrMatmatDescriptor>(opaque, opaque_len);
|
||||
auto handle = SparseHandlePool::Borrow(stream);
|
||||
|
||||
void* csr_values = buffers[0];
|
||||
void* csr_col_ind = buffers[1];
|
||||
void* csr_row_offsets = buffers[2];
|
||||
void* Bbuf = buffers[3];
|
||||
void* Cbuf = buffers[4];
|
||||
void* buf = buffers[5];
|
||||
|
||||
// TODO(jakevdp): alpha and beta should be user-specifiable, but constants
|
||||
// are sufficient for basic matvec operations.
|
||||
// Note that, contrary to cusparse docs, alpha and beta must be host pointers
|
||||
// or else the operation will segfault.
|
||||
CudaConst alpha = CudaOne(d.C.type);
|
||||
CudaConst beta = CudaZero(d.C.type);
|
||||
|
||||
cusparseSpMatDescr_t mat_a = 0;
|
||||
cusparseDnMatDescr_t mat_b = 0;
|
||||
cusparseDnMatDescr_t mat_c = 0;
|
||||
|
||||
ThrowIfErrorStatus(
|
||||
cusparseCreateCsr(&mat_a, d.A.rows, d.A.cols, d.A.nnz, csr_row_offsets,
|
||||
csr_col_ind, csr_values, d.A.index_type, d.A.index_type,
|
||||
CUSPARSE_INDEX_BASE_ZERO, d.A.value_type));
|
||||
ThrowIfErrorStatus(cusparseCreateDnMat(&mat_b, d.B.rows, d.B.cols,
|
||||
/*ld=*/d.B.cols, Bbuf, d.B.type,
|
||||
CUSPARSE_ORDER_ROW));
|
||||
ThrowIfErrorStatus(cusparseCreateDnMat(&mat_c, d.C.rows, d.C.cols,
|
||||
/*ld=*/d.C.cols, Cbuf, d.C.type,
|
||||
CUSPARSE_ORDER_ROW));
|
||||
ThrowIfErrorStatus(cusparseSpMM(
|
||||
handle.get(), d.op_A, /*opB=*/CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
|
||||
mat_a, mat_b, &beta, mat_c, d.C.type, CUSPARSE_SPMM_ALG_DEFAULT, buf));
|
||||
|
||||
ThrowIfErrorStatus(cusparseDestroySpMat(mat_a));
|
||||
ThrowIfErrorStatus(cusparseDestroyDnMat(mat_b));
|
||||
ThrowIfErrorStatus(cusparseDestroyDnMat(mat_c));
|
||||
}
|
||||
|
||||
// CooToDense: Convert COO matrix to dense matrix
|
||||
|
||||
// Returns the descriptor for a CooToDense operation.
|
||||
std::pair<size_t, py::bytes> BuildCooToDenseDescriptor(
|
||||
const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
|
||||
int cols, int nnz) {
|
||||
auto handle = SparseHandlePool::Borrow();
|
||||
SparseMatDescriptor d =
|
||||
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz);
|
||||
|
||||
cusparseSpMatDescr_t mat_a = 0;
|
||||
cusparseDnMatDescr_t mat_b = 0;
|
||||
|
||||
// bufferSize does not reference these pointers, but does error on NULL.
|
||||
int val = 0;
|
||||
void* empty = &val;
|
||||
|
||||
ThrowIfErrorStatus(cusparseCreateCoo(&mat_a, d.rows, d.cols, d.nnz, empty,
|
||||
empty, empty, d.index_type,
|
||||
CUSPARSE_INDEX_BASE_ZERO, d.value_type));
|
||||
ThrowIfErrorStatus(cusparseCreateDnMat(&mat_b, d.rows, d.cols,
|
||||
/*ld=*/d.cols, empty, d.value_type,
|
||||
CUSPARSE_ORDER_ROW));
|
||||
size_t buffer_size;
|
||||
ThrowIfErrorStatus(cusparseSparseToDense_bufferSize(
|
||||
handle.get(), mat_a, mat_b, CUSPARSE_SPARSETODENSE_ALG_DEFAULT,
|
||||
&buffer_size));
|
||||
|
||||
ThrowIfErrorStatus(cusparseDestroySpMat(mat_a));
|
||||
ThrowIfErrorStatus(cusparseDestroyDnMat(mat_b));
|
||||
|
||||
return {buffer_size, PackDescriptor(d)};
|
||||
}
|
||||
|
||||
void CooToDense(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len) {
|
||||
const SparseMatDescriptor& d =
|
||||
*UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
|
||||
auto handle = SparseHandlePool::Borrow(stream);
|
||||
|
||||
cusparseSpMatDescr_t mat_a = 0;
|
||||
cusparseDnMatDescr_t mat_b = 0;
|
||||
ThrowIfErrorStatus(cusparseCreateCoo(&mat_a, d.rows, d.cols, d.nnz,
|
||||
/*cooRowInd=*/buffers[1],
|
||||
/*cooColInd=*/buffers[2],
|
||||
/*cooValues=*/buffers[0], d.index_type,
|
||||
CUSPARSE_INDEX_BASE_ZERO, d.value_type));
|
||||
ThrowIfErrorStatus(cusparseCreateDnMat(&mat_b, d.rows, d.cols,
|
||||
/*ld=*/d.cols, buffers[3],
|
||||
d.value_type, CUSPARSE_ORDER_ROW));
|
||||
|
||||
ThrowIfErrorStatus(cusparseSparseToDense(handle.get(), mat_a, mat_b,
|
||||
CUSPARSE_SPARSETODENSE_ALG_DEFAULT,
|
||||
buffers[4]));
|
||||
|
||||
ThrowIfErrorStatus(cusparseDestroySpMat(mat_a));
|
||||
ThrowIfErrorStatus(cusparseDestroyDnMat(mat_b));
|
||||
}
|
||||
|
||||
// CooFromDense: Convert dense matrix to COO matrix
|
||||
|
||||
// Returns the descriptor for a CooFromDense operation.
|
||||
std::pair<size_t, py::bytes> BuildCooFromDenseDescriptor(
|
||||
const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
|
||||
int cols, int nnz) {
|
||||
auto handle = SparseHandlePool::Borrow();
|
||||
SparseMatDescriptor d =
|
||||
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz);
|
||||
|
||||
cusparseDnMatDescr_t mat_a = 0;
|
||||
cusparseSpMatDescr_t mat_b = 0;
|
||||
|
||||
// bufferSize does not reference these pointers, but does error on NULL.
|
||||
int val = 0;
|
||||
void* empty = &val;
|
||||
ThrowIfErrorStatus(cusparseCreateDnMat(&mat_a, d.rows, d.cols,
|
||||
/*ld=*/d.cols, empty, d.value_type,
|
||||
CUSPARSE_ORDER_ROW));
|
||||
ThrowIfErrorStatus(cusparseCreateCoo(&mat_b, d.rows, d.cols, d.nnz, empty,
|
||||
empty, empty, d.index_type,
|
||||
CUSPARSE_INDEX_BASE_ZERO, d.value_type));
|
||||
size_t buffer_size;
|
||||
ThrowIfErrorStatus(cusparseDenseToSparse_bufferSize(
|
||||
handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
|
||||
&buffer_size));
|
||||
|
||||
ThrowIfErrorStatus(cusparseDestroyDnMat(mat_a));
|
||||
ThrowIfErrorStatus(cusparseDestroySpMat(mat_b));
|
||||
|
||||
return {buffer_size, PackDescriptor(d)};
|
||||
}
|
||||
|
||||
void CooFromDense(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len) {
|
||||
const SparseMatDescriptor& d =
|
||||
*UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
|
||||
auto handle = SparseHandlePool::Borrow(stream);
|
||||
|
||||
cusparseDnMatDescr_t mat_a = 0;
|
||||
cusparseSpMatDescr_t mat_b = 0;
|
||||
ThrowIfErrorStatus(cusparseCreateDnMat(&mat_a, d.rows, d.cols,
|
||||
/*ld=*/d.cols, buffers[0],
|
||||
d.value_type, CUSPARSE_ORDER_ROW));
|
||||
ThrowIfErrorStatus(cusparseCreateCoo(&mat_b, d.rows, d.cols, d.nnz,
|
||||
/*cooRowInd=*/buffers[2],
|
||||
/*cooColInd=*/buffers[3],
|
||||
/*cooValues=*/buffers[1], d.index_type,
|
||||
CUSPARSE_INDEX_BASE_ZERO, d.value_type));
|
||||
ThrowIfErrorStatus(cusparseDenseToSparse_analysis(
|
||||
handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
|
||||
buffers[4]));
|
||||
ThrowIfErrorStatus(cusparseDenseToSparse_convert(
|
||||
handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
|
||||
buffers[4]));
|
||||
ThrowIfErrorStatus(cusparseDestroyDnMat(mat_a));
|
||||
ThrowIfErrorStatus(cusparseDestroySpMat(mat_b));
|
||||
}
|
||||
|
||||
// CooMatvec: Product of COO matrix and dense vector.
|
||||
|
||||
struct CooMatvecDescriptor {
|
||||
SparseMatDescriptor A;
|
||||
DenseVecDescriptor x, y;
|
||||
cusparseOperation_t op;
|
||||
};
|
||||
|
||||
// Returns the descriptor for a CooMatvec operation.
|
||||
std::pair<size_t, py::bytes> BuildCooMatvecDescriptor(
|
||||
const py::dtype& data_dtype, const py::dtype& x_dtype,
|
||||
const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
|
||||
int cols, int nnz, bool transpose) {
|
||||
auto handle = SparseHandlePool::Borrow();
|
||||
SparseMatDescriptor A =
|
||||
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz);
|
||||
DenseVecDescriptor x =
|
||||
BuildDenseVecDescriptor(x_dtype, transpose ? rows : cols);
|
||||
DenseVecDescriptor y =
|
||||
BuildDenseVecDescriptor(compute_dtype, transpose ? cols : rows);
|
||||
|
||||
cusparseSpMatDescr_t mat_a = 0;
|
||||
cusparseDnVecDescr_t vec_x = 0;
|
||||
cusparseDnVecDescr_t vec_y = 0;
|
||||
cusparseOperation_t op = transpose ? CUSPARSE_OPERATION_TRANSPOSE
|
||||
: CUSPARSE_OPERATION_NON_TRANSPOSE;
|
||||
|
||||
// bufferSize does not reference these pointers, but does error on NULL.
|
||||
int val = 0;
|
||||
void* empty = &val;
|
||||
ThrowIfErrorStatus(cusparseCreateCoo(&mat_a, A.rows, A.cols, A.nnz, empty,
|
||||
empty, empty, A.index_type,
|
||||
CUSPARSE_INDEX_BASE_ZERO, A.value_type));
|
||||
ThrowIfErrorStatus(cusparseCreateDnVec(&vec_x, x.size, empty, x.type));
|
||||
ThrowIfErrorStatus(cusparseCreateDnVec(&vec_y, y.size, empty, y.type));
|
||||
size_t buffer_size;
|
||||
CudaConst alpha = CudaOne(y.type);
|
||||
CudaConst beta = CudaZero(y.type);
|
||||
ThrowIfErrorStatus(cusparseSpMV_bufferSize(
|
||||
handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type,
|
||||
CUSPARSE_MV_ALG_DEFAULT, &buffer_size));
|
||||
|
||||
ThrowIfErrorStatus(cusparseDestroySpMat(mat_a));
|
||||
ThrowIfErrorStatus(cusparseDestroyDnVec(vec_x));
|
||||
ThrowIfErrorStatus(cusparseDestroyDnVec(vec_y));
|
||||
|
||||
return {buffer_size, PackDescriptor(CooMatvecDescriptor{A, x, y, op})};
|
||||
}
|
||||
|
||||
void CooMatvec(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len) {
|
||||
const CooMatvecDescriptor& d =
|
||||
*UnpackDescriptor<CooMatvecDescriptor>(opaque, opaque_len);
|
||||
auto handle = SparseHandlePool::Borrow(stream);
|
||||
|
||||
void* coo_values = buffers[0];
|
||||
void* coo_row_ind = buffers[1];
|
||||
void* coo_col_ind = buffers[2];
|
||||
void* xbuf = buffers[3];
|
||||
void* ybuf = buffers[4];
|
||||
void* buf = buffers[5];
|
||||
|
||||
// TODO(jakevdp): alpha and beta should be user-specifiable, but constants
|
||||
// are sufficient for basic matvec operations.
|
||||
// Note that, contrary to cusparse docs, alpha and beta must be host pointers
|
||||
// or else the operation will segfault.
|
||||
CudaConst alpha = CudaOne(d.y.type);
|
||||
CudaConst beta = CudaZero(d.y.type);
|
||||
|
||||
cusparseSpMatDescr_t mat_a = 0;
|
||||
cusparseDnVecDescr_t vec_x = 0;
|
||||
cusparseDnVecDescr_t vec_y = 0;
|
||||
|
||||
ThrowIfErrorStatus(cusparseCreateCoo(
|
||||
&mat_a, d.A.rows, d.A.cols, d.A.nnz, coo_row_ind, coo_col_ind, coo_values,
|
||||
d.A.index_type, CUSPARSE_INDEX_BASE_ZERO, d.A.value_type));
|
||||
ThrowIfErrorStatus(cusparseCreateDnVec(&vec_x, d.x.size, xbuf, d.x.type));
|
||||
ThrowIfErrorStatus(cusparseCreateDnVec(&vec_y, d.y.size, ybuf, d.y.type));
|
||||
|
||||
ThrowIfErrorStatus(cusparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x,
|
||||
&beta, vec_y, d.y.type,
|
||||
CUSPARSE_MV_ALG_DEFAULT, buf));
|
||||
|
||||
ThrowIfErrorStatus(cusparseDestroySpMat(mat_a));
|
||||
ThrowIfErrorStatus(cusparseDestroyDnVec(vec_x));
|
||||
ThrowIfErrorStatus(cusparseDestroyDnVec(vec_y));
|
||||
}
|
||||
|
||||
// CooMatmat: Product of COO matrix and dense matrix.
|
||||
|
||||
struct CooMatmatDescriptor {
|
||||
SparseMatDescriptor A;
|
||||
DenseMatDescriptor B, C;
|
||||
cusparseOperation_t op_A;
|
||||
};
|
||||
|
||||
// Returns the descriptor for a CooMatmat operation.
|
||||
std::pair<size_t, py::bytes> BuildCooMatmatDescriptor(
|
||||
const py::dtype& data_dtype, const py::dtype& b_dtype,
|
||||
const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
|
||||
int cols, int BCcols, int nnz, bool transpose) {
|
||||
auto handle = SparseHandlePool::Borrow();
|
||||
SparseMatDescriptor A =
|
||||
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz);
|
||||
DenseMatDescriptor B =
|
||||
BuildDenseMatDescriptor(b_dtype, transpose ? rows : cols, BCcols);
|
||||
DenseMatDescriptor C =
|
||||
BuildDenseMatDescriptor(compute_dtype, transpose ? cols : rows, BCcols);
|
||||
cusparseOperation_t op_A = transpose ? CUSPARSE_OPERATION_TRANSPOSE
|
||||
: CUSPARSE_OPERATION_NON_TRANSPOSE;
|
||||
|
||||
cusparseSpMatDescr_t mat_a = 0;
|
||||
cusparseDnMatDescr_t mat_b = 0;
|
||||
cusparseDnMatDescr_t mat_c = 0;
|
||||
|
||||
// bufferSize does not reference these pointers, but does error on NULL.
|
||||
int val = 0;
|
||||
void* empty = &val;
|
||||
ThrowIfErrorStatus(cusparseCreateCoo(&mat_a, A.rows, A.cols, A.nnz, empty,
|
||||
empty, empty, A.index_type,
|
||||
CUSPARSE_INDEX_BASE_ZERO, A.value_type));
|
||||
ThrowIfErrorStatus(cusparseCreateDnMat(&mat_b, B.rows, B.cols, /*ld=*/B.cols,
|
||||
empty, B.type, CUSPARSE_ORDER_ROW));
|
||||
ThrowIfErrorStatus(cusparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols,
|
||||
empty, C.type, CUSPARSE_ORDER_ROW));
|
||||
size_t buffer_size;
|
||||
CudaConst alpha = CudaOne(C.type);
|
||||
CudaConst beta = CudaZero(C.type);
|
||||
ThrowIfErrorStatus(cusparseSpMM_bufferSize(
|
||||
handle.get(), op_A, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a,
|
||||
mat_b, &beta, mat_c, C.type, CUSPARSE_SPMM_ALG_DEFAULT, &buffer_size));
|
||||
|
||||
ThrowIfErrorStatus(cusparseDestroySpMat(mat_a));
|
||||
ThrowIfErrorStatus(cusparseDestroyDnMat(mat_b));
|
||||
ThrowIfErrorStatus(cusparseDestroyDnMat(mat_c));
|
||||
|
||||
return {buffer_size, PackDescriptor(CooMatmatDescriptor{A, B, C, op_A})};
|
||||
}
|
||||
|
||||
void CooMatmat(cudaStream_t stream, void** buffers, const char* opaque,
               size_t opaque_len) {
  const CooMatmatDescriptor& d =
      *UnpackDescriptor<CooMatmatDescriptor>(opaque, opaque_len);
  auto handle = SparseHandlePool::Borrow(stream);

  void* coo_values = buffers[0];
  void* coo_row_ind = buffers[1];
  void* coo_col_ind = buffers[2];
  void* Bbuf = buffers[3];
  void* Cbuf = buffers[4];
  void* buf = buffers[5];

  // TODO(jakevdp): alpha and beta should be user-specifiable, but constants
  // are sufficient for basic matvec operations.
  // Note that, contrary to cusparse docs, alpha and beta must be host pointers
  // or else the operation will segfault.
  CudaConst alpha = CudaOne(d.C.type);
  CudaConst beta = CudaZero(d.C.type);

  cusparseSpMatDescr_t mat_a = 0;
  cusparseDnMatDescr_t mat_b = 0;
  cusparseDnMatDescr_t mat_c = 0;

  ThrowIfErrorStatus(cusparseCreateCoo(
      &mat_a, d.A.rows, d.A.cols, d.A.nnz, coo_row_ind, coo_col_ind,
      coo_values, d.A.index_type, CUSPARSE_INDEX_BASE_ZERO, d.A.value_type));
  ThrowIfErrorStatus(cusparseCreateDnMat(&mat_b, d.B.rows, d.B.cols,
                                         /*ld=*/d.B.cols, Bbuf, d.B.type,
                                         CUSPARSE_ORDER_ROW));
  ThrowIfErrorStatus(cusparseCreateDnMat(&mat_c, d.C.rows, d.C.cols,
                                         /*ld=*/d.C.cols, Cbuf, d.C.type,
                                         CUSPARSE_ORDER_ROW));
  ThrowIfErrorStatus(cusparseSpMM(
      handle.get(), d.op_A, /*opB=*/CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
      mat_a, mat_b, &beta, mat_c, d.C.type, CUSPARSE_SPMM_ALG_DEFAULT, buf));

  ThrowIfErrorStatus(cusparseDestroySpMat(mat_a));
  ThrowIfErrorStatus(cusparseDestroyDnMat(mat_b));
  ThrowIfErrorStatus(cusparseDestroyDnMat(mat_c));
}

#endif

py::dict Registrations() {
  py::dict dict;
#if JAX_USE_CUDA11
  dict["cusparse_csr_todense"] = EncapsulateFunction(CsrToDense);
  dict["cusparse_csr_fromdense"] = EncapsulateFunction(CsrFromDense);
  dict["cusparse_csr_matvec"] = EncapsulateFunction(CsrMatvec);
  dict["cusparse_csr_matmat"] = EncapsulateFunction(CsrMatmat);
  dict["cusparse_coo_todense"] = EncapsulateFunction(CooToDense);
  dict["cusparse_coo_fromdense"] = EncapsulateFunction(CooFromDense);
  dict["cusparse_coo_matvec"] = EncapsulateFunction(CooMatvec);
  dict["cusparse_coo_matmat"] = EncapsulateFunction(CooMatmat);
#endif
  return dict;
}

PYBIND11_MODULE(cusparse_kernels, m) {
  m.attr("cusparse_supported") = py::bool_(JAX_USE_CUDA11);
  m.def("registrations", &Registrations);
#if JAX_USE_CUDA11
  m.def("build_csr_todense_descriptor", &BuildCsrToDenseDescriptor);
  m.def("build_csr_fromdense_descriptor", &BuildCsrFromDenseDescriptor);
  m.def("build_csr_matvec_descriptor", &BuildCsrMatvecDescriptor);
  m.def("build_csr_matmat_descriptor", &BuildCsrMatmatDescriptor);
  m.def("build_coo_todense_descriptor", &BuildCooToDenseDescriptor);
  m.def("build_coo_fromdense_descriptor", &BuildCooFromDenseDescriptor);
  m.def("build_coo_matvec_descriptor", &BuildCooMatvecDescriptor);
  m.def("build_coo_matmat_descriptor", &BuildCooMatmatDescriptor);
#endif
}

}  // namespace
}  // namespace jax
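Editorial aside (not part of the diff): the descriptor round-trip above is the
core pattern of these kernels. At trace time, Python asks the extension for a
workspace size and a serialized descriptor; XLA later hands the opaque bytes
back to the C++ kernel (here CooMatmat), which recovers them via
UnpackDescriptor. A minimal sketch, assuming a CUDA-11-enabled jaxlib build
where the extension is importable as jaxlib.cusparse_kernels (values are
illustrative):

    import numpy as np
    from jaxlib import cusparse_kernels

    buffer_size, opaque = cusparse_kernels.build_coo_matmat_descriptor(
        np.dtype(np.float32),  # sparse matrix value dtype
        np.dtype(np.float32),  # dense operand B dtype
        np.dtype(np.float32),  # compute/output dtype
        np.dtype(np.int32),    # index dtype
        100, 50,               # rows, cols of the sparse matrix
        8,                     # Ccols: columns of B and of the output C
        1000,                  # nnz
        False)                 # transpose
    # buffer_size sizes the scratch buffer XLA must allocate for cusparseSpMM;
    # opaque is passed verbatim as the custom call's descriptor payload.
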
289
jaxlib/cusparse.py
Normal file
289
jaxlib/cusparse.py
Normal file
@ -0,0 +1,289 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""cusparse wrappers for performing sparse matrix computations in JAX."""

import numpy as np

from jax.lib import xla_client

try:
  from . import cusparse_kernels
except ImportError:
  cusparse_kernels = None
else:
  for _name, _value in cusparse_kernels.registrations().items():
    xla_client.register_custom_call_target(_name, _value, platform="CUDA")


is_supported: bool = bool(cusparse_kernels and cusparse_kernels.cusparse_supported)


_ops = xla_client.ops
_Shape = xla_client.Shape

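# Editorial sketch (not part of the diff): downstream code is expected to gate
# GPU lowering on `is_supported`. Hypothetical wiring (the real translation
# rules live in jax/experimental/sparse_ops.py):
#
#   from jax.interpreters import xla
#   from jax.lib import cusparse
#
#   if cusparse and cusparse.is_supported:
#     xla.backend_specific_translations["gpu"][csr_todense_p] = cusparse.csr_todense
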
def csr_todense(c, data, indices, indptr, *, shape):
  """CSR to dense matrix."""
  data_dtype = np.dtype(c.get_shape(data).element_type())
  index_dtype = np.dtype(c.get_shape(indices).element_type())
  rows, cols = shape
  nnz = c.get_shape(data).dimensions()[0]

  buffer_size, opaque = cusparse_kernels.build_csr_todense_descriptor(
      data_dtype, index_dtype, rows, cols, nnz)

  out = xla_client.ops.CustomCallWithLayout(
      c,
      b"cusparse_csr_todense",
      operands=(data, indices, indptr),
      operand_shapes_with_layout=(
          # All are 1D, so no layout necessary
          c.get_shape(data),
          c.get_shape(indices),
          c.get_shape(indptr),
      ),
      shape_with_layout=_Shape.tuple_shape((
          _Shape.array_shape(data_dtype, shape, (1, 0)),
          _Shape.array_shape(np.dtype(np.int8), (buffer_size,), (0,)),
      )),
      opaque=opaque,
  )
  return _ops.GetTupleElement(out, 0)

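# Editorial note (not part of the diff): each wrapper's custom call returns a
# tuple (result, workspace). The trailing byte array exists only so that XLA
# allocates `buffer_size` bytes of cusparse scratch space; GetTupleElement
# then drops it. For csr_todense with float32 data, the tuple shape is roughly:
#
#   (f32[rows, cols]{1,0}, s8[buffer_size]{0})
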
def csr_fromdense(c, mat, *, nnz, index_dtype):
  """CSR from dense matrix."""
  data_dtype = np.dtype(c.get_shape(mat).element_type())
  shape = c.get_shape(mat).dimensions()
  rows, cols = shape

  buffer_size, opaque = cusparse_kernels.build_csr_fromdense_descriptor(
      data_dtype, index_dtype, rows, cols, nnz)

  out = xla_client.ops.CustomCallWithLayout(
      c,
      b"cusparse_csr_fromdense",
      operands=(mat,),
      operand_shapes_with_layout=(
          _Shape.array_shape(data_dtype, shape, (1, 0)),
      ),
      shape_with_layout=_Shape.tuple_shape((
          _Shape.array_shape(data_dtype, (nnz,), (0,)),
          _Shape.array_shape(index_dtype, (nnz,), (0,)),
          _Shape.array_shape(index_dtype, (shape[0] + 1,), (0,)),
          _Shape.array_shape(np.dtype(np.int8), (buffer_size,), (0,)),
      )),
      opaque=opaque,
  )

  return tuple(_ops.GetTupleElement(out, i) for i in range(3))

def csr_matvec(c, data, indices, indptr, x, *, shape, transpose=False, compute_dtype=None):
  """CSR matrix/vector multiply."""
  dtype = np.dtype(c.get_shape(data).element_type())
  index_dtype = np.dtype(c.get_shape(indices).element_type())
  x_dtype = np.dtype(c.get_shape(x).element_type())
  rows, cols = shape
  nnz, = c.get_shape(data).dimensions()

  if compute_dtype is None:
    compute_dtype = dtype

  buffer_size, opaque = cusparse_kernels.build_csr_matvec_descriptor(
      dtype, x_dtype, compute_dtype, index_dtype,
      rows, cols, nnz, transpose)
  out_size = cols if transpose else rows

  out = xla_client.ops.CustomCallWithLayout(
      c,
      b"cusparse_csr_matvec",
      operands=(data, indices, indptr, x),
      operand_shapes_with_layout=(
          # All are 1D, so no layout necessary
          c.get_shape(data),
          c.get_shape(indices),
          c.get_shape(indptr),
          c.get_shape(x),
      ),
      shape_with_layout=_Shape.tuple_shape((
          _Shape.array_shape(compute_dtype, (out_size,), (0,)),
          _Shape.array_shape(np.dtype(np.uint8), (buffer_size,), (0,)))),
      opaque=opaque,
  )
  return _ops.GetTupleElement(out, 0)

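# Editorial note (not part of the diff): `transpose` flips the output length
# because y = A @ x lives in R^rows while y = A.T @ x lives in R^cols; e.g.
# for shape == (5, 8), out_size is 5 without transpose and 8 with it. The same
# rule sets the leading output dimension of the matmat wrappers below.
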
def csr_matmat(c, data, indices, indptr, B, *, shape, transpose=False, compute_dtype=None):
  """CSR matrix/matrix multiply."""
  dtype = np.dtype(c.get_shape(data).element_type())
  index_dtype = np.dtype(c.get_shape(indices).element_type())
  B_dtype = np.dtype(c.get_shape(B).element_type())
  rows, cols = shape
  _, Ccols = c.get_shape(B).dimensions()
  nnz, = c.get_shape(data).dimensions()

  if compute_dtype is None:
    compute_dtype = dtype

  buffer_size, opaque = cusparse_kernels.build_csr_matmat_descriptor(
      dtype, B_dtype, compute_dtype, index_dtype,
      rows, cols, Ccols, nnz, transpose)
  out_size = cols if transpose else rows

  out = xla_client.ops.CustomCallWithLayout(
      c,
      b"cusparse_csr_matmat",
      operands=(data, indices, indptr, B),
      operand_shapes_with_layout=(
          # data, indices, and indptr are 1D, so no layout necessary; B keeps
          # its existing 2D shape and layout.
          c.get_shape(data),
          c.get_shape(indices),
          c.get_shape(indptr),
          c.get_shape(B),
      ),
      shape_with_layout=_Shape.tuple_shape((
          _Shape.array_shape(compute_dtype, (out_size, Ccols), (1, 0)),
          _Shape.array_shape(np.dtype(np.uint8), (buffer_size,), (0,)))),
      opaque=opaque,
  )
  return _ops.GetTupleElement(out, 0)

def coo_todense(c, data, row, col, *, shape):
  """COO to dense matrix."""
  data_dtype = np.dtype(c.get_shape(data).element_type())
  index_dtype = np.dtype(c.get_shape(row).element_type())
  rows, cols = shape
  nnz = c.get_shape(data).dimensions()[0]

  buffer_size, opaque = cusparse_kernels.build_coo_todense_descriptor(
      data_dtype, index_dtype, rows, cols, nnz)

  out = xla_client.ops.CustomCallWithLayout(
      c,
      b"cusparse_coo_todense",
      operands=(data, row, col),
      operand_shapes_with_layout=(
          # All are 1D, so no layout necessary
          c.get_shape(data),
          c.get_shape(row),
          c.get_shape(col),
      ),
      shape_with_layout=_Shape.tuple_shape((
          _Shape.array_shape(data_dtype, shape, (1, 0)),
          _Shape.array_shape(np.dtype(np.int8), (buffer_size,), (0,)),
      )),
      opaque=opaque,
  )
  return _ops.GetTupleElement(out, 0)

def coo_fromdense(c, mat, *, nnz, index_dtype):
  """COO from dense matrix."""
  data_dtype = np.dtype(c.get_shape(mat).element_type())
  shape = c.get_shape(mat).dimensions()
  rows, cols = shape

  buffer_size, opaque = cusparse_kernels.build_coo_fromdense_descriptor(
      data_dtype, index_dtype, rows, cols, nnz)

  out = xla_client.ops.CustomCallWithLayout(
      c,
      b"cusparse_coo_fromdense",
      operands=(mat,),
      operand_shapes_with_layout=(
          _Shape.array_shape(data_dtype, shape, (1, 0)),
      ),
      shape_with_layout=_Shape.tuple_shape((
          _Shape.array_shape(data_dtype, (nnz,), (0,)),
          _Shape.array_shape(index_dtype, (nnz,), (0,)),
          _Shape.array_shape(index_dtype, (nnz,), (0,)),
          _Shape.array_shape(np.dtype(np.int8), (buffer_size,), (0,)),
      )),
      opaque=opaque,
  )

  return tuple(_ops.GetTupleElement(out, i) for i in range(3))

def coo_matvec(c, data, row, col, x, *, shape, transpose=False, compute_dtype=None):
  """COO matrix/vector multiply."""
  dtype = np.dtype(c.get_shape(data).element_type())
  index_dtype = np.dtype(c.get_shape(row).element_type())
  x_dtype = np.dtype(c.get_shape(x).element_type())
  rows, cols = shape
  nnz, = c.get_shape(data).dimensions()

  if compute_dtype is None:
    compute_dtype = dtype

  buffer_size, opaque = cusparse_kernels.build_coo_matvec_descriptor(
      dtype, x_dtype, compute_dtype, index_dtype,
      rows, cols, nnz, transpose)
  out_size = cols if transpose else rows

  out = xla_client.ops.CustomCallWithLayout(
      c,
      b"cusparse_coo_matvec",
      operands=(data, row, col, x),
      operand_shapes_with_layout=(
          # All are 1D, so no layout necessary
          c.get_shape(data),
          c.get_shape(row),
          c.get_shape(col),
          c.get_shape(x),
      ),
      shape_with_layout=_Shape.tuple_shape((
          _Shape.array_shape(compute_dtype, (out_size,), (0,)),
          _Shape.array_shape(np.dtype(np.uint8), (buffer_size,), (0,)))),
      opaque=opaque,
  )
  return _ops.GetTupleElement(out, 0)

def coo_matmat(c, data, row, col, B, *, shape, transpose=False, compute_dtype=None):
  """COO matrix/matrix multiply."""
  dtype = np.dtype(c.get_shape(data).element_type())
  index_dtype = np.dtype(c.get_shape(row).element_type())
  B_dtype = np.dtype(c.get_shape(B).element_type())
  rows, cols = shape
  _, Ccols = c.get_shape(B).dimensions()
  nnz, = c.get_shape(data).dimensions()

  if compute_dtype is None:
    compute_dtype = dtype

  buffer_size, opaque = cusparse_kernels.build_coo_matmat_descriptor(
      dtype, B_dtype, compute_dtype, index_dtype,
      rows, cols, Ccols, nnz, transpose)
  out_size = cols if transpose else rows

  out = xla_client.ops.CustomCallWithLayout(
      c,
      b"cusparse_coo_matmat",
      operands=(data, row, col, B),
      operand_shapes_with_layout=(
          # data, row, and col are 1D, so no layout necessary; B keeps its
          # existing 2D shape and layout.
          c.get_shape(data),
          c.get_shape(row),
          c.get_shape(col),
          c.get_shape(B),
      ),
      shape_with_layout=_Shape.tuple_shape((
          _Shape.array_shape(compute_dtype, (out_size, Ccols), (1, 0)),
          _Shape.array_shape(np.dtype(np.uint8), (buffer_size,), (0,)))),
      opaque=opaque,
  )
  return _ops.GetTupleElement(out, 0)
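Editorial aside (not part of the diff): the intended user-facing entry point is
jax.experimental.sparse_ops rather than these builder-level wrappers. A minimal
round trip through the high-level API, assuming a scipy CSR matrix:

    import numpy as np
    import jax.numpy as jnp
    from scipy import sparse
    from jax.experimental import sparse_ops

    M = sparse.csr_matrix(np.eye(4, dtype=np.float32))
    data, indices, indptr = sparse_ops.csr_fromdense(
        M.toarray(), nnz=M.nnz, index_dtype=jnp.int32)
    dense = sparse_ops.csr_todense(data, indices, indptr, shape=M.shape)
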
218
tests/sparse_ops_test.py
Normal file
218
tests/sparse_ops_test.py
Normal file
@ -0,0 +1,218 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from absl.testing import absltest
from absl.testing import parameterized
from jax import config
from jax.experimental import sparse_ops
from jax.lib import xla_bridge
from jax import jit
from jax import test_util as jtu
from jax import xla
import jax.numpy as jnp

import numpy as np
from scipy import sparse

config.parse_flags_with_absl()
FLAGS = config.FLAGS

def rand_sparse(rng, nnz=0.1, post=lambda x: x):
  def _rand_sparse(shape, dtype, nnz=nnz):
    rand = jtu.rand_default(rng)
    size = np.prod(shape)
    if 0 <= nnz < 1:
      nnz = nnz * size
    nnz = min(size, int(nnz))
    M = rand(shape, dtype)
    indices = rng.choice(size, size - nnz, replace=False)
    M.flat[indices] = 0
    return post(M)
  return _rand_sparse

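# Editorial sketch (not part of the diff): rand_sparse builds a factory whose
# output has roughly `nnz * size` nonzeros when nnz is a fraction, or exactly
# min(size, nnz) when it is an absolute count. With post=sparse.csr_matrix the
# result is a scipy CSR matrix ready to be split into (data, indices, indptr):
#
#   rng = rand_sparse(np.random.RandomState(0), post=sparse.csr_matrix)
#   M = rng((5, 8), np.float32)
#   args = (M.data, M.indices, M.indptr)
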
class cuSparseTest(jtu.JaxTestCase):
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_csr_todense(self, shape, dtype):
    rng = rand_sparse(self.rng(), post=sparse.csr_matrix)
    M = rng(shape, dtype)

    args = (M.data, M.indices, M.indptr)
    todense = lambda *args: sparse_ops.csr_todense(*args, shape=M.shape)

    self.assertArraysEqual(M.toarray(), todense(*args))
    self.assertArraysEqual(M.toarray(), jit(todense)(*args))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_csr_fromdense(self, shape, dtype):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    M_csr = sparse.csr_matrix(M)

    nnz = M_csr.nnz
    index_dtype = jnp.int32
    fromdense = lambda M: sparse_ops.csr_fromdense(M, nnz=nnz, index_dtype=index_dtype)

    data, indices, indptr = fromdense(M)
    self.assertArraysEqual(data, M_csr.data.astype(dtype))
    self.assertArraysEqual(indices, M_csr.indices.astype(index_dtype))
    self.assertArraysEqual(indptr, M_csr.indptr.astype(index_dtype))

    data, indices, indptr = jit(fromdense)(M)
    self.assertArraysEqual(data, M_csr.data.astype(dtype))
    self.assertArraysEqual(indices, M_csr.indices.astype(index_dtype))
    self.assertArraysEqual(indptr, M_csr.indptr.astype(index_dtype))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype), transpose),
       "shape": shape, "dtype": dtype, "transpose": transpose}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for transpose in [True, False]))
  def test_csr_matvec(self, shape, dtype, transpose):
    op = lambda M: M.T if transpose else M

    v_rng = jtu.rand_default(self.rng())
    rng = rand_sparse(self.rng(), post=sparse.csr_matrix)
    M = rng(shape, dtype)
    v = v_rng(op(M).shape[1], dtype)

    args = (M.data, M.indices, M.indptr, v)
    matvec = lambda *args: sparse_ops.csr_matvec(*args, shape=M.shape, transpose=transpose)

    self.assertAllClose(op(M) @ v, matvec(*args))
    self.assertAllClose(op(M) @ v, jit(matvec)(*args))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype), transpose),
       "shape": shape, "dtype": dtype, "transpose": transpose}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for transpose in [True, False]))
  def test_csr_matmat(self, shape, dtype, transpose):
    op = lambda M: M.T if transpose else M

    B_rng = jtu.rand_default(self.rng())
    rng = rand_sparse(self.rng(), post=sparse.csr_matrix)
    M = rng(shape, dtype)
    B = B_rng((op(M).shape[1], 4), dtype)

    args = (M.data, M.indices, M.indptr, B)
    matmat = lambda *args: sparse_ops.csr_matmat(*args, shape=shape, transpose=transpose)

    self.assertAllClose(op(M) @ B, matmat(*args))
    self.assertAllClose(op(M) @ B, jit(matmat)(*args))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_coo_todense(self, shape, dtype):
    rng = rand_sparse(self.rng(), post=sparse.coo_matrix)
    M = rng(shape, dtype)

    args = (M.data, M.row, M.col)
    todense = lambda *args: sparse_ops.coo_todense(*args, shape=M.shape)

    self.assertArraysEqual(M.toarray(), todense(*args))
    self.assertArraysEqual(M.toarray(), jit(todense)(*args))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_coo_fromdense(self, shape, dtype):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    M_coo = sparse.coo_matrix(M)

    nnz = M_coo.nnz
    index_dtype = jnp.int32
    fromdense = lambda M: sparse_ops.coo_fromdense(M, nnz=nnz, index_dtype=index_dtype)

    data, row, col = fromdense(M)
    self.assertArraysEqual(data, M_coo.data.astype(dtype))
    self.assertArraysEqual(row, M_coo.row.astype(index_dtype))
    self.assertArraysEqual(col, M_coo.col.astype(index_dtype))

    data, row, col = jit(fromdense)(M)
    self.assertArraysEqual(data, M_coo.data.astype(dtype))
    self.assertArraysEqual(row, M_coo.row.astype(index_dtype))
    self.assertArraysEqual(col, M_coo.col.astype(index_dtype))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype), transpose),
       "shape": shape, "dtype": dtype, "transpose": transpose}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for transpose in [True, False]))
  def test_coo_matvec(self, shape, dtype, transpose):
    op = lambda M: M.T if transpose else M

    v_rng = jtu.rand_default(self.rng())
    rng = rand_sparse(self.rng(), post=sparse.coo_matrix)
    M = rng(shape, dtype)
    v = v_rng(op(M).shape[1], dtype)

    args = (M.data, M.row, M.col, v)
    matvec = lambda *args: sparse_ops.coo_matvec(*args, shape=M.shape, transpose=transpose)

    self.assertAllClose(op(M) @ v, matvec(*args))
    self.assertAllClose(op(M) @ v, jit(matvec)(*args))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype), transpose),
       "shape": shape, "dtype": dtype, "transpose": transpose}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for transpose in [True, False]))
  def test_coo_matmat(self, shape, dtype, transpose):
    op = lambda M: M.T if transpose else M

    B_rng = jtu.rand_default(self.rng())
    rng = rand_sparse(self.rng(), post=sparse.coo_matrix)
    M = rng(shape, dtype)
    B = B_rng((op(M).shape[1], 4), dtype)

    args = (M.data, M.row, M.col, B)
    matmat = lambda *args: sparse_ops.coo_matmat(*args, shape=shape, transpose=transpose)

    self.assertAllClose(op(M) @ B, matmat(*args))
    self.assertAllClose(op(M) @ B, jit(matmat)(*args))

  @unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU")
  def test_gpu_translation_rule(self):
    version = xla_bridge.get_backend().platform_version
    cuda_version = None if version == "<unknown>" else int(version.split()[-1])
    if cuda_version is None or cuda_version < 11000:
      self.assertNotIn(sparse_ops.csr_todense_p, xla.backend_specific_translations["gpu"])
    else:
      self.assertIn(sparse_ops.csr_todense_p, xla.backend_specific_translations["gpu"])


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())