mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
[sparse] implement bcsr_concatenate
This commit is contained in:
parent
9b288e9ab9
commit
f3e5024787
@ -234,6 +234,7 @@ from jax.experimental.sparse.bcoo import (
|
||||
|
||||
from jax.experimental.sparse.bcsr import (
|
||||
bcsr_broadcast_in_dim as bcsr_broadcast_in_dim,
|
||||
bcsr_concatenate as bcsr_concatenate,
|
||||
bcsr_dot_general as bcsr_dot_general,
|
||||
bcsr_dot_general_p as bcsr_dot_general_p,
|
||||
bcsr_extract as bcsr_extract,
|
||||
|
@ -703,6 +703,22 @@ def bcsr_broadcast_in_dim(mat: BCSR, *, shape: Shape, broadcast_dimensions: Sequ
|
||||
mat.to_bcoo(), shape=shape, broadcast_dimensions=broadcast_dimensions)
|
||||
return BCSR.from_bcoo(result_bcoo)
|
||||
|
||||
def bcsr_concatenate(operands: Sequence[BCSR], *, dimension: int) -> BCSR:
|
||||
"""Sparse implementation of :func:`jax.lax.concatenate`
|
||||
|
||||
Args:
|
||||
operands : Sequence of BCSR arrays to concatenate. The arrays must have equal
|
||||
shapes, except in the `dimension` axis. Additionally, the arrays must have
|
||||
have equivalent batch, sparse, and dense dimensions.
|
||||
dimension : Positive integer specifying the dimension along which to concatenate
|
||||
the arrays. The dimension must be among batch or sparse dimensions of the input;
|
||||
concatenation along dense dimensions is not supported.
|
||||
|
||||
Returns:
|
||||
A BCSR array containing the concatenation of the inputs.
|
||||
"""
|
||||
return BCSR.from_bcoo(
|
||||
bcoo.bcoo_concatenate([mat.to_bcoo() for mat in operands], dimension=dimension))
|
||||
|
||||
@tree_util.register_pytree_node_class
|
||||
class BCSR(JAXSparse):
|
||||
|
@ -554,6 +554,7 @@ for prim, bcoo_impl in _BCOO_STANDARD_PRIMITIVES.items():
|
||||
_BCSR_STANDARD_PRIMITIVES = {
|
||||
lax.dot_general_p: sparse.bcsr_dot_general,
|
||||
lax.broadcast_in_dim_p: sparse.bcsr_broadcast_in_dim,
|
||||
lax.concatenate_p: lambda *a, **k: sparse.bcsr_concatenate(a, **k),
|
||||
}
|
||||
|
||||
for prim, bcsr_impl in _BCSR_STANDARD_PRIMITIVES.items():
|
||||
|
@ -2249,6 +2249,24 @@ class BCSRTest(sptu.SparseTestCase):
|
||||
self.assertEqual(xsp[:, None].n_batch, xsp.n_batch + 1)
|
||||
self.assertArraysEqual(xsp[:, None].todense(), x[:, None])
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense, dimension=dimension)
|
||||
for shape in [(3, 5), (3, 5, 4)]
|
||||
for layout in iter_sparse_layouts(shape)
|
||||
for dimension in range(len(shape) - layout.n_dense) # Concatenation of dense dimensions not implemented.
|
||||
],
|
||||
dtype=all_dtypes,
|
||||
)
|
||||
def test_bcsr_concatenate(self, shape, dtype, n_batch, n_dense, dimension):
|
||||
sprng = sptu.rand_bcoo(self.rng(), n_batch=n_batch, n_dense=n_dense)
|
||||
args_maker = lambda: [[sprng(shape, dtype) for i in range(3)]]
|
||||
dense_func = partial(lax.concatenate, dimension=dimension)
|
||||
sparse_func = partial(sparse.bcoo_concatenate, dimension=dimension)
|
||||
|
||||
self._CheckAgainstDense(dense_func, sparse_func, args_maker)
|
||||
if jnp.issubdtype(dtype, jnp.floating):
|
||||
self._CheckGradsSparse(dense_func, sparse_func, args_maker)
|
||||
|
||||
|
||||
class SparseGradTest(sptu.SparseTestCase):
|
||||
@jtu.sample_product(has_aux=[True, False])
|
||||
|
@ -326,9 +326,29 @@ class SparsifyTest(jtu.JaxTestCase):
|
||||
([(2, 4), (4,), (3, 4)], "vstack", 0),
|
||||
([(1, 4), (4,), (1, 4)], "vstack", 0),
|
||||
]
|
||||
],
|
||||
]
|
||||
)
|
||||
def testSparseConcatenate(self, shapes, func, n_batch):
|
||||
def testSparseConcatenateBCOO(self, shapes, func, n_batch):
|
||||
f = self.sparsify(getattr(jnp, func))
|
||||
rng = jtu.rand_some_zero(self.rng())
|
||||
arrs = [rng(shape, 'int32') for shape in shapes]
|
||||
sparrs = [BCOO.fromdense(arr, n_batch=n_batch) for arr in arrs]
|
||||
self.assertArraysEqual(f(arrs), f(sparrs).todense())
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(shapes=shapes, func=func, n_batch=n_batch)
|
||||
for shapes, func, n_batch in [
|
||||
([(2, 4), (2, 4)], "stack", 0),
|
||||
([(2, 4), (3, 4)], "vstack", 0),
|
||||
([(2, 4), (2, 5)], "hstack", 0),
|
||||
([(2, 4), (3, 4)], "vstack", 1),
|
||||
([(2, 4), (2, 5)], "hstack", 1),
|
||||
([(2, 4), (3, 4)], "vstack", 2),
|
||||
([(2, 4), (2, 5)], "hstack", 2),
|
||||
]
|
||||
]
|
||||
)
|
||||
def testSparseConcatenateBCSR(self, shapes, func, n_batch):
|
||||
f = self.sparsify(getattr(jnp, func))
|
||||
rng = jtu.rand_some_zero(self.rng())
|
||||
arrs = [rng(shape, 'int32') for shape in shapes]
|
||||
|
Loading…
x
Reference in New Issue
Block a user