[sparse] fix autodiff bug in spdot_general

This commit is contained in:
Jake VanderPlas 2022-04-08 11:04:26 -07:00
parent 8b6b736ef3
commit 8b9efe79e7
2 changed files with 38 additions and 2 deletions

View File

@ -84,6 +84,8 @@ def _bcoo_nse(mat, n_batch=0, n_dense=0):
mask = mask.sum(list(range(n_batch, mask.ndim)))
return mask.max()
# TODO(jakevdp): add a custom autodiff rule that errors if remove_zeros=True, because
# it produces wrong values. See https://github.com/google/jax/issues/10163
def _bcoo_sum_duplicates(data, indices, shape, nse=None, remove_zeros=True):
if nse is None and isinstance(jnp.array(0), core.Tracer):
raise ValueError("When used with JIT, vmap, or another transform, sum_duplicates() "
@ -1229,7 +1231,8 @@ def _bcoo_spdot_general_unbatched(lhs_data, lhs_indices, rhs_data, rhs_indices,
out_indices = out_indices.at[:, :, lhs_j.shape[-1]:].set(rhs_j[None, :])
out_indices = out_indices.reshape(len(out_data), out_indices.shape[-1])
out_nse = (lhs.nse if lhs_j.shape[1] else 1) * (rhs.nse if rhs_j.shape[1] else 1)
return _bcoo_sum_duplicates(out_data, out_indices, out_shape, nse=out_nse)
# Note: remove_zeros=True is incompatible with autodiff.
return _bcoo_sum_duplicates(out_data, out_indices, out_shape, nse=out_nse, remove_zeros=False)
@bcoo_spdot_general_p.def_impl
def _bcoo_spdot_general_impl(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_spinfo: BCOOInfo, rhs_spinfo: BCOOInfo, dimension_numbers):
@ -1770,7 +1773,8 @@ class BCOO(JAXSparse):
If it is smaller than the number required, data will be silently discarded.
remove_zeros : bool (default=True). If True, remove explicit zeros from the data
as part of summing duplicates. If False, then explicit zeros at unique indices
will remain among the specified elements.
will remain among the specified elements. Note: remove_zeros=True is incompatible
with autodiff.
"""
data, indices = _bcoo_sum_duplicates(self.data, self.indices, self.shape,
nse=nse, remove_zeros=remove_zeros)

View File

@ -1454,6 +1454,38 @@ class BCOOTest(jtu.JaxTestCase):
self.assertAllClose(jf_dense_0, jf_sparse_0, rtol=tol)
self.assertAllClose(jf_dense_1, jf_sparse_1, rtol=tol)
def test_bcoo_spdot_general_ad_bug(self):
# Regression test for https://github.com/google/jax/issues/10163
A_indices = jnp.array([[0, 1], [0, 2], [1, 1], [1, 2], [1, 0]])
A_values = jnp.array([-2.0, 1.0, -1.0, 0.5, 2.0])
A_shape = (2, 3)
B_indices = jnp.array([[0, 2], [2, 1], [0, 3], [1, 3], [1, 0], [0, 0]])
B_values = jnp.array([10.0, 100.0, 1000.0, -5.0, -50.0, -500.0])
B_shape = (3, 4)
def sp_sp_product(v1, v2):
A = sparse.BCOO((v1, A_indices), shape=A_shape)
B = sparse.BCOO((v2, B_indices), shape=B_shape)
return (A @ B).todense()
def sp_de_product(v1, v2):
A = sparse.BCOO((v1, A_indices), shape=A_shape)
B = sparse.BCOO((v2, B_indices), shape=B_shape).todense()
return A @ B
def de_de_product(v1, v2):
sparse1 = sparse.BCOO((v1, A_indices), shape=A_shape).todense()
dense2 = sparse.BCOO((v2, B_indices), shape=B_shape).todense()
return sparse1 @ dense2
sp_sp_jac = jax.jacfwd(sp_sp_product, argnums=1)(A_values, B_values)
sp_de_jac = jax.jacfwd(sp_de_product, argnums=1)(A_values, B_values)
de_de_jac = jax.jacfwd(de_de_product, argnums=1)(A_values, B_values)
self.assertAllClose(sp_sp_jac, de_de_jac)
self.assertAllClose(sp_de_jac, de_de_jac)
@unittest.skipIf(jtu.device_under_test() == "tpu", "TPU has insufficient precision")
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}[n_batch={}]_{}[n_batch={}]_in_axes={}".format(