mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[Pallas MGPU] Disable XLA:GPU autotuning in attention tests
We don't care about performance of the reference impl, we only use it for correctness testing. More importantly, it works around a deadlock at compile time that sometimes happens when testing large batch sizes. PiperOrigin-RevId: 703521029
This commit is contained in:
parent
8b656206af
commit
eda7506d6b
@@ -494,6 +494,7 @@ jax_multiplatform_test(
     srcs = ["//jax/experimental/pallas/ops/gpu:attention_mgpu.py"],
     enable_backends = [],
     enable_configs = ["gpu_h100_x32"],
+    env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"},
     tags = [
         "manual",
         "notap",
@@ -509,6 +510,7 @@ jax_multiplatform_test(
     srcs = ["mgpu_attention_test.py"],
     enable_backends = [],
     enable_configs = ["gpu_h100_x32"],
+    env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"},
     deps = [
         "//jax:pallas",
         "//jax:pallas_experimental_gpu_ops",
|
Loading…
x
Reference in New Issue
Block a user