mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
[sparse] improve worst-case nse in spdot_general
This commit is contained in:
parent
d62fc88fb1
commit
4180f8bf7b
@@ -17,6 +17,7 @@ from __future__ import annotations
|
||||
|
||||
import functools
|
||||
from functools import partial
|
||||
import math
|
||||
import operator
|
||||
from typing import Any, List, NamedTuple, Optional, Protocol, Sequence, Tuple, Union
|
||||
import warnings
|
||||
@@ -1083,7 +1084,7 @@ def _bcoo_spdot_general(lhs_data: Array, lhs_indices: Array, rhs_data: Array, rh
|
||||
lhs_spinfo=lhs_spinfo, rhs_spinfo=rhs_spinfo,
|
||||
dimension_numbers=(cdims, bdims))
|
||||
|
||||
def _bcoo_spdot_general_unbatched(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_spinfo, rhs_spinfo, lhs_contracting, rhs_contracting):
|
||||
def _bcoo_spdot_general_unbatched(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_spinfo, rhs_spinfo, lhs_contracting, rhs_contracting, out_nse):
|
||||
lhs_shape = lhs_spinfo.shape
|
||||
rhs_shape = rhs_spinfo.shape
|
||||
|
||||
@@ -1125,7 +1126,6 @@ def _bcoo_spdot_general_unbatched(lhs_data, lhs_indices, rhs_data, rhs_indices,
|
||||
out_indices = out_indices.at[:, :, :lhs_j.shape[-1]].set(lhs_j[:, None])
|
||||
out_indices = out_indices.at[:, :, lhs_j.shape[-1]:].set(rhs_j[None, :])
|
||||
out_indices = out_indices.reshape(len(out_data), out_indices.shape[-1])
|
||||
out_nse = (lhs.nse if lhs_j.shape[1] else 1) * (rhs.nse if rhs_j.shape[1] else 1)
|
||||
# Note: we do not eliminate zeros here, because it can cause issues with autodiff.
|
||||
# See https://github.com/google/jax/issues/10163.
|
||||
return _bcoo_sum_duplicates(out_data, out_indices, spinfo=SparseInfo(shape=out_shape), nse=out_nse)
|
||||
@@ -1138,12 +1138,10 @@ def _bcoo_spdot_general_impl(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lh
|
||||
lhs = _validate_bcoo(lhs_data, lhs_indices, lhs_shape)
|
||||
rhs = _validate_bcoo(rhs_data, rhs_indices, rhs_shape)
|
||||
assert lhs.n_dense == rhs.n_dense == 0
|
||||
data_aval, indices_aval = _bcoo_spdot_general_abstract_eval(
|
||||
data_aval, _ = _bcoo_spdot_general_abstract_eval(
|
||||
lhs_data.aval, lhs_indices.aval, rhs_data.aval, rhs_indices.aval,
|
||||
lhs_spinfo=lhs_spinfo, rhs_spinfo=rhs_spinfo, dimension_numbers=dimension_numbers)
|
||||
out_shape = _dot_general_validated_shape(lhs_shape, rhs_shape,
|
||||
dimension_numbers)
|
||||
_validate_bcoo(data_aval, indices_aval, out_shape)
|
||||
out_nse = data_aval.shape[-1]
|
||||
|
||||
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
|
||||
|
||||
@@ -1160,7 +1158,8 @@ def _bcoo_spdot_general_impl(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lh
|
||||
lhs_spinfo=SparseInfo(lhs_shape[lhs.n_batch:]),
|
||||
rhs_spinfo=SparseInfo(rhs_shape[rhs.n_batch:]),
|
||||
lhs_contracting=[d - lhs.n_batch for d in lhs_contracting],
|
||||
rhs_contracting=[d - rhs.n_batch for d in rhs_contracting])
|
||||
rhs_contracting=[d - rhs.n_batch for d in rhs_contracting],
|
||||
out_nse=out_nse)
|
||||
|
||||
func = nfold_vmap(func, rhs.n_batch - len(rhs_batch), in_axes=(None, None, 0, 0))
|
||||
func = nfold_vmap(func, lhs.n_batch - len(lhs_batch), in_axes=(0, 0, None, None))
|
||||
@@ -1171,6 +1170,8 @@ def _bcoo_spdot_general_impl(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lh
|
||||
def _bcoo_spdot_general_abstract_eval(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_spinfo: SparseInfo, rhs_spinfo: SparseInfo, dimension_numbers):
|
||||
lhs_shape = lhs_spinfo.shape
|
||||
rhs_shape = rhs_spinfo.shape
|
||||
out_shape = _dot_general_validated_shape(lhs_shape, rhs_shape,
|
||||
dimension_numbers)
|
||||
|
||||
if lhs_data.dtype != rhs_data.dtype:
|
||||
raise ValueError("bcoo_spdot_general requires inputs to have matching dtypes; "
|
||||
@@ -1201,6 +1202,11 @@ def _bcoo_spdot_general_abstract_eval(lhs_data, lhs_indices, rhs_data, rhs_indic
|
||||
(rhs.nse if rhs.n_sparse > len(rhs_contracting) else 1)
|
||||
)
|
||||
|
||||
# Ensure we're not storing more output elements than necessary.
|
||||
# TODO(jakevdp): should we warn here if output is effectively dense?
|
||||
out_n_batch = lhs.n_batch + rhs.n_batch - len(lhs_batch)
|
||||
out_nse = min(out_nse, math.prod(out_shape[out_n_batch:]))
|
||||
|
||||
data_shape = (
|
||||
*(lhs_shape[dim] for dim in lhs_batch),
|
||||
*(lhs_data.shape[dim] for dim in range(lhs.n_batch) if dim not in lhs_batch),
|
||||
@@ -1211,7 +1217,12 @@ def _bcoo_spdot_general_abstract_eval(lhs_data, lhs_indices, rhs_data, rhs_indic
|
||||
*(lhs_indices.shape[dim] for dim in range(lhs.n_batch) if dim not in lhs_batch),
|
||||
*(rhs_indices.shape[dim] for dim in range(rhs.n_batch) if dim not in rhs_batch),
|
||||
out_nse, lhs.n_sparse + rhs.n_sparse - 2 * len(lhs_contracting))
|
||||
return core.ShapedArray(data_shape, lhs_data.dtype), core.ShapedArray(indices_shape, lhs_indices.dtype)
|
||||
|
||||
data_aval = core.ShapedArray(data_shape, lhs_data.dtype)
|
||||
indices_aval = core.ShapedArray(indices_shape, lhs_indices.dtype)
|
||||
_validate_bcoo(data_aval, indices_aval, out_shape)
|
||||
|
||||
return data_aval, indices_aval
|
||||
|
||||
def _bcoo_spdot_general_batch_rule(batched_args, batch_dims, *, lhs_spinfo: SparseInfo, rhs_spinfo: SparseInfo, dimension_numbers):
|
||||
lhs_ndim = len(lhs_spinfo.shape)
|
||||
|
@@ -1472,18 +1472,22 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
if jnp.issubdtype(dtype, jnp.floating):
|
||||
self._CheckGradsSparse(f_dense, f_sparse, args_maker, modes=['fwd'])
|
||||
|
||||
def test_bcoo_spdot_general_nse(self):
|
||||
# vector-vector product -> nse=1
|
||||
x = sparse.BCOO.fromdense(jnp.arange(3))
|
||||
self.assertEqual((x @ x).nse, 1)
|
||||
@jtu.sample_product(
|
||||
lhs_shape=[(5,), (4, 5)],
|
||||
rhs_shape=[(5,), (5, 4)])
|
||||
@jax.default_matmul_precision("float32")
|
||||
def test_bcoo_spdot_general_nse(self, lhs_shape, rhs_shape):
|
||||
rng = sptu.rand_bcoo(self.rng())
|
||||
dtype = jnp.float32
|
||||
lhs = rng(lhs_shape, dtype)
|
||||
rhs = rng(rhs_shape, dtype)
|
||||
out = lhs @ rhs
|
||||
|
||||
# matrix-vector product -> nse matches matrix
|
||||
M = sparse.BCOO.fromdense(jnp.arange(6).reshape(2, 3))
|
||||
self.assertEqual((M @ x).nse, M.nse)
|
||||
expected_out = lhs.todense() @ rhs.todense()
|
||||
expected_nse = min(lhs.nse * rhs.nse, out.size)
|
||||
|
||||
# matrix-matrix product -> product of nse
|
||||
N = sparse.BCOO.fromdense(jnp.arange(12).reshape(3, 4))
|
||||
self.assertEqual((M @ N).nse, M.nse * N.nse)
|
||||
self.assertArraysAllClose(out.todense(), expected_out)
|
||||
self.assertEqual(out.nse, expected_nse)
|
||||
|
||||
def test_bcoo_spdot_general_ad_bug(self):
|
||||
# Regression test for https://github.com/google/jax/issues/10163
|
||||
|
Loading…
x
Reference in New Issue
Block a user