mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[sparse] Correct BCOO out-of-bound indices before calling cusparse SpMM.
PiperOrigin-RevId: 510248091
This commit is contained in:
parent
c368562529
commit
4fa69e60a0
@ -960,6 +960,13 @@ def _bcoo_dot_general_gpu_lowering(
|
||||
ctx, lhs_data, lhs_indices, rhs,
|
||||
dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)
|
||||
|
||||
# BCOO allows padded sparse representations to comply with the requirement
|
||||
# of static array size. However, paddings create out-of-bound (OOB) indices
|
||||
# which are not supported in any cuSparse routines.
|
||||
# Set the OOB indices and their correpsonding data to zeros as a correction.
|
||||
(lhs_data,), (lhs_indices,) = _bcoo_correct_out_of_bound_indices_lowered(
|
||||
ctx, lhs_data, lhs_indices, rhs, shape=lhs_spinfo.shape)
|
||||
|
||||
return _bcoo_dot_general_cuda_lowering(
|
||||
coo_matvec_lowering, coo_matmat_lowering, ctx, lhs_data, lhs_indices, rhs,
|
||||
dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)
|
||||
@ -1487,22 +1494,27 @@ def _unique_indices_unbatched(indices, *, shape, return_inverse=False,
|
||||
out = (*out[:-1], nse)
|
||||
return out
|
||||
|
||||
def _fix_oob_indices_unbatched(data, indices, *, shape):
|
||||
"""Set out-of-bound (OOB) indices and the corresponding data to zero."""
|
||||
def _bcoo_correct_out_of_bound_indices_unbatched(data, indices, rhs, *, shape):
|
||||
"""Set out-of-bound (OOB) indices and the corresponding data to zeros."""
|
||||
del rhs
|
||||
n_batch, n_sparse, n_dense, _ = _validate_bcoo_indices(indices, shape)
|
||||
assert n_dense == 0, "not implemented"
|
||||
assert n_batch == 0
|
||||
mask = indices >= jnp.array(shape[:n_sparse])[None, :]
|
||||
mask = indices >= jnp.array(shape[:n_sparse], dtype=indices.dtype)[None, :]
|
||||
new_indices = jnp.where(mask, 0, indices)
|
||||
new_data = jnp.where(mask.any(-1), 0, data)
|
||||
return new_data, new_indices
|
||||
|
||||
def _fix_oob_indices(data, indices, *, spinfo):
|
||||
"""Set out-of-bound (OOB) indices and the corresponding data to zero."""
|
||||
props = _validate_bcoo_indices(indices, spinfo.shape)
|
||||
f = partial(_fix_oob_indices_unbatched, shape=spinfo.shape[props.n_batch:])
|
||||
def _bcoo_correct_out_of_bound_indices(data, indices, rhs, *, shape):
|
||||
"""Set out-of-bound (OOB) indices and the corresponding data to zeros."""
|
||||
props = _validate_bcoo_indices(indices, shape)
|
||||
f = partial(_bcoo_correct_out_of_bound_indices_unbatched,
|
||||
shape=shape[props.n_batch:])
|
||||
f = nfold_vmap(f, props.n_batch)
|
||||
return f(data, indices)
|
||||
return f(data, indices, rhs)
|
||||
|
||||
_bcoo_correct_out_of_bound_indices_lowered = mlir.lower_fun(
|
||||
_bcoo_correct_out_of_bound_indices, multiple_results=True)
|
||||
|
||||
@bcoo_sum_duplicates_p.def_abstract_eval
|
||||
def _bcoo_sum_duplicates_abstract_eval(data, indices, *, spinfo, nse):
|
||||
|
@ -1690,12 +1690,13 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
self.assertArraysEqual(x.indices, y.indices)
|
||||
self.assertArraysEqual(x.data, y.data)
|
||||
|
||||
def test_bcoo_fix_oob_indices(self):
|
||||
def test_bcoo_correct_out_of_bound_indices(self):
|
||||
data = jnp.array([1, 2, 3, 4, 0, 0])
|
||||
indices = jnp.array([1, 3, 0, 2, 2, 4])[:, None]
|
||||
x1 = sparse.BCOO((data, indices), shape=(2,))
|
||||
data_unbatched, indices_unbatched = sparse_bcoo._fix_oob_indices(
|
||||
x1.data, x1.indices, spinfo=x1._info)
|
||||
rhs = jnp.zeros((2), dtype=data.dtype)
|
||||
data_unbatched, indices_unbatched = sparse_bcoo._bcoo_correct_out_of_bound_indices(
|
||||
x1.data, x1.indices, rhs, shape=x1._info.shape)
|
||||
expected_data_unbatched = jnp.array([1, 0, 3, 0, 0, 0])
|
||||
expected_indices_unbatched = jnp.array([[1], [0], [0], [0], [0], [0]])
|
||||
with self.subTest('unbatched data'):
|
||||
@ -1708,8 +1709,8 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
indices = jnp.array([[[0, 0], [1, 1], [3, 4], [4, 5]],
|
||||
[[2, 1], [1, 0], [2, 3], [1, 1]]])
|
||||
x2 = sparse.BCOO((data, indices), shape=(2, 3, 2))
|
||||
data_batched, indices_batched = sparse_bcoo._fix_oob_indices(
|
||||
x2.data, x2.indices, spinfo=x2._info)
|
||||
data_batched, indices_batched = sparse_bcoo._bcoo_correct_out_of_bound_indices(
|
||||
x2.data, x2.indices, rhs, shape=x2._info.shape)
|
||||
expected_data_batched = jnp.array([[0, 1, 0, 0],
|
||||
[4, 5, 0, 7]])
|
||||
expected_indices_batched = jnp.array(
|
||||
|
Loading…
x
Reference in New Issue
Block a user