Mirror of https://github.com/ROCm/jax.git

Add a first benchmark for tracing/lowering pallas splash attention.

Sample results below taken on a GCP n2d-standard-128 instance with 512 GB RAM and 128 vCPU AMD EPYC Milan.

---------------------------------------------------------------------------------
Benchmark                                       Time             CPU   Iterations
---------------------------------------------------------------------------------
test_pallas_mqa_splash_attention_trace       39.8 ms         39.8 ms           19
test_pallas_mqa_splash_attention_lower       42.1 ms         41.9 ms           18

PiperOrigin-RevId: 742259409
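The two benchmarks isolate the first two stages of JAX's ahead-of-time pipeline: test_pallas_mqa_splash_attention_trace times tracing the jitted splash-attention call into a jaxpr, and test_pallas_mqa_splash_attention_lower times lowering that traced computation for the TPU backend. A minimal sketch of the same staging API on a stand-in function (f, x, and y below are illustrative only and not part of the benchmark file):

import jax
import jax.numpy as jnp

def f(x, y):
  return jnp.dot(x, y)

x = jnp.ones((8, 8), jnp.float32)
y = jnp.ones((8, 8), jnp.float32)

jitted = jax.jit(f)
traced = jitted.trace(x, y)   # tracing: abstract evaluation of f into a jaxpr
lowered = traced.lower()      # lowering: jaxpr -> StableHLO for the default backend
compiled = lowered.compile()  # compilation: StableHLO -> executable (not timed by the benchmarks)
print(compiled(x, y))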
77 lines
2.3 KiB
Python
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Benchmarks for JAX tracing."""

import google_benchmark
import jax
from jax import random
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel as splash
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib
import numpy as np


def make_mqa_splash_attention_fn_and_args():
  """Builds a jitted MQA splash-attention kernel and random inputs for it."""
  seed = 0
  key = random.key(seed)
  k1, k2, k3 = random.split(key, 3)

  q_seq_len = 1024
  kv_seq_len = 1024
  num_q_heads = 2
  head_dim_qk = 128
  head_dim_v = 128
  dtype = np.dtype("float32")

  q = random.uniform(k1, (num_q_heads, q_seq_len, head_dim_qk), dtype=dtype)
  k = random.uniform(k2, (kv_seq_len, head_dim_qk), dtype=dtype)
  v = random.uniform(k3, (kv_seq_len, head_dim_v), dtype=dtype)

  mask = mask_lib.NumpyMask(
      mask_lib.make_random_mask((q_seq_len, kv_seq_len), sparsity=0.5, seed=0)
  )
  mask = mask_lib.MultiHeadMask(tuple(mask for _ in range(num_q_heads)))
  block_sizes = splash.BlockSizes.get_default()

  return (
      jax.jit(
          splash.make_splash_mqa_single_device(mask, block_sizes=block_sizes)
      )
  ), (q, k, v)


@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def test_pallas_mqa_splash_attention_trace(state):
  attn, (q, k, v) = make_mqa_splash_attention_fn_and_args()

  # Time tracing only; clearing caches forces a fresh trace on each iteration.
  while state:
    _ = attn.trace(q, k, v)
    jax.clear_caches()


@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def test_pallas_mqa_splash_attention_lower(state):
  attn, (q, k, v) = make_mqa_splash_attention_fn_and_args()
  # Tracing happens once outside the timed loop; only lowering is measured.
  traced = attn.trace(q, k, v)

  while state:
    _ = traced.lower(lowering_platforms=("tpu",))
    jax.clear_caches()


if __name__ == "__main__":
  google_benchmark.main()
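Both benchmarks call jax.clear_caches() inside the measurement loop so that repeated iterations do not reuse cached tracing or lowering results. When the module is run directly, google_benchmark.main() executes every registered benchmark; assuming the file is saved as tracing_benchmark.py (the filename is not shown here) and that the standard benchmark flags are forwarded to the underlying library, an invocation like python tracing_benchmark.py --benchmark_filter=test_pallas_mqa_splash_attention_lower would restrict the run to a single benchmark.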