mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00

Add a first benchmark for tracing/lowering pallas splash attention.

Sample results below were taken on a GCP n2d-standard-128 instance with 512 GB RAM and 128 vCPUs (AMD EPYC Milan).

---------------------------------------------------------------------------------
Benchmark                                       Time             CPU   Iterations
---------------------------------------------------------------------------------
test_pallas_mqa_splash_attention_trace       39.8 ms         39.8 ms           19
test_pallas_mqa_splash_attention_lower       42.1 ms         41.9 ms           18

PiperOrigin-RevId: 742259409
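
For readers unfamiliar with what "tracing" and "lowering" measure, the following is a minimal sketch and not the benchmark added by this commit. It assumes a hypothetical stand-in function, `toy_mqa_attention`, written in plain `jax.numpy` instead of the Pallas splash attention kernel, and a hypothetical `time_ms` helper based on `time.perf_counter` rather than a benchmark harness. Tracing is approximated here with `jax.make_jaxpr`, and lowering with `jax.jit(...).lower(...)`; caches are cleared before each iteration so the work is repeated every time.

```python
import time

import jax
import jax.numpy as jnp


def toy_mqa_attention(q, k, v):
    # Hypothetical stand-in for the real kernel (NOT Pallas splash attention).
    # q: (heads, q_len, dim); k, v: (kv_len, dim), shared across all heads.
    logits = jnp.einsum("hqd,kd->hqk", q, k) / jnp.sqrt(q.shape[-1])
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("hqk,kd->hqd", weights, v)


def time_ms(fn, iterations=5):
    # Hypothetical helper: mean wall-clock time of fn() in milliseconds.
    total = 0.0
    for _ in range(iterations):
        jax.clear_caches()  # Avoid measuring cache hits instead of real work.
        start = time.perf_counter()
        fn()
        total += time.perf_counter() - start
    return 1e3 * total / iterations


q = jnp.ones((4, 128, 64), dtype=jnp.float32)
k = jnp.ones((128, 64), dtype=jnp.float32)
v = jnp.ones((128, 64), dtype=jnp.float32)

# Tracing: turn the Python function into a jaxpr without compiling it.
trace_ms = time_ms(lambda: jax.make_jaxpr(toy_mqa_attention)(q, k, v))
# Lowering: trace, then convert to the compiler's input IR, without compiling.
lower_ms = time_ms(lambda: jax.jit(toy_mqa_attention).lower(q, k, v))

print(f"trace: {trace_ms:.1f} ms, lower: {lower_ms:.1f} ms")
```

Timings from this sketch are not comparable to the table above, since the function, shapes, and measurement method are all different.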