mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Disable cudnn_fusion_test on A100.
This test only seems to pass on H100 at the moment. PiperOrigin-RevId: 681070398
This commit is contained in:
parent
28098bef93
commit
1260ebbe05
@ -1496,9 +1496,8 @@ jax_py_test(
|
||||
jax_multiplatform_test(
|
||||
name = "cudnn_fusion_test",
|
||||
srcs = ["cudnn_fusion_test.py"],
|
||||
enable_backends = ["gpu"],
|
||||
enable_backends = [],
|
||||
enable_configs = [
|
||||
"gpu_a100",
|
||||
"gpu_h100",
|
||||
],
|
||||
tags = ["multiaccelerator"],
|
||||
|
@ -26,8 +26,8 @@ jax.config.parse_flags_with_absl()
|
||||
class CudnnFusionTest(jtu.JaxTestCase):
|
||||
def setUp(self):
|
||||
if (not jtu.test_device_matches(["cuda"]) or
|
||||
not jtu.is_cuda_compute_capability_at_least("8.0")):
|
||||
self.skipTest("Only works on >= sm80 GPUs")
|
||||
not jtu.is_cuda_compute_capability_at_least("9.0")):
|
||||
self.skipTest("Only works on >= sm90 GPUs")
|
||||
super().setUp()
|
||||
|
||||
@parameterized.parameters(["", "pmap"])
|
||||
|
Loading…
x
Reference in New Issue
Block a user