[sparse] add BCSR.to_bcoo and from_bcoo methods

This commit is contained in:
Jake VanderPlas 2023-01-30 10:42:05 -08:00
parent 65ef487a82
commit 5b0329daa8
4 changed files with 83 additions and 31 deletions

View File

@ -151,25 +151,6 @@ def _validate_bcoo_indices(indices: Buffer, shape: Sequence[int]) -> BCOOPropert
return BCOOProperties(n_batch=n_batch, n_sparse=n_sparse, n_dense=n_dense, nse=nse)
def _bcoo_to_bcsr(indices: Array, *, shape: Sequence[int],
index_dtype: DTypeLike = jnp.int32) -> Tuple[Array, Array]:
"""Given BCOO (indices), return BCSR (indices, indptr)."""
n_batch, n_sparse, _, _ = _validate_bcoo_indices(indices, shape)
if n_sparse != 2:
raise ValueError("Must have 2 sparse dimensions to be converted to BCSR.")
n_rows = shape[n_batch]
@partial(nfold_vmap, N=n_batch, broadcasted=False)
def get_ptr(i):
indptr = jnp.zeros(n_rows + 1, index_dtype)
return indptr.at[1:].set(jnp.cumsum(
jnp.bincount(i, length=n_rows).astype(index_dtype)))
return indices[..., 1], get_ptr(indices[..., 0])
#----------------------------------------------------------------------
# bcoo_todense

View File

