mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Fix and reenable cudnn_fusion_test.
Disable XLA autotuning fallback to cuBLAS so that the tested fusion always executes through cuDNN.
This commit is contained in:
parent
15024baabf
commit
b320dc2e5e
@ -1501,10 +1501,7 @@ jax_multiplatform_test(
|
||||
"gpu_a100",
|
||||
"gpu_h100",
|
||||
],
|
||||
tags = [
|
||||
"multiaccelerator",
|
||||
"notap", # TODO(phawkins): this test fails in our internal CI.
|
||||
],
|
||||
tags = ["multiaccelerator"],
|
||||
)
|
||||
|
||||
exports_files(
|
||||
|
@ -58,11 +58,13 @@ class CudnnFusionTest(jtu.JaxTestCase):
|
||||
self.assertIn('custom_call_target="__cudnn$fusion"', hlo)
|
||||
self.assertIn("called_computations=", hlo)
|
||||
|
||||
hlo_after_opt = lowered.compile().as_text()
|
||||
compiled = lowered.compile({"xla_gpu_cublas_fallback": False})
|
||||
hlo_after_opt = compiled.as_text()
|
||||
|
||||
self.assertIn("kind=kCustom", hlo_after_opt)
|
||||
self.assertIn("plan_id", hlo_after_opt)
|
||||
|
||||
self.assertAllClose(jitted(x, y, z), fn(x, y, z))
|
||||
self.assertAllClose(compiled(x, y, z), fn(x, y, z))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Loading…
x
Reference in New Issue
Block a user