Mirror of https://github.com/ROCm/jax.git, synced 2025-04-19 05:16:06 +00:00
Merge pull request #13887 from jakevdp:bcoo-dot-general-batch
PiperOrigin-RevId: 501637620
This commit is contained in: commit ba506cbfe2
@@ -991,22 +991,14 @@ def _bcoo_dot_general_transpose(ct, lhs_data, lhs_indices, rhs, *, dimension_num
   return lhs_data, lhs_indices, lax.transpose(result, out_axes)
 
 def _bcoo_dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers, lhs_spinfo: SparseInfo):
-  lhs_data, lhs_indices, rhs = batched_args
-  batch_dims = list(batch_dims)
-  batch_size = max(0 if dim is None else arg.shape[dim]
-                   for arg, dim in zip(batched_args, batch_dims))
-  if batch_dims[0] is None:
-    lhs_data = lhs_data[None]
-    batch_dims[0] = 0
-  if batch_dims[1] is None:
-    lhs_indices = lhs_indices[None]
-    batch_dims[1] = 0
-  # TODO: handle different batchings between lhs_data and lhs_indices?
-  assert batch_dims[0] == batch_dims[1] == 0
+  _, _, rhs = batched_args
+  _, _, rhs_bdim = batch_dims
+  new_lhs_data, new_lhs_indices, new_lhs_spinfo = _bcoo_batch_dims_to_front(
+      batched_args[:2], batch_dims[:2], lhs_spinfo,
+      batch_size=None if rhs_bdim is None else rhs.shape[rhs_bdim])
   new_dimension_numbers, result_batch_dim = _dot_general_batch_dim_nums(
-      (len(lhs_spinfo.shape), rhs.ndim), (batch_dims[0], batch_dims[2]), dimension_numbers)
-  new_shape = (batch_size, *lhs_spinfo.shape)
-  batched_out = _bcoo_dot_general(lhs_data, lhs_indices, rhs, lhs_spinfo=SparseInfo(new_shape),
+      (len(lhs_spinfo.shape), rhs.ndim), (0, rhs_bdim), dimension_numbers)
+  batched_out = _bcoo_dot_general(new_lhs_data, new_lhs_indices, rhs, lhs_spinfo=new_lhs_spinfo,
                                   dimension_numbers=new_dimension_numbers)
   return batched_out, result_batch_dim
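For context (not part of this commit): a minimal sketch of the kind of batched sparse-dense contraction that routes through this batching rule, using only public JAX APIs (jax.vmap, sparse.BCOO, sparse.bcoo_dot_general); the shapes below are illustrative.

import jax
import jax.numpy as jnp
from jax.experimental import sparse

# Sparse lhs of shape (4, 4); a batch of dense rhs operands of shape (8, 4, 3).
lhs = sparse.BCOO.fromdense(jnp.eye(4))
rhs = jnp.ones((8, 4, 3))

# Contract lhs dim 1 with rhs dim 0; vmap maps over the leading rhs axis,
# which is handled by the BCOO dot_general batching rule.
dot = lambda r: sparse.bcoo_dot_general(
    lhs, r, dimension_numbers=(((1,), (0,)), ((), ())))
out = jax.vmap(dot)(rhs)   # shape (8, 4, 3)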
@@ -734,10 +734,8 @@ class BCOOTest(sptu.SparseTestCase):
     self.assertArraysEqual(M.todense(), jnp.empty(shape, dtype))
 
   @jtu.sample_product(
-    [dict(n_batch=n_batch, n_dense=n_dense)
-      for n_batch in range(3)
-      for n_dense in range(3 - n_batch)
-    ],
+    [dict(n_batch=layout.n_batch, n_dense=layout.n_dense)
+     for layout in iter_sparse_layouts((3, 3))],
     N=[3, 5],
     M=[None, 4],
     k=[-3, -1, 0, 2, 4],
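As a side note (not from the diff), the n_batch/n_dense layout parameters swept by this test control how many leading dimensions of a BCOO array are stored as dense batch dimensions and how many trailing dimensions are stored densely per specified element; a hedged illustration with made-up shapes:

import jax.numpy as jnp
from jax.experimental import sparse

x = jnp.arange(24.0).reshape(2, 3, 4)
m = sparse.BCOO.fromdense(x, n_batch=1, n_dense=1)  # 1 batch, 1 sparse, 1 dense dim
print(m.n_batch, m.n_sparse, m.n_dense)             # -> 1 1 1
print(m.data.shape, m.indices.shape)                # data: (2, nse, 4), indices: (2, nse, 1)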
@@ -955,6 +953,8 @@ class BCOOTest(sptu.SparseTestCase):
     # Dense dimensions not yet fully supported in reverse mode.
     modes = ['fwd'] if props.n_dense != 0 else ['fwd', 'rev']
     self._CheckGradsSparse(dense_fun, sparse_fun, args_maker, modes=modes, atol=tol, rtol=tol)
+    self._CheckBatchingSparse(dense_fun, sparse_fun, args_maker, atol=tol, rtol=tol,
+                              bdims=self._random_bdims(props.n_batch, len(props.rhs_shape)))
 
   @unittest.skipIf(not GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse")
   @unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU")
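For illustration only (not part of the test suite): the property a batching check of this kind verifies is that vmapping the sparse operation agrees with the vmapped dense reference; a small hedged sketch using public JAX APIs:

import jax
import jax.numpy as jnp
from jax.experimental import sparse

dense_lhs = 2.0 * jnp.eye(3)
sp_lhs = sparse.BCOO.fromdense(dense_lhs)   # unbatched sparse operand
rhs = jnp.arange(12.0).reshape(4, 3)        # a batch of 4 dense vectors

# vmapping the sparse matmul over rhs should match the dense reference.
sparse_out = jax.vmap(lambda b: sp_lhs @ b)(rhs)
dense_out = jax.vmap(lambda b: dense_lhs @ b)(rhs)
assert jnp.allclose(sparse_out, dense_out)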