[Pallas MGPU] Disable XLA:GPU autotuning in attention tests

We don't care about the performance of the reference implementation; we only
use it for correctness testing. More importantly, disabling autotuning works
around a compile-time deadlock that sometimes occurs when testing large batch
sizes.

PiperOrigin-RevId: 703521029
commit eda7506d6b
parent 8b656206af
Author: Adam Paszke
Date: 2024-12-06 09:18:28 -08:00
Committed by: jax authors

@@ -494,6 +494,7 @@ jax_multiplatform_test(
     srcs = ["//jax/experimental/pallas/ops/gpu:attention_mgpu.py"],
     enable_backends = [],
     enable_configs = ["gpu_h100_x32"],
+    env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"},
     tags = [
         "manual",
         "notap",
@@ -509,6 +510,7 @@ jax_multiplatform_test(
     srcs = ["mgpu_attention_test.py"],
     enable_backends = [],
     enable_configs = ["gpu_h100_x32"],
+    env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"},
     deps = [
         "//jax:pallas",
         "//jax:pallas_experimental_gpu_ops",