mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[sparse] more robust correction for out-of-bound indices in BCOO
Previously we were setting out-of-bound indices to zero, which works in most (but not all) cases. The problem is that if (0, 0) is a defined matrix element, these subsequent zeros effectively overwrite this element in some cusparse routines. The fix here is to add another row or column to the matrix as necessary, and to push these undefined values into that row/col, where they can be sliced off at the end of the cusparse operation so that they will not affect the computation of interest. PiperOrigin-RevId: 513639921
This commit is contained in:
parent
17079d9072
commit
33c0a103c6
@ -813,22 +813,25 @@ def _bcoo_dot_general_gpu_impl(lhs_data, lhs_indices, rhs, *,
|
||||
if (len(lhs_contract) == 1 and len(lhs_batch) == 0 and rhs.ndim in (1, 2)
|
||||
and (n_batch, n_sparse, n_dense) == (0, 1, 0)
|
||||
and not _bcoo_dot_general_fallback(lhs_data, lhs_indices, lhs_spinfo)):
|
||||
data, indices = _bcoo_correct_out_of_bound_indices(lhs_data, lhs_indices, lhs_spinfo.shape)
|
||||
row, col = jnp.zeros(indices.shape[0], indices.dtype), indices.ravel()
|
||||
out = coo_matmul_p.bind(data, row, col,
|
||||
row, col = jnp.zeros(lhs_indices.shape[0], lhs_indices.dtype), lhs_indices.ravel()
|
||||
transpose = False
|
||||
shape = (1, *lhs_spinfo.shape)
|
||||
row, col, shape = _coo_correct_out_of_bound_indices(row, col, shape, transpose)
|
||||
out = coo_matmul_p.bind(lhs_data, row, col,
|
||||
rhs.T if rhs_contract[0] == 1 else rhs,
|
||||
transpose=False,
|
||||
shape=(1, *lhs_spinfo.shape))
|
||||
return lax.squeeze(out, (0,))
|
||||
transpose=transpose, shape=shape)
|
||||
return out[0]
|
||||
elif (len(lhs_contract) == 1 and len(lhs_batch) == 0 and rhs.ndim in (1, 2)
|
||||
and (n_batch, n_sparse, n_dense) == (0, 2, 0)
|
||||
and not _bcoo_dot_general_fallback(lhs_data, lhs_indices, lhs_spinfo)):
|
||||
data, indices = _bcoo_correct_out_of_bound_indices(lhs_data, lhs_indices, lhs_spinfo.shape)
|
||||
row, col = indices[:, 0], indices[:, 1]
|
||||
return coo_matmul_p.bind(lhs_data, row, col,
|
||||
rhs.T if rhs_contract[0] == 1 else rhs,
|
||||
transpose=(lhs_contract[0] == 0),
|
||||
shape=lhs_spinfo.shape)
|
||||
row, col = lhs_indices[:, 0], lhs_indices[:, 1]
|
||||
transpose = (lhs_contract[0] == 0)
|
||||
shape = lhs_spinfo.shape
|
||||
row, col, shape = _coo_correct_out_of_bound_indices(row, col, shape, transpose)
|
||||
out = coo_matmul_p.bind(lhs_data, row, col,
|
||||
rhs.T if rhs_contract[0] == 1 else rhs,
|
||||
transpose=transpose, shape=shape)
|
||||
return out[:-1]
|
||||
else:
|
||||
return _bcoo_dot_general_impl(lhs_data, lhs_indices, rhs,
|
||||
dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)
|
||||
@ -1428,17 +1431,26 @@ def _unique_indices_unbatched(indices, *, shape, return_inverse=False,
|
||||
out = (*out[:-1], nse)
|
||||
return out
|
||||
|
||||
def _bcoo_correct_out_of_bound_indices(data, indices, shape):
|
||||
"""Set out-of-bound (OOB) indices and the corresponding data to zeros."""
|
||||
props = _validate_bcoo(data, indices, shape)
|
||||
assert props.n_dense == 0, "not implemented"
|
||||
if props.n_batch:
|
||||
f = partial(_bcoo_correct_out_of_bound_indices, shape=shape[props.n_batch:])
|
||||
return nfold_vmap(f, props.n_batch)(data, indices)
|
||||
mask = indices >= jnp.array(shape[:props.n_sparse], dtype=indices.dtype)[None, :]
|
||||
new_indices = jnp.where(mask, 0, indices)
|
||||
new_data = jnp.where(mask.any(-1), 0, data)
|
||||
return new_data, new_indices
|
||||
def _coo_correct_out_of_bound_indices(row, col, shape, transpose):
|
||||
# Since cusparse does not have any well-tested support for padded indices,
|
||||
# we push them into an extra row/col of the matrix, which will then be
|
||||
# sliced away in the output.
|
||||
assert row.ndim == col.ndim, f"{row.ndim} != {col.ndim}"
|
||||
assert len(shape) == row.ndim + 1, f"{len(shape)} != {row.ndim + 1}"
|
||||
if row.ndim > 1:
|
||||
f = partial(_coo_correct_out_of_bound_indices,
|
||||
shape=shape[row.ndim:], transpose=transpose)
|
||||
return nfold_vmap(f, row.ndim)(row, col)
|
||||
mask = (row > shape[0]) | (col > shape[1])
|
||||
if transpose:
|
||||
row = jnp.where(mask, 0, row)
|
||||
col = jnp.where(mask, shape[1], col)
|
||||
shape = (shape[0], shape[1] + 1)
|
||||
else:
|
||||
row = jnp.where(mask, shape[0], row)
|
||||
col = jnp.where(mask, 0, col)
|
||||
shape = (shape[0] + 1, shape[1])
|
||||
return row, col, shape
|
||||
|
||||
@bcoo_sum_duplicates_p.def_abstract_eval
|
||||
def _bcoo_sum_duplicates_abstract_eval(data, indices, *, spinfo, nse):
|
||||
|
@ -1745,36 +1745,6 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
self.assertArraysEqual(x.indices, y.indices)
|
||||
self.assertArraysEqual(x.data, y.data)
|
||||
|
||||
def test_bcoo_correct_out_of_bound_indices(self):
|
||||
data = jnp.array([1, 2, 3, 4, 0, 0])
|
||||
indices = jnp.array([1, 3, 0, 2, 2, 4])[:, None]
|
||||
x1 = sparse.BCOO((data, indices), shape=(2,))
|
||||
data_unbatched, indices_unbatched = sparse_bcoo._bcoo_correct_out_of_bound_indices(
|
||||
x1.data, x1.indices, x1._info.shape)
|
||||
expected_data_unbatched = jnp.array([1, 0, 3, 0, 0, 0])
|
||||
expected_indices_unbatched = jnp.array([[1], [0], [0], [0], [0], [0]])
|
||||
with self.subTest('unbatched data'):
|
||||
self.assertArraysEqual(data_unbatched, expected_data_unbatched)
|
||||
with self.subTest('unbatched indices'):
|
||||
self.assertArraysEqual(indices_unbatched, expected_indices_unbatched)
|
||||
|
||||
data = jnp.array([[0, 1, 2, 3],
|
||||
[4, 5, 6, 7]])
|
||||
indices = jnp.array([[[0, 0], [1, 1], [3, 4], [4, 5]],
|
||||
[[2, 1], [1, 0], [2, 3], [1, 1]]])
|
||||
x2 = sparse.BCOO((data, indices), shape=(2, 3, 2))
|
||||
data_batched, indices_batched = sparse_bcoo._bcoo_correct_out_of_bound_indices(
|
||||
x2.data, x2.indices, x2._info.shape)
|
||||
expected_data_batched = jnp.array([[0, 1, 0, 0],
|
||||
[4, 5, 0, 7]])
|
||||
expected_indices_batched = jnp.array(
|
||||
[[[0, 0], [1, 1], [0, 0], [0, 0]],
|
||||
[[2, 1], [1, 0], [2, 0], [1, 1]]])
|
||||
with self.subTest('batched data'):
|
||||
self.assertArraysEqual(data_batched, expected_data_batched)
|
||||
with self.subTest('batched indices'):
|
||||
self.assertArraysEqual(indices_batched, expected_indices_batched)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense, axes=axes)
|
||||
for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
|
||||
@ -1935,14 +1905,6 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
@jax.default_matmul_precision("float32")
|
||||
@jtu.ignore_warning(category=sparse.CuSparseEfficiencyWarning)
|
||||
def test_bcoo_matmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype):
|
||||
# TODO(b/259538729): Disable gpu test when type promotion is required.
|
||||
# BCOO type promotion calls `convert_element_type`, which further calls
|
||||
# `sum_duplicates` and creates padding with out-of-bound indices.
|
||||
# `bcoo_dot_general` cusparse lowering is not able to handle out-of-bound
|
||||
# indices right now.
|
||||
if jtu.device_under_test() == "gpu" and lhs_dtype != rhs_dtype:
|
||||
raise self.skipTest("Disable gpu test when type promotion is required")
|
||||
|
||||
# Note: currently, batch dimensions in matmul must correspond to batch
|
||||
# dimensions in the sparse representation.
|
||||
n_batch_lhs = max(0, len(lhs_shape) - 2)
|
||||
|
Loading…
x
Reference in New Issue
Block a user