[sparse] BCSR batching rule.

Co-authored-by: Jake Vanderplas <vanderplas@google.com>
This commit is contained in:
Tianjian Lu 2022-11-10 13:55:35 -08:00
parent dc0d7ba368
commit 332fced0cc
2 changed files with 67 additions and 0 deletions

View File

@ -179,6 +179,21 @@ def _bcoo_fromdense_abstract_eval(mat, *, nse, n_batch, n_dense, index_dtype):
core.ShapedArray(indptr_shape, index_dtype))
def _bcsr_fromdense_batching_rule(batched_args, batch_dims, *, nse, n_batch,
n_dense, index_dtype):
M, = batched_args
if batch_dims != (0,):
raise NotImplementedError(f"batch_dims={batch_dims}")
new_n_batch = n_batch + 1
n_sparse = M.ndim - n_dense - new_n_batch
if n_sparse != 2:
raise ValueError("_bcsr_fromdense_batching_rule: must have 2 sparse "
f"dimensions but {n_sparse} is given.")
return _bcsr_fromdense(M, nse=nse, n_batch=new_n_batch, n_dense=n_dense,
index_dtype=index_dtype), (0, 0, 0)
# Register the batching rule so bcsr_fromdense works under jax.vmap.
batching.primitive_batchers[bcsr_fromdense_p] = _bcsr_fromdense_batching_rule
# Lower bcsr_fromdense by tracing its Python impl; the primitive produces
# multiple outputs (data, indices, indptr).
mlir.register_lowering(bcsr_fromdense_p, mlir.lower_fun(
    _bcsr_fromdense_impl, multiple_results=True))
@ -229,6 +244,20 @@ def _bcsr_todense_abstract_eval(data, indices, indptr, *, shape):
return core.ShapedArray(shape, data.dtype)
def _bcsr_todense_batching_rule(batched_args, batch_dims, *, shape):
data, indices, indptr = batched_args
if any(b not in [0, None] for b in batch_dims):
raise NotImplementedError(f"batch_dims={batch_dims}. "
"Only 0 and None are supported.")
if batch_dims[0] is None:
data = data[None, ...]
if batch_dims[1] is None:
indices = indices[None, ...]
if batch_dims[2] is None:
indptr = indptr[None, ...]
return _bcsr_todense(data, indices, indptr, shape=shape), 0
# Register the batching rule so bcsr_todense works under jax.vmap.
batching.primitive_batchers[bcsr_todense_p] = _bcsr_todense_batching_rule
# Lower bcsr_todense by tracing its Python impl; the primitive returns a
# single dense array.
mlir.register_lowering(bcsr_todense_p, mlir.lower_fun(
    _bcsr_todense_impl, multiple_results=False))

View File

@ -2308,6 +2308,44 @@ class BCSRTest(jtu.JaxTestCase):
args_maker_todense = lambda: [data, indices, indptr]
self._CompileAndCheck(todense, args_maker_todense)
@jtu.sample_product(
    [dict(shape=shape, n_batch=n_batch)
     for shape in [(5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
     for n_batch in range(len(shape) - 1)
    ],
    dtype=jtu.dtypes.floating + jtu.dtypes.complex,
)
def test_bcsr_dense_round_trip_batched(self, shape, dtype, n_batch):
  """Round-trip dense -> BCSR -> dense under nested jax.vmap."""
  n_sparse = 2
  n_dense = len(shape) - n_sparse - n_batch
  rng = rand_sparse(self.rng())
  mat = rng(shape, dtype)
  nse = sparse.util._count_stored_elements(mat, n_batch=n_batch,
                                           n_dense=n_dense)
  # Start from the unbatched conversions (n_batch=0) and let vmap — i.e.
  # the primitive batching rules — supply every leading batch axis.
  fromdense = partial(sparse_bcsr._bcsr_fromdense, nse=nse, n_batch=0,
                      n_dense=n_dense)
  todense = partial(sparse_bcsr._bcsr_todense, shape=shape)
  for _ in range(n_batch):
    fromdense = jax.vmap(fromdense)
    todense = jax.vmap(todense)

  data, indices, indptr = fromdense(mat)

  # Component dtypes and shapes of the batched BCSR representation.
  self.assertEqual(data.dtype, dtype)
  self.assertEqual(data.shape,
                   shape[:n_batch] + (nse,) + shape[n_batch + n_sparse:])
  self.assertEqual(indices.dtype, jnp.int32)
  self.assertEqual(indices.shape, shape[:n_batch] + (nse,))
  self.assertEqual(indptr.dtype, jnp.int32)
  self.assertEqual(indptr.shape, shape[:n_batch] + (shape[n_batch] + 1,))

  # Converting back must recover the original matrix exactly.
  self.assertArraysEqual(mat, todense(data, indices, indptr))
  args_maker_todense = lambda: [data, indices, indptr]
  self._CompileAndCheck(todense, args_maker_todense)
@jtu.sample_product(
[dict(shape=shape, n_batch=n_batch)
for shape in [(5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]