mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[sparse] remove handling of padded indices from COO/CSR
This commit is contained in:
parent
2d93b28b18
commit
bf1f5d21a2
@ -45,7 +45,6 @@ class COOInfo(NamedTuple):
|
||||
shape: Shape
|
||||
rows_sorted: bool = False
|
||||
cols_sorted: bool = False
|
||||
padded: bool = False
|
||||
|
||||
|
||||
@tree_util.register_pytree_node_class
|
||||
@ -55,6 +54,10 @@ class COO(JAXSparse):
|
||||
Note: this class has minimal compatibility with JAX transforms such as
|
||||
grad and autodiff, and offers very little functionality. In general you
|
||||
should prefer :class:`jax.experimental.sparse.BCOO`.
|
||||
|
||||
Additionally, there are known failures in the case that `nse` is larger
|
||||
than the true number of nonzeros in the represented matrix. This situation
|
||||
is better handled in BCOO.
|
||||
"""
|
||||
data: jax.Array
|
||||
row: jax.Array
|
||||
@ -64,19 +67,16 @@ class COO(JAXSparse):
|
||||
dtype = property(lambda self: self.data.dtype)
|
||||
_info = property(lambda self: COOInfo(
|
||||
shape=self.shape, rows_sorted=self._rows_sorted,
|
||||
cols_sorted=self._cols_sorted, padded=self._padded))
|
||||
cols_sorted=self._cols_sorted))
|
||||
_bufs = property(lambda self: (self.data, self.row, self.col))
|
||||
_rows_sorted: bool
|
||||
_cols_sorted: bool
|
||||
_padded: bool
|
||||
|
||||
def __init__(self, args: Tuple[Array, Array, Array], *, shape: Shape,
|
||||
rows_sorted: bool = False, cols_sorted: bool = False,
|
||||
padded: bool = True):
|
||||
rows_sorted: bool = False, cols_sorted: bool = False):
|
||||
self.data, self.row, self.col = map(jnp.asarray, args)
|
||||
self._rows_sorted = rows_sorted
|
||||
self._cols_sorted = cols_sorted
|
||||
self._padded = padded
|
||||
super().__init__(args, shape=shape)
|
||||
|
||||
@classmethod
|
||||
@ -135,7 +135,7 @@ class COO(JAXSparse):
|
||||
if axes is not None:
|
||||
raise NotImplementedError("axes argument to transpose()")
|
||||
return COO((self.data, self.col, self.row), shape=self.shape[::-1],
|
||||
rows_sorted=self._cols_sorted, cols_sorted=self._rows_sorted, padded=self._padded)
|
||||
rows_sorted=self._cols_sorted, cols_sorted=self._rows_sorted)
|
||||
|
||||
def tree_flatten(self) -> Tuple[Tuple[Array, Array, Array], Dict[str, Any]]:
|
||||
return (self.data, self.row, self.col), self._info._asdict()
|
||||
@ -144,12 +144,11 @@ class COO(JAXSparse):
|
||||
def tree_unflatten(cls, aux_data, children):
|
||||
obj = object.__new__(cls)
|
||||
obj.data, obj.row, obj.col = children
|
||||
if aux_data.keys() != {'shape', 'rows_sorted', 'cols_sorted', 'padded'}:
|
||||
if aux_data.keys() != {'shape', 'rows_sorted', 'cols_sorted'}:
|
||||
raise ValueError(f"COO.tree_unflatten: invalid {aux_data=}")
|
||||
obj.shape = aux_data['shape']
|
||||
obj._rows_sorted = aux_data['rows_sorted']
|
||||
obj._cols_sorted = aux_data['cols_sorted']
|
||||
obj._padded = aux_data['padded']
|
||||
return obj
|
||||
|
||||
def __matmul__(self, other: ArrayLike) -> Array:
|
||||
@ -212,9 +211,6 @@ def _coo_todense_gpu_lowering(coo_todense_hlo, ctx, data, row, col, *, spinfo):
|
||||
warnings.warn(f"coo_todense cusparse/hipsparse lowering not available for {dtype=}. "
|
||||
"Falling back to default implementation.", CuSparseEfficiencyWarning)
|
||||
return _coo_todense_lowering(ctx, data, row, col, spinfo=spinfo)
|
||||
if spinfo.padded:
|
||||
# GPU rule returns incorrect results with padded representation.
|
||||
return _coo_todense_lowering(ctx, data, row, col, spinfo=spinfo)
|
||||
|
||||
if spinfo.rows_sorted:
|
||||
shape = spinfo.shape
|
||||
@ -282,12 +278,11 @@ def coo_fromdense(mat: Array, *, nse: Optional[int] = None, index_dtype: DTypeLi
|
||||
Returns:
|
||||
mat_coo : COO representation of the matrix.
|
||||
"""
|
||||
padded = nse is not None
|
||||
if nse is None:
|
||||
nse = int((mat != 0).sum())
|
||||
nse_int = core.concrete_or_error(operator.index, nse, "coo_fromdense nse argument")
|
||||
return COO(_coo_fromdense(mat, nse=nse_int, index_dtype=index_dtype),
|
||||
shape=mat.shape, rows_sorted=True, padded=padded)
|
||||
shape=mat.shape, rows_sorted=True)
|
||||
|
||||
def _coo_fromdense(mat: Array, *, nse: int, index_dtype: DTypeLike = jnp.int32) -> Tuple[Array, Array, Array]:
|
||||
"""Create COO-format sparse matrix from a dense matrix.
|
||||
@ -456,9 +451,6 @@ def _coo_matvec_gpu_lowering(coo_matvec_hlo, ctx, data, row, col, v, *, spinfo,
|
||||
warnings.warn(f"coo_matvec cusparse/hipsparse lowering not available for {dtype=}. "
|
||||
"Falling back to default implementation.", CuSparseEfficiencyWarning)
|
||||
return _coo_matvec_lowering(ctx, data, row, col, v, spinfo=spinfo, transpose=transpose)
|
||||
if spinfo.padded:
|
||||
# GPU rule returns incorrect results with padded representation.
|
||||
return _coo_matvec_lowering(ctx, data, row, col, v, spinfo=spinfo, transpose=transpose)
|
||||
|
||||
if spinfo.rows_sorted:
|
||||
shape = spinfo.shape
|
||||
@ -581,9 +573,6 @@ def _coo_matmat_gpu_lowering(coo_matmat_hlo, ctx, data, row, col, B, *, spinfo,
|
||||
warnings.warn(f"coo_matmat cusparse/hipsprse lowering not available for {dtype=}. "
|
||||
"Falling back to default implementation.", CuSparseEfficiencyWarning)
|
||||
return _coo_matmat_lowering(ctx, data, row, col, B, spinfo=spinfo, transpose=transpose)
|
||||
if spinfo.padded:
|
||||
# GPU rule returns incorrect results with padded representation.
|
||||
return _coo_matmat_lowering(ctx, data, row, col, B, spinfo=spinfo, transpose=transpose)
|
||||
|
||||
if spinfo.rows_sorted:
|
||||
shape = spinfo.shape
|
||||
|
@ -48,6 +48,10 @@ class CSR(JAXSparse):
|
||||
Note: this class has minimal compatibility with JAX transforms such as
|
||||
grad and autodiff, and offers very little functionality. In general you
|
||||
should prefer :class:`jax.experimental.sparse.BCOO`.
|
||||
|
||||
Additionally, there are known failures in the case that `nse` is larger
|
||||
than the true number of nonzeros in the represented matrix. This situation
|
||||
is better handled in BCOO.
|
||||
"""
|
||||
data: jax.Array
|
||||
indices: jax.Array
|
||||
|
@ -508,7 +508,7 @@ class cuSparseTest(sptu.SparseTestCase):
|
||||
self.assertFalse(mat_cols_sorted._rows_sorted)
|
||||
self.assertTrue(mat_cols_sorted._cols_sorted)
|
||||
|
||||
mat_unsorted = sparse.COO(mat_rows_sorted._bufs, shape=mat_rows_sorted.shape, padded=False)
|
||||
mat_unsorted = sparse.COO(mat_rows_sorted._bufs, shape=mat_rows_sorted.shape)
|
||||
self.assertFalse(mat_unsorted._rows_sorted)
|
||||
self.assertFalse(mat_unsorted._cols_sorted)
|
||||
|
||||
@ -575,57 +575,6 @@ class cuSparseTest(sptu.SparseTestCase):
|
||||
self.assertIn(sparse.csr_todense_p,
|
||||
mlir._platform_specific_lowerings["rocm"])
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=[(5, 8), (8, 5), (5, 5), (8, 8)],
|
||||
dtype=jtu.dtypes.floating + jtu.dtypes.complex,
|
||||
mat_type=[sparse.CSR, sparse.COO],
|
||||
)
|
||||
def test_extra_nse(self, shape, dtype, mat_type):
|
||||
rng = rand_sparse(self.rng())
|
||||
rng_dense = jtu.rand_default(self.rng())
|
||||
M = rng(shape, dtype)
|
||||
nse = (M != 0).sum() + 5
|
||||
M_sp = mat_type.fromdense(M, nse=nse)
|
||||
|
||||
with self.subTest("todense"):
|
||||
def todense1(M, _):
|
||||
assert isinstance(M, np.ndarray)
|
||||
return M
|
||||
def todense2(_, M):
|
||||
assert isinstance(M, mat_type)
|
||||
return M.todense()
|
||||
args_maker = lambda: [M, M_sp]
|
||||
self._CheckAgainstNumpy(todense1, todense2, args_maker)
|
||||
self._CompileAndCheck(todense2, args_maker)
|
||||
|
||||
with self.subTest("matvec"):
|
||||
v = rng_dense(M.shape[-1:], dtype)
|
||||
args_maker = lambda: [M, M_sp, v]
|
||||
def matvec1(M, _, v):
|
||||
assert isinstance(M, np.ndarray)
|
||||
return M @ v
|
||||
def matvec2(_, M, v):
|
||||
assert isinstance(M, mat_type)
|
||||
return M @ v
|
||||
self._CheckAgainstNumpy(matvec1, matvec2, args_maker)
|
||||
self._CompileAndCheck(matvec2, args_maker)
|
||||
|
||||
with self.subTest("matmat"):
|
||||
B = rng_dense(M.shape[::-1], dtype)
|
||||
args_maker = lambda: [M, M_sp, B]
|
||||
def matmat1(M, _, B):
|
||||
assert isinstance(M, np.ndarray)
|
||||
return M @ B
|
||||
def matmat2(_, M, B):
|
||||
assert isinstance(M, mat_type)
|
||||
return M @ B
|
||||
if dtype == np.dtype(np.float64):
|
||||
tol = 1e-14 # Lower the precision a tiny bit to avoid flakiness.
|
||||
else:
|
||||
tol = None
|
||||
self._CheckAgainstNumpy(matmat1, matmat2, args_maker, tol=tol)
|
||||
self._CompileAndCheck(matmat2, args_maker, tol=tol)
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=[(5, 8), (8, 5), (5, 5), (8, 8)],
|
||||
dtype=jtu.dtypes.floating + jtu.dtypes.complex,
|
||||
|
Loading…
x
Reference in New Issue
Block a user