[sparse] add BCOO lowering for div

We had avoiding this previously because dividing by zero is
a densifying operation, but we already support mul which has
similar issues if the operand contains infinities.
This commit is contained in:
Jake VanderPlas 2023-03-14 11:58:43 -07:00
parent ed8ddfb3f7
commit 74242f06d9
2 changed files with 45 additions and 0 deletions

View File

@ -671,6 +671,22 @@ def _mul_sparse(spenv, *spvalues):
sparse_rules_bcoo[lax.mul_p] = _mul_sparse
def _div_sparse(spenv, *spvalues):
X, Y = spvalues
if Y.is_sparse():
raise NotImplementedError(
"Division by a sparse array is not implemented because it "
"would result in dense output. If this is your intent, use "
"sparse.todense() to convert your arguments to a dense array.")
X_promoted = spvalues_to_arrays(spenv, X)
out_data = bcoo_multiply_dense(X_promoted, 1. / spenv.data(Y))
out_spvalue = spenv.sparse(X.shape, out_data, indices_ref=X.indices_ref,
indices_sorted=X.indices_sorted,
unique_indices=X.unique_indices)
return (out_spvalue,)
sparse_rules_bcoo[lax.div_p] = _div_sparse
def _reduce_sum_sparse(spenv, *spvalues, axes):
X, = spvalues
X_promoted = spvalues_to_arrays(spenv, X)
@ -894,6 +910,8 @@ _bcoo_methods = {
"__rmatmul__": sparsify(_swap_args(jnp.matmul)),
"__mul__": sparsify(jnp.multiply),
"__rmul__": sparsify(_swap_args(jnp.multiply)),
"__truediv__": sparsify(jnp.divide),
"__rtruediv__": sparsify(_swap_args(jnp.divide)),
"__add__": sparsify(jnp.add),
"__radd__": sparsify(_swap_args(jnp.add)),
"__sub__": sparsify(jnp.subtract),

View File

@ -240,6 +240,33 @@ class SparsifyTest(jtu.JaxTestCase):
self.assertAllClose(out.todense(), x.todense() * y.todense())
@jtu.sample_product(
[dict(shape=shape, n_batch=n_batch, n_dense=n_dense)
for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
for n_batch in range(len(shape) + 1)
for n_dense in range(len(shape) + 1 - n_batch)
],
dtype=jtu.dtypes.integer + jtu.dtypes.floating + jtu.dtypes.complex,
)
def testSparseDiv(self, shape, dtype, n_batch, n_dense):
rng_dense = jtu.rand_nonzero(self.rng())
rng_sparse = rand_sparse(self.rng(), rand_method=jtu.rand_some_zero)
x = BCOO.fromdense(rng_sparse(shape, dtype), n_batch=n_batch,
n_dense=n_dense)
spdiv = self.sparsify(operator.truediv)
# Scalar division
divisor = 2
expected = x.todense() / divisor
self.assertAllClose(expected, spdiv(x, divisor).todense())
self.assertAllClose(expected, (x / divisor).todense())
# Array division
divisor = rng_dense(shape, dtype)
expected = x.todense() / divisor
self.assertAllClose(expected, spdiv(x, divisor).todense())
self.assertAllClose(expected, (x / divisor).todense())
def testSparseSubtract(self):
x = BCOO.fromdense(3 * jnp.arange(5))
y = BCOO.fromdense(jnp.arange(5))