diff --git a/tests/BUILD b/tests/BUILD index 3442048e2..d66342dc3 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1496,9 +1496,8 @@ jax_py_test( jax_multiplatform_test( name = "cudnn_fusion_test", srcs = ["cudnn_fusion_test.py"], - enable_backends = ["gpu"], + enable_backends = [], enable_configs = [ - "gpu_a100", "gpu_h100", ], tags = ["multiaccelerator"], diff --git a/tests/cudnn_fusion_test.py b/tests/cudnn_fusion_test.py index 320fba370..151cb72be 100644 --- a/tests/cudnn_fusion_test.py +++ b/tests/cudnn_fusion_test.py @@ -26,8 +26,8 @@ jax.config.parse_flags_with_absl() class CudnnFusionTest(jtu.JaxTestCase): def setUp(self): if (not jtu.test_device_matches(["cuda"]) or - not jtu.is_cuda_compute_capability_at_least("8.0")): - self.skipTest("Only works on >= sm80 GPUs") + not jtu.is_cuda_compute_capability_at_least("9.0")): + self.skipTest("Only works on >= sm90 GPUs") super().setUp() @parameterized.parameters(["", "pmap"])