mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
[sparse] add BCOO lowering for div
We had avoided doing this previously because dividing by zero is a densifying operation, but we already support mul, which has similar issues if the operand contains infinities.
This commit is contained in:
parent
ed8ddfb3f7
commit
74242f06d9
@ -671,6 +671,22 @@ def _mul_sparse(spenv, *spvalues):
|
||||
|
||||
sparse_rules_bcoo[lax.mul_p] = _mul_sparse
|
||||
|
||||
def _div_sparse(spenv, *spvalues):
  """Sparsify rule for lax.div_p: divide a sparse operand by a dense one.

  Division *by* a sparse operand is rejected, because zeros in the divisor
  would densify the output.
  """
  numerator, denominator = spvalues
  if denominator.is_sparse():
    raise NotImplementedError(
      "Division by a sparse array is not implemented because it "
      "would result in dense output. If this is your intent, use "
      "sparse.todense() to convert your arguments to a dense array.")
  # Implement division as multiplication by the reciprocal of the dense
  # operand; this only touches the stored entries of the sparse operand.
  promoted = spvalues_to_arrays(spenv, numerator)
  quotient_data = bcoo_multiply_dense(promoted, 1. / spenv.data(denominator))
  # The result reuses the numerator's index structure unchanged: only the
  # stored data values are scaled, so sortedness/uniqueness flags carry over.
  result = spenv.sparse(
      numerator.shape, quotient_data,
      indices_ref=numerator.indices_ref,
      indices_sorted=numerator.indices_sorted,
      unique_indices=numerator.unique_indices)
  return (result,)

# Register the rule so sparsify() dispatches lax.div_p here.
sparse_rules_bcoo[lax.div_p] = _div_sparse
|
||||
|
||||
def _reduce_sum_sparse(spenv, *spvalues, axes):
|
||||
X, = spvalues
|
||||
X_promoted = spvalues_to_arrays(spenv, X)
|
||||
@ -894,6 +910,8 @@ _bcoo_methods = {
|
||||
"__rmatmul__": sparsify(_swap_args(jnp.matmul)),
|
||||
"__mul__": sparsify(jnp.multiply),
|
||||
"__rmul__": sparsify(_swap_args(jnp.multiply)),
|
||||
"__truediv__": sparsify(jnp.divide),
|
||||
"__rtruediv__": sparsify(_swap_args(jnp.divide)),
|
||||
"__add__": sparsify(jnp.add),
|
||||
"__radd__": sparsify(_swap_args(jnp.add)),
|
||||
"__sub__": sparsify(jnp.subtract),
|
||||
|
@ -240,6 +240,33 @@ class SparsifyTest(jtu.JaxTestCase):
|
||||
|
||||
self.assertAllClose(out.todense(), x.todense() * y.todense())
|
||||
|
||||
@jtu.sample_product(
  [dict(shape=shape, n_batch=n_batch, n_dense=n_dense)
   for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
   for n_batch in range(len(shape) + 1)
   for n_dense in range(len(shape) + 1 - n_batch)
  ],
  dtype=jtu.dtypes.integer + jtu.dtypes.floating + jtu.dtypes.complex,
)
def testSparseDiv(self, shape, dtype, n_batch, n_dense):
  """Check sparse/dense division against the dense reference computation."""
  rng_dense = jtu.rand_nonzero(self.rng())
  rng_sparse = rand_sparse(self.rng(), rand_method=jtu.rand_some_zero)
  operand = BCOO.fromdense(rng_sparse(shape, dtype),
                           n_batch=n_batch, n_dense=n_dense)
  sparse_div = self.sparsify(operator.truediv)

  # Exercise both a scalar divisor and a full dense-array divisor, each
  # through the sparsified primitive and through BCOO.__truediv__.
  for divisor in (2, rng_dense(shape, dtype)):
    reference = operand.todense() / divisor
    self.assertAllClose(reference, sparse_div(operand, divisor).todense())
    self.assertAllClose(reference, (operand / divisor).todense())
|
||||
|
||||
def testSparseSubtract(self):
|
||||
x = BCOO.fromdense(3 * jnp.arange(5))
|
||||
y = BCOO.fromdense(jnp.arange(5))
|
||||
|
Loading…
x
Reference in New Issue
Block a user