[sparse] implement bcsr_concatenate

This commit is contained in:
Jake VanderPlas 2023-02-15 12:38:57 -08:00
parent 9b288e9ab9
commit f3e5024787
5 changed files with 58 additions and 2 deletions

View File

@ -234,6 +234,7 @@ from jax.experimental.sparse.bcoo import (
from jax.experimental.sparse.bcsr import (
bcsr_broadcast_in_dim as bcsr_broadcast_in_dim,
bcsr_concatenate as bcsr_concatenate,
bcsr_dot_general as bcsr_dot_general,
bcsr_dot_general_p as bcsr_dot_general_p,
bcsr_extract as bcsr_extract,

View File

@ -703,6 +703,22 @@ def bcsr_broadcast_in_dim(mat: BCSR, *, shape: Shape, broadcast_dimensions: Sequ
mat.to_bcoo(), shape=shape, broadcast_dimensions=broadcast_dimensions)
return BCSR.from_bcoo(result_bcoo)
def bcsr_concatenate(operands: Sequence[BCSR], *, dimension: int) -> BCSR:
  """Sparse implementation of :func:`jax.lax.concatenate`.

  Args:
    operands : Sequence of BCSR arrays to concatenate. The arrays must have
      equal shapes, except in the `dimension` axis. Additionally, the arrays
      must have equivalent batch, sparse, and dense dimensions.
    dimension : Positive integer specifying the dimension along which to
      concatenate the arrays. The dimension must be among batch or sparse
      dimensions of the input; concatenation along dense dimensions is not
      supported.

  Returns:
    A BCSR array containing the concatenation of the inputs.
  """
  # Implemented by round-tripping through BCOO, which already supports
  # concatenation; BCSR <-> BCOO conversion preserves values and layout.
  return BCSR.from_bcoo(
      bcoo.bcoo_concatenate([mat.to_bcoo() for mat in operands],
                            dimension=dimension))
@tree_util.register_pytree_node_class
class BCSR(JAXSparse):

View File

@ -554,6 +554,7 @@ for prim, bcoo_impl in _BCOO_STANDARD_PRIMITIVES.items():
# Mapping from lax primitive to the BCSR implementation used when sparsifying
# that primitive; entries are registered by the loop that follows.
_BCSR_STANDARD_PRIMITIVES = {
lax.dot_general_p: sparse.bcsr_dot_general,
lax.broadcast_in_dim_p: sparse.bcsr_broadcast_in_dim,
# concatenate_p passes its operands positionally, while bcsr_concatenate
# expects a single sequence argument; the lambda adapts the calling convention.
lax.concatenate_p: lambda *a, **k: sparse.bcsr_concatenate(a, **k),
}
for prim, bcsr_impl in _BCSR_STANDARD_PRIMITIVES.items():

View File

@ -2249,6 +2249,24 @@ class BCSRTest(sptu.SparseTestCase):
self.assertEqual(xsp[:, None].n_batch, xsp.n_batch + 1)
self.assertArraysEqual(xsp[:, None].todense(), x[:, None])
@jtu.sample_product(
    [dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense,
          dimension=dimension)
     for shape in [(3, 5), (3, 5, 4)]
     for layout in iter_sparse_layouts(shape)
     # BCSR requires exactly two sparse dimensions — skip other layouts.
     # NOTE(review): confirm layout exposes n_batch/n_dense only; n_sparse is
     # derived here as the remaining dimensions.
     if len(shape) - layout.n_batch - layout.n_dense == 2
     for dimension in range(len(shape) - layout.n_dense)  # Concatenation of dense dimensions not implemented.
    ],
    dtype=all_dtypes,
)
def test_bcsr_concatenate(self, shape, dtype, n_batch, n_dense, dimension):
  """Check sparse.bcsr_concatenate against lax.concatenate, values and grads."""
  # The original used sptu.rand_bcoo and sparse.bcoo_concatenate, so this BCSR
  # test never exercised bcsr_concatenate; generate BCSR inputs and call the
  # BCSR implementation instead.
  sprng = sptu.rand_bcsr(self.rng(), n_batch=n_batch, n_dense=n_dense)
  args_maker = lambda: [[sprng(shape, dtype) for i in range(3)]]
  dense_func = partial(lax.concatenate, dimension=dimension)
  sparse_func = partial(sparse.bcsr_concatenate, dimension=dimension)
  self._CheckAgainstDense(dense_func, sparse_func, args_maker)
  if jnp.issubdtype(dtype, jnp.floating):
    self._CheckGradsSparse(dense_func, sparse_func, args_maker)
class SparseGradTest(sptu.SparseTestCase):
@jtu.sample_product(has_aux=[True, False])

View File

@ -326,9 +326,29 @@ class SparsifyTest(jtu.JaxTestCase):
([(2, 4), (4,), (3, 4)], "vstack", 0),
([(1, 4), (4,), (1, 4)], "vstack", 0),
]
],
]
)
def testSparseConcatenate(self, shapes, func, n_batch):
def testSparseConcatenateBCOO(self, shapes, func, n_batch):
  """Sparsified jnp.{stack,vstack,hstack} on BCOO inputs matches the dense result."""
  concat = self.sparsify(getattr(jnp, func))
  rng = jtu.rand_some_zero(self.rng())
  dense_inputs = [rng(shape, 'int32') for shape in shapes]
  sparse_inputs = [BCOO.fromdense(arr, n_batch=n_batch) for arr in dense_inputs]
  expected = concat(dense_inputs)
  actual = concat(sparse_inputs).todense()
  self.assertArraysEqual(expected, actual)
@jtu.sample_product(
[dict(shapes=shapes, func=func, n_batch=n_batch)
for shapes, func, n_batch in [
([(2, 4), (2, 4)], "stack", 0),
([(2, 4), (3, 4)], "vstack", 0),
([(2, 4), (2, 5)], "hstack", 0),
([(2, 4), (3, 4)], "vstack", 1),
([(2, 4), (2, 5)], "hstack", 1),
([(2, 4), (3, 4)], "vstack", 2),
([(2, 4), (2, 5)], "hstack", 2),
]
]
)
def testSparseConcatenateBCSR(self, shapes, func, n_batch):
f = self.sparsify(getattr(jnp, func))
rng = jtu.rand_some_zero(self.rng())
arrs = [rng(shape, 'int32') for shape in shapes]