mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[sparse] add BCSR.to_bcoo and from_bcoo methods
This commit is contained in:
parent
65ef487a82
commit
5b0329daa8
@ -151,25 +151,6 @@ def _validate_bcoo_indices(indices: Buffer, shape: Sequence[int]) -> BCOOPropert
|
||||
return BCOOProperties(n_batch=n_batch, n_sparse=n_sparse, n_dense=n_dense, nse=nse)
|
||||
|
||||
|
||||
def _bcoo_to_bcsr(indices: Array, *, shape: Sequence[int],
|
||||
index_dtype: DTypeLike = jnp.int32) -> Tuple[Array, Array]:
|
||||
"""Given BCOO (indices), return BCSR (indices, indptr)."""
|
||||
n_batch, n_sparse, _, _ = _validate_bcoo_indices(indices, shape)
|
||||
|
||||
if n_sparse != 2:
|
||||
raise ValueError("Must have 2 sparse dimensions to be converted to BCSR.")
|
||||
|
||||
n_rows = shape[n_batch]
|
||||
|
||||
@partial(nfold_vmap, N=n_batch, broadcasted=False)
|
||||
def get_ptr(i):
|
||||
indptr = jnp.zeros(n_rows + 1, index_dtype)
|
||||
return indptr.at[1:].set(jnp.cumsum(
|
||||
jnp.bincount(i, length=n_rows).astype(index_dtype)))
|
||||
|
||||
return indices[..., 1], get_ptr(indices[..., 0])
|
||||
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
# bcoo_todense
|
||||
|
||||
|
@ -15,6 +15,7 @@
|
||||
"""BCSR (Bached compressed row) matrix object and associated primitives."""
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
import operator
|
||||
|
||||
from typing import NamedTuple, Optional, Sequence, Tuple, Union
|
||||
@ -94,6 +95,28 @@ def _bcsr_to_bcoo(indices: jnp.ndarray, indptr: jnp.ndarray, *,
|
||||
return jnp.stack(csr_to_coo(indices, indptr), axis=indices.ndim)
|
||||
|
||||
|
||||
def _bcoo_to_bcsr(indices: Array, *, shape: Sequence[int],
|
||||
index_dtype: DTypeLike = jnp.int32) -> Tuple[Array, Array]:
|
||||
"""Given BCOO (indices), return BCSR (indices, indptr).
|
||||
|
||||
Note: this assumes that ``indices`` are lexicographically sorted within each batch.
|
||||
"""
|
||||
n_batch, n_sparse, _, _ = bcoo._validate_bcoo_indices(indices, shape)
|
||||
|
||||
if n_sparse != 2:
|
||||
raise ValueError("Must have 2 sparse dimensions to be converted to BCSR.")
|
||||
|
||||
n_rows = shape[n_batch]
|
||||
|
||||
@partial(nfold_vmap, N=n_batch, broadcasted=False)
|
||||
def get_ptr(i):
|
||||
indptr = jnp.zeros(n_rows + 1, index_dtype)
|
||||
return indptr.at[1:].set(jnp.cumsum(
|
||||
jnp.bincount(i, length=n_rows).astype(index_dtype)))
|
||||
|
||||
return indices[..., 1], get_ptr(indices[..., 0])
|
||||
|
||||
|
||||
#--------------------------------------------------------------------
|
||||
# bcsr_fromdense
|
||||
bcsr_fromdense_p = core.Primitive('bcsr_fromdense')
|
||||
@ -165,7 +188,7 @@ def _bcsr_fromdense_impl(mat, *, nse, n_batch, n_dense, index_dtype):
|
||||
raise ValueError("bcsr_fromdense: must have 2 sparse dimensions.")
|
||||
bcoo_mat = bcoo.bcoo_fromdense(mat, nse=nse, index_dtype=index_dtype,
|
||||
n_dense=n_dense, n_batch=n_batch)
|
||||
indices, indptr = bcoo._bcoo_to_bcsr(bcoo_mat.indices, shape=mat.shape)
|
||||
indices, indptr = _bcoo_to_bcsr(bcoo_mat.indices, shape=mat.shape)
|
||||
return bcoo_mat.data, indices, indptr
|
||||
|
||||
|
||||
@ -539,6 +562,19 @@ class BCSR(JAXSparse):
|
||||
"""Create a dense version of the array."""
|
||||
return bcsr_todense(self)
|
||||
|
||||
def to_bcoo(self) -> bcoo.BCOO:
|
||||
coo_indices = _bcsr_to_bcoo(self.indices, self.indptr, shape=self.shape)
|
||||
return bcoo.BCOO((self.data, coo_indices), shape=self.shape)
|
||||
|
||||
@classmethod
|
||||
def from_bcoo(cls, arr: bcoo.BCOO) -> BCSR:
|
||||
if arr.n_sparse != 2:
|
||||
raise NotImplementedError(f"BSCR.from_bcoo requires n_sparse=2; got {arr.n_sparse=}")
|
||||
if not arr.indices_sorted:
|
||||
arr = arr.sort_indices()
|
||||
indices, indptr = _bcoo_to_bcsr(arr.indices, shape=arr.shape)
|
||||
return cls((arr.data, indices, indptr), shape=arr.shape)
|
||||
|
||||
@classmethod
|
||||
def from_scipy_sparse(cls, mat, *, index_dtype=None, n_dense=0, n_batch=0):
|
||||
"""Create a BCSR array from a :mod:`scipy.sparse` array."""
|
||||
|
@ -27,6 +27,7 @@ from jax import tree_util
|
||||
from jax.util import safe_zip, split_list
|
||||
from jax.experimental import sparse
|
||||
from jax.experimental.sparse import bcoo as sparse_bcoo
|
||||
from jax.experimental.sparse import bcsr as sparse_bcsr
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
@ -136,18 +137,23 @@ def _rand_sparse(shape: Sequence[int], dtype: DTypeLike, *,
|
||||
if 0 <= nse < 1:
|
||||
nse = int(np.ceil(nse * np.prod(sparse_shape)))
|
||||
data_rng = rand_method(rng)
|
||||
index_shape = (*batch_shape, nse, n_sparse)
|
||||
data_shape = (*batch_shape, nse, *dense_shape)
|
||||
bcoo_indices = jnp.array(
|
||||
rng.randint(0, sparse_shape, size=index_shape, dtype=np.int32)) # type: ignore[arg-type]
|
||||
data = jnp.array(data_rng(data_shape, dtype))
|
||||
|
||||
if sparse_format == 'bcoo':
|
||||
return sparse.BCOO((data, bcoo_indices), shape=shape)
|
||||
|
||||
bcsr_indices, bcsr_indptr = sparse_bcoo._bcoo_to_bcsr(
|
||||
bcoo_indices, shape=shape)
|
||||
return sparse.BCSR((data, bcsr_indices, bcsr_indptr), shape=shape)
|
||||
index_shape = (*batch_shape, nse, n_sparse)
|
||||
indices = jnp.array(
|
||||
rng.randint(0, sparse_shape, size=index_shape, dtype=np.int32)) # type: ignore[arg-type]
|
||||
return sparse.BCOO((data, indices), shape=shape)
|
||||
else:
|
||||
index_shape = (*batch_shape, nse)
|
||||
indptr_shape = (*batch_shape, sparse_shape[0] + 1)
|
||||
indices = jnp.array(
|
||||
rng.randint(0, sparse_shape[1], size=index_shape, dtype=np.int32)) # type: ignore[arg-type]
|
||||
indptr = jnp.sort(
|
||||
rng.randint(0, nse + 1, size=indptr_shape, dtype=np.int32), axis=-1) # type: ignore[call-overload]
|
||||
indptr = indptr.at[..., 0].set(0)
|
||||
return sparse.BCSR((data, indices, indptr), shape=shape)
|
||||
|
||||
def rand_bcoo(rng: np.random.RandomState,
|
||||
rand_method: Callable[..., Any]=jtu.rand_default,
|
||||
|
@ -1738,7 +1738,8 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
sparse_fun = partial(sparse.bcoo_reduce_sum, axes=axes)
|
||||
dense_fun = partial(lambda x: x.sum(axes))
|
||||
|
||||
self._CheckAgainstDense(dense_fun, sparse_fun, args_maker)
|
||||
tol = {np.float64: 1E-14}
|
||||
self._CheckAgainstDense(dense_fun, sparse_fun, args_maker, tol=tol)
|
||||
if jnp.issubdtype(dtype, jnp.floating):
|
||||
self._CheckGradsSparse(dense_fun, sparse_fun, args_maker)
|
||||
|
||||
@ -1901,7 +1902,7 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
sprng(rhs_shape, rhs_dtype, n_batch=rhs_n_batch)]
|
||||
|
||||
tol = {np.float64: 1E-13, np.complex128: 1E-13,
|
||||
np.float32: 1E-6, np.complex64: 1E-6}
|
||||
np.float32: 1E-5, np.complex64: 1E-5}
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
|
||||
self._CheckAgainstDense(operator.mul, operator.mul, args_maker, tol=tol)
|
||||
@ -2119,6 +2120,34 @@ class BCSRTest(sptu.SparseTestCase):
|
||||
args_maker_todense = lambda: [data, indices, indptr]
|
||||
self._CompileAndCheck(todense, args_maker_todense)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, n_batch=n_batch)
|
||||
for shape in [(5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
|
||||
for n_batch in range(len(shape) - 1)
|
||||
],
|
||||
dtype=jtu.dtypes.floating + jtu.dtypes.complex,
|
||||
)
|
||||
def test_bcsr_bcoo_round_trip(self, shape, n_batch, dtype):
|
||||
n_sparse = 2
|
||||
n_dense = len(shape) - n_sparse - n_batch
|
||||
rng = self.rng()
|
||||
sprng = sptu.rand_bcsr(rng, n_batch=n_batch, n_dense=n_dense)
|
||||
|
||||
M_bcsr = sprng(shape, dtype)
|
||||
self.assertIsInstance(M_bcsr, sparse.BCSR)
|
||||
|
||||
M_dense = M_bcsr.todense()
|
||||
M_bcoo = M_bcsr.to_bcoo()
|
||||
self.assertIsInstance(M_bcoo, sparse.BCOO)
|
||||
self.assertAllClose(M_dense, M_bcoo.todense())
|
||||
|
||||
M_bcsr2 = sparse.BCSR.from_bcoo(M_bcoo)
|
||||
self.assertAllClose(M_dense, M_bcsr2.todense())
|
||||
self.assertArraysEqual(M_bcsr.indptr, M_bcsr2.indptr)
|
||||
|
||||
# TODO(jakevdp): This will only be true in general when M_bcsr.indices is sorted.
|
||||
# self.assertSparseArraysEquivalent(M_bcsr, M_bcsr2)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, n_batch=n_batch)
|
||||
for shape in [(5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
|
||||
@ -2541,7 +2570,7 @@ class SparseObjectTest(sptu.SparseTestCase):
|
||||
_, bcoo_indices = sparse_bcoo._bcoo_fromdense(M, nse=nse, n_batch=n_batch,
|
||||
n_dense=n_dense)
|
||||
|
||||
bcoo_to_bcsr = partial(sparse_bcoo._bcoo_to_bcsr, shape=shape)
|
||||
bcoo_to_bcsr = partial(sparse_bcsr._bcoo_to_bcsr, shape=shape)
|
||||
|
||||
args_maker_bcoo_to_bcsr = lambda: [bcoo_indices]
|
||||
self._CompileAndCheck(bcoo_to_bcsr, args_maker_bcoo_to_bcsr)
|
||||
|
Loading…
x
Reference in New Issue
Block a user