[sparse] remove handling of padded indices from COO/CSR

Jake VanderPlas 2023-02-22 15:15:02 -08:00
parent 2d93b28b18
commit bf1f5d21a2
3 changed files with 14 additions and 72 deletions

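For context: a COO or CSR matrix is "padded" when its `nse` (number of stored elements) exceeds the true number of nonzeros, leaving unused slots in the data and index buffers. A minimal sketch of how such a matrix arises, using the public `jax.experimental.sparse` API:

```python
import jax.numpy as jnp
from jax.experimental import sparse

# A 2x3 matrix with exactly two nonzero entries.
mat = jnp.array([[1., 0., 0.],
                 [0., 0., 2.]])

# Requesting nse=4 reserves four slots, two of which are padding.
# Before this commit COO tracked this via a `padded` flag and routed
# padded inputs away from the buggy GPU kernels; after it, oversized
# nse is simply documented as a known-failure mode.
mat_coo = sparse.coo_fromdense(mat, nse=4)
print(mat_coo.data.shape)  # (4,) even though only 2 entries are nonzero
```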

@ -45,7 +45,6 @@ class COOInfo(NamedTuple):
  shape: Shape
  rows_sorted: bool = False
  cols_sorted: bool = False
-  padded: bool = False
@tree_util.register_pytree_node_class
@ -55,6 +54,10 @@ class COO(JAXSparse):
  Note: this class has minimal compatibility with JAX transforms such as
  grad and autodiff, and offers very little functionality. In general you
  should prefer :class:`jax.experimental.sparse.BCOO`.
+
+  Additionally, there are known failures in the case that `nse` is larger
+  than the true number of nonzeros in the represented matrix. This situation
+  is better handled in BCOO.
  """
  data: jax.Array
  row: jax.Array
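The new docstring note steers users to BCOO for this case. A sketch of the recommended alternative; as I understand it, BCOO marks padding with out-of-bounds indices so downstream rules handle it correctly:

```python
import jax.numpy as jnp
from jax.experimental import sparse

mat = jnp.array([[1., 0.],
                 [0., 2.]])
# BCOO tolerates nse larger than the true nonzero count.
mat_bcoo = sparse.BCOO.fromdense(mat, nse=4)
print(mat_bcoo @ jnp.ones(2))  # matches the dense result despite padding
```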
@ -64,19 +67,16 @@ class COO(JAXSparse):
  dtype = property(lambda self: self.data.dtype)
  _info = property(lambda self: COOInfo(
      shape=self.shape, rows_sorted=self._rows_sorted,
-      cols_sorted=self._cols_sorted, padded=self._padded))
+      cols_sorted=self._cols_sorted))
  _bufs = property(lambda self: (self.data, self.row, self.col))
  _rows_sorted: bool
  _cols_sorted: bool
-  _padded: bool

  def __init__(self, args: Tuple[Array, Array, Array], *, shape: Shape,
-               rows_sorted: bool = False, cols_sorted: bool = False,
-               padded: bool = True):
+               rows_sorted: bool = False, cols_sorted: bool = False):
    self.data, self.row, self.col = map(jnp.asarray, args)
    self._rows_sorted = rows_sorted
    self._cols_sorted = cols_sorted
-    self._padded = padded
    super().__init__(args, shape=shape)

  @classmethod
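After this change the constructor takes only the buffer triple, the shape, and the two sortedness flags; passing `padded=...` would now be a TypeError. A usage sketch with hand-built buffers:

```python
import jax.numpy as jnp
from jax.experimental import sparse

data = jnp.array([1., 2.])
row = jnp.array([0, 1])
col = jnp.array([0, 2])
mat = sparse.COO((data, row, col), shape=(2, 3), rows_sorted=True)
```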
@ -135,7 +135,7 @@ class COO(JAXSparse):
    if axes is not None:
      raise NotImplementedError("axes argument to transpose()")
    return COO((self.data, self.col, self.row), shape=self.shape[::-1],
-               rows_sorted=self._cols_sorted, cols_sorted=self._rows_sorted, padded=self._padded)
+               rows_sorted=self._cols_sorted, cols_sorted=self._rows_sorted)

  def tree_flatten(self) -> Tuple[Tuple[Array, Array, Array], Dict[str, Any]]:
    return (self.data, self.row, self.col), self._info._asdict()
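`transpose()` swaps the row and column buffers along with the two sortedness flags, and no longer threads `padded` through. A quick illustration (the underscore-prefixed flags are internal, inspected here only for demonstration):

```python
import jax.numpy as jnp
from jax.experimental import sparse

mat = sparse.coo_fromdense(jnp.eye(3))  # coo_fromdense yields rows_sorted=True
mat_T = mat.transpose()
# Sorted-by-rows becomes sorted-by-columns under transposition.
assert mat_T._cols_sorted and not mat_T._rows_sorted
```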
@ -144,12 +144,11 @@ class COO(JAXSparse):
  def tree_unflatten(cls, aux_data, children):
    obj = object.__new__(cls)
    obj.data, obj.row, obj.col = children
-    if aux_data.keys() != {'shape', 'rows_sorted', 'cols_sorted', 'padded'}:
+    if aux_data.keys() != {'shape', 'rows_sorted', 'cols_sorted'}:
      raise ValueError(f"COO.tree_unflatten: invalid {aux_data=}")
    obj.shape = aux_data['shape']
    obj._rows_sorted = aux_data['rows_sorted']
    obj._cols_sorted = aux_data['cols_sorted']
-    obj._padded = aux_data['padded']
    return obj

  def __matmul__(self, other: ArrayLike) -> Array:
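With `padded` gone from the aux data, a pytree flatten/unflatten round trip carries only the shape and the two flags. A sketch:

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

mat = sparse.coo_fromdense(jnp.eye(2))
leaves, treedef = jax.tree_util.tree_flatten(mat)  # leaves: data, row, col
mat2 = jax.tree_util.tree_unflatten(treedef, leaves)
assert mat2._rows_sorted == mat._rows_sorted
```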
@ -212,9 +211,6 @@ def _coo_todense_gpu_lowering(coo_todense_hlo, ctx, data, row, col, *, spinfo):
    warnings.warn(f"coo_todense cusparse/hipsparse lowering not available for {dtype=}. "
                  "Falling back to default implementation.", CuSparseEfficiencyWarning)
    return _coo_todense_lowering(ctx, data, row, col, spinfo=spinfo)
-  if spinfo.padded:
-    # GPU rule returns incorrect results with padded representation.
-    return _coo_todense_lowering(ctx, data, row, col, spinfo=spinfo)
  if spinfo.rows_sorted:
    shape = spinfo.shape
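The GPU lowering for `coo_todense` now falls back to the default implementation only for unsupported dtypes; padded inputs go straight to the cusparse kernel, which is why oversized `nse` becomes a documented failure mode. The user-facing entry point is unchanged:

```python
import jax.numpy as jnp
from jax.experimental import sparse

mat = sparse.coo_fromdense(jnp.eye(3))
dense = mat.todense()  # dispatches through the coo_todense primitive
```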
@ -282,12 +278,11 @@ def coo_fromdense(mat: Array, *, nse: Optional[int] = None, index_dtype: DTypeLi
  Returns:
    mat_coo : COO representation of the matrix.
  """
-  padded = nse is not None
  if nse is None:
    nse = int((mat != 0).sum())
  nse_int = core.concrete_or_error(operator.index, nse, "coo_fromdense nse argument")
  return COO(_coo_fromdense(mat, nse=nse_int, index_dtype=index_dtype),
-             shape=mat.shape, rows_sorted=True, padded=padded)
+             shape=mat.shape, rows_sorted=True)

def _coo_fromdense(mat: Array, *, nse: int, index_dtype: DTypeLike = jnp.int32) -> Tuple[Array, Array, Array]:
  """Create COO-format sparse matrix from a dense matrix.
@ -456,9 +451,6 @@ def _coo_matvec_gpu_lowering(coo_matvec_hlo, ctx, data, row, col, v, *, spinfo,
    warnings.warn(f"coo_matvec cusparse/hipsparse lowering not available for {dtype=}. "
                  "Falling back to default implementation.", CuSparseEfficiencyWarning)
    return _coo_matvec_lowering(ctx, data, row, col, v, spinfo=spinfo, transpose=transpose)
-  if spinfo.padded:
-    # GPU rule returns incorrect results with padded representation.
-    return _coo_matvec_lowering(ctx, data, row, col, v, spinfo=spinfo, transpose=transpose)
  if spinfo.rows_sorted:
    shape = spinfo.shape
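Same pattern for the matvec lowering: the dtype fallback stays, the padded fallback goes. From user code this is just the `@` operator on a COO matrix:

```python
import jax.numpy as jnp
from jax.experimental import sparse

mat = sparse.coo_fromdense(jnp.eye(3))
v = jnp.arange(3.0)
print(mat @ v)  # lowers to the coo_matvec primitive
```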
@ -581,9 +573,6 @@ def _coo_matmat_gpu_lowering(coo_matmat_hlo, ctx, data, row, col, B, *, spinfo,
    warnings.warn(f"coo_matmat cusparse/hipsparse lowering not available for {dtype=}. "
                  "Falling back to default implementation.", CuSparseEfficiencyWarning)
    return _coo_matmat_lowering(ctx, data, row, col, B, spinfo=spinfo, transpose=transpose)
-  if spinfo.padded:
-    # GPU rule returns incorrect results with padded representation.
-    return _coo_matmat_lowering(ctx, data, row, col, B, spinfo=spinfo, transpose=transpose)
  if spinfo.rows_sorted:
    shape = spinfo.shape

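And the matmat rule gets the identical treatment; a dense right-hand side exercises it:

```python
import jax.numpy as jnp
from jax.experimental import sparse

mat = sparse.coo_fromdense(jnp.eye(3))
B = jnp.ones((3, 2))
print(mat @ B)  # lowers to the coo_matmat primitive
```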

@ -48,6 +48,10 @@ class CSR(JAXSparse):
  Note: this class has minimal compatibility with JAX transforms such as
  grad and autodiff, and offers very little functionality. In general you
  should prefer :class:`jax.experimental.sparse.BCOO`.
+
+  Additionally, there are known failures in the case that `nse` is larger
+  than the true number of nonzeros in the represented matrix. This situation
+  is better handled in BCOO.
  """
  data: jax.Array
  indices: jax.Array

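CSR gains the same docstring caveat. The CSR analogue of the earlier COO sketch; `fromdense` with an explicit `nse` is the same call the deleted test below used:

```python
import jax.numpy as jnp
from jax.experimental import sparse

mat = jnp.array([[1., 0., 0.],
                 [0., 0., 2.]])
# Oversized nse is now a documented known-failure mode for CSR too;
# prefer sparse.BCOO.fromdense when padding is needed.
mat_csr = sparse.CSR.fromdense(mat, nse=4)
```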

@ -508,7 +508,7 @@ class cuSparseTest(sptu.SparseTestCase):
    self.assertFalse(mat_cols_sorted._rows_sorted)
    self.assertTrue(mat_cols_sorted._cols_sorted)

-    mat_unsorted = sparse.COO(mat_rows_sorted._bufs, shape=mat_rows_sorted.shape, padded=False)
+    mat_unsorted = sparse.COO(mat_rows_sorted._bufs, shape=mat_rows_sorted.shape)
    self.assertFalse(mat_unsorted._rows_sorted)
    self.assertFalse(mat_unsorted._cols_sorted)
@ -575,57 +575,6 @@ class cuSparseTest(sptu.SparseTestCase):
    self.assertIn(sparse.csr_todense_p,
                  mlir._platform_specific_lowerings["rocm"])

-  @jtu.sample_product(
-      shape=[(5, 8), (8, 5), (5, 5), (8, 8)],
-      dtype=jtu.dtypes.floating + jtu.dtypes.complex,
-      mat_type=[sparse.CSR, sparse.COO],
-  )
-  def test_extra_nse(self, shape, dtype, mat_type):
-    rng = rand_sparse(self.rng())
-    rng_dense = jtu.rand_default(self.rng())
-    M = rng(shape, dtype)
-    nse = (M != 0).sum() + 5
-    M_sp = mat_type.fromdense(M, nse=nse)
-
-    with self.subTest("todense"):
-      def todense1(M, _):
-        assert isinstance(M, np.ndarray)
-        return M
-      def todense2(_, M):
-        assert isinstance(M, mat_type)
-        return M.todense()
-      args_maker = lambda: [M, M_sp]
-      self._CheckAgainstNumpy(todense1, todense2, args_maker)
-      self._CompileAndCheck(todense2, args_maker)
-
-    with self.subTest("matvec"):
-      v = rng_dense(M.shape[-1:], dtype)
-      args_maker = lambda: [M, M_sp, v]
-      def matvec1(M, _, v):
-        assert isinstance(M, np.ndarray)
-        return M @ v
-      def matvec2(_, M, v):
-        assert isinstance(M, mat_type)
-        return M @ v
-      self._CheckAgainstNumpy(matvec1, matvec2, args_maker)
-      self._CompileAndCheck(matvec2, args_maker)
-
-    with self.subTest("matmat"):
-      B = rng_dense(M.shape[::-1], dtype)
-      args_maker = lambda: [M, M_sp, B]
-      def matmat1(M, _, B):
-        assert isinstance(M, np.ndarray)
-        return M @ B
-      def matmat2(_, M, B):
-        assert isinstance(M, mat_type)
-        return M @ B
-      if dtype == np.dtype(np.float64):
-        tol = 1e-14  # Lower the precision a tiny bit to avoid flakiness.
-      else:
-        tol = None
-      self._CheckAgainstNumpy(matmat1, matmat2, args_maker, tol=tol)
-      self._CompileAndCheck(matmat2, args_maker, tol=tol)

  @jtu.sample_product(
      shape=[(5, 8), (8, 5), (5, 5), (8, 8)],
      dtype=jtu.dtypes.floating + jtu.dtypes.complex,