mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[sparse] test coo/csr extra nse
This commit is contained in:
parent
2b9ad0d93e
commit
df358242ff
@ -45,6 +45,7 @@ class COOInfo(NamedTuple):
|
||||
shape: Shape
|
||||
rows_sorted: bool = False
|
||||
cols_sorted: bool = False
|
||||
padded: bool = False
|
||||
|
||||
|
||||
@tree_util.register_pytree_node_class
|
||||
@ -63,16 +64,19 @@ class COO(JAXSparse):
|
||||
dtype = property(lambda self: self.data.dtype)
|
||||
_info = property(lambda self: COOInfo(
|
||||
shape=self.shape, rows_sorted=self._rows_sorted,
|
||||
cols_sorted=self._cols_sorted))
|
||||
cols_sorted=self._cols_sorted, padded=self._padded))
|
||||
_bufs = property(lambda self: (self.data, self.row, self.col))
|
||||
_rows_sorted: bool
|
||||
_cols_sorted: bool
|
||||
_padded: bool
|
||||
|
||||
def __init__(self, args: Tuple[Array, Array, Array], *, shape: Shape,
|
||||
rows_sorted: bool = False, cols_sorted: bool = False):
|
||||
rows_sorted: bool = False, cols_sorted: bool = False,
|
||||
padded: bool = True):
|
||||
self.data, self.row, self.col = map(jnp.asarray, args)
|
||||
self._rows_sorted = rows_sorted
|
||||
self._cols_sorted = cols_sorted
|
||||
self._padded = padded
|
||||
super().__init__(args, shape=shape)
|
||||
|
||||
@classmethod
|
||||
@ -131,7 +135,7 @@ class COO(JAXSparse):
|
||||
if axes is not None:
|
||||
raise NotImplementedError("axes argument to transpose()")
|
||||
return COO((self.data, self.col, self.row), shape=self.shape[::-1],
|
||||
rows_sorted=self._cols_sorted, cols_sorted=self._rows_sorted)
|
||||
rows_sorted=self._cols_sorted, cols_sorted=self._rows_sorted, padded=self._padded)
|
||||
|
||||
def tree_flatten(self) -> Tuple[Tuple[Array, Array, Array], Dict[str, Any]]:
|
||||
return (self.data, self.row, self.col), self._info._asdict()
|
||||
@ -140,11 +144,12 @@ class COO(JAXSparse):
|
||||
def tree_unflatten(cls, aux_data, children):
|
||||
obj = object.__new__(cls)
|
||||
obj.data, obj.row, obj.col = children
|
||||
if aux_data.keys() != {'shape', 'rows_sorted', 'cols_sorted'}:
|
||||
if aux_data.keys() != {'shape', 'rows_sorted', 'cols_sorted', 'padded'}:
|
||||
raise ValueError(f"COO.tree_unflatten: invalid {aux_data=}")
|
||||
obj.shape = aux_data['shape']
|
||||
obj._rows_sorted = aux_data['rows_sorted']
|
||||
obj._cols_sorted = aux_data['cols_sorted']
|
||||
obj._padded = aux_data['padded']
|
||||
return obj
|
||||
|
||||
def __matmul__(self, other: ArrayLike) -> Array:
|
||||
@ -207,6 +212,9 @@ def _coo_todense_gpu_lowering(coo_todense_hlo, ctx, data, row, col, *, spinfo):
|
||||
warnings.warn(f"coo_todense cusparse/hipsparse lowering not available for {dtype=}. "
|
||||
"Falling back to default implementation.", CuSparseEfficiencyWarning)
|
||||
return _coo_todense_lowering(ctx, data, row, col, spinfo=spinfo)
|
||||
if spinfo.padded:
|
||||
# GPU rule returns incorrect results with padded representation.
|
||||
return _coo_todense_lowering(ctx, data, row, col, spinfo=spinfo)
|
||||
|
||||
if spinfo.rows_sorted:
|
||||
shape = spinfo.shape
|
||||
@ -274,11 +282,12 @@ def coo_fromdense(mat: Array, *, nse: Optional[int] = None, index_dtype: DTypeLi
|
||||
Returns:
|
||||
mat_coo : COO representation of the matrix.
|
||||
"""
|
||||
padded = nse is not None
|
||||
if nse is None:
|
||||
nse = int((mat != 0).sum())
|
||||
nse_int = core.concrete_or_error(operator.index, nse, "coo_fromdense nse argument")
|
||||
return COO(_coo_fromdense(mat, nse=nse_int, index_dtype=index_dtype),
|
||||
shape=mat.shape, rows_sorted=True)
|
||||
shape=mat.shape, rows_sorted=True, padded=padded)
|
||||
|
||||
def _coo_fromdense(mat: Array, *, nse: int, index_dtype: DTypeLike = jnp.int32) -> Tuple[Array, Array, Array]:
|
||||
"""Create COO-format sparse matrix from a dense matrix.
|
||||
@ -446,8 +455,10 @@ def _coo_matvec_gpu_lowering(coo_matvec_hlo, ctx, data, row, col, v, *, spinfo,
|
||||
if dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
|
||||
warnings.warn(f"coo_matvec cusparse/hipsparse lowering not available for {dtype=}. "
|
||||
"Falling back to default implementation.", CuSparseEfficiencyWarning)
|
||||
return _coo_matvec_lowering(ctx, data, row, col, v, spinfo=spinfo,
|
||||
transpose=transpose)
|
||||
return _coo_matvec_lowering(ctx, data, row, col, v, spinfo=spinfo, transpose=transpose)
|
||||
if spinfo.padded:
|
||||
# GPU rule returns incorrect results with padded representation.
|
||||
return _coo_matvec_lowering(ctx, data, row, col, v, spinfo=spinfo, transpose=transpose)
|
||||
|
||||
if spinfo.rows_sorted:
|
||||
shape = spinfo.shape
|
||||
@ -569,8 +580,11 @@ def _coo_matmat_gpu_lowering(coo_matmat_hlo, ctx, data, row, col, B, *, spinfo,
|
||||
if dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
|
||||
warnings.warn(f"coo_matmat cusparse/hipsprse lowering not available for {dtype=}. "
|
||||
"Falling back to default implementation.", CuSparseEfficiencyWarning)
|
||||
return _coo_matmat_lowering(ctx, data, row, col, B, spinfo=spinfo,
|
||||
transpose=transpose)
|
||||
return _coo_matmat_lowering(ctx, data, row, col, B, spinfo=spinfo, transpose=transpose)
|
||||
if spinfo.padded:
|
||||
# GPU rule returns incorrect results with padded representation.
|
||||
return _coo_matmat_lowering(ctx, data, row, col, B, spinfo=spinfo, transpose=transpose)
|
||||
|
||||
if spinfo.rows_sorted:
|
||||
shape = spinfo.shape
|
||||
elif spinfo.cols_sorted:
|
||||
|
@ -508,7 +508,7 @@ class cuSparseTest(sptu.SparseTestCase):
|
||||
self.assertFalse(mat_cols_sorted._rows_sorted)
|
||||
self.assertTrue(mat_cols_sorted._cols_sorted)
|
||||
|
||||
mat_unsorted = sparse.COO(mat_rows_sorted._bufs, shape=mat_rows_sorted.shape)
|
||||
mat_unsorted = sparse.COO(mat_rows_sorted._bufs, shape=mat_rows_sorted.shape, padded=False)
|
||||
self.assertFalse(mat_unsorted._rows_sorted)
|
||||
self.assertFalse(mat_unsorted._cols_sorted)
|
||||
|
||||
@ -582,10 +582,45 @@ class cuSparseTest(sptu.SparseTestCase):
|
||||
)
|
||||
def test_extra_nse(self, shape, dtype, mat_type):
|
||||
rng = rand_sparse(self.rng())
|
||||
rng_dense = jtu.rand_default(self.rng())
|
||||
M = rng(shape, dtype)
|
||||
nse = (M != 0).sum() + 5
|
||||
M_out = mat_type.fromdense(M, nse=nse, index_dtype=jnp.int32).todense()
|
||||
self.assertArraysEqual(M, M_out)
|
||||
M_sp = mat_type.fromdense(M, nse=nse)
|
||||
|
||||
with self.subTest("todense"):
|
||||
def todense1(M, _):
|
||||
assert isinstance(M, np.ndarray)
|
||||
return M
|
||||
def todense2(_, M):
|
||||
assert isinstance(M, mat_type)
|
||||
return M.todense()
|
||||
args_maker = lambda: [M, M_sp]
|
||||
self._CheckAgainstNumpy(todense1, todense2, args_maker)
|
||||
self._CompileAndCheck(todense2, args_maker)
|
||||
|
||||
with self.subTest("matvec"):
|
||||
v = rng_dense(M.shape[-1:], dtype)
|
||||
args_maker = lambda: [M, M_sp, v]
|
||||
def matvec1(M, _, v):
|
||||
assert isinstance(M, np.ndarray)
|
||||
return M @ v
|
||||
def matvec2(_, M, v):
|
||||
assert isinstance(M, mat_type)
|
||||
return M @ v
|
||||
self._CheckAgainstNumpy(matvec1, matvec2, args_maker)
|
||||
self._CompileAndCheck(matvec2, args_maker)
|
||||
|
||||
with self.subTest("matmat"):
|
||||
B = rng_dense(M.shape[::-1], dtype)
|
||||
args_maker = lambda: [M, M_sp, B]
|
||||
def matmat1(M, _, B):
|
||||
assert isinstance(M, np.ndarray)
|
||||
return M @ B
|
||||
def matmat2(_, M, B):
|
||||
assert isinstance(M, mat_type)
|
||||
return M @ B
|
||||
self._CheckAgainstNumpy(matmat1, matmat2, args_maker)
|
||||
self._CompileAndCheck(matmat2, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=[(5, 8), (8, 5), (5, 5), (8, 8)],
|
||||
|
Loading…
x
Reference in New Issue
Block a user