Mirror of https://github.com/ROCm/jax.git
[sparse] Improve type safety of cusparse lowerings

Fixes https://github.com/google/jax/issues/8577
PiperOrigin-RevId: 410624036
parent bb3f19891e
commit 7ce5568435
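The change applies one pattern to all eight cusparse lowerings: keep a named default translation rule, and have the GPU rule warn and delegate to it whenever the operand dtype has no cusparse kernel. A runnable toy model of that guard follows; the JAX/XLA and cusparse internals are replaced by stand-ins, so treat the names here as illustrative rather than the library's API:

```python
import warnings
import numpy as np

class CuSparseEfficiencyWarning(UserWarning):  # mirrors the class added below
  pass

def default_rule(dtype):     # stand-in for xla.lower_fun(_impl, new_style=True)
  return f"default lowering for {dtype}"

def cusparse_kernel(dtype):  # stand-in for cusparse.csr_todense and friends
  return f"cusparse lowering for {dtype}"

def gpu_rule(dtype):
  # Same dtype guard as the diff: only float/complex dtypes go to cusparse.
  if not (np.issubdtype(dtype, np.floating) or
          np.issubdtype(dtype, np.complexfloating)):
    warnings.warn(f"cusparse lowering not available for dtype={dtype}. "
                  "Falling back to default implementation.",
                  CuSparseEfficiencyWarning)
    return default_rule(dtype)
  return cusparse_kernel(dtype)

print(gpu_rule(np.dtype('float32')))  # cusparse path
print(gpu_rule(np.dtype('int32')))    # warns, then default path
```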
--- a/jax/experimental/sparse/__init__.py
+++ b/jax/experimental/sparse/__init__.py
@@ -226,6 +226,7 @@ from .ops import (
   csr_todense_p as csr_todense_p,
   todense as todense,
   todense_p as todense_p,
+  CuSparseEfficiencyWarning as CuSparseEfficiencyWarning,
   COO as COO,
   CSC as CSC,
   CSR as CSR,
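With this re-export, downstream code can target the new warning category precisely. For example (illustrative; uses only the stdlib `warnings` API plus `sparse.csr_todense`, which the tests below exercise the same way):

```python
import warnings
from jax.experimental import sparse

def quiet_todense(data, indices, indptr, shape):
  # Suppress only the cusparse fallback warning, leaving other warnings intact.
  with warnings.catch_warnings():
    warnings.simplefilter("ignore", sparse.CuSparseEfficiencyWarning)
    return sparse.csr_todense(data, indices, indptr, shape=shape)
```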
--- a/jax/experimental/sparse/ops.py
+++ b/jax/experimental/sparse/ops.py
@@ -32,6 +32,7 @@ Further down are some examples of potential high-level wrappers for sparse objec
 from functools import partial
 import operator
 from typing import Tuple
+import warnings

 import numpy as np

@@ -50,6 +51,9 @@ import jax.numpy as jnp
 xb = xla_bridge
 xops = xla_client.ops

+class CuSparseEfficiencyWarning(UserWarning):
+  pass
+
 #--------------------------------------------------------------------
 # utilities
 # TODO: possibly make these primitives, targeting cusparse routines
@@ -109,8 +113,17 @@ def _csr_todense_abstract_eval(data, indices, indptr, *, shape):
   assert indptr.shape[0] == shape[0] + 1
   return core.ShapedArray(shape, data.dtype)

+_csr_todense_translation_rule = xla.lower_fun(
+    _csr_todense_impl, multiple_results=False, new_style=True)
+
 def _csr_todense_gpu_translation_rule(ctx, avals_in, avals_out, data, indices,
                                       indptr, *, shape):
+  dtype = avals_in[0].dtype
+  if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)):
+    warnings.warn(f"csr_todense cusparse lowering not available for dtype={dtype}. "
+                  "Falling back to default implementation.", CuSparseEfficiencyWarning)
+    return _csr_todense_translation_rule(ctx, avals_in, avals_out, data, indices,
+                                         indptr, shape=shape)
   return [cusparse.csr_todense(ctx.builder, data, indices, indptr, shape=shape)]

 def _csr_todense_jvp(data_dot, data, indices, indptr, *, shape):
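On a GPU backend this guard is observable from user code: integer-dtype inputs still produce the correct dense result, but the jitted call emits the warning. A hypothetical session (values chosen for illustration; the call shape matches what the tests below do):

```python
import numpy as np
from jax import jit
from jax.experimental import sparse

# CSR encoding of [[3, 0], [0, 5]] with integer data, which cusparse does not
# support: on GPU the lowering warns and falls back to the default path.
data = np.array([3, 5], dtype=np.int32)
indices = np.array([0, 1], dtype=np.int32)
indptr = np.array([0, 1, 2], dtype=np.int32)

dense = jit(lambda *a: sparse.csr_todense(*a, shape=(2, 2)))(data, indices, indptr)
# CuSparseEfficiencyWarning: csr_todense cusparse lowering not available for
# dtype=int32. Falling back to default implementation.
```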
@@ -129,8 +142,7 @@ def _csr_todense_transpose(ct, data, indices, indptr, *, shape):

 ad.defjvp(csr_todense_p, _csr_todense_jvp, None, None)
 ad.primitive_transposes[csr_todense_p] = _csr_todense_transpose
-xla.register_translation(csr_todense_p, xla.lower_fun(
-    _csr_todense_impl, multiple_results=False, new_style=True))
+xla.register_translation(csr_todense_p, _csr_todense_translation_rule)
 if cusparse and cusparse.is_supported:
   xla.register_translation(csr_todense_p, _csr_todense_gpu_translation_rule,
                            platform='gpu')
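Registration order matters here: the default rule is registered for all platforms, then the cusparse rule overrides it for `platform='gpu'` only when the plugin is available. A toy registry illustrating that dispatch (not JAX's actual registry implementation):

```python
# (primitive, platform) -> rule; platform None is the catch-all default.
_rules = {}

def register(prim, rule, platform=None):
  _rules[(prim, platform)] = rule

def lookup(prim, platform):
  return _rules.get((prim, platform), _rules[(prim, None)])

register("csr_todense", "default-rule")
register("csr_todense", "cusparse-rule", platform="gpu")
assert lookup("csr_todense", "gpu") == "cusparse-rule"
assert lookup("csr_todense", "cpu") == "default-rule"
```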
@@ -182,8 +194,17 @@ def _csr_fromdense_abstract_eval(mat, *, nse, index_dtype):
   indptr = core.ShapedArray((mat.shape[0] + 1,), index_dtype)
   return data, indices, indptr

+_csr_fromdense_translation_rule = xla.lower_fun(
+    _csr_fromdense_impl, multiple_results=True, new_style=True)
+
 def _csr_fromdense_gpu_translation_rule(ctx, avals_in, avals_out, mat, *, nse,
                                         index_dtype):
+  dtype = avals_in[0].dtype
+  if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)):
+    warnings.warn(f"csr_fromdense cusparse lowering not available for dtype={dtype}. "
+                  "Falling back to default implementation.", CuSparseEfficiencyWarning)
+    return _csr_fromdense_translation_rule(ctx, avals_in, avals_out, mat,
+                                           nse=nse, index_dtype=index_dtype)
   data, indices, indptr = cusparse.csr_fromdense(
       ctx.builder, mat, nnz=nse, index_dtype=np.dtype(index_dtype))
   return [data, indices, indptr]
@@ -215,8 +236,7 @@ def _csr_fromdense_transpose(ct, M, *, nse, index_dtype):

 ad.primitive_jvps[csr_fromdense_p] = _csr_fromdense_jvp
 ad.primitive_transposes[csr_fromdense_p] = _csr_fromdense_transpose
-xla.register_translation(csr_fromdense_p, xla.lower_fun(
-    _csr_fromdense_impl, multiple_results=True, new_style=True))
+xla.register_translation(csr_fromdense_p, _csr_fromdense_translation_rule)
 if cusparse and cusparse.is_supported:
   xla.register_translation(csr_fromdense_p,
                            _csr_fromdense_gpu_translation_rule,
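A hypothetical round trip through the two primitives, using the keyword names (`nse`, `index_dtype`, `shape`) established above; treat the exact public call signatures as an assumption based on this diff rather than documented API:

```python
import numpy as np
from jax.experimental import sparse

M = np.array([[1., 0.], [0., 2.]], dtype=np.float32)
# nse = number of stored (nonzero) elements.
data, indices, indptr = sparse.csr_fromdense(M, nse=2, index_dtype=np.int32)
M2 = sparse.csr_todense(data, indices, indptr, shape=M.shape)  # == M
```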
@@ -262,8 +282,17 @@ def _csr_matvec_abstract_eval(data, indices, indptr, v, *, shape, transpose):
   assert v.shape[0] == (shape[0] if transpose else shape[1])
   return core.ShapedArray((out_shape,), data.dtype)

+_csr_matvec_translation_rule = xla.lower_fun(
+    _csr_matvec_impl, multiple_results=False, new_style=True)
+
 def _csr_matvec_gpu_translation_rule(ctx, avals_in, avals_out, data, indices,
                                      indptr, v, *, shape, transpose):
+  dtype = avals_in[0].dtype
+  if dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
+    warnings.warn(f"csr_matvec cusparse lowering not available for dtype={dtype}. "
+                  "Falling back to default implementation.", CuSparseEfficiencyWarning)
+    return _csr_matvec_translation_rule(ctx, avals_in, avals_out, data, indices, indptr, v,
+                                        shape=shape, transpose=transpose)
   return [cusparse.csr_matvec(ctx.builder, data, indices, indptr, v,
                               shape=shape, transpose=transpose)]

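Note the two flavors of guard: the dense-conversion routines accept any float or complex dtype via `np.issubdtype`, while matvec/matmat are pinned to the four dtypes with cusparse SpMV/SpMM kernels. The difference shows up for `float16`:

```python
import numpy as np

exact = [np.float32, np.float64, np.complex64, np.complex128]
for dtype in [np.float16, np.float32, np.int32]:
  by_kind = (np.issubdtype(dtype, np.floating)
             or np.issubdtype(dtype, np.complexfloating))
  print(np.dtype(dtype), by_kind, dtype in exact)
# float16 True False  -> todense/fromdense use cusparse; matvec/matmat fall back
# float32 True True   -> cusparse everywhere
# int32   False False -> falls back everywhere
```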
@@ -288,8 +317,7 @@ def _csr_matvec_transpose(ct, data, indices, indptr, v, *, shape, transpose):

 ad.defjvp(csr_matvec_p, _csr_matvec_jvp_mat, None, None, _csr_matvec_jvp_vec)
 ad.primitive_transposes[csr_matvec_p] = _csr_matvec_transpose
-xla.register_translation(csr_matvec_p, xla.lower_fun(
-    _csr_matvec_impl, multiple_results=False, new_style=True))
+xla.register_translation(csr_matvec_p, _csr_matvec_translation_rule)
 if cusparse and cusparse.is_supported:
   xla.register_translation(csr_matvec_p, _csr_matvec_gpu_translation_rule,
                            platform='gpu')
@@ -336,8 +364,17 @@ def _csr_matmat_abstract_eval(data, indices, indptr, B, *, shape, transpose):
   assert B.shape[0] == (shape[0] if transpose else shape[1])
   return core.ShapedArray((out_shape, B.shape[1]), data.dtype)

+_csr_matmat_translation_rule = xla.lower_fun(
+    _csr_matmat_impl, multiple_results=False, new_style=True)
+
 def _csr_matmat_gpu_translation_rule(ctx, avals_in, avals_out, data, indices,
                                      indptr, B, *, shape, transpose):
+  dtype = avals_in[0].dtype
+  if dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
+    warnings.warn(f"csr_matmat cusparse lowering not available for dtype={dtype}. "
+                  "Falling back to default implementation.", CuSparseEfficiencyWarning)
+    return _csr_matmat_translation_rule(ctx, avals_in, avals_out, data, indices, indptr, B,
+                                        shape=shape, transpose=transpose)
   return [cusparse.csr_matmat(ctx.builder, data, indices, indptr, B,
                               shape=shape, transpose=transpose)]

@@ -360,8 +397,7 @@ def _csr_matmat_transpose(ct, data, indices, indptr, B, *, shape, transpose):

 ad.defjvp(csr_matmat_p, _csr_matmat_jvp_left, None, None, _csr_matmat_jvp_right)
 ad.primitive_transposes[csr_matmat_p] = _csr_matmat_transpose
-xla.register_translation(csr_matmat_p, xla.lower_fun(
-    _csr_matmat_impl, multiple_results=False, new_style=True))
+xla.register_translation(csr_matmat_p, _csr_matmat_translation_rule)
 if cusparse and cusparse.is_supported:
   xla.register_translation(csr_matmat_p, _csr_matmat_gpu_translation_rule,
                            platform='gpu')
@@ -394,8 +430,17 @@ def _coo_todense_impl(data, row, col, *, shape):
 def _coo_todense_abstract_eval(data, row, col, *, shape):
   return core.ShapedArray(shape, data.dtype)

+_coo_todense_translation_rule = xla.lower_fun(
+    _coo_todense_impl, multiple_results=False, new_style=True)
+
 def _coo_todense_gpu_translation_rule(ctx, avals_in, avals_out, data, row, col,
                                       *, shape):
+  dtype = avals_in[0].dtype
+  if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)):
+    warnings.warn(f"coo_todense cusparse lowering not available for dtype={dtype}. "
+                  "Falling back to default implementation.", CuSparseEfficiencyWarning)
+    return _coo_todense_translation_rule(ctx, avals_in, avals_out, data, row, col,
+                                         shape=shape)
   return [cusparse.coo_todense(ctx.builder, data, row, col, shape=shape)]

 def _coo_todense_jvp(data_dot, data, row, col, *, shape):
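For reference, the two index layouts involved (standard sparse formats, shown for the same matrix [[1, 0], [0, 2]]):

```python
# COO: one (row, col) pair per stored element.
data, row, col = [1., 2.], [0, 1], [0, 1]

# CSR: per-element column indices plus a row-pointer array of length nrows+1;
# row i occupies data[indptr[i]:indptr[i+1]].
data, indices, indptr = [1., 2.], [0, 1], [0, 1, 2]
```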
@@ -414,8 +459,7 @@ def _coo_todense_transpose(ct, data, row, col, *, shape):

 ad.defjvp(coo_todense_p, _coo_todense_jvp, None, None)
 ad.primitive_transposes[coo_todense_p] = _coo_todense_transpose
-xla.register_translation(coo_todense_p, xla.lower_fun(
-    _coo_todense_impl, multiple_results=False, new_style=True))
+xla.register_translation(coo_todense_p, _coo_todense_translation_rule)
 if cusparse and cusparse.is_supported:
   xla.register_translation(coo_todense_p, _coo_todense_gpu_translation_rule,
                            platform='gpu')
@@ -462,8 +506,17 @@ def _coo_fromdense_abstract_eval(mat, *, nse, index_dtype):
   row = col = core.ShapedArray((nse,), index_dtype)
   return data, row, col

+_coo_fromdense_translation_rule = xla.lower_fun(
+    _coo_fromdense_impl, multiple_results=True, new_style=True)
+
 def _coo_fromdense_gpu_translation_rule(ctx, avals_in, avals_out, mat, *, nse,
                                         index_dtype):
+  dtype = avals_in[0].dtype
+  if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)):
+    warnings.warn(f"coo_fromdense cusparse lowering not available for dtype={dtype}. "
+                  "Falling back to default implementation.", CuSparseEfficiencyWarning)
+    return _coo_fromdense_translation_rule(ctx, avals_in, avals_out, mat,
+                                           nse=nse, index_dtype=index_dtype)
   data, row, col = cusparse.coo_fromdense(
       ctx.builder, mat, nnz=nse, index_dtype=np.dtype(index_dtype))
   return [data, row, col]
@@ -496,8 +549,7 @@ def _coo_fromdense_transpose(ct, M, *, nse, index_dtype):
 ad.primitive_jvps[coo_fromdense_p] = _coo_fromdense_jvp
 ad.primitive_transposes[coo_fromdense_p] = _coo_fromdense_transpose

-xla.register_translation(coo_fromdense_p, xla.lower_fun(
-    _coo_fromdense_impl, multiple_results=True, new_style=True))
+xla.register_translation(coo_fromdense_p, _coo_fromdense_translation_rule)
 if cusparse and cusparse.is_supported:
   xla.register_translation(coo_fromdense_p,
                            _coo_fromdense_gpu_translation_rule,
@@ -547,8 +599,17 @@ def _coo_matvec_abstract_eval(data, row, col, v, *, shape, transpose):
   out_shape = shape[1] if transpose else shape[0]
   return core.ShapedArray((out_shape,), data.dtype)

+_coo_matvec_translation_rule = xla.lower_fun(
+    _coo_matvec_impl, multiple_results=False, new_style=True)
+
 def _coo_matvec_gpu_translation_rule(ctx, avals_in, avals_out, data, row, col,
                                      v, *, shape, transpose):
+  dtype = avals_in[0].dtype
+  if dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
+    warnings.warn(f"coo_matvec cusparse lowering not available for dtype={dtype}. "
+                  "Falling back to default implementation.", CuSparseEfficiencyWarning)
+    return _coo_matvec_translation_rule(ctx, avals_in, avals_out, data, row, col, v,
+                                        shape=shape, transpose=transpose)
   return [cusparse.coo_matvec(ctx.builder, data, row, col, v, shape=shape,
                               transpose=transpose)]

@@ -572,8 +633,7 @@ def _coo_matvec_transpose(ct, data, row, col, v, *, shape, transpose):

 ad.defjvp(coo_matvec_p, _coo_matvec_jvp_mat, None, None, _coo_matvec_jvp_vec)
 ad.primitive_transposes[coo_matvec_p] = _coo_matvec_transpose
-xla.register_translation(coo_matvec_p, xla.lower_fun(
-    _coo_matvec_impl, multiple_results=False, new_style=True))
+xla.register_translation(coo_matvec_p, _coo_matvec_translation_rule)
 if cusparse and cusparse.is_supported:
   xla.register_translation(coo_matvec_p, _coo_matvec_gpu_translation_rule,
                            platform='gpu')
@@ -621,8 +681,17 @@ def _coo_matmat_abstract_eval(data, row, col, B, *, shape, transpose):
   out_shape = shape[1] if transpose else shape[0]
   return core.ShapedArray((out_shape, B.shape[1]), data.dtype)

+_coo_matmat_translation_rule = xla.lower_fun(
+    _coo_matmat_impl, multiple_results=False, new_style=True)
+
 def _coo_matmat_gpu_translation_rule(ctx, avals_in, avals_out, data, row, col,
                                      B, *, shape, transpose):
+  dtype = avals_in[0].dtype
+  if dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
+    warnings.warn(f"coo_matmat cusparse lowering not available for dtype={dtype}. "
+                  "Falling back to default implementation.", CuSparseEfficiencyWarning)
+    return _coo_matmat_translation_rule(ctx, avals_in, avals_out, data, row, col, B,
+                                        shape=shape, transpose=transpose)
   return [cusparse.coo_matmat(ctx.builder, data, row, col, B, shape=shape,
                               transpose=transpose)]

@@ -643,8 +712,7 @@ def _coo_matmat_transpose(ct, data, row, col, B, *, shape, transpose):

 ad.defjvp(coo_matmat_p, _coo_matmat_jvp_left, None, None, _coo_matmat_jvp_right)
 ad.primitive_transposes[coo_matmat_p] = _coo_matmat_transpose
-xla.register_translation(coo_matmat_p, xla.lower_fun(
-    _coo_matmat_impl, multiple_results=False, new_style=True))
+xla.register_translation(coo_matmat_p, _coo_matmat_translation_rule)
 if cusparse and cusparse.is_supported:
   xla.register_translation(coo_matmat_p, _coo_matmat_gpu_translation_rule,
                            platform='gpu')
--- a/tests/sparse_test.py
+++ b/tests/sparse_test.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import contextlib
 from functools import partial
 import itertools
 import operator
@@ -38,7 +39,6 @@ from jax._src.lax.lax import remaining, DotDimensionNumbers
 from jax import xla
 import jax.numpy as jnp
 from jax.util import split_list
-from jax import jvp
 import numpy as np
 import scipy.sparse
 config.parse_flags_with_absl()
@@ -116,11 +116,21 @@ def rand_sparse(rng, nse=0.5, post=lambda x: x):


 class cuSparseTest(jtu.JaxTestCase):
+  def gpu_dense_conversion_warning_context(self, dtype):
+    if jtu.device_under_test() == "gpu" and np.issubdtype(dtype, np.integer):
+      return self.assertWarns(sparse.CuSparseEfficiencyWarning)
+    return contextlib.nullcontext()
+
+  def gpu_matmul_warning_context(self, dtype):
+    if jtu.device_under_test() == "gpu" and dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
+      return self.assertWarns(sparse.CuSparseEfficiencyWarning)
+    return contextlib.nullcontext()
+
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
        "shape": shape, "dtype": dtype}
       for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
-      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
+      for dtype in all_dtypes))
   def test_csr_todense(self, shape, dtype):
     rng = rand_sparse(self.rng(), post=scipy.sparse.csr_matrix)
     M = rng(shape, dtype)
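The two helpers let every assertion site use a single `with` block whether or not a warning is expected, by returning either `assertWarns(...)` or a no-op context. A self-contained sketch of the same pattern in plain `unittest`:

```python
import contextlib
import unittest
import warnings

class WarnContextExample(unittest.TestCase):
  def maybe_warns(self, expect):
    # assertWarns doubles as a context manager; nullcontext is the no-op case.
    return self.assertWarns(UserWarning) if expect else contextlib.nullcontext()

  def test_pattern(self):
    with self.maybe_warns(True):
      warnings.warn("falling back", UserWarning)
    with self.maybe_warns(False):
      pass  # no warning required here

if __name__ == "__main__":
  unittest.main()
```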
@@ -129,7 +139,8 @@ class cuSparseTest(jtu.JaxTestCase):
     todense = lambda *args: sparse.csr_todense(*args, shape=M.shape)

     self.assertArraysEqual(M.toarray(), todense(*args))
-    self.assertArraysEqual(M.toarray(), jit(todense)(*args))
+    with self.gpu_dense_conversion_warning_context(dtype):
+      self.assertArraysEqual(M.toarray(), jit(todense)(*args))

   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
@@ -242,7 +253,7 @@ class cuSparseTest(jtu.JaxTestCase):
       {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
        "shape": shape, "dtype": dtype}
       for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
-      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
+      for dtype in all_dtypes))
   def test_csr_fromdense(self, shape, dtype):
     rng = rand_sparse(self.rng())
     M = rng(shape, dtype)
@@ -257,7 +268,8 @@ class cuSparseTest(jtu.JaxTestCase):
     self.assertArraysEqual(indices, M_csr.indices.astype(index_dtype))
     self.assertArraysEqual(indptr, M_csr.indptr.astype(index_dtype))

-    data, indices, indptr = jit(fromdense)(M)
+    with self.gpu_dense_conversion_warning_context(dtype):
+      data, indices, indptr = jit(fromdense)(M)
     self.assertArraysEqual(data, M_csr.data.astype(dtype))
     self.assertArraysEqual(indices, M_csr.indices.astype(index_dtype))
     self.assertArraysEqual(indptr, M_csr.indptr.astype(index_dtype))
@@ -266,7 +278,7 @@ class cuSparseTest(jtu.JaxTestCase):
       {"testcase_name": "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype), transpose),
        "shape": shape, "dtype": dtype, "transpose": transpose}
       for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
-      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
+      for dtype in all_dtypes
       for transpose in [True, False]))
   def test_csr_matvec(self, shape, dtype, transpose):
     op = lambda M: M.T if transpose else M
@@ -280,13 +292,14 @@ class cuSparseTest(jtu.JaxTestCase):
     matvec = lambda *args: sparse.csr_matvec(*args, shape=M.shape, transpose=transpose)

     self.assertAllClose(op(M) @ v, matvec(*args), rtol=MATMUL_TOL)
-    self.assertAllClose(op(M) @ v, jit(matvec)(*args), rtol=MATMUL_TOL)
+    with self.gpu_matmul_warning_context(dtype):
+      self.assertAllClose(op(M) @ v, jit(matvec)(*args), rtol=MATMUL_TOL)

   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype), transpose),
        "shape": shape, "dtype": dtype, "transpose": transpose}
       for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
-      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
+      for dtype in all_dtypes
       for transpose in [True, False]))
   def test_csr_matmat(self, shape, dtype, transpose):
     op = lambda M: M.T if transpose else M
@@ -300,13 +313,14 @@ class cuSparseTest(jtu.JaxTestCase):
     matmat = lambda *args: sparse.csr_matmat(*args, shape=shape, transpose=transpose)

     self.assertAllClose(op(M) @ B, matmat(*args), rtol=MATMUL_TOL)
-    self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=MATMUL_TOL)
+    with self.gpu_matmul_warning_context(dtype):
+      self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=MATMUL_TOL)

   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
        "shape": shape, "dtype": dtype}
       for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
-      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
+      for dtype in all_dtypes))
   def test_coo_todense(self, shape, dtype):
     rng = rand_sparse(self.rng(), post=scipy.sparse.coo_matrix)
     M = rng(shape, dtype)
@@ -315,13 +329,14 @@ class cuSparseTest(jtu.JaxTestCase):
     todense = lambda *args: sparse.coo_todense(*args, shape=M.shape)

     self.assertArraysEqual(M.toarray(), todense(*args))
-    self.assertArraysEqual(M.toarray(), jit(todense)(*args))
+    with self.gpu_dense_conversion_warning_context(dtype):
+      self.assertArraysEqual(M.toarray(), jit(todense)(*args))

   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
        "shape": shape, "dtype": dtype}
       for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
-      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
+      for dtype in all_dtypes))
   def test_coo_fromdense(self, shape, dtype):
     rng = rand_sparse(self.rng())
     M = rng(shape, dtype)
@@ -336,7 +351,8 @@ class cuSparseTest(jtu.JaxTestCase):
     self.assertArraysEqual(row, M_coo.row.astype(index_dtype))
     self.assertArraysEqual(col, M_coo.col.astype(index_dtype))

-    data, indices, indptr = jit(fromdense)(M)
+    with self.gpu_dense_conversion_warning_context(dtype):
+      data, indices, indptr = jit(fromdense)(M)
     self.assertArraysEqual(data, M_coo.data.astype(dtype))
     self.assertArraysEqual(row, M_coo.row.astype(index_dtype))
     self.assertArraysEqual(col, M_coo.col.astype(index_dtype))
@@ -345,7 +361,7 @@ class cuSparseTest(jtu.JaxTestCase):
       {"testcase_name": "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype), transpose),
        "shape": shape, "dtype": dtype, "transpose": transpose}
       for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
-      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
+      for dtype in all_dtypes
       for transpose in [True, False]))
   def test_coo_matvec(self, shape, dtype, transpose):
     op = lambda M: M.T if transpose else M
@@ -359,13 +375,14 @@ class cuSparseTest(jtu.JaxTestCase):
     matvec = lambda *args: sparse.coo_matvec(*args, shape=M.shape, transpose=transpose)

     self.assertAllClose(op(M) @ v, matvec(*args), rtol=MATMUL_TOL)
-    self.assertAllClose(op(M) @ v, jit(matvec)(*args), rtol=MATMUL_TOL)
+    with self.gpu_matmul_warning_context(dtype):
+      self.assertAllClose(op(M) @ v, jit(matvec)(*args), rtol=MATMUL_TOL)

   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype), transpose),
        "shape": shape, "dtype": dtype, "transpose": transpose}
       for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
-      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
+      for dtype in all_dtypes
       for transpose in [True, False]))
   def test_coo_matmat(self, shape, dtype, transpose):
     op = lambda M: M.T if transpose else M
@@ -379,13 +396,8 @@ class cuSparseTest(jtu.JaxTestCase):
     matmat = lambda *args: sparse.coo_matmat(*args, shape=shape, transpose=transpose)

     self.assertAllClose(op(M) @ B, matmat(*args), rtol=MATMUL_TOL)
-    self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=MATMUL_TOL)
-
-    y, dy = jvp(lambda x: sparse.coo_matmat(M.data, M.row, M.col, x, shape=shape, transpose=transpose).sum(), (B, ), (jnp.ones_like(B), ))
-    self.assertAllClose((op(M) @ B).sum(), y, rtol=MATMUL_TOL)
-
-    y, dy = jvp(lambda x: sparse.coo_matmat(x, M.row, M.col, B, shape=shape, transpose=transpose).sum(), (M.data, ), (jnp.ones_like(M.data), ))
-    self.assertAllClose((op(M) @ B).sum(), y, rtol=MATMUL_TOL)
+    with self.gpu_matmul_warning_context(dtype):
+      self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=MATMUL_TOL)

   def test_coo_matmat_layout(self):
     # Regression test for https://github.com/google/jax/issues/7533
@@ -546,7 +558,7 @@ class BCOOTest(jtu.JaxTestCase):
           jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
        "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
       for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
-      for dtype in jtu.dtypes.integer + jtu.dtypes.floating + jtu.dtypes.complex
+      for dtype in all_dtypes
       for n_batch in range(len(shape) + 1)
       for n_dense in range(len(shape) + 1 - n_batch)))
   def test_bcoo_dense_round_trip(self, shape, dtype, n_batch, n_dense):
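`all_dtypes` itself is defined outside the hunks shown in this diff. Judging from the line it replaces here (and from the dtype guards above), a plausible definition added near the top of sparse_test.py would be the following; treat it as an assumption, not a line from this commit:

```python
all_dtypes = jtu.dtypes.integer + jtu.dtypes.floating + jtu.dtypes.complex
```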