mirror of https://github.com/ROCm/jax.git synced 2025-04-22 19:46:05 +00:00
Jevin Jiang 4b49c03523 Open-source TPU-friendly ragged paged attention kernel.
Key features:
* ***Support mixed prefill and decode*** to increase inference throughput (e.g., a ***5x*** speedup over the padded Multi-Queries Paged Attention implementation for llama-3-8b).
* ***No explicit `swapaxes`*** of `seq_len` and `num_head` before or after the kernel. The kernel takes `num_head` in the 2nd-minor dimension, as it naturally is; we fold the swap into strided loads/stores inside the kernel and apply the transpose on the fly.
* ***No GMM (Grouped Matmul) metadata required!*** We compute the metadata on the fly in the kernel, which alone can yield a ***10%*** speedup (a rough sketch of the idea appears after this list).
* ***Increase MXU utilization 8x in GQA*** by grouping the query heads that share a KV head into a single matmul during decode (see the second sketch after this list).
* ***Minimize recompilation:*** the only factors that can trigger recompilation are the model specs, `max_num_batched_tokens`, and `max_num_seqs` in the mixed-engine setting.
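
To make the on-the-fly metadata point concrete, here is a minimal, plain-JAX sketch of the idea; it is not the kernel's actual code, and the function `row_to_seq_ids` and its arguments are hypothetical names. Instead of shipping precomputed GMM-style metadata into the kernel, the row-to-sequence mapping can be derived from the per-sequence token counts:

```python
# Hypothetical sketch: derive "which sequence does each flattened query row
# belong to" from per-sequence token counts, rather than passing it in as
# precomputed grouped-matmul metadata.
import jax.numpy as jnp

def row_to_seq_ids(seq_lens, total_rows):
  # seq_lens: [num_seqs] number of query tokens contributed by each sequence
  # in the ragged (mixed prefill/decode) batch.
  starts = jnp.cumsum(seq_lens) - seq_lens   # first row owned by each sequence
  rows = jnp.arange(total_rows)
  # Map every flattened token row to the sequence whose range contains it.
  return jnp.searchsorted(starts, rows, side="right") - 1
```

For example, `seq_lens = jnp.array([3, 1, 2])` with `total_rows = 6` gives `[0, 0, 0, 1, 2, 2]`; the actual kernel performs the analogous bookkeeping internally rather than on the host.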

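Similarly, here is a rough, plain-JAX illustration of the GQA head-grouping idea, again not the kernel itself; the function name, the shapes, and the assumption that query heads sharing a KV head are laid out contiguously are assumptions made for the sketch. Folding the shared q heads into the row dimension of the contraction turns several separate vector-matrix products into one matmul per KV head:

```python
# Hypothetical sketch of GQA decode attention with shared q heads grouped so
# the MXU sees a [group, head_dim] x [head_dim, seq_len] matmul per KV head
# instead of `group` separate [1, head_dim] vector-matrix products.
import jax
import jax.numpy as jnp

def decode_attention_grouped(q, k, v):
  # q: [num_q_heads, head_dim] for a single decode token.
  # k, v: [num_kv_heads, seq_len, head_dim]; assumes q heads that share a
  # KV head are contiguous, as in the usual GQA layout.
  num_q_heads, head_dim = q.shape
  num_kv_heads, seq_len, _ = k.shape
  group = num_q_heads // num_kv_heads

  qg = q.reshape(num_kv_heads, group, head_dim)
  scores = jnp.einsum("ngd,nsd->ngs", qg, k) / jnp.sqrt(head_dim)
  probs = jax.nn.softmax(scores, axis=-1)
  out = jnp.einsum("ngs,nsd->ngd", probs, v)
  return out.reshape(num_q_heads, head_dim)
```

With `group` rows per contraction instead of one, each pass through the MXU services several query heads at once, which is where the claimed utilization gain in decode comes from.
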
PiperOrigin-RevId: 734269519
2025-03-06 13:36:45 -08:00