mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
[sparse] Use sorted indices instead of sorted rows only.
PiperOrigin-RevId: 440579642
This commit is contained in:
parent
e9f95fa5fa
commit
a11b41f581
@ -57,7 +57,8 @@ class COO(JAXSparse):
|
||||
nse = property(lambda self: self.data.size)
|
||||
dtype = property(lambda self: self.data.dtype)
|
||||
_info = property(lambda self: COOInfo(
|
||||
shape=self.shape, rows_sorted=self._rows_sorted, cols_sorted=self._cols_sorted))
|
||||
shape=self.shape, rows_sorted=self._rows_sorted,
|
||||
cols_sorted=self._cols_sorted))
|
||||
_bufs = property(lambda self: (self.data, self.row, self.col))
|
||||
_rows_sorted: bool
|
||||
_cols_sorted: bool
|
||||
@ -72,16 +73,18 @@ class COO(JAXSparse):
|
||||
def fromdense(cls, mat, *, nse=None, index_dtype=np.int32):
|
||||
return coo_fromdense(mat, nse=nse, index_dtype=index_dtype)
|
||||
|
||||
def _sort_rows(self):
|
||||
"""Return a copy of the COO matrix with sorted rows.
|
||||
def _sort_indices(self):
|
||||
"""Return a copy of the COO matrix with sorted indices.
|
||||
|
||||
The matrix is sorted by row indices and column indices per row.
|
||||
If self._rows_sorted is True, this returns ``self`` without a copy.
|
||||
"""
|
||||
# TODO(jakevdp): would be benefit from lowering this to cusparse sort_rows utility?
|
||||
if self._rows_sorted:
|
||||
return self
|
||||
row, col, data = lax.sort((self.row, self.col, self.data), num_keys=1)
|
||||
return self.__class__((data, row, col), shape=self.shape, rows_sorted=True)
|
||||
row, col, data = lax.sort((self.row, self.col, self.data), num_keys=2)
|
||||
return self.__class__((data, row, col), shape=self.shape,
|
||||
rows_sorted=True)
|
||||
|
||||
@classmethod
|
||||
def _empty(cls, shape, *, dtype=None, index_dtype='int32'):
|
||||
@ -91,7 +94,8 @@ class COO(JAXSparse):
|
||||
raise ValueError(f"COO must have ndim=2; got shape={shape}")
|
||||
data = jnp.empty(0, dtype)
|
||||
row = col = jnp.empty(0, index_dtype)
|
||||
return cls((data, row, col), shape=shape, rows_sorted=True, cols_sorted=True)
|
||||
return cls((data, row, col), shape=shape, rows_sorted=True,
|
||||
cols_sorted=True)
|
||||
|
||||
def todense(self):
|
||||
return coo_todense(self)
|
||||
@ -185,7 +189,6 @@ def _coo_todense_gpu_translation_rule(ctx, avals_in, avals_out, data, row, col,
|
||||
|
||||
|
||||
return [xops.Transpose(result, (1, 0))] if transpose else [result]
|
||||
|
||||
_coo_todense_lowering = mlir.lower_fun(
|
||||
_coo_todense_impl, multiple_results=False)
|
||||
|
||||
@ -263,7 +266,8 @@ def coo_fromdense(mat, *, nse=None, index_dtype=jnp.int32):
|
||||
if nse is None:
|
||||
nse = (mat != 0).sum()
|
||||
nse = core.concrete_or_error(operator.index, nse, "coo_fromdense nse argument")
|
||||
return COO(_coo_fromdense(mat, nse=nse, index_dtype=index_dtype), shape=mat.shape, rows_sorted=True)
|
||||
return COO(_coo_fromdense(mat, nse=nse, index_dtype=index_dtype),
|
||||
shape=mat.shape, rows_sorted=True)
|
||||
|
||||
def _coo_fromdense(mat, *, nse, index_dtype=jnp.int32):
|
||||
"""Create COO-format sparse matrix from a dense matrix.
|
||||
@ -589,6 +593,7 @@ def _coo_matmat_gpu_translation_rule(ctx, avals_in, avals_out, data, row, col,
|
||||
"Falling back to default implementation.", CuSparseEfficiencyWarning)
|
||||
return _coo_matmat_translation_rule(ctx, avals_in, avals_out, data, row, col, B,
|
||||
spinfo=spinfo, transpose=transpose)
|
||||
|
||||
if spinfo.rows_sorted:
|
||||
shape = spinfo.shape
|
||||
elif spinfo.cols_sorted:
|
||||
|
@ -438,7 +438,7 @@ class cuSparseTest(jtu.JaxTestCase):
|
||||
mat = sparse.COO.fromdense(sprng((5, 6), np.float32))
|
||||
perm = rng.permutation(mat.nse)
|
||||
mat_unsorted = sparse.COO((mat.data[perm], mat.row[perm], mat.col[perm]), shape=mat.shape)
|
||||
mat_resorted = mat_unsorted._sort_rows()
|
||||
mat_resorted = mat_unsorted._sort_indices()
|
||||
self.assertArraysEqual(mat.todense(), mat_resorted.todense())
|
||||
|
||||
@unittest.skipIf(not GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse")
|
||||
@ -460,15 +460,15 @@ class cuSparseTest(jtu.JaxTestCase):
|
||||
self.assertFalse(mat_unsorted._rows_sorted)
|
||||
self.assertFalse(mat_unsorted._cols_sorted)
|
||||
|
||||
self.assertArraysEqual(mat, mat_rows_sorted._sort_rows().todense())
|
||||
self.assertArraysEqual(mat, mat_cols_sorted._sort_rows().todense())
|
||||
self.assertArraysEqual(mat, mat_unsorted._sort_rows().todense())
|
||||
self.assertArraysEqual(mat, mat_rows_sorted._sort_indices().todense())
|
||||
self.assertArraysEqual(mat, mat_cols_sorted._sort_indices().todense())
|
||||
self.assertArraysEqual(mat, mat_unsorted._sort_indices().todense())
|
||||
|
||||
todense = jit(sparse.coo_todense)
|
||||
with self.assertNoWarnings():
|
||||
dense_rows_sorted = todense(mat_rows_sorted)
|
||||
dense_cols_sorted = todense(mat_cols_sorted)
|
||||
dense_unsorted = todense(mat_unsorted._sort_rows())
|
||||
dense_unsorted = todense(mat_unsorted._sort_indices())
|
||||
with self.assertWarnsRegex(sparse.CuSparseEfficiencyWarning, "coo_todense GPU lowering requires matrices with sorted rows.*"):
|
||||
dense_unsorted_fallback = todense(mat_unsorted)
|
||||
self.assertArraysEqual(mat, dense_rows_sorted)
|
||||
@ -482,7 +482,7 @@ class cuSparseTest(jtu.JaxTestCase):
|
||||
with self.assertNoWarnings():
|
||||
matvec_rows_sorted = matvec(mat_rows_sorted, rhs_vec)
|
||||
matvec_cols_sorted = matvec(mat_cols_sorted, rhs_vec)
|
||||
matvec_unsorted = matvec(mat_unsorted._sort_rows(), rhs_vec)
|
||||
matvec_unsorted = matvec(mat_unsorted._sort_indices(), rhs_vec)
|
||||
with self.assertWarnsRegex(sparse.CuSparseEfficiencyWarning, "coo_matvec GPU lowering requires matrices with sorted rows.*"):
|
||||
matvec_unsorted_fallback = matvec(mat_unsorted, rhs_vec)
|
||||
self.assertArraysEqual(matvec_expected, matvec_rows_sorted)
|
||||
@ -496,7 +496,7 @@ class cuSparseTest(jtu.JaxTestCase):
|
||||
with self.assertNoWarnings():
|
||||
matmat_rows_sorted = matmat(mat_rows_sorted, rhs_mat)
|
||||
matmat_cols_sorted = matmat(mat_cols_sorted, rhs_mat)
|
||||
matmat_unsorted = matmat(mat_unsorted._sort_rows(), rhs_mat)
|
||||
matmat_unsorted = matmat(mat_unsorted._sort_indices(), rhs_mat)
|
||||
with self.assertWarnsRegex(sparse.CuSparseEfficiencyWarning, "coo_matmat GPU lowering requires matrices with sorted rows.*"):
|
||||
matmat_unsorted_fallback = matmat(mat_unsorted, rhs_mat)
|
||||
self.assertArraysEqual(matmat_expected, matmat_rows_sorted)
|
||||
|
Loading…
x
Reference in New Issue
Block a user