From e083c0800170927ffaeade5b846c857673bf17cb Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Wed, 2 Oct 2024 12:44:21 +0200 Subject: [PATCH] Re-enable cudnn_fusion_test on A100. Check that the required cuDNN version is available. --- tests/BUILD | 1 + tests/cudnn_fusion_test.py | 6 ++++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/BUILD b/tests/BUILD index 58de34049..9b6b0bf66 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1523,6 +1523,7 @@ jax_multiplatform_test( srcs = ["cudnn_fusion_test.py"], enable_backends = [], enable_configs = [ + "gpu_a100", "gpu_h100", ], tags = ["multiaccelerator"], diff --git a/tests/cudnn_fusion_test.py b/tests/cudnn_fusion_test.py index 151cb72be..7dc0571bc 100644 --- a/tests/cudnn_fusion_test.py +++ b/tests/cudnn_fusion_test.py @@ -15,6 +15,7 @@ from absl.testing import absltest, parameterized from unittest import SkipTest from jax._src import test_util as jtu +from jax._src.lib import cuda_versions import jax import jax.numpy as jnp from jax._src.cudnn import cudnn_fusion @@ -26,8 +27,9 @@ jax.config.parse_flags_with_absl() class CudnnFusionTest(jtu.JaxTestCase): def setUp(self): if (not jtu.test_device_matches(["cuda"]) or - not jtu.is_cuda_compute_capability_at_least("9.0")): - self.skipTest("Only works on >= sm90 GPUs") + not jtu.is_cuda_compute_capability_at_least("8.0") or + cuda_versions.cudnn_get_version() < 90110): + self.skipTest("Only works on >= sm80 GPUs with cuDNN 9.1.1+") super().setUp() @parameterized.parameters(["", "pmap"])