[sparse] Use sorted indices instead of sorted rows only.

PiperOrigin-RevId: 440579642
This commit is contained in:
Tianjian Lu 2022-04-09 08:33:20 -07:00 committed by jax authors
parent e9f95fa5fa
commit a11b41f581
2 changed files with 20 additions and 15 deletions

View File

@ -57,7 +57,8 @@ class COO(JAXSparse):
nse = property(lambda self: self.data.size)
dtype = property(lambda self: self.data.dtype)
_info = property(lambda self: COOInfo(
shape=self.shape, rows_sorted=self._rows_sorted, cols_sorted=self._cols_sorted))
shape=self.shape, rows_sorted=self._rows_sorted,
cols_sorted=self._cols_sorted))
_bufs = property(lambda self: (self.data, self.row, self.col))
_rows_sorted: bool
_cols_sorted: bool
@ -72,16 +73,18 @@ class COO(JAXSparse):
def fromdense(cls, mat, *, nse=None, index_dtype=np.int32):
return coo_fromdense(mat, nse=nse, index_dtype=index_dtype)
def _sort_rows(self):
"""Return a copy of the COO matrix with sorted rows.
def _sort_indices(self):
"""Return a copy of the COO matrix with sorted indices.
The matrix is sorted by row indices and column indices per row.
If self._rows_sorted is True, this returns ``self`` without a copy.
"""
# TODO(jakevdp): would be benefit from lowering this to cusparse sort_rows utility?
if self._rows_sorted:
return self
row, col, data = lax.sort((self.row, self.col, self.data), num_keys=1)
return self.__class__((data, row, col), shape=self.shape, rows_sorted=True)
row, col, data = lax.sort((self.row, self.col, self.data), num_keys=2)
return self.__class__((data, row, col), shape=self.shape,
rows_sorted=True)
@classmethod
def _empty(cls, shape, *, dtype=None, index_dtype='int32'):
@ -91,7 +94,8 @@ class COO(JAXSparse):
raise ValueError(f"COO must have ndim=2; got shape={shape}")
data = jnp.empty(0, dtype)
row = col = jnp.empty(0, index_dtype)
return cls((data, row, col), shape=shape, rows_sorted=True, cols_sorted=True)
return cls((data, row, col), shape=shape, rows_sorted=True,
cols_sorted=True)
def todense(self):
return coo_todense(self)
@ -185,7 +189,6 @@ def _coo_todense_gpu_translation_rule(ctx, avals_in, avals_out, data, row, col,
return [xops.Transpose(result, (1, 0))] if transpose else [result]
_coo_todense_lowering = mlir.lower_fun(
_coo_todense_impl, multiple_results=False)
@ -263,7 +266,8 @@ def coo_fromdense(mat, *, nse=None, index_dtype=jnp.int32):
if nse is None:
nse = (mat != 0).sum()
nse = core.concrete_or_error(operator.index, nse, "coo_fromdense nse argument")
return COO(_coo_fromdense(mat, nse=nse, index_dtype=index_dtype), shape=mat.shape, rows_sorted=True)
return COO(_coo_fromdense(mat, nse=nse, index_dtype=index_dtype),
shape=mat.shape, rows_sorted=True)
def _coo_fromdense(mat, *, nse, index_dtype=jnp.int32):
"""Create COO-format sparse matrix from a dense matrix.
@ -589,6 +593,7 @@ def _coo_matmat_gpu_translation_rule(ctx, avals_in, avals_out, data, row, col,
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _coo_matmat_translation_rule(ctx, avals_in, avals_out, data, row, col, B,
spinfo=spinfo, transpose=transpose)
if spinfo.rows_sorted:
shape = spinfo.shape
elif spinfo.cols_sorted:

View File

@ -438,7 +438,7 @@ class cuSparseTest(jtu.JaxTestCase):
mat = sparse.COO.fromdense(sprng((5, 6), np.float32))
perm = rng.permutation(mat.nse)
mat_unsorted = sparse.COO((mat.data[perm], mat.row[perm], mat.col[perm]), shape=mat.shape)
mat_resorted = mat_unsorted._sort_rows()
mat_resorted = mat_unsorted._sort_indices()
self.assertArraysEqual(mat.todense(), mat_resorted.todense())
@unittest.skipIf(not GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse")
@ -460,15 +460,15 @@ class cuSparseTest(jtu.JaxTestCase):
self.assertFalse(mat_unsorted._rows_sorted)
self.assertFalse(mat_unsorted._cols_sorted)
self.assertArraysEqual(mat, mat_rows_sorted._sort_rows().todense())
self.assertArraysEqual(mat, mat_cols_sorted._sort_rows().todense())
self.assertArraysEqual(mat, mat_unsorted._sort_rows().todense())
self.assertArraysEqual(mat, mat_rows_sorted._sort_indices().todense())
self.assertArraysEqual(mat, mat_cols_sorted._sort_indices().todense())
self.assertArraysEqual(mat, mat_unsorted._sort_indices().todense())
todense = jit(sparse.coo_todense)
with self.assertNoWarnings():
dense_rows_sorted = todense(mat_rows_sorted)
dense_cols_sorted = todense(mat_cols_sorted)
dense_unsorted = todense(mat_unsorted._sort_rows())
dense_unsorted = todense(mat_unsorted._sort_indices())
with self.assertWarnsRegex(sparse.CuSparseEfficiencyWarning, "coo_todense GPU lowering requires matrices with sorted rows.*"):
dense_unsorted_fallback = todense(mat_unsorted)
self.assertArraysEqual(mat, dense_rows_sorted)
@ -482,7 +482,7 @@ class cuSparseTest(jtu.JaxTestCase):
with self.assertNoWarnings():
matvec_rows_sorted = matvec(mat_rows_sorted, rhs_vec)
matvec_cols_sorted = matvec(mat_cols_sorted, rhs_vec)
matvec_unsorted = matvec(mat_unsorted._sort_rows(), rhs_vec)
matvec_unsorted = matvec(mat_unsorted._sort_indices(), rhs_vec)
with self.assertWarnsRegex(sparse.CuSparseEfficiencyWarning, "coo_matvec GPU lowering requires matrices with sorted rows.*"):
matvec_unsorted_fallback = matvec(mat_unsorted, rhs_vec)
self.assertArraysEqual(matvec_expected, matvec_rows_sorted)
@ -496,7 +496,7 @@ class cuSparseTest(jtu.JaxTestCase):
with self.assertNoWarnings():
matmat_rows_sorted = matmat(mat_rows_sorted, rhs_mat)
matmat_cols_sorted = matmat(mat_cols_sorted, rhs_mat)
matmat_unsorted = matmat(mat_unsorted._sort_rows(), rhs_mat)
matmat_unsorted = matmat(mat_unsorted._sort_indices(), rhs_mat)
with self.assertWarnsRegex(sparse.CuSparseEfficiencyWarning, "coo_matmat GPU lowering requires matrices with sorted rows.*"):
matmat_unsorted_fallback = matmat(mat_unsorted, rhs_mat)
self.assertArraysEqual(matmat_expected, matmat_rows_sorted)