The other JAX profiling tools are a little heavyweight when we only care about timing a single kernel programmatically. Also adapt wgmma.py to fix failures triggered by upstream MLIR changes.

PiperOrigin-RevId: 628096973
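For comparison, here is a minimal sketch of timing one kernel programmatically using only the public JAX API. It is not the utility this change adds. `time_kernel` is a hypothetical helper. It measures host-side wall-clock time, so dispatch overhead is included, unlike device-side timing. Warm-up calls keep compilation out of the measurement, and `jax.block_until_ready` is needed because JAX dispatches work asynchronously.

```python
import time

import jax
import jax.numpy as jnp


def time_kernel(fn, *args, warmup=1, iters=10):
    """Return the mean wall-clock time of fn(*args) in milliseconds.

    Hypothetical helper for illustration: a host-side approximation,
    not the profiling utility introduced by this commit.
    """
    jitted = jax.jit(fn)
    # Warm-up calls trigger compilation so it is excluded from the timing.
    for _ in range(warmup):
        jax.block_until_ready(jitted(*args))
    start = time.perf_counter()
    for _ in range(iters):
        # JAX dispatches asynchronously; block so each call actually finishes.
        jax.block_until_ready(jitted(*args))
    elapsed = time.perf_counter() - start
    return elapsed / iters * 1e3


x = jnp.ones((1024, 1024), dtype=jnp.float32)
ms = time_kernel(lambda a: a @ a, x)
print(f"{ms:.3f} ms per call")
```

Averaging over several iterations smooths out per-call noise. Host-side timing like this still overstates the cost of very short kernels, which is one reason a dedicated per-kernel timer is useful.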