[XLA:GPU] Skip small tile sizes for sparse gemms on Ampere as well. Enable the JAX test again that has been failing.

PiperOrigin-RevId: 695360850
This commit is contained in:
Christian Sigg 2024-11-11 08:57:11 -08:00 committed by jax authors
parent 8a7bf2e4b0
commit f18f62a5d2

View File

@ -47,9 +47,6 @@ class SpmmTest(jtu.JaxTestCase):
)
@jtu.run_on_devices("gpu")
def test_shapes(self, tile_m, tile_n, tile_k, batch, sparse_idx):
if not jtu.is_cuda_compute_capability_at_least("9.0"):
self.skipTest("Skipping test on Ampere because of bug b/377940729")
# Build keyword arguments
kwargs = {
"dimension_numbers": (((1,), (1,)), (tuple(), tuple())),
@ -96,9 +93,6 @@ class SpmmTest(jtu.JaxTestCase):
)
@jtu.run_on_devices("gpu")
def test_types(self, lhs_type, rhs_type, output_type):
if not jtu.is_cuda_compute_capability_at_least("9.0"):
self.skipTest("Skipping test on Ampere because of bug b/377940729")
tile_m, tile_n, tile_k = 64, 32, 128
# Build input data