@ -15,6 +15,7 @@
"""BCSR (Bached compressed row) matrix object and associated primitives."""
from __future__ import annotations
from functools import partial
import operator
from typing import NamedTuple, Optional, Sequence, Tuple, Union
@ -94,6 +95,28 @@ def _bcsr_to_bcoo(indices: jnp.ndarray, indptr: jnp.ndarray, *,
return jnp.stack(csr_to_coo(indices, indptr), axis=indices.ndim)
def _bcoo_to_bcsr(indices: Array, *, shape: Sequence[int],
index_dtype: DTypeLike = jnp.int32) -> Tuple[Array, Array]:
"""Given BCOO (indices), return BCSR (indices, indptr).
Note: this assumes that ``indices`` are lexicographically sorted within each batch.
"""
n_batch, n_sparse, _, _ = bcoo._validate_bcoo_indices(indices, shape)
if n_sparse != 2:
raise ValueError("Must have 2 sparse dimensions to be converted to BCSR.")
n_rows = shape[n_batch]
@partial(nfold_vmap, N=n_batch, broadcasted=False)
def get_ptr(i):
indptr = jnp.zeros(n_rows + 1, index_dtype)
return indptr.at[1:].set(jnp.cumsum(
jnp.bincount(i, length=n_rows).astype(index_dtype)))
return indices[..., 1], get_ptr(indices[..., 0])
#--------------------------------------------------------------------
# bcsr_fromdense
bcsr_fromdense_p = core.Primitive('bcsr_fromdense')
@ -165,7 +188,7 @@ def _bcsr_fromdense_impl(mat, *, nse, n_batch, n_dense, index_dtype):
raise ValueError("bcsr_fromdense: must have 2 sparse dimensions.")
bcoo_mat = bcoo.bcoo_fromdense(mat, nse=nse, index_dtype=index_dtype,
n_dense=n_dense, n_batch=n_batch)
indices, indptr = bcoo._bcoo_to_bcsr(bcoo_mat.indices, shape=mat.shape)
indices, indptr = _bcoo_to_bcsr(bcoo_mat.indices, shape=mat.shape)
return bcoo_mat.data, indices, indptr
@ -539,6 +562,19 @@ class BCSR(JAXSparse):
"""Create a dense version of the array."""
return bcsr_todense(self)
def to_bcoo(self) -> bcoo.BCOO:
  """Return a BCOO array with the same data and shape as this BCSR array."""
  shape = self.shape
  # Expand the compressed row pointers into explicit per-entry coordinates.
  indices = _bcsr_to_bcoo(self.indices, self.indptr, shape=shape)
  buffers = (self.data, indices)
  return bcoo.BCOO(buffers, shape=shape)
@classmethod
def from_bcoo(cls, arr: bcoo.BCOO) -> BCSR:
  """Create a BCSR array from a BCOO array.

  Args:
    arr: BCOO array with exactly two sparse dimensions. If its indices are
      not flagged as sorted, a sorted copy is used for the conversion.

  Returns:
    A ``cls`` instance built from ``arr``'s data, the converted
    ``(indices, indptr)`` pair, and ``arr.shape``.

  Raises:
    NotImplementedError: if ``arr.n_sparse`` is not 2.
  """
  if arr.n_sparse != 2:
    raise NotImplementedError(f"BCSR.from_bcoo requires n_sparse=2; got {arr.n_sparse=}")
  # _bcoo_to_bcsr assumes lexicographically sorted indices, so sort first
  # whenever the input does not already guarantee it.
  if not arr.indices_sorted:
    arr = arr.sort_indices()
  indices, indptr = _bcoo_to_bcsr(arr.indices, shape=arr.shape)
  return cls((arr.data, indices, indptr), shape=arr.shape)
@classmethod
def from_scipy_sparse(cls, mat, *, index_dtype=None, n_dense=0, n_batch=0):
"""Create a BCSR array from a :mod:`scipy.sparse` array."""

View File

@ -27,6 +27,7 @@ from jax import tree_util
from jax.util import safe_zip, split_list
from jax.experimental import sparse
from jax.experimental.sparse import bcoo as sparse_bcoo
from jax.experimental.sparse import bcsr as sparse_bcsr
import jax.numpy as jnp
@ -136,18 +137,23 @@ def _rand_sparse(shape: Sequence[int], dtype: DTypeLike, *,
if 0 <= nse < 1:
nse = int(np.ceil(nse * np.prod(sparse_shape)))
data_rng = rand_method(rng)
index_shape = (*batch_shape, nse, n_sparse)
data_shape = (*batch_shape, nse, *dense_shape)
bcoo_indices = jnp.array(
rng.randint(0, sparse_shape, size=index_shape, dtype=np.int32)) # type: ignore[arg-type]
data = jnp.array(data_rng(data_shape, dtype))
if sparse_format == 'bcoo':
return sparse.BCOO((data, bcoo_indices), shape=shape)
bcsr_indices, bcsr_indptr = sparse_bcoo._bcoo_to_bcsr(
bcoo_indices, shape=shape)
return sparse.BCSR((data, bcsr_indices, bcsr_indptr), shape=shape)
index_shape = (*batch_shape, nse, n_sparse)
indices = jnp.array(
rng.randint(0, sparse_shape, size=index_shape, dtype=np.int32)) # type: ignore[arg-type]
return sparse.BCOO((data, indices), shape=shape)
else:
index_shape = (*batch_shape, nse)
indptr_shape = (*batch_shape, sparse_shape[0] + 1)
indices = jnp.array(
rng.randint(0, sparse_shape[1], size=index_shape, dtype=np.int32)) # type: ignore[arg-type]
indptr = jnp.sort(
rng.randint(0, nse + 1, size=indptr_shape, dtype=np.int32), axis=-1) # type: ignore[call-overload]
indptr = indptr.at[..., 0].set(0)
return sparse.BCSR((data, indices, indptr), shape=shape)
def rand_bcoo(rng: np.random.RandomState,
rand_method: Callable[..., Any]=jtu.rand_default,

View File

@ -1738,7 +1738,8 @@ class BCOOTest(sptu.SparseTestCase):
sparse_fun = partial(sparse.bcoo_reduce_sum, axes=axes)
dense_fun = partial(lambda x: x.sum(axes))
self._CheckAgainstDense(dense_fun, sparse_fun, args_maker)
tol = {np.float64: 1E-14}
self._CheckAgainstDense(dense_fun, sparse_fun, args_maker, tol=tol)
if jnp.issubdtype(dtype, jnp.floating):
self._CheckGradsSparse(dense_fun, sparse_fun, args_maker)
@ -1901,7 +1902,7 @@ class BCOOTest(sptu.SparseTestCase):
sprng(rhs_shape, rhs_dtype, n_batch=rhs_n_batch)]
tol = {np.float64: 1E-13, np.complex128: 1E-13,
np.float32: 1E-6, np.complex64: 1E-6}
np.float32: 1E-5, np.complex64: 1E-5}
with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
self._CheckAgainstDense(operator.mul, operator.mul, args_maker, tol=tol)
@ -2119,6 +2120,34 @@ class BCSRTest(sptu.SparseTestCase):
args_maker_todense = lambda: [data, indices, indptr]
self._CompileAndCheck(todense, args_maker_todense)
@jtu.sample_product(
  [dict(shape=shape, n_batch=n_batch)
    for shape in [(5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
    for n_batch in range(len(shape) - 1)
  ],
  dtype=jtu.dtypes.floating + jtu.dtypes.complex,
)
def test_bcsr_bcoo_round_trip(self, shape, n_batch, dtype):
  """Check that BCSR -> BCOO -> BCSR keeps dense values and row pointers."""
  # Two dimensions are always sparse; whatever is not batch is dense.
  n_dense = len(shape) - 2 - n_batch
  make_bcsr = sptu.rand_bcsr(self.rng(), n_batch=n_batch, n_dense=n_dense)

  bcsr_in = make_bcsr(shape, dtype)
  self.assertIsInstance(bcsr_in, sparse.BCSR)
  expected = bcsr_in.todense()

  as_bcoo = bcsr_in.to_bcoo()
  self.assertIsInstance(as_bcoo, sparse.BCOO)
  self.assertAllClose(expected, as_bcoo.todense())

  bcsr_out = sparse.BCSR.from_bcoo(as_bcoo)
  self.assertAllClose(expected, bcsr_out.todense())
  self.assertArraysEqual(bcsr_in.indptr, bcsr_out.indptr)
  # TODO(jakevdp): full equivalence of bcsr_in and bcsr_out (via
  # assertSparseArraysEquivalent) only holds in general when bcsr_in.indices
  # is sorted, so that check is not enabled here.
@jtu.sample_product(
[dict(shape=shape, n_batch=n_batch)
for shape in [(5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
@ -2541,7 +2570,7 @@ class SparseObjectTest(sptu.SparseTestCase):
_, bcoo_indices = sparse_bcoo._bcoo_fromdense(M, nse=nse, n_batch=n_batch,
n_dense=n_dense)
bcoo_to_bcsr = partial(sparse_bcoo._bcoo_to_bcsr, shape=shape)
bcoo_to_bcsr = partial(sparse_bcsr._bcoo_to_bcsr, shape=shape)
args_maker_bcoo_to_bcsr = lambda: [bcoo_indices]
self._CompileAndCheck(bcoo_to_bcsr, args_maker_bcoo_to_bcsr)