mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #14500 from jakevdp:bcsr-matmul-test
PiperOrigin-RevId: 510034750
This commit is contained in:
commit
d8514d0ec6
@ -2227,9 +2227,8 @@ class BCSRTest(sptu.SparseTestCase):
|
||||
# Dense dimensions not yet fully supported in reverse mode.
|
||||
modes = ['fwd'] if props.n_dense != 0 else ['fwd', 'rev']
|
||||
self._CheckGradsSparse(dense_fun, sparse_fun, args_maker, modes=modes, atol=tol, rtol=tol)
|
||||
# TODO: add this once bcsr_broadcast_in_dim & bcsr_concatenate are implemented
|
||||
# self._CheckBatchingSparse(dense_fun, sparse_fun, args_maker, atol=tol, rtol=tol,
|
||||
# bdims=self._random_bdims(props.n_batch, len(props.rhs_shape)))
|
||||
self._CheckBatchingSparse(dense_fun, sparse_fun, args_maker, atol=tol, rtol=tol,
|
||||
bdims=self._random_bdims(props.n_batch, len(props.rhs_shape)))
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense)
|
||||
|
Loading…
x
Reference in New Issue
Block a user