[sparse] Improve type safety of cusparse lowerings

Fixes https://github.com/google/jax/issues/8577

PiperOrigin-RevId: 410624036
This commit is contained in:
Jake VanderPlas 2021-11-17 14:04:53 -08:00 committed by jax authors
parent bb3f19891e
commit 7ce5568435
3 changed files with 121 additions and 40 deletions

View File

@ -226,6 +226,7 @@ from .ops import (
csr_todense_p as csr_todense_p,
todense as todense,
todense_p as todense_p,
CuSparseEfficiencyWarning as CuSparseEfficiencyWarning,
COO as COO,
CSC as CSC,
CSR as CSR,

View File

@ -32,6 +32,7 @@ Further down are some examples of potential high-level wrappers for sparse objec
from functools import partial
import operator
from typing import Tuple
import warnings
import numpy as np
@ -50,6 +51,9 @@ import jax.numpy as jnp
xb = xla_bridge
xops = xla_client.ops
class CuSparseEfficiencyWarning(UserWarning):
pass
#--------------------------------------------------------------------
# utilities
# TODO: possibly make these primitives, targeting cusparse rountines
@ -109,8 +113,17 @@ def _csr_todense_abstract_eval(data, indices, indptr, *, shape):
assert indptr.shape[0] == shape[0] + 1
return core.ShapedArray(shape, data.dtype)
_csr_todense_translation_rule = xla.lower_fun(
_csr_todense_impl, multiple_results=False, new_style=True)
def _csr_todense_gpu_translation_rule(ctx, avals_in, avals_out, data, indices,
indptr, *, shape):
dtype = avals_in[0].dtype
if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)):
warnings.warn(f"csr_todense cusparse lowering not available for dtype={dtype}. "
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _csr_todense_translation_rule(ctx, avals_in, avals_out, data, indices,
indptr, shape=shape)
return [cusparse.csr_todense(ctx.builder, data, indices, indptr, shape=shape)]
def _csr_todense_jvp(data_dot, data, indices, indptr, *, shape):
@ -129,8 +142,7 @@ def _csr_todense_transpose(ct, data, indices, indptr, *, shape):
ad.defjvp(csr_todense_p, _csr_todense_jvp, None, None)
ad.primitive_transposes[csr_todense_p] = _csr_todense_transpose
xla.register_translation(csr_todense_p, xla.lower_fun(
_csr_todense_impl, multiple_results=False, new_style=True))
xla.register_translation(csr_todense_p, _csr_todense_translation_rule)
if cusparse and cusparse.is_supported:
xla.register_translation(csr_todense_p, _csr_todense_gpu_translation_rule,
platform='gpu')
@ -182,8 +194,17 @@ def _csr_fromdense_abstract_eval(mat, *, nse, index_dtype):
indptr = core.ShapedArray((mat.shape[0] + 1,), index_dtype)
return data, indices, indptr
_csr_fromdense_translation_rule = xla.lower_fun(
_csr_fromdense_impl, multiple_results=True, new_style=True)
def _csr_fromdense_gpu_translation_rule(ctx, avals_in, avals_out, mat, *, nse,
index_dtype):
dtype = avals_in[0].dtype
if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)):
warnings.warn(f"csr_fromdense cusparse lowering not available for dtype={dtype}. "
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _csr_fromdense_translation_rule(ctx, avals_in, avals_out, mat,
nse=nse, index_dtype=index_dtype)
data, indices, indptr = cusparse.csr_fromdense(
ctx.builder, mat, nnz=nse, index_dtype=np.dtype(index_dtype))
return [data, indices, indptr]
@ -215,8 +236,7 @@ def _csr_fromdense_transpose(ct, M, *, nse, index_dtype):
ad.primitive_jvps[csr_fromdense_p] = _csr_fromdense_jvp
ad.primitive_transposes[csr_fromdense_p] = _csr_fromdense_transpose
xla.register_translation(csr_fromdense_p, xla.lower_fun(
_csr_fromdense_impl, multiple_results=True, new_style=True))
xla.register_translation(csr_fromdense_p, _csr_fromdense_translation_rule)
if cusparse and cusparse.is_supported:
xla.register_translation(csr_fromdense_p,
_csr_fromdense_gpu_translation_rule,
@ -262,8 +282,17 @@ def _csr_matvec_abstract_eval(data, indices, indptr, v, *, shape, transpose):
assert v.shape[0] == (shape[0] if transpose else shape[1])
return core.ShapedArray((out_shape,), data.dtype)
_csr_matvec_translation_rule = xla.lower_fun(
_csr_matvec_impl, multiple_results=False, new_style=True)
def _csr_matvec_gpu_translation_rule(ctx, avals_in, avals_out, data, indices,
indptr, v, *, shape, transpose):
dtype = avals_in[0].dtype
if dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
warnings.warn(f"csr_matvec cusparse lowering not available for dtype={dtype}. "
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _csr_matvec_translation_rule(ctx, avals_in, avals_out, data, indices, indptr, v,
shape=shape, transpose=transpose)
return [cusparse.csr_matvec(ctx.builder, data, indices, indptr, v,
shape=shape, transpose=transpose)]
@ -288,8 +317,7 @@ def _csr_matvec_transpose(ct, data, indices, indptr, v, *, shape, transpose):
ad.defjvp(csr_matvec_p, _csr_matvec_jvp_mat, None, None, _csr_matvec_jvp_vec)
ad.primitive_transposes[csr_matvec_p] = _csr_matvec_transpose
xla.register_translation(csr_matvec_p, xla.lower_fun(
_csr_matvec_impl, multiple_results=False, new_style=True))
xla.register_translation(csr_matvec_p, _csr_matvec_translation_rule)
if cusparse and cusparse.is_supported:
xla.register_translation(csr_matvec_p, _csr_matvec_gpu_translation_rule,
platform='gpu')
@ -336,8 +364,17 @@ def _csr_matmat_abstract_eval(data, indices, indptr, B, *, shape, transpose):
assert B.shape[0] == (shape[0] if transpose else shape[1])
return core.ShapedArray((out_shape, B.shape[1]), data.dtype)
_csr_matmat_translation_rule = xla.lower_fun(
_csr_matmat_impl, multiple_results=False, new_style=True)
def _csr_matmat_gpu_translation_rule(ctx, avals_in, avals_out, data, indices,
indptr, B, *, shape, transpose):
dtype = avals_in[0].dtype
if dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
warnings.warn(f"csr_matmat cusparse lowering not available for dtype={dtype}. "
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _csr_matmat_translation_rule(ctx, avals_in, avals_out, data, indices, indptr, B,
shape=shape, transpose=transpose)
return [cusparse.csr_matmat(ctx.builder, data, indices, indptr, B,
shape=shape, transpose=transpose)]
@ -360,8 +397,7 @@ def _csr_matmat_transpose(ct, data, indices, indptr, B, *, shape, transpose):
ad.defjvp(csr_matmat_p, _csr_matmat_jvp_left, None, None, _csr_matmat_jvp_right)
ad.primitive_transposes[csr_matmat_p] = _csr_matmat_transpose
xla.register_translation(csr_matmat_p, xla.lower_fun(
_csr_matmat_impl, multiple_results=False, new_style=True))
xla.register_translation(csr_matmat_p, _csr_matmat_translation_rule)
if cusparse and cusparse.is_supported:
xla.register_translation(csr_matmat_p, _csr_matmat_gpu_translation_rule,
platform='gpu')
@ -394,8 +430,17 @@ def _coo_todense_impl(data, row, col, *, shape):
def _coo_todense_abstract_eval(data, row, col, *, shape):
return core.ShapedArray(shape, data.dtype)
_coo_todense_translation_rule = xla.lower_fun(
_coo_todense_impl, multiple_results=False, new_style=True)
def _coo_todense_gpu_translation_rule(ctx, avals_in, avals_out, data, row, col,
*, shape):
dtype = avals_in[0].dtype
if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)):
warnings.warn(f"coo_todense cusparse lowering not available for dtype={dtype}. "
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _coo_todense_translation_rule(ctx, avals_in, avals_out, data, row, col,
shape=shape)
return [cusparse.coo_todense(ctx.builder, data, row, col, shape=shape)]
def _coo_todense_jvp(data_dot, data, row, col, *, shape):
@ -414,8 +459,7 @@ def _coo_todense_transpose(ct, data, row, col, *, shape):
ad.defjvp(coo_todense_p, _coo_todense_jvp, None, None)
ad.primitive_transposes[coo_todense_p] = _coo_todense_transpose
xla.register_translation(coo_todense_p, xla.lower_fun(
_coo_todense_impl, multiple_results=False, new_style=True))
xla.register_translation(coo_todense_p, _coo_todense_translation_rule)
if cusparse and cusparse.is_supported:
xla.register_translation(coo_todense_p, _coo_todense_gpu_translation_rule,
platform='gpu')
@ -462,8 +506,17 @@ def _coo_fromdense_abstract_eval(mat, *, nse, index_dtype):
row = col = core.ShapedArray((nse,), index_dtype)
return data, row, col
_coo_fromdense_translation_rule = xla.lower_fun(
_coo_fromdense_impl, multiple_results=True, new_style=True)
def _coo_fromdense_gpu_translation_rule(ctx, avals_in, avals_out, mat, *, nse,
index_dtype):
dtype = avals_in[0].dtype
if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)):
warnings.warn(f"coo_fromdense cusparse lowering not available for dtype={dtype}. "
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _coo_fromdense_translation_rule(ctx, avals_in, avals_out, mat,
nse=nse, index_dtype=index_dtype)
data, row, col = cusparse.coo_fromdense(
ctx.builder, mat, nnz=nse, index_dtype=np.dtype(index_dtype))
return [data, row, col]
@ -496,8 +549,7 @@ def _coo_fromdense_transpose(ct, M, *, nse, index_dtype):
ad.primitive_jvps[coo_fromdense_p] = _coo_fromdense_jvp
ad.primitive_transposes[coo_fromdense_p] = _coo_fromdense_transpose
xla.register_translation(coo_fromdense_p, xla.lower_fun(
_coo_fromdense_impl, multiple_results=True, new_style=True))
xla.register_translation(coo_fromdense_p, _coo_fromdense_translation_rule)
if cusparse and cusparse.is_supported:
xla.register_translation(coo_fromdense_p,
_coo_fromdense_gpu_translation_rule,
@ -547,8 +599,17 @@ def _coo_matvec_abstract_eval(data, row, col, v, *, shape, transpose):
out_shape = shape[1] if transpose else shape[0]
return core.ShapedArray((out_shape,), data.dtype)
_coo_matvec_translation_rule = xla.lower_fun(
_coo_matvec_impl, multiple_results=False, new_style=True)
def _coo_matvec_gpu_translation_rule(ctx, avals_in, avals_out, data, row, col,
v, *, shape, transpose):
dtype = avals_in[0].dtype
if dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
warnings.warn(f"coo_matvec cusparse lowering not available for dtype={dtype}. "
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _coo_matvec_translation_rule(ctx, avals_in, avals_out, data, row, col, v,
shape=shape, transpose=transpose)
return [cusparse.coo_matvec(ctx.builder, data, row, col, v, shape=shape,
transpose=transpose)]
@ -572,8 +633,7 @@ def _coo_matvec_transpose(ct, data, row, col, v, *, shape, transpose):
ad.defjvp(coo_matvec_p, _coo_matvec_jvp_mat, None, None, _coo_matvec_jvp_vec)
ad.primitive_transposes[coo_matvec_p] = _coo_matvec_transpose
xla.register_translation(coo_matvec_p, xla.lower_fun(
_coo_matvec_impl, multiple_results=False, new_style=True))
xla.register_translation(coo_matvec_p, _coo_matvec_translation_rule)
if cusparse and cusparse.is_supported:
xla.register_translation(coo_matvec_p, _coo_matvec_gpu_translation_rule,
platform='gpu')
@ -621,8 +681,17 @@ def _coo_matmat_abstract_eval(data, row, col, B, *, shape, transpose):
out_shape = shape[1] if transpose else shape[0]
return core.ShapedArray((out_shape, B.shape[1]), data.dtype)
_coo_matmat_translation_rule = xla.lower_fun(
_coo_matmat_impl, multiple_results=False, new_style=True)
def _coo_matmat_gpu_translation_rule(ctx, avals_in, avals_out, data, row, col,
B, *, shape, transpose):
dtype = avals_in[0].dtype
if dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
warnings.warn(f"coo_matmat cusparse lowering not available for dtype={dtype}. "
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _coo_matmat_translation_rule(ctx, avals_in, avals_out, data, row, col, B,
shape=shape, transpose=transpose)
return [cusparse.coo_matmat(ctx.builder, data, row, col, B, shape=shape,
transpose=transpose)]
@ -643,8 +712,7 @@ def _coo_matmat_transpose(ct, data, row, col, B, *, shape, transpose):
ad.defjvp(coo_matmat_p, _coo_matmat_jvp_left, None, None, _coo_matmat_jvp_right)
ad.primitive_transposes[coo_matmat_p] = _coo_matmat_transpose
xla.register_translation(coo_matmat_p, xla.lower_fun(
_coo_matmat_impl, multiple_results=False, new_style=True))
xla.register_translation(coo_matmat_p, _coo_matmat_translation_rule)
if cusparse and cusparse.is_supported:
xla.register_translation(coo_matmat_p, _coo_matmat_gpu_translation_rule,
platform='gpu')

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from functools import partial
import itertools
import operator
@ -38,7 +39,6 @@ from jax._src.lax.lax import remaining, DotDimensionNumbers
from jax import xla
import jax.numpy as jnp
from jax.util import split_list
from jax import jvp
import numpy as np
import scipy.sparse
config.parse_flags_with_absl()
@ -116,11 +116,21 @@ def rand_sparse(rng, nse=0.5, post=lambda x: x):
class cuSparseTest(jtu.JaxTestCase):
def gpu_dense_conversion_warning_context(self, dtype):
if jtu.device_under_test() == "gpu" and np.issubdtype(dtype, np.integer):
return self.assertWarns(sparse.CuSparseEfficiencyWarning)
return contextlib.nullcontext()
def gpu_matmul_warning_context(self, dtype):
if jtu.device_under_test() == "gpu" and dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
return self.assertWarns(sparse.CuSparseEfficiencyWarning)
return contextlib.nullcontext()
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype}
for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
for dtype in all_dtypes))
def test_csr_todense(self, shape, dtype):
rng = rand_sparse(self.rng(), post=scipy.sparse.csr_matrix)
M = rng(shape, dtype)
@ -129,7 +139,8 @@ class cuSparseTest(jtu.JaxTestCase):
todense = lambda *args: sparse.csr_todense(*args, shape=M.shape)
self.assertArraysEqual(M.toarray(), todense(*args))
self.assertArraysEqual(M.toarray(), jit(todense)(*args))
with self.gpu_dense_conversion_warning_context(dtype):
self.assertArraysEqual(M.toarray(), jit(todense)(*args))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
@ -242,7 +253,7 @@ class cuSparseTest(jtu.JaxTestCase):
{"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype}
for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
for dtype in all_dtypes))
def test_csr_fromdense(self, shape, dtype):
rng = rand_sparse(self.rng())
M = rng(shape, dtype)
@ -257,7 +268,8 @@ class cuSparseTest(jtu.JaxTestCase):
self.assertArraysEqual(indices, M_csr.indices.astype(index_dtype))
self.assertArraysEqual(indptr, M_csr.indptr.astype(index_dtype))
data, indices, indptr = jit(fromdense)(M)
with self.gpu_dense_conversion_warning_context(dtype):
data, indices, indptr = jit(fromdense)(M)
self.assertArraysEqual(data, M_csr.data.astype(dtype))
self.assertArraysEqual(indices, M_csr.indices.astype(index_dtype))
self.assertArraysEqual(indptr, M_csr.indptr.astype(index_dtype))
@ -266,7 +278,7 @@ class cuSparseTest(jtu.JaxTestCase):
{"testcase_name": "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype), transpose),
"shape": shape, "dtype": dtype, "transpose": transpose}
for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
for dtype in jtu.dtypes.floating + jtu.dtypes.complex
for dtype in all_dtypes
for transpose in [True, False]))
def test_csr_matvec(self, shape, dtype, transpose):
op = lambda M: M.T if transpose else M
@ -280,13 +292,14 @@ class cuSparseTest(jtu.JaxTestCase):
matvec = lambda *args: sparse.csr_matvec(*args, shape=M.shape, transpose=transpose)
self.assertAllClose(op(M) @ v, matvec(*args), rtol=MATMUL_TOL)
self.assertAllClose(op(M) @ v, jit(matvec)(*args), rtol=MATMUL_TOL)
with self.gpu_matmul_warning_context(dtype):
self.assertAllClose(op(M) @ v, jit(matvec)(*args), rtol=MATMUL_TOL)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype), transpose),
"shape": shape, "dtype": dtype, "transpose": transpose}
for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
for dtype in jtu.dtypes.floating + jtu.dtypes.complex
for dtype in all_dtypes
for transpose in [True, False]))
def test_csr_matmat(self, shape, dtype, transpose):
op = lambda M: M.T if transpose else M
@ -300,13 +313,14 @@ class cuSparseTest(jtu.JaxTestCase):
matmat = lambda *args: sparse.csr_matmat(*args, shape=shape, transpose=transpose)
self.assertAllClose(op(M) @ B, matmat(*args), rtol=MATMUL_TOL)
self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=MATMUL_TOL)
with self.gpu_matmul_warning_context(dtype):
self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=MATMUL_TOL)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype}
for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
for dtype in all_dtypes))
def test_coo_todense(self, shape, dtype):
rng = rand_sparse(self.rng(), post=scipy.sparse.coo_matrix)
M = rng(shape, dtype)
@ -315,13 +329,14 @@ class cuSparseTest(jtu.JaxTestCase):
todense = lambda *args: sparse.coo_todense(*args, shape=M.shape)
self.assertArraysEqual(M.toarray(), todense(*args))
self.assertArraysEqual(M.toarray(), jit(todense)(*args))
with self.gpu_dense_conversion_warning_context(dtype):
self.assertArraysEqual(M.toarray(), jit(todense)(*args))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype}
for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
for dtype in all_dtypes))
def test_coo_fromdense(self, shape, dtype):
rng = rand_sparse(self.rng())
M = rng(shape, dtype)
@ -336,7 +351,8 @@ class cuSparseTest(jtu.JaxTestCase):
self.assertArraysEqual(row, M_coo.row.astype(index_dtype))
self.assertArraysEqual(col, M_coo.col.astype(index_dtype))
data, indices, indptr = jit(fromdense)(M)
with self.gpu_dense_conversion_warning_context(dtype):
data, indices, indptr = jit(fromdense)(M)
self.assertArraysEqual(data, M_coo.data.astype(dtype))
self.assertArraysEqual(row, M_coo.row.astype(index_dtype))
self.assertArraysEqual(col, M_coo.col.astype(index_dtype))
@ -345,7 +361,7 @@ class cuSparseTest(jtu.JaxTestCase):
{"testcase_name": "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype), transpose),
"shape": shape, "dtype": dtype, "transpose": transpose}
for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
for dtype in jtu.dtypes.floating + jtu.dtypes.complex
for dtype in all_dtypes
for transpose in [True, False]))
def test_coo_matvec(self, shape, dtype, transpose):
op = lambda M: M.T if transpose else M
@ -359,13 +375,14 @@ class cuSparseTest(jtu.JaxTestCase):
matvec = lambda *args: sparse.coo_matvec(*args, shape=M.shape, transpose=transpose)
self.assertAllClose(op(M) @ v, matvec(*args), rtol=MATMUL_TOL)
self.assertAllClose(op(M) @ v, jit(matvec)(*args), rtol=MATMUL_TOL)
with self.gpu_matmul_warning_context(dtype):
self.assertAllClose(op(M) @ v, jit(matvec)(*args), rtol=MATMUL_TOL)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype), transpose),
"shape": shape, "dtype": dtype, "transpose": transpose}
for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
for dtype in jtu.dtypes.floating + jtu.dtypes.complex
for dtype in all_dtypes
for transpose in [True, False]))
def test_coo_matmat(self, shape, dtype, transpose):
op = lambda M: M.T if transpose else M
@ -379,13 +396,8 @@ class cuSparseTest(jtu.JaxTestCase):
matmat = lambda *args: sparse.coo_matmat(*args, shape=shape, transpose=transpose)
self.assertAllClose(op(M) @ B, matmat(*args), rtol=MATMUL_TOL)
self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=MATMUL_TOL)
y, dy = jvp(lambda x: sparse.coo_matmat(M.data, M.row, M.col, x, shape=shape, transpose=transpose).sum(), (B, ), (jnp.ones_like(B), ))
self.assertAllClose((op(M) @ B).sum(), y, rtol=MATMUL_TOL)
y, dy = jvp(lambda x: sparse.coo_matmat(x, M.row, M.col, B, shape=shape, transpose=transpose).sum(), (M.data, ), (jnp.ones_like(M.data), ))
self.assertAllClose((op(M) @ B).sum(), y, rtol=MATMUL_TOL)
with self.gpu_matmul_warning_context(dtype):
self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=MATMUL_TOL)
def test_coo_matmat_layout(self):
# Regression test for https://github.com/google/jax/issues/7533
@ -546,7 +558,7 @@ class BCOOTest(jtu.JaxTestCase):
jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
"shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
for dtype in jtu.dtypes.integer + jtu.dtypes.floating + jtu.dtypes.complex
for dtype in all_dtypes
for n_batch in range(len(shape) + 1)
for n_dense in range(len(shape) + 1 - n_batch)))
def test_bcoo_dense_round_trip(self, shape, dtype, n_batch, n_dense):