Merge pull request #14500 from jakevdp:bcsr-matmul-test

PiperOrigin-RevId: 510034750
This commit is contained in:
jax authors 2023-02-15 21:26:06 -08:00
commit d8514d0ec6

View File

@ -2227,9 +2227,8 @@ class BCSRTest(sptu.SparseTestCase):
# Dense dimensions not yet fully supported in reverse mode.
modes = ['fwd'] if props.n_dense != 0 else ['fwd', 'rev']
self._CheckGradsSparse(dense_fun, sparse_fun, args_maker, modes=modes, atol=tol, rtol=tol)
# TODO: add this once bcsr_broadcast_in_dim & bcsr_concatenate are implemented
# self._CheckBatchingSparse(dense_fun, sparse_fun, args_maker, atol=tol, rtol=tol,
# bdims=self._random_bdims(props.n_batch, len(props.rhs_shape)))
self._CheckBatchingSparse(dense_fun, sparse_fun, args_maker, atol=tol, rtol=tol,
bdims=self._random_bdims(props.n_batch, len(props.rhs_shape)))
@jtu.sample_product(
[dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense)