mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[sparse] refactor tests to improve runtime
This commit is contained in:
parent
cd5b26a0b9
commit
b00890b036
@ -44,51 +44,24 @@ class SparseTestCase(jtu.JaxTestCase):
|
||||
self.assertAllClose(x_bufs, y_bufs, check_dtypes=check_dtypes, atol=atol, rtol=rtol,
|
||||
canonicalize_dtypes=canonicalize_dtypes, err_msg=err_msg)
|
||||
|
||||
def _CheckAgainstDense(self, dense_op, sparse_op, args_maker,
|
||||
def _CheckAgainstDense(self, dense_op, sparse_op, args_maker, check_jit=True,
|
||||
check_dtypes=True, tol=None, atol=None, rtol=None,
|
||||
canonicalize_dtypes=True):
|
||||
"""Check an operation against a dense equivalent"""
|
||||
sparse_args = args_maker()
|
||||
dense_args = tree_util.tree_map(sparse.todense, sparse_args, is_leaf=is_sparse)
|
||||
expected = dense_op(*dense_args)
|
||||
|
||||
sparse_ans = sparse_op(*sparse_args)
|
||||
actual = tree_util.tree_map(sparse.todense, sparse_ans, is_leaf=is_sparse)
|
||||
expected = dense_op(*dense_args)
|
||||
|
||||
self.assertAllClose(expected, actual, check_dtypes=check_dtypes,
|
||||
atol=atol or tol, rtol=rtol or tol,
|
||||
canonicalize_dtypes=canonicalize_dtypes)
|
||||
|
||||
def _CompileAndCheckSparse(self, fun, args_maker, check_dtypes=True,
|
||||
rtol=None, atol=None, check_cache_misses=True):
|
||||
args = args_maker()
|
||||
|
||||
def wrapped_fun(*args):
|
||||
self.assertTrue(python_should_be_executing)
|
||||
return fun(*args)
|
||||
|
||||
python_should_be_executing = True
|
||||
python_ans = fun(*args)
|
||||
|
||||
cache_misses = dispatch.xla_primitive_callable.cache_info().misses
|
||||
python_ans = fun(*args)
|
||||
if check_cache_misses:
|
||||
self.assertEqual(
|
||||
cache_misses, dispatch.xla_primitive_callable.cache_info().misses,
|
||||
"Compilation detected during second call of {} in op-by-op "
|
||||
"mode.".format(fun))
|
||||
|
||||
cfun = api.jit(wrapped_fun)
|
||||
python_should_be_executing = True
|
||||
monitored_ans = cfun(*args)
|
||||
|
||||
python_should_be_executing = False
|
||||
compiled_ans = cfun(*args)
|
||||
|
||||
self.assertSparseArraysEquivalent(python_ans, monitored_ans, check_dtypes=check_dtypes,
|
||||
atol=atol, rtol=rtol)
|
||||
self.assertSparseArraysEquivalent(python_ans, compiled_ans, check_dtypes=check_dtypes,
|
||||
atol=atol, rtol=rtol)
|
||||
if check_jit:
|
||||
sparse_ans_jit = jax.jit(sparse_op)(*sparse_args)
|
||||
self.assertSparseArraysEquivalent(sparse_ans, sparse_ans_jit,
|
||||
atol=atol or tol, rtol=rtol or tol)
|
||||
|
||||
def _CheckGradsSparse(self, dense_fun, sparse_fun, args_maker, *,
|
||||
argnums=None, modes=('fwd', 'rev'), atol=None, rtol=None):
|
||||
@ -125,7 +98,7 @@ class SparseTestCase(jtu.JaxTestCase):
|
||||
return [rng.randint(0, arg + 1) for arg in args]
|
||||
|
||||
def _CheckBatchingSparse(self, dense_fun, sparse_fun, args_maker, *, batch_size=3, bdims=None,
|
||||
check_dtypes=True, tol=None, atol=None, rtol=None,
|
||||
check_jit=False, check_dtypes=True, tol=None, atol=None, rtol=None,
|
||||
canonicalize_dtypes=True):
|
||||
if bdims is None:
|
||||
bdims = self._random_bdims(*(arg.n_batch if is_sparse(arg) else arg.ndim
|
||||
@ -139,7 +112,7 @@ class SparseTestCase(jtu.JaxTestCase):
|
||||
return [arg[0] if bdim is None else concat([expand(x, bdim) for x in arg], bdim)
|
||||
for arg, bdim in safe_zip(args, bdims)]
|
||||
self._CheckAgainstDense(jax.vmap(dense_fun, bdims), jax.vmap(sparse_fun, bdims), batched_args_maker,
|
||||
check_dtypes=check_dtypes, tol=tol, atol=atol, rtol=rtol,
|
||||
check_dtypes=check_dtypes, tol=tol, atol=atol, rtol=rtol, check_jit=check_jit,
|
||||
canonicalize_dtypes=canonicalize_dtypes)
|
||||
|
||||
def _rand_sparse(shape: Sequence[int], dtype: DTypeLike, *,
|
||||
|
@ -820,7 +820,6 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
sparse_op = partial(sparse.bcoo_extract, assume_unique=assume_unique)
|
||||
|
||||
self._CheckAgainstDense(dense_op, sparse_op, args_maker)
|
||||
self._CompileAndCheckSparse(sparse_op, args_maker)
|
||||
self._CheckBatchingSparse(dense_op, sparse_op, args_maker, bdims=2 * self._random_bdims(n_batch))
|
||||
|
||||
def test_bcoo_extract_duplicate_indices(self):
|
||||
@ -920,7 +919,6 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
sparse_func = partial(sparse.bcoo_transpose, permutation=permutation)
|
||||
|
||||
self._CheckAgainstDense(dense_func, sparse_func, args_maker)
|
||||
self._CompileAndCheckSparse(sparse_func, args_maker)
|
||||
if jnp.issubdtype(dtype, jnp.floating):
|
||||
self._CheckGradsSparse(dense_func, sparse_func, args_maker)
|
||||
self._CheckBatchingSparse(dense_func, sparse_func, args_maker, bdims=self._random_bdims(n_batch))
|
||||
@ -979,7 +977,6 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
tol = {np.float64: 1E-12, np.complex128: 1E-12,
|
||||
np.float32: 1E-5, np.complex64: 1E-5}
|
||||
self._CheckAgainstDense(dense_fun, sparse_fun, args_maker, tol=tol)
|
||||
self._CompileAndCheckSparse(sparse_fun, args_maker, atol=tol, rtol=tol)
|
||||
if jnp.issubdtype(dtype, jnp.floating) and props.n_dense == 0:
|
||||
# Dense dimensions not yet fully supported in reverse mode.
|
||||
modes = ['fwd'] if props.n_dense != 0 else ['fwd', 'rev']
|
||||
@ -1195,7 +1192,6 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
tol = {np.float64: 1E-12, np.complex128: 1E-12,
|
||||
np.float32: 1E-5, np.complex64: 1E-5}
|
||||
self._CheckAgainstDense(dense_fun, sparse_fun, args_maker, tol=tol)
|
||||
self._CompileAndCheckSparse(sparse_fun, args_maker, atol=tol, rtol=tol)
|
||||
if jnp.issubdtype(dtype, jnp.floating):
|
||||
# Dense dimensions not yet fully supported in reverse mode.
|
||||
modes = ['fwd'] if props.n_dense != 0 else ['fwd', 'rev']
|
||||
@ -1260,8 +1256,7 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
return sparse.bcoo_dot_general_sampled(
|
||||
lhs, rhs, indices, dimension_numbers=props.dimension_numbers)
|
||||
|
||||
self._CheckAgainstNumpy(dense_fun, sparse_fun, args_maker)
|
||||
self._CompileAndCheckSparse(sparse_fun, args_maker)
|
||||
self._CheckAgainstDense(dense_fun, sparse_fun, args_maker)
|
||||
if jnp.issubdtype(dtype, jnp.floating):
|
||||
# Note: forward mode fails for some sparse layouts.
|
||||
# TODO(jakevdp) fix forward-mode autodiff & enable tests here.
|
||||
@ -1375,7 +1370,6 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
else:
|
||||
tol = {"float32": 1E-5, "complex64": 1E-5, "float64": 1E-14, "complex128": 1E-14}
|
||||
self._CheckAgainstDense(f_dense, f_sparse, args_maker, tol=tol)
|
||||
self._CompileAndCheckSparse(f_sparse, args_maker)
|
||||
self._CheckBatchingSparse(f_dense, f_sparse, args_maker, tol=tol)
|
||||
if jnp.issubdtype(dtype, jnp.floating):
|
||||
self._CheckGradsSparse(f_dense, f_sparse, args_maker, modes=['fwd'])
|
||||
@ -1446,7 +1440,6 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
sparse_func = partial(sparse.bcoo_slice, **kwds)
|
||||
|
||||
self._CheckAgainstDense(dense_func, sparse_func, args_maker)
|
||||
self._CompileAndCheckSparse(sparse_func, args_maker)
|
||||
if jnp.issubdtype(dtype, jnp.floating):
|
||||
self._CheckGradsSparse(dense_func, sparse_func, args_maker)
|
||||
|
||||
@ -1482,7 +1475,6 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
sparse_func = partial(sparse.bcoo_dynamic_slice, **kwds)
|
||||
|
||||
self._CheckAgainstDense(dense_func, sparse_func, args_maker)
|
||||
self._CompileAndCheckSparse(sparse_func, args_maker)
|
||||
if jnp.issubdtype(dtype, jnp.floating):
|
||||
self._CheckGradsSparse(dense_func, sparse_func, args_maker)
|
||||
|
||||
@ -1529,7 +1521,6 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
fun = lambda x: x[idx]
|
||||
|
||||
self._CheckAgainstDense(fun, fun, args_maker)
|
||||
self._CompileAndCheckSparse(fun, args_maker)
|
||||
if jnp.issubdtype(dtype, jnp.floating):
|
||||
self._CheckGradsSparse(fun, fun, args_maker)
|
||||
|
||||
@ -1546,7 +1537,6 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
args_maker = lambda: [sprng(shape, dtype)]
|
||||
|
||||
self._CheckAgainstDense(list, list, args_maker)
|
||||
self._CompileAndCheckSparse(list, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense, nse=nse)
|
||||
@ -1575,12 +1565,10 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
if nse:
|
||||
self.assertEqual(out.nse, nse)
|
||||
return out
|
||||
|
||||
self._CheckAgainstDense(dense_fun, sparse_fun, args_maker)
|
||||
self._CheckAgainstDense(dense_fun, sparse_fun, args_maker, check_jit=(nse is not None))
|
||||
if jnp.issubdtype(dtype, jnp.floating):
|
||||
self._CheckGradsSparse(dense_fun, sparse_fun, args_maker)
|
||||
if nse is not None:
|
||||
self._CompileAndCheckSparse(sparse_fun, args_maker)
|
||||
self._CheckBatchingSparse(dense_fun, sparse_fun, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
@ -1721,7 +1709,6 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
dense_fun = partial(lambda x: x.sum(axes))
|
||||
|
||||
self._CheckAgainstDense(dense_fun, sparse_fun, args_maker)
|
||||
self._CompileAndCheckSparse(sparse_fun, args_maker)
|
||||
if jnp.issubdtype(dtype, jnp.floating):
|
||||
self._CheckGradsSparse(dense_fun, sparse_fun, args_maker)
|
||||
|
||||
@ -1746,7 +1733,6 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
sparse_func = partial(sparse.bcoo_squeeze, dimensions=dimensions)
|
||||
|
||||
self._CheckAgainstDense(dense_func, sparse_func, args_maker)
|
||||
self._CompileAndCheckSparse(sparse_func, args_maker)
|
||||
if jnp.issubdtype(dtype, jnp.floating):
|
||||
self._CheckGradsSparse(dense_func, sparse_func, args_maker)
|
||||
|
||||
@ -1783,7 +1769,6 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
dense_func = partial(lax.reshape, new_sizes=new_sizes, dimensions=dimensions)
|
||||
|
||||
self._CheckAgainstDense(dense_func, sparse_func, args_maker)
|
||||
self._CompileAndCheckSparse(sparse_func, args_maker)
|
||||
if jnp.issubdtype(dtype, jnp.floating):
|
||||
self._CheckGradsSparse(dense_func, sparse_func, args_maker)
|
||||
|
||||
@ -1838,9 +1823,6 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
self._CheckAgainstDense(operator.matmul, operator.matmul, args_maker_de_sp, tol=tol)
|
||||
self._CheckAgainstDense(operator.matmul, operator.matmul, args_maker_sp_de, tol=tol)
|
||||
|
||||
self._CompileAndCheckSparse(operator.matmul, args_maker_de_sp, rtol=tol, atol=tol)
|
||||
self._CompileAndCheckSparse(operator.matmul, args_maker_sp_de, rtol=tol, atol=tol)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, n_batch=layout.n_batch,
|
||||
n_dense=layout.n_dense)
|
||||
@ -1867,9 +1849,6 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
self._CheckAgainstDense(operator.mul, operator.mul, args_maker_de_sp, tol=tol)
|
||||
self._CheckAgainstDense(operator.mul, operator.mul, args_maker_sp_de, tol=tol)
|
||||
|
||||
self._CompileAndCheckSparse(operator.mul, args_maker_de_sp, rtol=tol, atol=tol)
|
||||
self._CompileAndCheckSparse(operator.mul, args_maker_sp_de, rtol=tol, atol=tol)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, lhs_n_batch=lhs_n_batch,
|
||||
rhs_n_batch=rhs_n_batch, n_dense=n_dense)
|
||||
@ -1896,7 +1875,6 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
|
||||
self._CheckAgainstDense(operator.mul, operator.mul, args_maker, tol=tol)
|
||||
self._CompileAndCheckSparse(operator.mul, args_maker, atol=tol, rtol=tol)
|
||||
|
||||
def test_bcoo_mul_sparse_with_duplicates(self):
|
||||
# Regression test for https://github.com/google/jax/issues/8888
|
||||
@ -1944,7 +1922,6 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
sparse_func = partial(sparse.bcoo_concatenate, dimension=dimension)
|
||||
|
||||
self._CheckAgainstDense(dense_func, sparse_func, args_maker)
|
||||
self._CompileAndCheckSparse(sparse_func, args_maker)
|
||||
if jnp.issubdtype(dtype, jnp.floating):
|
||||
self._CheckGradsSparse(dense_func, sparse_func, args_maker)
|
||||
|
||||
@ -2152,7 +2129,6 @@ class BCSRTest(sptu.SparseTestCase):
|
||||
np.float32: 1E-5, np.complex64: 1E-5}
|
||||
|
||||
self._CheckAgainstDense(dense_fun, sparse_fun, args_maker, tol=tol)
|
||||
self._CompileAndCheckSparse(sparse_fun, args_maker, atol=tol, rtol=tol)
|
||||
|
||||
class SparseGradTest(sptu.SparseTestCase):
|
||||
@jtu.sample_product(has_aux=[True, False])
|
||||
|
Loading…
x
Reference in New Issue
Block a user