mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
sparse] BCSR batching rule.
[Co-authored-by: Jake Vanderplas: <vanderplas@google.com>
This commit is contained in:
parent
dc0d7ba368
commit
332fced0cc
@ -179,6 +179,21 @@ def _bcoo_fromdense_abstract_eval(mat, *, nse, n_batch, n_dense, index_dtype):
|
||||
core.ShapedArray(indptr_shape, index_dtype))
|
||||
|
||||
|
||||
def _bcsr_fromdense_batching_rule(batched_args, batch_dims, *, nse, n_batch,
|
||||
n_dense, index_dtype):
|
||||
M, = batched_args
|
||||
if batch_dims != (0,):
|
||||
raise NotImplementedError(f"batch_dims={batch_dims}")
|
||||
new_n_batch = n_batch + 1
|
||||
n_sparse = M.ndim - n_dense - new_n_batch
|
||||
if n_sparse != 2:
|
||||
raise ValueError("_bcsr_fromdense_batching_rule: must have 2 sparse "
|
||||
f"dimensions but {n_sparse} is given.")
|
||||
return _bcsr_fromdense(M, nse=nse, n_batch=new_n_batch, n_dense=n_dense,
|
||||
index_dtype=index_dtype), (0, 0, 0)
|
||||
|
||||
|
||||
batching.primitive_batchers[bcsr_fromdense_p] = _bcsr_fromdense_batching_rule
|
||||
mlir.register_lowering(bcsr_fromdense_p, mlir.lower_fun(
|
||||
_bcsr_fromdense_impl, multiple_results=True))
|
||||
|
||||
@ -229,6 +244,20 @@ def _bcsr_todense_abstract_eval(data, indices, indptr, *, shape):
|
||||
return core.ShapedArray(shape, data.dtype)
|
||||
|
||||
|
||||
def _bcsr_todense_batching_rule(batched_args, batch_dims, *, shape):
|
||||
data, indices, indptr = batched_args
|
||||
if any(b not in [0, None] for b in batch_dims):
|
||||
raise NotImplementedError(f"batch_dims={batch_dims}. "
|
||||
"Only 0 and None are supported.")
|
||||
if batch_dims[0] is None:
|
||||
data = data[None, ...]
|
||||
if batch_dims[1] is None:
|
||||
indices = indices[None, ...]
|
||||
if batch_dims[2] is None:
|
||||
indptr = indptr[None, ...]
|
||||
return _bcsr_todense(data, indices, indptr, shape=shape), 0
|
||||
|
||||
batching.primitive_batchers[bcsr_todense_p] = _bcsr_todense_batching_rule
|
||||
mlir.register_lowering(bcsr_todense_p, mlir.lower_fun(
|
||||
_bcsr_todense_impl, multiple_results=False))
|
||||
|
||||
|
@ -2308,6 +2308,44 @@ class BCSRTest(jtu.JaxTestCase):
|
||||
args_maker_todense = lambda: [data, indices, indptr]
|
||||
self._CompileAndCheck(todense, args_maker_todense)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, n_batch=n_batch)
|
||||
for shape in [(5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
|
||||
for n_batch in range(len(shape) - 1)
|
||||
],
|
||||
dtype=jtu.dtypes.floating + jtu.dtypes.complex,
|
||||
)
|
||||
def test_bcsr_dense_round_trip_batched(self, shape, dtype, n_batch):
|
||||
n_sparse = 2
|
||||
n_dense = len(shape) - n_sparse - n_batch
|
||||
rng = rand_sparse(self.rng())
|
||||
M = rng(shape, dtype)
|
||||
|
||||
nse = sparse.util._count_stored_elements(M, n_batch=n_batch,
|
||||
n_dense=n_dense)
|
||||
|
||||
fromdense = partial(sparse_bcsr._bcsr_fromdense, nse=nse, n_batch=0,
|
||||
n_dense=n_dense)
|
||||
todense = partial(sparse_bcsr._bcsr_todense, shape=shape)
|
||||
|
||||
for _ in range(n_batch):
|
||||
fromdense = jax.vmap(fromdense)
|
||||
todense = jax.vmap(todense)
|
||||
|
||||
data, indices, indptr = fromdense(M)
|
||||
|
||||
self.assertEqual(data.dtype, dtype)
|
||||
self.assertEqual(data.shape,
|
||||
shape[:n_batch] + (nse,) + shape[n_batch + n_sparse:])
|
||||
self.assertEqual(indices.dtype, jnp.int32)
|
||||
self.assertEqual(indices.shape, shape[:n_batch] + (nse,))
|
||||
self.assertEqual(indptr.dtype, jnp.int32)
|
||||
self.assertEqual(indptr.shape, shape[:n_batch] + (shape[n_batch] + 1,))
|
||||
|
||||
self.assertArraysEqual(M, todense(data, indices, indptr))
|
||||
args_maker_todense = lambda: [data, indices, indptr]
|
||||
self._CompileAndCheck(todense, args_maker_todense)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, n_batch=n_batch)
|
||||
for shape in [(5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
|
||||
|
Loading…
x
Reference in New Issue
Block a user