[sparse] improve worst-case nse in spdot_general

This commit is contained in:
Jake VanderPlas 2023-03-07 18:24:12 -08:00
parent d62fc88fb1
commit 4180f8bf7b
2 changed files with 33 additions and 18 deletions

View File

@ -17,6 +17,7 @@ from __future__ import annotations
import functools
from functools import partial
import math
import operator
from typing import Any, List, NamedTuple, Optional, Protocol, Sequence, Tuple, Union
import warnings
@ -1083,7 +1084,7 @@ def _bcoo_spdot_general(lhs_data: Array, lhs_indices: Array, rhs_data: Array, rh
lhs_spinfo=lhs_spinfo, rhs_spinfo=rhs_spinfo,
dimension_numbers=(cdims, bdims))
def _bcoo_spdot_general_unbatched(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_spinfo, rhs_spinfo, lhs_contracting, rhs_contracting):
def _bcoo_spdot_general_unbatched(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_spinfo, rhs_spinfo, lhs_contracting, rhs_contracting, out_nse):
lhs_shape = lhs_spinfo.shape
rhs_shape = rhs_spinfo.shape
@ -1125,7 +1126,6 @@ def _bcoo_spdot_general_unbatched(lhs_data, lhs_indices, rhs_data, rhs_indices,
out_indices = out_indices.at[:, :, :lhs_j.shape[-1]].set(lhs_j[:, None])
out_indices = out_indices.at[:, :, lhs_j.shape[-1]:].set(rhs_j[None, :])
out_indices = out_indices.reshape(len(out_data), out_indices.shape[-1])
out_nse = (lhs.nse if lhs_j.shape[1] else 1) * (rhs.nse if rhs_j.shape[1] else 1)
# Note: we do not eliminate zeros here, because it can cause issues with autodiff.
# See https://github.com/google/jax/issues/10163.
return _bcoo_sum_duplicates(out_data, out_indices, spinfo=SparseInfo(shape=out_shape), nse=out_nse)
@ -1138,12 +1138,10 @@ def _bcoo_spdot_general_impl(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lh
lhs = _validate_bcoo(lhs_data, lhs_indices, lhs_shape)
rhs = _validate_bcoo(rhs_data, rhs_indices, rhs_shape)
assert lhs.n_dense == rhs.n_dense == 0
data_aval, indices_aval = _bcoo_spdot_general_abstract_eval(
data_aval, _ = _bcoo_spdot_general_abstract_eval(
lhs_data.aval, lhs_indices.aval, rhs_data.aval, rhs_indices.aval,
lhs_spinfo=lhs_spinfo, rhs_spinfo=rhs_spinfo, dimension_numbers=dimension_numbers)
out_shape = _dot_general_validated_shape(lhs_shape, rhs_shape,
dimension_numbers)
_validate_bcoo(data_aval, indices_aval, out_shape)
out_nse = data_aval.shape[-1]
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
@ -1160,7 +1158,8 @@ def _bcoo_spdot_general_impl(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lh
lhs_spinfo=SparseInfo(lhs_shape[lhs.n_batch:]),
rhs_spinfo=SparseInfo(rhs_shape[rhs.n_batch:]),
lhs_contracting=[d - lhs.n_batch for d in lhs_contracting],
rhs_contracting=[d - rhs.n_batch for d in rhs_contracting])
rhs_contracting=[d - rhs.n_batch for d in rhs_contracting],
out_nse=out_nse)
func = nfold_vmap(func, rhs.n_batch - len(rhs_batch), in_axes=(None, None, 0, 0))
func = nfold_vmap(func, lhs.n_batch - len(lhs_batch), in_axes=(0, 0, None, None))
@ -1171,6 +1170,8 @@ def _bcoo_spdot_general_impl(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lh
def _bcoo_spdot_general_abstract_eval(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_spinfo: SparseInfo, rhs_spinfo: SparseInfo, dimension_numbers):
lhs_shape = lhs_spinfo.shape
rhs_shape = rhs_spinfo.shape
out_shape = _dot_general_validated_shape(lhs_shape, rhs_shape,
dimension_numbers)
if lhs_data.dtype != rhs_data.dtype:
raise ValueError("bcoo_spdot_general requires inputs to have matching dtypes; "
@ -1201,6 +1202,11 @@ def _bcoo_spdot_general_abstract_eval(lhs_data, lhs_indices, rhs_data, rhs_indic
(rhs.nse if rhs.n_sparse > len(rhs_contracting) else 1)
)
# Ensure we're not storing more output elements than necessary: out_nse is
# capped at the number of entries in the non-batch dimensions of the output.
# TODO(jakevdp): should we warn here if output is effectively dense?
out_n_batch = lhs.n_batch + rhs.n_batch - len(lhs_batch)
out_nse = min(out_nse, math.prod(out_shape[out_n_batch:]))
data_shape = (
*(lhs_shape[dim] for dim in lhs_batch),
*(lhs_data.shape[dim] for dim in range(lhs.n_batch) if dim not in lhs_batch),
@ -1211,7 +1217,12 @@ def _bcoo_spdot_general_abstract_eval(lhs_data, lhs_indices, rhs_data, rhs_indic
*(lhs_indices.shape[dim] for dim in range(lhs.n_batch) if dim not in lhs_batch),
*(rhs_indices.shape[dim] for dim in range(rhs.n_batch) if dim not in rhs_batch),
out_nse, lhs.n_sparse + rhs.n_sparse - 2 * len(lhs_contracting))
return core.ShapedArray(data_shape, lhs_data.dtype), core.ShapedArray(indices_shape, lhs_indices.dtype)
data_aval = core.ShapedArray(data_shape, lhs_data.dtype)
indices_aval = core.ShapedArray(indices_shape, lhs_indices.dtype)
_validate_bcoo(data_aval, indices_aval, out_shape)
return data_aval, indices_aval
def _bcoo_spdot_general_batch_rule(batched_args, batch_dims, *, lhs_spinfo: SparseInfo, rhs_spinfo: SparseInfo, dimension_numbers):
lhs_ndim = len(lhs_spinfo.shape)

View File

@ -1472,18 +1472,22 @@ class BCOOTest(sptu.SparseTestCase):
if jnp.issubdtype(dtype, jnp.floating):
self._CheckGradsSparse(f_dense, f_sparse, args_maker, modes=['fwd'])
def test_bcoo_spdot_general_nse(self):
# vector-vector product -> nse=1
x = sparse.BCOO.fromdense(jnp.arange(3))
self.assertEqual((x @ x).nse, 1)
@jtu.sample_product(
lhs_shape=[(5,), (4, 5)],
rhs_shape=[(5,), (5, 4)])
@jax.default_matmul_precision("float32")
def test_bcoo_spdot_general_nse(self, lhs_shape, rhs_shape):
rng = sptu.rand_bcoo(self.rng())
dtype = jnp.float32
lhs = rng(lhs_shape, dtype)
rhs = rng(rhs_shape, dtype)
out = lhs @ rhs
# matrix-vector product -> nse matches matrix
M = sparse.BCOO.fromdense(jnp.arange(6).reshape(2, 3))
self.assertEqual((M @ x).nse, M.nse)
expected_out = lhs.todense() @ rhs.todense()
expected_nse = min(lhs.nse * rhs.nse, out.size)
# matrix-matrix product -> product of nse
N = sparse.BCOO.fromdense(jnp.arange(12).reshape(3, 4))
self.assertEqual((M @ N).nse, M.nse * N.nse)
self.assertArraysAllClose(out.todense(), expected_out)
self.assertEqual(out.nse, expected_nse)
def test_bcoo_spdot_general_ad_bug(self):
# Regression test for https://github.com/google/jax/issues/10163