Temporarily disable GQA attention tests on GPU, which were broken by a Triton integrate.

PiperOrigin-RevId: 715516188
This commit is contained in:
Peter Hawkins 2025-01-14 13:47:56 -08:00 committed by jax authors
parent c78487d23d
commit d1810b42cb

View File

@ -14,6 +14,7 @@
import os
import sys
import unittest
from absl.testing import absltest
from absl.testing import parameterized
@ -168,6 +169,7 @@ class DecodeAttentionTest(PallasBaseTest):
for return_residuals in [False, True]
])
@jax.numpy_dtype_promotion("standard")
@unittest.skip("TODO(b/389925439): gqa tests started failing after triton integrate")
def test_gqa(
self,
batch_size,