mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Previously, using sparse.todense on a BCSR matrix with sparse.sparsify would raise NotImplementedError: sparse rule for todense is not implemented
. By adding the sparse rule, it will resolve this issue.
PiperOrigin-RevId: 551291543
This commit is contained in:
parent
416814df2a
commit
735637e313
@ -858,6 +858,7 @@ def _todense_sparse_rule(spenv, spvalue, *, tree):
|
||||
return (spenv.dense(out),)
|
||||
|
||||
sparse_rules_bcoo[sparse.todense_p] = _todense_sparse_rule
|
||||
sparse_rules_bcsr[sparse.todense_p] = _todense_sparse_rule
|
||||
|
||||
def _custom_jvp_sparse_rule(spenv, *spvalues, **params):
|
||||
call_jaxpr = params.pop('call_jaxpr')
|
||||
|
@ -25,7 +25,7 @@ import jax
|
||||
from jax import config, jit, lax
|
||||
import jax.numpy as jnp
|
||||
import jax._src.test_util as jtu
|
||||
from jax.experimental.sparse import BCOO, sparsify, todense, SparseTracer
|
||||
from jax.experimental.sparse import BCOO, BCSR, sparsify, todense, SparseTracer
|
||||
from jax.experimental.sparse.transform import (
|
||||
arrays_to_spvalues, spvalues_to_arrays, sparsify_raw, SparsifyValue, SparsifyEnv)
|
||||
from jax.experimental.sparse.util import CuSparseEfficiencyWarning
|
||||
@ -224,7 +224,6 @@ class SparsifyTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(NotImplementedError, msg):
|
||||
self.sparsify(operator.add)(x, 1.)
|
||||
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, n_batch=n_batch, n_dense=n_dense)
|
||||
for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
|
||||
@ -564,12 +563,23 @@ class SparsifyTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(TypeError, "sparsified true_fun and false_fun output.*"):
|
||||
func(x_bcoo, y)
|
||||
|
||||
def testToDense(self):
|
||||
M = jnp.arange(4)
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"_{fmt}", "fmt": fmt}
|
||||
for fmt in ["BCSR", "BCOO"]
|
||||
)
|
||||
def testToDense(self, fmt):
|
||||
M = jnp.arange(4).reshape(2, 2)
|
||||
if fmt == "BCOO":
|
||||
Msp = BCOO.fromdense(M)
|
||||
elif fmt == "BCSR":
|
||||
Msp = BCSR.fromdense(M)
|
||||
else:
|
||||
raise ValueError(f"Unrecognized {fmt=}")
|
||||
|
||||
@self.sparsify
|
||||
def func(M):
|
||||
return todense(M) + 1
|
||||
|
||||
self.assertArraysEqual(func(M), M + 1)
|
||||
self.assertArraysEqual(func(Msp), M + 1)
|
||||
self.assertArraysEqual(jit(func)(M), M + 1)
|
||||
|
Loading…
x
Reference in New Issue
Block a user