mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[sparse] fix autodiff bug in spdot_general
This commit is contained in:
parent
8b6b736ef3
commit
8b9efe79e7
@ -84,6 +84,8 @@ def _bcoo_nse(mat, n_batch=0, n_dense=0):
|
||||
mask = mask.sum(list(range(n_batch, mask.ndim)))
|
||||
return mask.max()
|
||||
|
||||
# TODO(jakevdp): add a custom autodiff rule that errors if remove_zeros=True, because
|
||||
# it produces wrong values. See https://github.com/google/jax/issues/10163
|
||||
def _bcoo_sum_duplicates(data, indices, shape, nse=None, remove_zeros=True):
|
||||
if nse is None and isinstance(jnp.array(0), core.Tracer):
|
||||
raise ValueError("When used with JIT, vmap, or another transform, sum_duplicates() "
|
||||
@ -1229,7 +1231,8 @@ def _bcoo_spdot_general_unbatched(lhs_data, lhs_indices, rhs_data, rhs_indices,
|
||||
out_indices = out_indices.at[:, :, lhs_j.shape[-1]:].set(rhs_j[None, :])
|
||||
out_indices = out_indices.reshape(len(out_data), out_indices.shape[-1])
|
||||
out_nse = (lhs.nse if lhs_j.shape[1] else 1) * (rhs.nse if rhs_j.shape[1] else 1)
|
||||
return _bcoo_sum_duplicates(out_data, out_indices, out_shape, nse=out_nse)
|
||||
# Note: remove_zeros=True is incompatible with autodiff.
|
||||
return _bcoo_sum_duplicates(out_data, out_indices, out_shape, nse=out_nse, remove_zeros=False)
|
||||
|
||||
@bcoo_spdot_general_p.def_impl
|
||||
def _bcoo_spdot_general_impl(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_spinfo: BCOOInfo, rhs_spinfo: BCOOInfo, dimension_numbers):
|
||||
@ -1770,7 +1773,8 @@ class BCOO(JAXSparse):
|
||||
If it is smaller than the number required, data will be silently discarded.
|
||||
remove_zeros : bool (default=True). If True, remove explicit zeros from the data
|
||||
as part of summing duplicates. If False, then explicit zeros at unique indices
|
||||
will remain among the specified elements.
|
||||
will remain among the specified elements. Note: remove_zeros=True is incompatible
|
||||
with autodiff.
|
||||
"""
|
||||
data, indices = _bcoo_sum_duplicates(self.data, self.indices, self.shape,
|
||||
nse=nse, remove_zeros=remove_zeros)
|
||||
|
@ -1454,6 +1454,38 @@ class BCOOTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(jf_dense_0, jf_sparse_0, rtol=tol)
|
||||
self.assertAllClose(jf_dense_1, jf_sparse_1, rtol=tol)
|
||||
|
||||
def test_bcoo_spdot_general_ad_bug(self):
|
||||
# Regression test for https://github.com/google/jax/issues/10163
|
||||
A_indices = jnp.array([[0, 1], [0, 2], [1, 1], [1, 2], [1, 0]])
|
||||
A_values = jnp.array([-2.0, 1.0, -1.0, 0.5, 2.0])
|
||||
A_shape = (2, 3)
|
||||
|
||||
B_indices = jnp.array([[0, 2], [2, 1], [0, 3], [1, 3], [1, 0], [0, 0]])
|
||||
B_values = jnp.array([10.0, 100.0, 1000.0, -5.0, -50.0, -500.0])
|
||||
B_shape = (3, 4)
|
||||
|
||||
def sp_sp_product(v1, v2):
|
||||
A = sparse.BCOO((v1, A_indices), shape=A_shape)
|
||||
B = sparse.BCOO((v2, B_indices), shape=B_shape)
|
||||
return (A @ B).todense()
|
||||
|
||||
def sp_de_product(v1, v2):
|
||||
A = sparse.BCOO((v1, A_indices), shape=A_shape)
|
||||
B = sparse.BCOO((v2, B_indices), shape=B_shape).todense()
|
||||
return A @ B
|
||||
|
||||
def de_de_product(v1, v2):
|
||||
sparse1 = sparse.BCOO((v1, A_indices), shape=A_shape).todense()
|
||||
dense2 = sparse.BCOO((v2, B_indices), shape=B_shape).todense()
|
||||
return sparse1 @ dense2
|
||||
|
||||
sp_sp_jac = jax.jacfwd(sp_sp_product, argnums=1)(A_values, B_values)
|
||||
sp_de_jac = jax.jacfwd(sp_de_product, argnums=1)(A_values, B_values)
|
||||
de_de_jac = jax.jacfwd(de_de_product, argnums=1)(A_values, B_values)
|
||||
|
||||
self.assertAllClose(sp_sp_jac, de_de_jac)
|
||||
self.assertAllClose(sp_de_jac, de_de_jac)
|
||||
|
||||
@unittest.skipIf(jtu.device_under_test() == "tpu", "TPU has insufficient precision")
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}[n_batch={}]_{}[n_batch={}]_in_axes={}".format(
|
||||
|
Loading…
x
Reference in New Issue
Block a user