Merge pull request #13431 from jakevdp:fix-sparse-matmul-warning

PiperOrigin-RevId: 491649932
This commit is contained in:
jax authors 2022-11-29 08:55:23 -08:00
commit 21bab5efab

View File

@ -2076,6 +2076,7 @@ class BCOOTest(sptu.SparseTestCase):
rhs_dtype=all_dtypes,
)
@jax.default_matmul_precision("float32")
@jtu.ignore_warning(category=sparse.CuSparseEfficiencyWarning)
def test_bcoo_matmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype):
# TODO(b/259538729): Disable gpu test when type promotion is required.
# BCOO type promotion calls `convert_element_type`, which further calls