mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Re-enable the non-quantized K/V pages portion of the paged_attention_kernel tests for TPU v6.
The paged_attention_kernel tests for TPU v6 were disabled in the past, but I discovered that all of the failing tests had `are_kv_quantized=True`. So we can still test the non-quantized part on TPU v6. PiperOrigin-RevId: 717969073
This commit is contained in:
parent
70a5175d0a
commit
96a3ed36c7
@ -108,10 +108,6 @@ def _megacore_enabled():
|
||||
|
||||
@jtu.with_config(jax_numpy_dtype_promotion="standard")
|
||||
class PagedAttentionKernelTest(jtu.JaxTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
if jtu.is_device_tpu_at_least(6):
|
||||
self.skipTest('Not implemented for TPU v6')
|
||||
|
||||
@parameterized.product(
|
||||
dtype=(jnp.float32, jnp.bfloat16),
|
||||
@ -144,6 +140,8 @@ class PagedAttentionKernelTest(jtu.JaxTestCase):
|
||||
# weight and scale tensors for quantized tensors. When enabled on TPUv4,
|
||||
# the tests sometimes failed with resource exhausted error.
|
||||
self.skipTest("Quantization is not supported on TPU v4")
|
||||
if jtu.is_device_tpu_at_least(6) and are_kv_quantized:
|
||||
self.skipTest("Quantization is not supported on TPU v6")
|
||||
if megacore_mode and not _megacore_enabled():
|
||||
self.skipTest("Megacore is only available on TPU v4 or TPU v5p")
|
||||
if num_kv_heads % 2 != 0 and megacore_mode == "kv_head":
|
||||
|
Loading…
x
Reference in New Issue
Block a user