Merge pull request #13887 from jakevdp:bcoo-dot-general-batch

PiperOrigin-RevId: 501637620
This commit is contained in:
jax authors 2023-01-12 12:34:39 -08:00
commit ba506cbfe2
2 changed files with 11 additions and 19 deletions

View File

@ -991,22 +991,14 @@ def _bcoo_dot_general_transpose(ct, lhs_data, lhs_indices, rhs, *, dimension_num
return lhs_data, lhs_indices, lax.transpose(result, out_axes)
def _bcoo_dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers, lhs_spinfo: SparseInfo):
lhs_data, lhs_indices, rhs = batched_args
batch_dims = list(batch_dims)
batch_size = max(0 if dim is None else arg.shape[dim]
for arg, dim in zip(batched_args, batch_dims))
if batch_dims[0] is None:
lhs_data = lhs_data[None]
batch_dims[0] = 0
if batch_dims[1] is None:
lhs_indices = lhs_indices[None]
batch_dims[1] = 0
# TODO: handle different batchings between lhs_data and lhs_indices?
assert batch_dims[0] == batch_dims[1] == 0
_, _, rhs = batched_args
_, _, rhs_bdim = batch_dims
new_lhs_data, new_lhs_indices, new_lhs_spinfo = _bcoo_batch_dims_to_front(
batched_args[:2], batch_dims[:2], lhs_spinfo,
batch_size=None if rhs_bdim is None else rhs.shape[rhs_bdim])
new_dimension_numbers, result_batch_dim = _dot_general_batch_dim_nums(
(len(lhs_spinfo.shape), rhs.ndim), (batch_dims[0], batch_dims[2]), dimension_numbers)
new_shape = (batch_size, *lhs_spinfo.shape)
batched_out = _bcoo_dot_general(lhs_data, lhs_indices, rhs, lhs_spinfo=SparseInfo(new_shape),
(len(lhs_spinfo.shape), rhs.ndim), (0, rhs_bdim), dimension_numbers)
batched_out = _bcoo_dot_general(new_lhs_data, new_lhs_indices, rhs, lhs_spinfo=new_lhs_spinfo,
dimension_numbers=new_dimension_numbers)
return batched_out, result_batch_dim

View File

@ -734,10 +734,8 @@ class BCOOTest(sptu.SparseTestCase):
self.assertArraysEqual(M.todense(), jnp.empty(shape, dtype))
@jtu.sample_product(
[dict(n_batch=n_batch, n_dense=n_dense)
for n_batch in range(3)
for n_dense in range(3 - n_batch)
],
[dict(n_batch=layout.n_batch, n_dense=layout.n_dense)
for layout in iter_sparse_layouts((3, 3))],
N=[3, 5],
M=[None, 4],
k=[-3, -1, 0, 2, 4],
@ -955,6 +953,8 @@ class BCOOTest(sptu.SparseTestCase):
# Dense dimensions not yet fully supported in reverse mode.
modes = ['fwd'] if props.n_dense != 0 else ['fwd', 'rev']
self._CheckGradsSparse(dense_fun, sparse_fun, args_maker, modes=modes, atol=tol, rtol=tol)
self._CheckBatchingSparse(dense_fun, sparse_fun, args_maker, atol=tol, rtol=tol,
bdims=self._random_bdims(props.n_batch, len(props.rhs_shape)))
@unittest.skipIf(not GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse")
@unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU")