Merge pull request #13280 from jakevdp:sparse-test-util

PiperOrigin-RevId: 489019436
This commit is contained in:
jax authors 2022-11-16 13:13:04 -08:00
commit 58af581e3b

View File

@ -974,20 +974,6 @@ class BCOOTest(sptu.SparseTestCase):
self._CheckAgainstDense(dense_func, sparse_func, args_maker)
self._CompileAndCheckSparse(sparse_func, args_maker)
@jtu.sample_product(
[dict(shape=shape, n_batch=n_batch, n_dense=n_dense)
for shape in [(2,), (3, 4), (5, 6, 2)]
for n_batch in range(len(shape) + 1)
for n_dense in [0] # TODO(jakevdp): add tests with n_dense
],
dtype=jtu.dtypes.numeric,
)
def test_bcoo_iter(self, shape, dtype, n_batch, n_dense):
sprng = rand_sparse(self.rng())
M = sprng(shape, dtype)
Msp = sparse.BCOO.fromdense(M, n_batch=n_batch, n_dense=n_dense)
self.assertAllClose(list(M), [row.todense() for row in Msp])
@jtu.sample_product(
[dict(shape=shape, n_batch=n_batch, n_dense=n_dense)
for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
@ -1068,28 +1054,17 @@ class BCOOTest(sptu.SparseTestCase):
)
@jax.default_matmul_precision("float32")
def test_bcoo_dot_general(self, props: BcooDotGeneralProperties):
rng = jtu.rand_small(self.rng())
rng_sparse = rand_sparse(self.rng())
rng = jtu.rand_default(self.rng())
sprng = sptu.rand_bcoo(self.rng(), n_batch=props.n_batch, n_dense=props.n_dense)
args_maker = lambda: [sprng(props.lhs_shape, props.dtype),
rng(props.rhs_shape, props.dtype)]
dense_fun = partial(lax.dot_general, dimension_numbers=props.dimension_numbers)
sparse_fun = partial(sparse.bcoo_dot_general, dimension_numbers=props.dimension_numbers)
def args_maker():
lhs = rng_sparse(props.lhs_shape, props.dtype)
rhs = rng(props.rhs_shape, props.dtype)
nse = sparse.util._count_stored_elements(lhs, n_batch=props.n_batch,
n_dense=props.n_dense)
data, indices = sparse_bcoo._bcoo_fromdense(lhs, nse=nse, n_batch=props.n_batch, n_dense=props.n_dense)
return data, indices, lhs, rhs
def f_dense(data, indices, lhs, rhs):
return lax.dot_general(lhs, rhs, dimension_numbers=props.dimension_numbers)
def f_sparse(data, indices, lhs, rhs):
return sparse_bcoo._bcoo_dot_general(data, indices, rhs, lhs_spinfo=BCOOInfo(lhs.shape),
dimension_numbers=props.dimension_numbers)
self._CheckAgainstNumpy(f_dense, f_sparse, args_maker)
self._CheckAgainstNumpy(f_dense, jit(f_sparse), args_maker)
# TODO(jakevdp): In rare cases, this fails python_should_be_executing check. Why?
# self._CompileAndCheck(f_sparse, args_maker)
tol = {np.float64: 1E-12, np.complex128: 1E-12,
np.float32: 1E-5, np.complex64: 1E-5}
self._CheckAgainstDense(dense_fun, sparse_fun, args_maker, tol=tol)
self._CompileAndCheckSparse(sparse_fun, args_maker, atol=tol, rtol=tol)
@unittest.skipIf(not GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse")
@unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU")
@ -1302,33 +1277,18 @@ class BCOOTest(sptu.SparseTestCase):
)
@jax.default_matmul_precision("float32")
def test_bcoo_rdot_general(self, props: BcooDotGeneralProperties):
rng = jtu.rand_small(self.rng())
rng_sparse = rand_sparse(self.rng())
lhs_shape, rhs_shape = props.rhs_shape, props.lhs_shape
rng = jtu.rand_default(self.rng())
sprng = sptu.rand_bcoo(self.rng(), n_batch=props.n_batch, n_dense=props.n_dense)
args_maker = lambda: [rng(props.rhs_shape, props.dtype),
sprng(props.lhs_shape, props.dtype)]
dimension_numbers = tuple(d[::-1] for d in props.dimension_numbers)
sparse_fun = partial(sparse.bcoo_dot_general, dimension_numbers=dimension_numbers)
dense_fun = partial(lax.dot_general, dimension_numbers=dimension_numbers)
def args_maker():
lhs = rng_sparse(lhs_shape, props.dtype)
rhs = rng(rhs_shape, props.dtype)
nse = sparse.util._count_stored_elements(rhs, n_batch=props.n_batch,
n_dense=props.n_dense)
data, indices = sparse_bcoo._bcoo_fromdense(
rhs, nse=nse, n_batch=props.n_batch, n_dense=props.n_dense)
return data, indices, lhs, rhs
def f_dense(data, indices, lhs, rhs):
return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers)
def f_sparse(data, indices, lhs, rhs):
return sparse_bcoo._bcoo_rdot_general(lhs, data, indices,
rhs_spinfo=BCOOInfo(rhs.shape),
dimension_numbers=dimension_numbers)
self._CheckAgainstNumpy(f_dense, f_sparse, args_maker)
self._CheckAgainstNumpy(f_dense, jit(f_sparse), args_maker)
# TODO(jakevdp): In rare cases, this fails python_should_be_executing check. Why?
# self._CompileAndCheck(f_sparse, args_maker)
tol = {np.float64: 1E-12, np.complex128: 1E-12,
np.float32: 1E-5, np.complex64: 1E-5}
self._CheckAgainstDense(dense_fun, sparse_fun, args_maker, tol=tol)
self._CompileAndCheckSparse(sparse_fun, args_maker, atol=tol, rtol=tol)
@jtu.sample_product(
[dict(n_batch=n_batch, n_dense=n_dense, lhs_shape=lhs_shape,
@ -1829,6 +1789,21 @@ class BCOOTest(sptu.SparseTestCase):
self._CheckAgainstDense(fun, fun, args_maker)
self._CompileAndCheckSparse(fun, args_maker)
@jtu.sample_product(
[dict(shape=shape, n_batch=n_batch, n_dense=n_dense)
for shape in [(2,), (3, 4), (5, 6, 2)]
for n_batch in range(len(shape) + 1)
for n_dense in [0] # TODO(jakevdp): add tests with n_dense
],
dtype=jtu.dtypes.numeric,
)
def test_bcoo_iter(self, shape, dtype, n_batch, n_dense):
sprng = rand_sparse(self.rng())
args_maker = lambda: [sprng(shape, dtype)]
self._CheckAgainstDense(list, list, args_maker)
self._CompileAndCheckSparse(list, args_maker)
@jtu.sample_product(
[dict(shape=shape, n_batch=n_batch, n_dense=n_dense, nse=nse)
for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
@ -2011,17 +1986,13 @@ class BCOOTest(sptu.SparseTestCase):
dtype=jtu.dtypes.floating + jtu.dtypes.complex,
)
def test_bcoo_reduce_sum(self, shape, dtype, n_batch, n_dense, axes):
rng = rand_sparse(self.rng())
M = rng(shape, dtype)
nse = sparse.util._count_stored_elements(M, n_batch=n_batch,
n_dense=n_dense)
data, indices = sparse_bcoo._bcoo_fromdense(M, nse=nse, n_batch=n_batch, n_dense=n_dense)
data_out, indices_out, shape_out = sparse_bcoo._bcoo_reduce_sum(
data, indices, spinfo=BCOOInfo(shape), axes=axes)
result_dense = M.sum(axes)
result_sparse = sparse_bcoo._bcoo_todense(data_out, indices_out, spinfo=BCOOInfo(shape_out))
tol = {np.float32: 1E-6, np.float64: 1E-14}
self.assertAllClose(result_dense, result_sparse, atol=tol, rtol=tol)
sprng = sptu.rand_bcoo(self.rng(), n_batch=n_batch, n_dense=n_dense)
args_maker = lambda: [sprng(shape, dtype)]
sparse_fun = partial(sparse.bcoo_reduce_sum, axes=axes)
dense_fun = partial(lambda x: x.sum(axes))
self._CheckAgainstDense(dense_fun, sparse_fun, args_maker)
self._CompileAndCheckSparse(sparse_fun, args_maker)
@jtu.sample_product(
[dict(shape=shape, dimensions=dimensions, n_batch=n_batch, n_dense=n_dense)
@ -2106,24 +2077,27 @@ class BCOOTest(sptu.SparseTestCase):
)
@jax.default_matmul_precision("float32")
def test_bcoo_matmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype):
rng = jtu.rand_default(self.rng())
lhs = jnp.array(rng(lhs_shape, lhs_dtype))
rhs = jnp.array(rng(rhs_shape, rhs_dtype))
# Note: currently, batch dimensions in matmul must correspond to batch
# dimensions in the sparse representation.
lhs_sp = sparse.BCOO.fromdense(lhs, n_batch=max(0, len(lhs_shape) - 2))
rhs_sp = sparse.BCOO.fromdense(rhs, n_batch=max(0, len(rhs_shape) - 2))
n_batch_lhs = max(0, len(lhs_shape) - 2)
n_batch_rhs = max(0, len(rhs_shape) - 2)
with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
out1 = lhs @ rhs
out2 = lhs_sp @ rhs
out3 = lhs @ rhs_sp
rng = jtu.rand_default(self.rng())
sprng = sptu.rand_bcoo(self.rng())
args_maker_de_sp = lambda: [jnp.array(rng(lhs_shape, lhs_dtype)),
sprng(rhs_shape, rhs_dtype, n_batch=n_batch_rhs)]
args_maker_sp_de = lambda: [sprng(lhs_shape, lhs_dtype, n_batch=n_batch_lhs),
jnp.array(rng(rhs_shape, rhs_dtype))]
tol = {np.float64: 1E-13, np.complex128: 1E-13,
np.float32: 1E-6, np.complex64: 1E-6}
self.assertAllClose(out1, out2, rtol=tol)
self.assertAllClose(out1, out3, rtol=tol)
with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
self._CheckAgainstDense(operator.matmul, operator.matmul, args_maker_de_sp, tol=tol)
self._CheckAgainstDense(operator.matmul, operator.matmul, args_maker_sp_de, tol=tol)
self._CompileAndCheckSparse(operator.matmul, args_maker_de_sp, rtol=tol, atol=tol)
self._CompileAndCheckSparse(operator.matmul, args_maker_sp_de, rtol=tol, atol=tol)
@jtu.sample_product(
[dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, n_batch=n_batch,
@ -2139,22 +2113,21 @@ class BCOOTest(sptu.SparseTestCase):
)
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def test_bcoo_mul_dense(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, n_batch, n_dense):
rng_lhs = rand_sparse(self.rng())
rng_rhs = jtu.rand_default(self.rng())
lhs = jnp.array(rng_lhs(lhs_shape, lhs_dtype))
rhs = jnp.array(rng_rhs(rhs_shape, rhs_dtype))
rng = jtu.rand_default(self.rng())
sprng = sptu.rand_bcoo(self.rng(), n_batch=n_batch, n_dense=n_dense)
sp = lambda x: sparse.BCOO.fromdense(x, n_batch=n_batch, n_dense=n_dense)
with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
out1 = lhs * rhs
out2 = (sp(lhs) * rhs).todense()
out3 = (rhs * sp(lhs)).todense()
args_maker_sp_de = lambda: [sprng(lhs_shape, lhs_dtype), jnp.array(rng(rhs_shape, rhs_dtype))]
args_maker_de_sp = lambda: [jnp.array(rng(rhs_shape, rhs_dtype)), sprng(lhs_shape, lhs_dtype)]
tol = {np.float64: 1E-13, np.complex128: 1E-13,
np.float32: 1E-6, np.complex64: 1E-6}
self.assertAllClose(out1, out2, rtol=tol)
self.assertAllClose(out1, out3, rtol=tol)
with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
self._CheckAgainstDense(operator.mul, operator.mul, args_maker_de_sp, tol=tol)
self._CheckAgainstDense(operator.mul, operator.mul, args_maker_sp_de, tol=tol)
self._CompileAndCheckSparse(operator.mul, args_maker_de_sp, rtol=tol, atol=tol)
self._CompileAndCheckSparse(operator.mul, args_maker_sp_de, rtol=tol, atol=tol)
@jtu.sample_product(
[dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, lhs_n_batch=lhs_n_batch,
@ -2173,20 +2146,16 @@ class BCOOTest(sptu.SparseTestCase):
rhs_dtype=all_dtypes,
)
def test_bcoo_mul_sparse(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, lhs_n_batch, rhs_n_batch, n_dense):
rng = rand_sparse(self.rng())
lhs = jnp.array(rng(lhs_shape, lhs_dtype))
rhs = jnp.array(rng(rhs_shape, rhs_dtype))
lhs_sp = sparse.BCOO.fromdense(lhs, n_batch=lhs_n_batch, n_dense=n_dense)
rhs_sp = sparse.BCOO.fromdense(rhs, n_batch=rhs_n_batch, n_dense=n_dense)
with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
out1 = lhs * rhs
out2 = (lhs_sp * rhs_sp).todense()
sprng = sptu.rand_bcoo(self.rng(), n_dense=n_dense)
args_maker = lambda: [sprng(lhs_shape, lhs_dtype, n_batch=lhs_n_batch),
sprng(rhs_shape, rhs_dtype, n_batch=rhs_n_batch)]
tol = {np.float64: 1E-13, np.complex128: 1E-13,
np.float32: 1E-6, np.complex64: 1E-6}
self.assertAllClose(out1, out2, rtol=tol)
with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
self._CheckAgainstDense(operator.mul, operator.mul, args_maker, tol=tol)
self._CompileAndCheckSparse(operator.mul, args_maker, atol=tol, rtol=tol)
def test_bcoo_mul_sparse_with_duplicates(self):
# Regression test for https://github.com/google/jax/issues/8888