[sparse] Correct BCOO out-of-bound indices before calling cusparse SpMM.

PiperOrigin-RevId: 510248091
This commit is contained in:
Tianjian Lu 2023-02-16 14:39:37 -08:00 committed by jax authors
parent c368562529
commit 4fa69e60a0
2 changed files with 26 additions and 13 deletions

View File

@ -960,6 +960,13 @@ def _bcoo_dot_general_gpu_lowering(
ctx, lhs_data, lhs_indices, rhs,
dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)
# BCOO allows padded sparse representations to comply with the requirement
# of static array size. However, paddings create out-of-bound (OOB) indices
# which are not supported in any cuSparse routines.
# Set the OOB indices and their corresponding data to zeros as a correction.
(lhs_data,), (lhs_indices,) = _bcoo_correct_out_of_bound_indices_lowered(
ctx, lhs_data, lhs_indices, rhs, shape=lhs_spinfo.shape)
return _bcoo_dot_general_cuda_lowering(
coo_matvec_lowering, coo_matmat_lowering, ctx, lhs_data, lhs_indices, rhs,
dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)
@ -1487,22 +1494,27 @@ def _unique_indices_unbatched(indices, *, shape, return_inverse=False,
out = (*out[:-1], nse)
return out
def _fix_oob_indices_unbatched(data, indices, *, shape):
"""Set out-of-bound (OOB) indices and the corresponding data to zero."""
def _bcoo_correct_out_of_bound_indices_unbatched(data, indices, rhs, *, shape):
"""Set out-of-bound (OOB) indices and the corresponding data to zeros."""
del rhs
n_batch, n_sparse, n_dense, _ = _validate_bcoo_indices(indices, shape)
assert n_dense == 0, "not implemented"
assert n_batch == 0
mask = indices >= jnp.array(shape[:n_sparse])[None, :]
mask = indices >= jnp.array(shape[:n_sparse], dtype=indices.dtype)[None, :]
new_indices = jnp.where(mask, 0, indices)
new_data = jnp.where(mask.any(-1), 0, data)
return new_data, new_indices
def _fix_oob_indices(data, indices, *, spinfo):
"""Set out-of-bound (OOB) indices and the corresponding data to zero."""
props = _validate_bcoo_indices(indices, spinfo.shape)
f = partial(_fix_oob_indices_unbatched, shape=spinfo.shape[props.n_batch:])
def _bcoo_correct_out_of_bound_indices(data, indices, rhs, *, shape):
"""Set out-of-bound (OOB) indices and the corresponding data to zeros."""
props = _validate_bcoo_indices(indices, shape)
f = partial(_bcoo_correct_out_of_bound_indices_unbatched,
shape=shape[props.n_batch:])
f = nfold_vmap(f, props.n_batch)
return f(data, indices)
return f(data, indices, rhs)
_bcoo_correct_out_of_bound_indices_lowered = mlir.lower_fun(
_bcoo_correct_out_of_bound_indices, multiple_results=True)
@bcoo_sum_duplicates_p.def_abstract_eval
def _bcoo_sum_duplicates_abstract_eval(data, indices, *, spinfo, nse):

View File

@ -1690,12 +1690,13 @@ class BCOOTest(sptu.SparseTestCase):
self.assertArraysEqual(x.indices, y.indices)
self.assertArraysEqual(x.data, y.data)
def test_bcoo_fix_oob_indices(self):
def test_bcoo_correct_out_of_bound_indices(self):
data = jnp.array([1, 2, 3, 4, 0, 0])
indices = jnp.array([1, 3, 0, 2, 2, 4])[:, None]
x1 = sparse.BCOO((data, indices), shape=(2,))
data_unbatched, indices_unbatched = sparse_bcoo._fix_oob_indices(
x1.data, x1.indices, spinfo=x1._info)
rhs = jnp.zeros((2), dtype=data.dtype)
data_unbatched, indices_unbatched = sparse_bcoo._bcoo_correct_out_of_bound_indices(
x1.data, x1.indices, rhs, shape=x1._info.shape)
expected_data_unbatched = jnp.array([1, 0, 3, 0, 0, 0])
expected_indices_unbatched = jnp.array([[1], [0], [0], [0], [0], [0]])
with self.subTest('unbatched data'):
@ -1708,8 +1709,8 @@ class BCOOTest(sptu.SparseTestCase):
indices = jnp.array([[[0, 0], [1, 1], [3, 4], [4, 5]],
[[2, 1], [1, 0], [2, 3], [1, 1]]])
x2 = sparse.BCOO((data, indices), shape=(2, 3, 2))
data_batched, indices_batched = sparse_bcoo._fix_oob_indices(
x2.data, x2.indices, spinfo=x2._info)
data_batched, indices_batched = sparse_bcoo._bcoo_correct_out_of_bound_indices(
x2.data, x2.indices, rhs, shape=x2._info.shape)
expected_data_batched = jnp.array([[0, 1, 0, 0],
[4, 5, 0, 7]])
expected_indices_batched = jnp.array(