mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[sparse] Add BCSR primitive bcsr_extract.
PiperOrigin-RevId: 482530210
This commit is contained in:
parent
1a0affddd8
commit
2d563bf0a2
@ -218,6 +218,8 @@ from jax.experimental.sparse.bcoo import (
|
||||
)
|
||||
|
||||
from jax.experimental.sparse.bcsr import (
|
||||
bcsr_extract as bcsr_extract,
|
||||
bcsr_extract_p as bcsr_extract_p,
|
||||
bcsr_fromdense as bcsr_fromdense,
|
||||
bcsr_fromdense_p as bcsr_fromdense_p,
|
||||
bcsr_todense as bcsr_todense,
|
||||
|
@ -233,6 +233,43 @@ mlir.register_lowering(bcsr_todense_p, mlir.lower_fun(
|
||||
_bcsr_todense_impl, multiple_results=False))
|
||||
|
||||
|
||||
#--------------------------------------------------------------------
|
||||
# bcsr_extract
|
||||
bcsr_extract_p = core.Primitive('bcsr_extract')
|
||||
|
||||
|
||||
def bcsr_extract(indices, indptr, mat):
|
||||
"""Extract values from a dense matrix at given BCSR (indices, indptr).
|
||||
|
||||
Args:
|
||||
indices: An ndarray; see BCSR indices.
|
||||
indptr: An ndarray; see BCSR indptr.
|
||||
mat: A dense matrix.
|
||||
|
||||
Returns:
|
||||
An ndarray; see BCSR data.
|
||||
"""
|
||||
return bcsr_extract_p.bind(indices, indptr, mat)
|
||||
|
||||
|
||||
@bcsr_extract_p.def_impl
|
||||
def _bcsr_extract_impl(indices, indptr, mat):
|
||||
mat = jnp.asarray(mat)
|
||||
bcoo_indices = _bcsr_to_bcoo(indices, indptr, shape=mat.shape)
|
||||
return bcoo.bcoo_extract(bcoo_indices, mat)
|
||||
|
||||
|
||||
@bcsr_extract_p.def_abstract_eval
|
||||
def _bcsr_extract_abstract_eval(indices, indptr, mat):
|
||||
n_batch, n_dense, nse = _validate_bcsr_indices(indices, indptr, mat.shape)
|
||||
out_shape = mat.shape[:n_batch] + (nse,) + mat.shape[mat.ndim - n_dense:]
|
||||
return core.ShapedArray(out_shape, mat.dtype)
|
||||
|
||||
|
||||
mlir.register_lowering(bcsr_extract_p, mlir.lower_fun(
|
||||
_bcsr_extract_impl, multiple_results=False))
|
||||
|
||||
|
||||
class BCSR(JAXSparse):
|
||||
"""Experimental batched CSR matrix implemented in JAX."""
|
||||
|
||||
|
@ -2269,6 +2269,26 @@ class BCSRTest(jtu.JaxTestCase):
|
||||
args_maker_todense = lambda: [data, indices, indptr]
|
||||
self._CompileAndCheck(todense, args_maker_todense)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, n_batch=n_batch)
|
||||
for shape in [(5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
|
||||
for n_batch in range(len(shape) - 1)
|
||||
],
|
||||
dtype=jtu.dtypes.floating + jtu.dtypes.complex,
|
||||
)
|
||||
def test_bcsr_extract(self, shape, dtype, n_batch):
|
||||
n_dense = len(shape) - n_batch - 2
|
||||
rng = rand_sparse(self.rng())
|
||||
M = rng(shape, dtype)
|
||||
nse = sparse.util._count_stored_elements(M, n_batch=n_batch,
|
||||
n_dense=n_dense)
|
||||
data, indices, indptr = sparse_bcsr._bcsr_fromdense(
|
||||
M, nse=nse, n_batch=n_batch, n_dense=n_dense)
|
||||
data2 = sparse.bcsr_extract(indices, indptr, M)
|
||||
self.assertArraysEqual(data, data2)
|
||||
args_maker_bcsr_extract = lambda: [indices, indptr, M]
|
||||
self._CompileAndCheck(sparse.bcsr_extract, args_maker_bcsr_extract)
|
||||
|
||||
|
||||
class SparseGradTest(jtu.JaxTestCase):
|
||||
def test_sparse_grad(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user