Previously, using sparse.todense on a BCSR matrix with sparse.sparsify would raise NotImplementedError: sparse rule for todense is not implemented. By adding the sparse rule, it will resolve this issue.

PiperOrigin-RevId: 551291543
This commit is contained in:
jax authors 2023-07-26 13:00:24 -07:00
parent 416814df2a
commit 735637e313
2 changed files with 16 additions and 5 deletions

View File

@ -858,6 +858,7 @@ def _todense_sparse_rule(spenv, spvalue, *, tree):
return (spenv.dense(out),)
sparse_rules_bcoo[sparse.todense_p] = _todense_sparse_rule
sparse_rules_bcsr[sparse.todense_p] = _todense_sparse_rule
def _custom_jvp_sparse_rule(spenv, *spvalues, **params):
call_jaxpr = params.pop('call_jaxpr')

View File

@ -25,7 +25,7 @@ import jax
from jax import config, jit, lax
import jax.numpy as jnp
import jax._src.test_util as jtu
from jax.experimental.sparse import BCOO, sparsify, todense, SparseTracer
from jax.experimental.sparse import BCOO, BCSR, sparsify, todense, SparseTracer
from jax.experimental.sparse.transform import (
arrays_to_spvalues, spvalues_to_arrays, sparsify_raw, SparsifyValue, SparsifyEnv)
from jax.experimental.sparse.util import CuSparseEfficiencyWarning
@ -224,7 +224,6 @@ class SparsifyTest(jtu.JaxTestCase):
with self.assertRaisesRegex(NotImplementedError, msg):
self.sparsify(operator.add)(x, 1.)
@jtu.sample_product(
[dict(shape=shape, n_batch=n_batch, n_dense=n_dense)
for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
@ -564,12 +563,23 @@ class SparsifyTest(jtu.JaxTestCase):
with self.assertRaisesRegex(TypeError, "sparsified true_fun and false_fun output.*"):
func(x_bcoo, y)
def testToDense(self):
M = jnp.arange(4)
@parameterized.named_parameters(
{"testcase_name": f"_{fmt}", "fmt": fmt}
for fmt in ["BCSR", "BCOO"]
)
def testToDense(self, fmt):
M = jnp.arange(4).reshape(2, 2)
if fmt == "BCOO":
Msp = BCOO.fromdense(M)
elif fmt == "BCSR":
Msp = BCSR.fromdense(M)
else:
raise ValueError(f"Unrecognized {fmt=}")
@self.sparsify
def func(M):
return todense(M) + 1
self.assertArraysEqual(func(M), M + 1)
self.assertArraysEqual(func(Msp), M + 1)
self.assertArraysEqual(jit(func)(M), M + 1)