Re-enable cudnn_fusion_test on A100.

Check that the required cuDNN version is available.
This commit is contained in:
Ilia Sergachev 2024-10-02 12:44:21 +02:00
parent 2a41c04fef
commit e083c08001
2 changed files with 5 additions and 2 deletions

View File

@ -1523,6 +1523,7 @@ jax_multiplatform_test(
srcs = ["cudnn_fusion_test.py"],
enable_backends = [],
enable_configs = [
"gpu_a100",
"gpu_h100",
],
tags = ["multiaccelerator"],

View File

@ -15,6 +15,7 @@
from absl.testing import absltest, parameterized
from unittest import SkipTest
from jax._src import test_util as jtu
from jax._src.lib import cuda_versions
import jax
import jax.numpy as jnp
from jax._src.cudnn import cudnn_fusion
@ -26,8 +27,9 @@ jax.config.parse_flags_with_absl()
class CudnnFusionTest(jtu.JaxTestCase):
def setUp(self):
if (not jtu.test_device_matches(["cuda"]) or
not jtu.is_cuda_compute_capability_at_least("9.0")):
self.skipTest("Only works on >= sm90 GPUs")
not jtu.is_cuda_compute_capability_at_least("8.0") or
cuda_versions.cudnn_get_version() < 90110):
self.skipTest("Only works on >= sm80 GPUs with cuDNN 9.1.1+")
super().setUp()
@parameterized.parameters(["", "pmap"])