diff --git a/tests/sparse_nm_test.py b/tests/sparse_nm_test.py index 686575ee4..9ecf30eb6 100644 --- a/tests/sparse_nm_test.py +++ b/tests/sparse_nm_test.py @@ -47,9 +47,6 @@ class SpmmTest(jtu.JaxTestCase): ) @jtu.run_on_devices("gpu") def test_shapes(self, tile_m, tile_n, tile_k, batch, sparse_idx): - if not jtu.is_cuda_compute_capability_at_least("9.0"): - self.skipTest("Skipping test on Ampere because of bug b/377940729") - # Build keyword arguments kwargs = { "dimension_numbers": (((1,), (1,)), (tuple(), tuple())), @@ -96,9 +93,6 @@ class SpmmTest(jtu.JaxTestCase): ) @jtu.run_on_devices("gpu") def test_types(self, lhs_type, rhs_type, output_type): - if not jtu.is_cuda_compute_capability_at_least("9.0"): - self.skipTest("Skipping test on Ampere because of bug b/377940729") - tile_m, tile_n, tile_k = 64, 32, 128 # Build input data