Disable cudnn_fusion_test on A100.

This test only seems to pass on H100 at the moment.

PiperOrigin-RevId: 681070398
This commit is contained in:
Peter Hawkins 2024-10-01 10:18:06 -07:00 committed by jax authors
parent 28098bef93
commit 1260ebbe05
2 changed files with 3 additions and 4 deletions

View File

@ -1496,9 +1496,8 @@ jax_py_test(
jax_multiplatform_test(
name = "cudnn_fusion_test",
srcs = ["cudnn_fusion_test.py"],
enable_backends = ["gpu"],
enable_backends = [],
enable_configs = [
"gpu_a100",
"gpu_h100",
],
tags = ["multiaccelerator"],

View File

@ -26,8 +26,8 @@ jax.config.parse_flags_with_absl()
class CudnnFusionTest(jtu.JaxTestCase):
def setUp(self):
if (not jtu.test_device_matches(["cuda"]) or
not jtu.is_cuda_compute_capability_at_least("8.0")):
self.skipTest("Only works on >= sm80 GPUs")
not jtu.is_cuda_compute_capability_at_least("9.0")):
self.skipTest("Only works on >= sm90 GPUs")
super().setUp()
@parameterized.parameters(["", "pmap"])