From b320dc2e5e506ffca6a03813d7c34e339a2d6f30 Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Mon, 9 Sep 2024 12:19:35 +0000 Subject: [PATCH] Fix and reenable cudnn_fusion_test. Disable XLA autotuning fallback to cuBLAS so that the tested fusion always executes through cuDNN. --- tests/BUILD | 5 +---- tests/cudnn_fusion_test.py | 6 ++++-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/BUILD b/tests/BUILD index 134e7a4ba..3442048e2 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1501,10 +1501,7 @@ jax_multiplatform_test( "gpu_a100", "gpu_h100", ], - tags = [ - "multiaccelerator", - "notap", # TODO(phawkins): this test fails in our internal CI. - ], + tags = ["multiaccelerator"], ) exports_files( diff --git a/tests/cudnn_fusion_test.py b/tests/cudnn_fusion_test.py index e70ba1236..320fba370 100644 --- a/tests/cudnn_fusion_test.py +++ b/tests/cudnn_fusion_test.py @@ -58,11 +58,13 @@ class CudnnFusionTest(jtu.JaxTestCase): self.assertIn('custom_call_target="__cudnn$fusion"', hlo) self.assertIn("called_computations=", hlo) - hlo_after_opt = lowered.compile().as_text() + compiled = lowered.compile({"xla_gpu_cublas_fallback": False}) + hlo_after_opt = compiled.as_text() + self.assertIn("kind=kCustom", hlo_after_opt) self.assertIn("plan_id", hlo_after_opt) - self.assertAllClose(jitted(x, y, z), fn(x, y, z)) + self.assertAllClose(compiled(x, y, z), fn(x, y, z)) if __name__ == '__main__':