mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Merge pull request #13280 from jakevdp:sparse-test-util
PiperOrigin-RevId: 489019436
This commit is contained in:
commit
58af581e3b
@ -974,20 +974,6 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
self._CheckAgainstDense(dense_func, sparse_func, args_maker)
|
||||
self._CompileAndCheckSparse(sparse_func, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, n_batch=n_batch, n_dense=n_dense)
|
||||
for shape in [(2,), (3, 4), (5, 6, 2)]
|
||||
for n_batch in range(len(shape) + 1)
|
||||
for n_dense in [0] # TODO(jakevdp): add tests with n_dense
|
||||
],
|
||||
dtype=jtu.dtypes.numeric,
|
||||
)
|
||||
def test_bcoo_iter(self, shape, dtype, n_batch, n_dense):
|
||||
sprng = rand_sparse(self.rng())
|
||||
M = sprng(shape, dtype)
|
||||
Msp = sparse.BCOO.fromdense(M, n_batch=n_batch, n_dense=n_dense)
|
||||
self.assertAllClose(list(M), [row.todense() for row in Msp])
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, n_batch=n_batch, n_dense=n_dense)
|
||||
for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
|
||||
@ -1068,28 +1054,17 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
)
|
||||
@jax.default_matmul_precision("float32")
|
||||
def test_bcoo_dot_general(self, props: BcooDotGeneralProperties):
|
||||
rng = jtu.rand_small(self.rng())
|
||||
rng_sparse = rand_sparse(self.rng())
|
||||
rng = jtu.rand_default(self.rng())
|
||||
sprng = sptu.rand_bcoo(self.rng(), n_batch=props.n_batch, n_dense=props.n_dense)
|
||||
args_maker = lambda: [sprng(props.lhs_shape, props.dtype),
|
||||
rng(props.rhs_shape, props.dtype)]
|
||||
dense_fun = partial(lax.dot_general, dimension_numbers=props.dimension_numbers)
|
||||
sparse_fun = partial(sparse.bcoo_dot_general, dimension_numbers=props.dimension_numbers)
|
||||
|
||||
def args_maker():
|
||||
lhs = rng_sparse(props.lhs_shape, props.dtype)
|
||||
rhs = rng(props.rhs_shape, props.dtype)
|
||||
nse = sparse.util._count_stored_elements(lhs, n_batch=props.n_batch,
|
||||
n_dense=props.n_dense)
|
||||
data, indices = sparse_bcoo._bcoo_fromdense(lhs, nse=nse, n_batch=props.n_batch, n_dense=props.n_dense)
|
||||
return data, indices, lhs, rhs
|
||||
|
||||
def f_dense(data, indices, lhs, rhs):
|
||||
return lax.dot_general(lhs, rhs, dimension_numbers=props.dimension_numbers)
|
||||
|
||||
def f_sparse(data, indices, lhs, rhs):
|
||||
return sparse_bcoo._bcoo_dot_general(data, indices, rhs, lhs_spinfo=BCOOInfo(lhs.shape),
|
||||
dimension_numbers=props.dimension_numbers)
|
||||
|
||||
self._CheckAgainstNumpy(f_dense, f_sparse, args_maker)
|
||||
self._CheckAgainstNumpy(f_dense, jit(f_sparse), args_maker)
|
||||
# TODO(jakevdp): In rare cases, this fails python_should_be_executing check. Why?
|
||||
# self._CompileAndCheck(f_sparse, args_maker)
|
||||
tol = {np.float64: 1E-12, np.complex128: 1E-12,
|
||||
np.float32: 1E-5, np.complex64: 1E-5}
|
||||
self._CheckAgainstDense(dense_fun, sparse_fun, args_maker, tol=tol)
|
||||
self._CompileAndCheckSparse(sparse_fun, args_maker, atol=tol, rtol=tol)
|
||||
|
||||
@unittest.skipIf(not GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse")
|
||||
@unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU")
|
||||
@ -1302,33 +1277,18 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
)
|
||||
@jax.default_matmul_precision("float32")
|
||||
def test_bcoo_rdot_general(self, props: BcooDotGeneralProperties):
|
||||
rng = jtu.rand_small(self.rng())
|
||||
rng_sparse = rand_sparse(self.rng())
|
||||
|
||||
lhs_shape, rhs_shape = props.rhs_shape, props.lhs_shape
|
||||
rng = jtu.rand_default(self.rng())
|
||||
sprng = sptu.rand_bcoo(self.rng(), n_batch=props.n_batch, n_dense=props.n_dense)
|
||||
args_maker = lambda: [rng(props.rhs_shape, props.dtype),
|
||||
sprng(props.lhs_shape, props.dtype)]
|
||||
dimension_numbers = tuple(d[::-1] for d in props.dimension_numbers)
|
||||
sparse_fun = partial(sparse.bcoo_dot_general, dimension_numbers=dimension_numbers)
|
||||
dense_fun = partial(lax.dot_general, dimension_numbers=dimension_numbers)
|
||||
|
||||
def args_maker():
|
||||
lhs = rng_sparse(lhs_shape, props.dtype)
|
||||
rhs = rng(rhs_shape, props.dtype)
|
||||
nse = sparse.util._count_stored_elements(rhs, n_batch=props.n_batch,
|
||||
n_dense=props.n_dense)
|
||||
data, indices = sparse_bcoo._bcoo_fromdense(
|
||||
rhs, nse=nse, n_batch=props.n_batch, n_dense=props.n_dense)
|
||||
return data, indices, lhs, rhs
|
||||
|
||||
def f_dense(data, indices, lhs, rhs):
|
||||
return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers)
|
||||
|
||||
def f_sparse(data, indices, lhs, rhs):
|
||||
return sparse_bcoo._bcoo_rdot_general(lhs, data, indices,
|
||||
rhs_spinfo=BCOOInfo(rhs.shape),
|
||||
dimension_numbers=dimension_numbers)
|
||||
|
||||
self._CheckAgainstNumpy(f_dense, f_sparse, args_maker)
|
||||
self._CheckAgainstNumpy(f_dense, jit(f_sparse), args_maker)
|
||||
# TODO(jakevdp): In rare cases, this fails python_should_be_executing check. Why?
|
||||
# self._CompileAndCheck(f_sparse, args_maker)
|
||||
tol = {np.float64: 1E-12, np.complex128: 1E-12,
|
||||
np.float32: 1E-5, np.complex64: 1E-5}
|
||||
self._CheckAgainstDense(dense_fun, sparse_fun, args_maker, tol=tol)
|
||||
self._CompileAndCheckSparse(sparse_fun, args_maker, atol=tol, rtol=tol)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(n_batch=n_batch, n_dense=n_dense, lhs_shape=lhs_shape,
|
||||
@ -1829,6 +1789,21 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
self._CheckAgainstDense(fun, fun, args_maker)
|
||||
self._CompileAndCheckSparse(fun, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, n_batch=n_batch, n_dense=n_dense)
|
||||
for shape in [(2,), (3, 4), (5, 6, 2)]
|
||||
for n_batch in range(len(shape) + 1)
|
||||
for n_dense in [0] # TODO(jakevdp): add tests with n_dense
|
||||
],
|
||||
dtype=jtu.dtypes.numeric,
|
||||
)
|
||||
def test_bcoo_iter(self, shape, dtype, n_batch, n_dense):
|
||||
sprng = rand_sparse(self.rng())
|
||||
args_maker = lambda: [sprng(shape, dtype)]
|
||||
|
||||
self._CheckAgainstDense(list, list, args_maker)
|
||||
self._CompileAndCheckSparse(list, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, n_batch=n_batch, n_dense=n_dense, nse=nse)
|
||||
for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
|
||||
@ -2011,17 +1986,13 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
dtype=jtu.dtypes.floating + jtu.dtypes.complex,
|
||||
)
|
||||
def test_bcoo_reduce_sum(self, shape, dtype, n_batch, n_dense, axes):
|
||||
rng = rand_sparse(self.rng())
|
||||
M = rng(shape, dtype)
|
||||
nse = sparse.util._count_stored_elements(M, n_batch=n_batch,
|
||||
n_dense=n_dense)
|
||||
data, indices = sparse_bcoo._bcoo_fromdense(M, nse=nse, n_batch=n_batch, n_dense=n_dense)
|
||||
data_out, indices_out, shape_out = sparse_bcoo._bcoo_reduce_sum(
|
||||
data, indices, spinfo=BCOOInfo(shape), axes=axes)
|
||||
result_dense = M.sum(axes)
|
||||
result_sparse = sparse_bcoo._bcoo_todense(data_out, indices_out, spinfo=BCOOInfo(shape_out))
|
||||
tol = {np.float32: 1E-6, np.float64: 1E-14}
|
||||
self.assertAllClose(result_dense, result_sparse, atol=tol, rtol=tol)
|
||||
sprng = sptu.rand_bcoo(self.rng(), n_batch=n_batch, n_dense=n_dense)
|
||||
args_maker = lambda: [sprng(shape, dtype)]
|
||||
sparse_fun = partial(sparse.bcoo_reduce_sum, axes=axes)
|
||||
dense_fun = partial(lambda x: x.sum(axes))
|
||||
|
||||
self._CheckAgainstDense(dense_fun, sparse_fun, args_maker)
|
||||
self._CompileAndCheckSparse(sparse_fun, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, dimensions=dimensions, n_batch=n_batch, n_dense=n_dense)
|
||||
@ -2106,24 +2077,27 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
)
|
||||
@jax.default_matmul_precision("float32")
|
||||
def test_bcoo_matmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
lhs = jnp.array(rng(lhs_shape, lhs_dtype))
|
||||
rhs = jnp.array(rng(rhs_shape, rhs_dtype))
|
||||
|
||||
# Note: currently, batch dimensions in matmul must correspond to batch
|
||||
# dimensions in the sparse representation.
|
||||
lhs_sp = sparse.BCOO.fromdense(lhs, n_batch=max(0, len(lhs_shape) - 2))
|
||||
rhs_sp = sparse.BCOO.fromdense(rhs, n_batch=max(0, len(rhs_shape) - 2))
|
||||
n_batch_lhs = max(0, len(lhs_shape) - 2)
|
||||
n_batch_rhs = max(0, len(rhs_shape) - 2)
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
|
||||
out1 = lhs @ rhs
|
||||
out2 = lhs_sp @ rhs
|
||||
out3 = lhs @ rhs_sp
|
||||
rng = jtu.rand_default(self.rng())
|
||||
sprng = sptu.rand_bcoo(self.rng())
|
||||
args_maker_de_sp = lambda: [jnp.array(rng(lhs_shape, lhs_dtype)),
|
||||
sprng(rhs_shape, rhs_dtype, n_batch=n_batch_rhs)]
|
||||
args_maker_sp_de = lambda: [sprng(lhs_shape, lhs_dtype, n_batch=n_batch_lhs),
|
||||
jnp.array(rng(rhs_shape, rhs_dtype))]
|
||||
|
||||
tol = {np.float64: 1E-13, np.complex128: 1E-13,
|
||||
np.float32: 1E-6, np.complex64: 1E-6}
|
||||
self.assertAllClose(out1, out2, rtol=tol)
|
||||
self.assertAllClose(out1, out3, rtol=tol)
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
|
||||
self._CheckAgainstDense(operator.matmul, operator.matmul, args_maker_de_sp, tol=tol)
|
||||
self._CheckAgainstDense(operator.matmul, operator.matmul, args_maker_sp_de, tol=tol)
|
||||
|
||||
self._CompileAndCheckSparse(operator.matmul, args_maker_de_sp, rtol=tol, atol=tol)
|
||||
self._CompileAndCheckSparse(operator.matmul, args_maker_sp_de, rtol=tol, atol=tol)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, n_batch=n_batch,
|
||||
@ -2139,22 +2113,21 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
)
|
||||
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
|
||||
def test_bcoo_mul_dense(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, n_batch, n_dense):
|
||||
rng_lhs = rand_sparse(self.rng())
|
||||
rng_rhs = jtu.rand_default(self.rng())
|
||||
lhs = jnp.array(rng_lhs(lhs_shape, lhs_dtype))
|
||||
rhs = jnp.array(rng_rhs(rhs_shape, rhs_dtype))
|
||||
rng = jtu.rand_default(self.rng())
|
||||
sprng = sptu.rand_bcoo(self.rng(), n_batch=n_batch, n_dense=n_dense)
|
||||
|
||||
sp = lambda x: sparse.BCOO.fromdense(x, n_batch=n_batch, n_dense=n_dense)
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
|
||||
out1 = lhs * rhs
|
||||
out2 = (sp(lhs) * rhs).todense()
|
||||
out3 = (rhs * sp(lhs)).todense()
|
||||
args_maker_sp_de = lambda: [sprng(lhs_shape, lhs_dtype), jnp.array(rng(rhs_shape, rhs_dtype))]
|
||||
args_maker_de_sp = lambda: [jnp.array(rng(rhs_shape, rhs_dtype)), sprng(lhs_shape, lhs_dtype)]
|
||||
|
||||
tol = {np.float64: 1E-13, np.complex128: 1E-13,
|
||||
np.float32: 1E-6, np.complex64: 1E-6}
|
||||
self.assertAllClose(out1, out2, rtol=tol)
|
||||
self.assertAllClose(out1, out3, rtol=tol)
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
|
||||
self._CheckAgainstDense(operator.mul, operator.mul, args_maker_de_sp, tol=tol)
|
||||
self._CheckAgainstDense(operator.mul, operator.mul, args_maker_sp_de, tol=tol)
|
||||
|
||||
self._CompileAndCheckSparse(operator.mul, args_maker_de_sp, rtol=tol, atol=tol)
|
||||
self._CompileAndCheckSparse(operator.mul, args_maker_sp_de, rtol=tol, atol=tol)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, lhs_n_batch=lhs_n_batch,
|
||||
@ -2173,20 +2146,16 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
rhs_dtype=all_dtypes,
|
||||
)
|
||||
def test_bcoo_mul_sparse(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, lhs_n_batch, rhs_n_batch, n_dense):
|
||||
rng = rand_sparse(self.rng())
|
||||
lhs = jnp.array(rng(lhs_shape, lhs_dtype))
|
||||
rhs = jnp.array(rng(rhs_shape, rhs_dtype))
|
||||
|
||||
lhs_sp = sparse.BCOO.fromdense(lhs, n_batch=lhs_n_batch, n_dense=n_dense)
|
||||
rhs_sp = sparse.BCOO.fromdense(rhs, n_batch=rhs_n_batch, n_dense=n_dense)
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
|
||||
out1 = lhs * rhs
|
||||
out2 = (lhs_sp * rhs_sp).todense()
|
||||
sprng = sptu.rand_bcoo(self.rng(), n_dense=n_dense)
|
||||
args_maker = lambda: [sprng(lhs_shape, lhs_dtype, n_batch=lhs_n_batch),
|
||||
sprng(rhs_shape, rhs_dtype, n_batch=rhs_n_batch)]
|
||||
|
||||
tol = {np.float64: 1E-13, np.complex128: 1E-13,
|
||||
np.float32: 1E-6, np.complex64: 1E-6}
|
||||
self.assertAllClose(out1, out2, rtol=tol)
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
|
||||
self._CheckAgainstDense(operator.mul, operator.mul, args_maker, tol=tol)
|
||||
self._CompileAndCheckSparse(operator.mul, args_maker, atol=tol, rtol=tol)
|
||||
|
||||
def test_bcoo_mul_sparse_with_duplicates(self):
|
||||
# Regression test for https://github.com/google/jax/issues/8888
|
||||
|
Loading…
x
Reference in New Issue
Block a user