Mirror of https://github.com/ROCm/jax.git (synced 2025-04-24 00:56:05 +00:00)

The stock MLIR pipeline was a good way to get the prototype off the ground, but its default passes can be problematic. In particular, each gpu.launch op is compiled into a sequence of instructions that load the kernel onto the GPU, run it, and immediately unload it again. This has the correct semantics, but loading the kernel is expensive and forces a synchronization point, which hurts performance.

To resolve this, I implemented a new MLIR pass that finds the gpu.launch ops and splits each function containing one into two functions: one that preloads the kernel onto the GPU, and another that consumes the handle produced by the first. We call the first function at compile time, while only the second one is used at run time.

There are other overheads in MLIR's implementation of kernel launch, but I will fix those later.

PiperOrigin-RevId: 627670773
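
For illustration only, here is a toy Python model of the split described above. It is not the MLIR pass itself and uses no MLIR or JAX APIs. FakeGpu, stock_launch and split_launch are hypothetical names introduced for this sketch, and the load counter stands in for the expensive, synchronizing kernel load.

```python
# Toy model: every name here is hypothetical, not an MLIR or JAX API.


class FakeGpu:
    """Pretend device that counts kernel loads (the expensive step)."""

    def __init__(self):
        self.loads = 0
        self.modules = {}

    def load(self, kernel):
        # Stands in for loading the kernel onto the GPU.
        self.loads += 1
        handle = self.loads
        self.modules[handle] = kernel
        return handle

    def launch(self, handle, *args):
        return self.modules[handle](*args)

    def unload(self, handle):
        del self.modules[handle]


def stock_launch(gpu, kernel, *args):
    # Models the stock lowering: load, run, unload on every call.
    handle = gpu.load(kernel)
    try:
        return gpu.launch(handle, *args)
    finally:
        gpu.unload(handle)


def split_launch(gpu, kernel):
    # First half: preload once (the "compile-time" function).
    handle = gpu.load(kernel)

    def run(*args):
        # Second half: consume the handle (the "run-time" function).
        return gpu.launch(handle, *args)

    return run


def add_one(x):
    return x + 1


stock_gpu = FakeGpu()
stock_results = [stock_launch(stock_gpu, add_one, i) for i in range(3)]

split_gpu = FakeGpu()
run = split_launch(split_gpu, add_one)
split_results = [run(i) for i in range(3)]
```

In this model, three calls through stock_launch perform three loads, while the split version performs one load up front and none per call. Both produce the same results.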