[sparse] Add BCSR primitive bcsr_extract.

PiperOrigin-RevId: 482530210
This commit is contained in:
Tianjian Lu 2022-10-20 10:29:48 -07:00 committed by jax authors
parent 1a0affddd8
commit 2d563bf0a2
3 changed files with 59 additions and 0 deletions

View File

@ -218,6 +218,8 @@ from jax.experimental.sparse.bcoo import (
)
from jax.experimental.sparse.bcsr import (
bcsr_extract as bcsr_extract,
bcsr_extract_p as bcsr_extract_p,
bcsr_fromdense as bcsr_fromdense,
bcsr_fromdense_p as bcsr_fromdense_p,
bcsr_todense as bcsr_todense,

View File

@ -233,6 +233,43 @@ mlir.register_lowering(bcsr_todense_p, mlir.lower_fun(
_bcsr_todense_impl, multiple_results=False))
#--------------------------------------------------------------------
# bcsr_extract
bcsr_extract_p = core.Primitive('bcsr_extract')
def bcsr_extract(indices, indptr, mat):
"""Extract values from a dense matrix at given BCSR (indices, indptr).
Args:
indices: An ndarray; see BCSR indices.
indptr: An ndarray; see BCSR indptr.
mat: A dense matrix.
Returns:
An ndarray; see BCSR data.
"""
return bcsr_extract_p.bind(indices, indptr, mat)
@bcsr_extract_p.def_impl
def _bcsr_extract_impl(indices, indptr, mat):
mat = jnp.asarray(mat)
bcoo_indices = _bcsr_to_bcoo(indices, indptr, shape=mat.shape)
return bcoo.bcoo_extract(bcoo_indices, mat)
@bcsr_extract_p.def_abstract_eval
def _bcsr_extract_abstract_eval(indices, indptr, mat):
n_batch, n_dense, nse = _validate_bcsr_indices(indices, indptr, mat.shape)
out_shape = mat.shape[:n_batch] + (nse,) + mat.shape[mat.ndim - n_dense:]
return core.ShapedArray(out_shape, mat.dtype)
mlir.register_lowering(bcsr_extract_p, mlir.lower_fun(
_bcsr_extract_impl, multiple_results=False))
class BCSR(JAXSparse):
"""Experimental batched CSR matrix implemented in JAX."""

View File

@ -2269,6 +2269,26 @@ class BCSRTest(jtu.JaxTestCase):
args_maker_todense = lambda: [data, indices, indptr]
self._CompileAndCheck(todense, args_maker_todense)
@jtu.sample_product(
[dict(shape=shape, n_batch=n_batch)
for shape in [(5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
for n_batch in range(len(shape) - 1)
],
dtype=jtu.dtypes.floating + jtu.dtypes.complex,
)
def test_bcsr_extract(self, shape, dtype, n_batch):
n_dense = len(shape) - n_batch - 2
rng = rand_sparse(self.rng())
M = rng(shape, dtype)
nse = sparse.util._count_stored_elements(M, n_batch=n_batch,
n_dense=n_dense)
data, indices, indptr = sparse_bcsr._bcsr_fromdense(
M, nse=nse, n_batch=n_batch, n_dense=n_dense)
data2 = sparse.bcsr_extract(indices, indptr, M)
self.assertArraysEqual(data, data2)
args_maker_bcsr_extract = lambda: [indices, indptr, M]
self._CompileAndCheck(sparse.bcsr_extract, args_maker_bcsr_extract)
class SparseGradTest(jtu.JaxTestCase):
def test_sparse_grad(self):