mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
[XLA:GPU] Skip small tile sizes for sparse gemms on Ampere as well. Enable the JAX test again that has been failing.
PiperOrigin-RevId: 695360850
This commit is contained in:
parent
8a7bf2e4b0
commit
f18f62a5d2
@ -47,9 +47,6 @@ class SpmmTest(jtu.JaxTestCase):
|
||||
)
|
||||
@jtu.run_on_devices("gpu")
|
||||
def test_shapes(self, tile_m, tile_n, tile_k, batch, sparse_idx):
|
||||
if not jtu.is_cuda_compute_capability_at_least("9.0"):
|
||||
self.skipTest("Skipping test on Ampere because of bug b/377940729")
|
||||
|
||||
# Build keyword arguments
|
||||
kwargs = {
|
||||
"dimension_numbers": (((1,), (1,)), (tuple(), tuple())),
|
||||
@ -96,9 +93,6 @@ class SpmmTest(jtu.JaxTestCase):
|
||||
)
|
||||
@jtu.run_on_devices("gpu")
|
||||
def test_types(self, lhs_type, rhs_type, output_type):
|
||||
if not jtu.is_cuda_compute_capability_at_least("9.0"):
|
||||
self.skipTest("Skipping test on Ampere because of bug b/377940729")
|
||||
|
||||
tile_m, tile_n, tile_k = 64, 32, 128
|
||||
|
||||
# Build input data
|
||||
|
Loading…
x
Reference in New Issue
Block a user