Fix and re-enable cudnn_fusion_test.

Disable XLA autotuning fallback to cuBLAS so that the tested fusion
always executes through cuDNN.
This commit is contained in:
Ilia Sergachev 2024-09-09 12:19:35 +00:00
parent 15024baabf
commit b320dc2e5e
2 changed files with 5 additions and 6 deletions

View File

@@ -1501,10 +1501,7 @@ jax_multiplatform_test(
"gpu_a100",
"gpu_h100",
],
tags = [
"multiaccelerator",
"notap", # TODO(phawkins): this test fails in our internal CI.
],
tags = ["multiaccelerator"],
)
exports_files(

View File

@@ -58,11 +58,13 @@ class CudnnFusionTest(jtu.JaxTestCase):
self.assertIn('custom_call_target="__cudnn$fusion"', hlo)
self.assertIn("called_computations=", hlo)
hlo_after_opt = lowered.compile().as_text()
compiled = lowered.compile({"xla_gpu_cublas_fallback": False})
hlo_after_opt = compiled.as_text()
self.assertIn("kind=kCustom", hlo_after_opt)
self.assertIn("plan_id", hlo_after_opt)
self.assertAllClose(jitted(x, y, z), fn(x, y, z))
self.assertAllClose(compiled(x, y, z), fn(x, y, z))
if __name__ == '__main__':