1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 05:16:06 +00:00

Re-enable the non-quantized K/V pages part of the paged_attention_kernel tests on TPU v6.

The paged_attention_kernel tests for TPU v6 were disabled in the past, but I discovered that all the failing tests have `are_kv_quantized=True`. So we can still test the non-quantized part on TPU v6.

PiperOrigin-RevId: 717969073
This commit is contained in:
jax authors 2025-01-21 10:12:14 -08:00
parent 70a5175d0a
commit 96a3ed36c7

@@ -108,10 +108,6 @@ def _megacore_enabled():
@jtu.with_config(jax_numpy_dtype_promotion="standard")
class PagedAttentionKernelTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
if jtu.is_device_tpu_at_least(6):
self.skipTest('Not implemented for TPU v6')
@parameterized.product(
dtype=(jnp.float32, jnp.bfloat16),
@@ -144,6 +140,8 @@ class PagedAttentionKernelTest(jtu.JaxTestCase):
# weight and scale tensors for quantized tensors. When enabled on TPUv4,
# the tests sometimes failed with resource exhausted error.
self.skipTest("Quantization is not supported on TPU v4")
if jtu.is_device_tpu_at_least(6) and are_kv_quantized:
self.skipTest("Quantization is not supported on TPU v6")
if megacore_mode and not _megacore_enabled():
self.skipTest("Megacore is only available on TPU v4 or TPU v5p")
if num_kv_heads % 2 != 0 and megacore_mode == "kv_head":