mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Re-enable cudnn_fusion_test on A100.
Check that the required cuDNN version is available.
This commit is contained in:
parent
2a41c04fef
commit
e083c08001
@ -1523,6 +1523,7 @@ jax_multiplatform_test(
|
||||
srcs = ["cudnn_fusion_test.py"],
|
||||
enable_backends = [],
|
||||
enable_configs = [
|
||||
"gpu_a100",
|
||||
"gpu_h100",
|
||||
],
|
||||
tags = ["multiaccelerator"],
|
||||
|
@ -15,6 +15,7 @@
|
||||
from absl.testing import absltest, parameterized
|
||||
from unittest import SkipTest
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib import cuda_versions
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax._src.cudnn import cudnn_fusion
|
||||
@ -26,8 +27,9 @@ jax.config.parse_flags_with_absl()
|
||||
class CudnnFusionTest(jtu.JaxTestCase):
|
||||
def setUp(self):
|
||||
if (not jtu.test_device_matches(["cuda"]) or
|
||||
not jtu.is_cuda_compute_capability_at_least("9.0")):
|
||||
self.skipTest("Only works on >= sm90 GPUs")
|
||||
not jtu.is_cuda_compute_capability_at_least("8.0") or
|
||||
cuda_versions.cudnn_get_version() < 90110):
|
||||
self.skipTest("Only works on >= sm80 GPUs with cuDNN 9.1.1+")
|
||||
super().setUp()
|
||||
|
||||
@parameterized.parameters(["", "pmap"])
|
||||
|
Loading…
x
Reference in New Issue
Block a user