mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Temporarily disable GQA attention tests on GPU, which were broken by a Triton integrate.
PiperOrigin-RevId: 715516188
This commit is contained in:
parent
c78487d23d
commit
d1810b42cb
@ -14,6 +14,7 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
@ -168,6 +169,7 @@ class DecodeAttentionTest(PallasBaseTest):
|
||||
for return_residuals in [False, True]
|
||||
])
|
||||
@jax.numpy_dtype_promotion("standard")
|
||||
@unittest.skip("TODO(b/389925439): gqa tests started failing after triton integrate")
|
||||
def test_gqa(
|
||||
self,
|
||||
batch_size,
|
||||
|
Loading…
x
Reference in New Issue
Block a user