# rocm_jax/tests/pallas/tpu_ragged_paged_attention_test.py

# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax._src import test_util as jtu
from jax.experimental.pallas.ops.tpu.ragged_paged_attention import (
    ragged_paged_attention,
    ref_ragged_paged_attention,
    validate_inputs_on_runtime,
)
import jax.numpy as jnp

jax.config.parse_flags_with_absl()


def ceil_div(x, a):
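  """Ceiling division: smallest integer >= x / a."""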
  assert a != 0
  return (x + a - 1) // a


@jtu.with_config(jax_numpy_dtype_promotion="standard")
class PagedAttentionKernelTest(jtu.JaxTestCase):

  def _test_ragged_paged_attention(
      self,
      seq_lens,  # List[(q_len, kv_len)]
      num_heads,  # (num_q_heads, num_kv_heads)
      head_dim,
      page_size,
      dtype,
      num_pages,
      *,
      num_kv_pages_per_block=8,
      num_queries_per_block=64,
      vmem_limit_bytes=32 * 1024 * 1024,
      max_num_batched_tokens=512,
      max_num_seq=8,
  ):
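    """Runs the Pallas kernel and compares it against the reference.

    Builds packed ragged-attention inputs from seq_lens, pads them to the
    kernel's static shapes, and checks ragged_paged_attention against
    ref_ragged_paged_attention within per-dtype tolerances.
    """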
    if not jtu.is_device_tpu_at_least(version=4):
      self.skipTest("Expect TPUv4+")
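    # cu_q_lens[i] is the cumulative number of query tokens before sequence i,
    # i.e. the offset of sequence i in the packed token dimension.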
    cu_q_lens = [0]
    kv_lens = []
    for q_len, kv_len in seq_lens:
      assert q_len <= kv_len
      cu_q_lens.append(cu_q_lens[-1] + q_len)
      kv_lens.append(kv_len)
    max_num_batched_tokens = max(cu_q_lens[-1], max_num_batched_tokens)
    max_num_seq = max(len(seq_lens), max_num_seq)
    max_kv_len = max(kv_lens)
    pages_per_seq = ceil_div(max_kv_len, page_size)
    num_q_heads, num_kv_heads = num_heads
    cu_q_lens = jnp.array(cu_q_lens, dtype=jnp.int32)
    kv_lens = jnp.array(kv_lens, dtype=jnp.int32)
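    # Pad the per-sequence metadata out to the static max_num_seq; entries
    # beyond the actual number of sequences are padding.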
    cu_q_lens = jnp.pad(cu_q_lens, (0, max_num_seq + 1 - cu_q_lens.shape[0]))
    kv_lens = jnp.pad(kv_lens, (0, max_num_seq - kv_lens.shape[0]))
    prng_key = jax.random.key(1234)
    k0, k1, k2, k3 = jax.random.split(prng_key, 4)
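    # Random queries packed along the token dimension, and paged K/V caches of
    # shape (num_pages, page_size, num_kv_heads, head_dim).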
    q = jax.random.normal(
        k0,
        (max_num_batched_tokens, num_q_heads, head_dim),
        dtype=dtype,
    )
    k_pages = jax.random.normal(
        k1,
        (num_pages, page_size, num_kv_heads, head_dim),
        dtype=dtype,
    )
    v_pages = jax.random.normal(
        k2,
        (num_pages, page_size, num_kv_heads, head_dim),
        dtype=dtype,
    )
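    # Map each sequence's logical page slots to random physical pages.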
    page_indices = jax.random.randint(
        k3, (max_num_seq, pages_per_seq), 0, num_pages, dtype=jnp.int32
    )
    num_seqs = jnp.array([len(seq_lens)], dtype=jnp.int32)
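    # Validate shape and length invariants at runtime before launching the
    # kernel.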
    validate_inputs_on_runtime(
        q,
        k_pages,
        v_pages,
        kv_lens,
        page_indices,
        cu_q_lens,
        num_seqs,
    )
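    # Run the kernel, then drop the padded tail: only the first
    # cu_q_lens[num_seqs[0]] rows hold real query tokens.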
    output = ragged_paged_attention(
        q,
        k_pages,
        v_pages,
        kv_lens,
        page_indices,
        cu_q_lens,
        num_seqs=num_seqs,
        num_kv_pages_per_block=num_kv_pages_per_block,
        num_queries_per_block=num_queries_per_block,
        vmem_limit_bytes=vmem_limit_bytes,
    )[: cu_q_lens[num_seqs[0]]]
    expected = ref_ragged_paged_attention(
        q,
        k_pages,
        v_pages,
        kv_lens,
        page_indices,
        cu_q_lens,
        num_seqs=num_seqs,
    )
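    # Loose per-dtype tolerances: kernel and reference are only expected to
    # agree approximately (e.g. different accumulation order in low precision).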
    tols = {
        "float32": 0.15,
        "bfloat16": 0.2,
    }
    tol = tols[jnp.dtype(dtype).name]
    self.assertAllClose(output, expected, atol=tol, rtol=tol)

  @parameterized.product(
      dtype=[jnp.float32, jnp.bfloat16],
  )
  def test_ragged_paged_attention_basic(self, dtype):
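    # A small batch of prefill sequences with pre-existing KV context.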
    seq_lens = [(192, 328), (128, 180), (64, 255)]
    num_heads = (32, 8)
    head_dim = 128
    page_size = 16
    num_pages = 1000
    self._test_ragged_paged_attention(
        seq_lens,
        num_heads,
        head_dim,
        page_size,
        dtype,
        num_pages,
    )

  @parameterized.product(
      dtype=[jnp.float32, jnp.bfloat16],
  )
  def test_ragged_paged_attention_decode_only(self, dtype):
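    # Pure decode: each sequence contributes exactly one query token.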
    seq_lens = [
        (1, 18),
        (1, 129),
        (1, 597),
        (1, 122),
        (1, 64),
        (1, 322),
        (1, 463),
        (1, 181),
        (1, 1107),
        (1, 123),
        (1, 31),
        (1, 18),
        (1, 1229),
        (1, 229),
        (1, 87),
        (1, 1328),
    ]
    num_heads = (32, 8)
    head_dim = 128
    page_size = 16
    num_pages = 1000
    self._test_ragged_paged_attention(
        seq_lens,
        num_heads,
        head_dim,
        page_size,
        dtype,
        num_pages,
    )

  @parameterized.product(
      dtype=[jnp.float32, jnp.bfloat16],
  )
  def test_ragged_paged_attention_prefill_only(self, dtype):
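    # Prefill only: every sequence has multiple query tokens (q_len > 1).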
    seq_lens = [
        (5, 18),
        (15, 129),
        (120, 597),
        (100, 122),
        (21, 64),
        (32, 322),
        (251, 463),
        (40, 181),
        (64, 1107),
        (99, 123),
        (10, 31),
        (5, 18),
        (3, 1229),
        (120, 229),
        (9, 87),
        (2, 1328),
    ]
    num_heads = (32, 8)
    head_dim = 128
    page_size = 16
    num_pages = 1000
    self._test_ragged_paged_attention(
        seq_lens,
        num_heads,
        head_dim,
        page_size,
        dtype,
        num_pages,
    )

  @parameterized.product(
      dtype=[jnp.float32, jnp.bfloat16],
  )
  def test_ragged_paged_attention_mixed(self, dtype):
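    # Mixed batch: decode steps (q_len == 1) interleaved with prefills.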
    seq_lens = [
        (5, 18),
        (1, 129),
        (120, 597),
        (1, 122),
        (1, 64),
        (32, 322),
        (251, 463),
        (1, 181),
        (1, 1107),
        (99, 123),
        (1, 31),
        (5, 18),
        (3, 1229),
        (117, 229),
        (1, 87),
        (1, 1328),
    ]
    num_heads = (32, 8)
    head_dim = 128
    page_size = 16
    num_pages = 1000
    self._test_ragged_paged_attention(
        seq_lens,
        num_heads,
        head_dim,
        page_size,
        dtype,
        num_pages,
    )

  @parameterized.product(
      num_seqs=[1, 5, 16],
      # TODO(jevinjiang): Support more num_heads!
      num_heads=[(32, 8), (32, 16), (12, 2), (4, 4)],
      dtype=[jnp.float32, jnp.bfloat16],
      num_kv_pages_per_block=[4, 8],
      num_queries_per_block=[32, 64],
  )
  def test_ragged_paged_attention_complex(
      self,
      num_seqs,
      num_heads,
      dtype,
      num_kv_pages_per_block,
      num_queries_per_block,
  ):
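    # Randomized sequence lengths; note that `random` is not seeded here, so
    # the exact lengths differ from run to run.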
    seq_lens = []
    for _ in range(num_seqs):
      q_len = random.randint(1, 100)
      kv_len = q_len + random.randint(0, 50)
      seq_lens.append((q_len, kv_len))
    # TODO(jevinjiang): Support non-128 head_dim!
    head_dim = 128
    page_size = 16
    num_pages = 1000
    self._test_ragged_paged_attention(
        seq_lens,
        num_heads,
        head_dim,
        page_size,
        dtype,
        num_pages,
        num_kv_pages_per_block=num_kv_pages_per_block,
        num_queries_per_block=num_queries_per_block,
    )


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())