From 735637e31376259c2ee7a12e78fa66d93e021f50 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 26 Jul 2023 13:00:24 -0700 Subject: [PATCH] Previously, using sparse.todense on a BCSR matrix with sparse.sparsify would raise `NotImplementedError: sparse rule for todense is not implemented`. By adding the sparse rule, it will resolve this issue. PiperOrigin-RevId: 551291543 --- jax/experimental/sparse/transform.py | 1 + tests/sparsify_test.py | 20 +++++++++++++++----- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/jax/experimental/sparse/transform.py b/jax/experimental/sparse/transform.py index 987edcdea..1d3d910a7 100644 --- a/jax/experimental/sparse/transform.py +++ b/jax/experimental/sparse/transform.py @@ -858,6 +858,7 @@ def _todense_sparse_rule(spenv, spvalue, *, tree): return (spenv.dense(out),) sparse_rules_bcoo[sparse.todense_p] = _todense_sparse_rule +sparse_rules_bcsr[sparse.todense_p] = _todense_sparse_rule def _custom_jvp_sparse_rule(spenv, *spvalues, **params): call_jaxpr = params.pop('call_jaxpr') diff --git a/tests/sparsify_test.py b/tests/sparsify_test.py index 9d6fc530d..56626ec44 100644 --- a/tests/sparsify_test.py +++ b/tests/sparsify_test.py @@ -25,7 +25,7 @@ import jax from jax import config, jit, lax import jax.numpy as jnp import jax._src.test_util as jtu -from jax.experimental.sparse import BCOO, sparsify, todense, SparseTracer +from jax.experimental.sparse import BCOO, BCSR, sparsify, todense, SparseTracer from jax.experimental.sparse.transform import ( arrays_to_spvalues, spvalues_to_arrays, sparsify_raw, SparsifyValue, SparsifyEnv) from jax.experimental.sparse.util import CuSparseEfficiencyWarning @@ -224,7 +224,6 @@ class SparsifyTest(jtu.JaxTestCase): with self.assertRaisesRegex(NotImplementedError, msg): self.sparsify(operator.add)(x, 1.) - @jtu.sample_product( [dict(shape=shape, n_batch=n_batch, n_dense=n_dense) for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] @@ -564,12 +563,23 @@ class SparsifyTest(jtu.JaxTestCase): with self.assertRaisesRegex(TypeError, "sparsified true_fun and false_fun output.*"): func(x_bcoo, y) - def testToDense(self): - M = jnp.arange(4) - Msp = BCOO.fromdense(M) + @parameterized.named_parameters( + {"testcase_name": f"_{fmt}", "fmt": fmt} + for fmt in ["BCSR", "BCOO"] + ) + def testToDense(self, fmt): + M = jnp.arange(4).reshape(2, 2) + if fmt == "BCOO": + Msp = BCOO.fromdense(M) + elif fmt == "BCSR": + Msp = BCSR.fromdense(M) + else: + raise ValueError(f"Unrecognized {fmt=}") + @self.sparsify def func(M): return todense(M) + 1 + self.assertArraysEqual(func(M), M + 1) self.assertArraysEqual(func(Msp), M + 1) self.assertArraysEqual(jit(func)(M), M + 1)