rocm_jax/benchmarks
Daniel Suo c76b8adcf5 Add JAX tracing microbenchmarks.
Add a first benchmark for tracing and lowering Pallas splash attention.

Sample results below were taken on a GCP n2d-standard-128 instance with 512 GB RAM and 128 vCPUs (AMD EPYC Milan).

---------------------------------------------------------------------------------
Benchmark                                       Time             CPU   Iterations
---------------------------------------------------------------------------------
test_pallas_mqa_splash_attention_trace       39.8 ms         39.8 ms           19
test_pallas_mqa_splash_attention_lower       42.1 ms         41.9 ms           18
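
The two benchmarks separate the cost of tracing a function from the cost of lowering it. Below is a minimal sketch of that split. It is not the benchmark added by this commit. It assumes a plain dot-product `attention` function as a stand-in for the Pallas splash attention kernel, which targets TPU. The names `attention`, `time_stage`, and `benchmark_trace_and_lower` are hypothetical. The tracing stage uses `jax.make_jaxpr`, and the lowering stage uses `jax.jit(...).lower(...)`. The sketch calls `jax.clear_caches()` before each sample so that later iterations are not served from JAX's caches.

```python
import statistics
import time

import jax
import jax.numpy as jnp


def attention(q, k, v):
    # Hypothetical stand-in workload: plain dot-product attention,
    # NOT the Pallas splash attention kernel benchmarked in the commit.
    scores = jnp.einsum("hqd,hkd->hqk", q, k) / jnp.sqrt(q.shape[-1])
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("hqk,hkd->hqd", weights, v)


def time_stage(fn, iterations=10):
    """Return the median wall-clock time of fn() in milliseconds."""
    samples = []
    for _ in range(iterations):
        # Drop cached traces/lowerings so each sample does the full work.
        jax.clear_caches()
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1e3)
    return statistics.median(samples)


def benchmark_trace_and_lower(num_heads=4, seq_len=128, head_dim=64):
    # Abstract shapes are enough: tracing and lowering never touch real data.
    spec = jax.ShapeDtypeStruct((num_heads, seq_len, head_dim), jnp.float32)
    # Tracing only: build the jaxpr for these input shapes.
    trace_ms = time_stage(lambda: jax.make_jaxpr(attention)(spec, spec, spec))
    # Lowering: trace again, then emit StableHLO text for the current backend.
    lower_ms = time_stage(
        lambda: jax.jit(attention).lower(spec, spec, spec).as_text()
    )
    return {"trace_ms": trace_ms, "lower_ms": lower_ms}
```

Calling `benchmark_trace_and_lower()` returns the median times for both stages. Because `jit(...).lower(...)` traces the function before lowering it, the lowering figure includes the tracing work.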

PiperOrigin-RevId: 742259409
2025-04-09 18:05:14 +00:00