mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 04:16:07 +00:00

The stock MLIR pipeline was a good way to get the prototype off the ground, but its default passes can be problematic. In particular, each gpu.launch op is compiled into a sequence of instructions that loads the kernel onto the GPU, runs it, and immediately unloads it again. This has the correct semantics, but loading the kernel is expensive and forces a synchronization point, which hurts performance.

To resolve this, I implemented a new MLIR pass that finds the gpu.launch ops and splits each function containing one into two functions: one that preloads the kernel onto the GPU, and another that consumes the handle produced by the first. We call the first function at compile time, and only the second is used at run time.

There are other overheads in MLIR's implementation of kernel launch, but I will fix those later.

PiperOrigin-RevId: 627670773
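The load/run/unload pattern and the preload/run split described in the commit message can be illustrated with a small C++ model. This is only a sketch under stated assumptions: `FakeDriver`, `RunStock`, `Preload`, and `RunPreloaded` are hypothetical names invented for illustration, and none of this is the actual pass, MLIR's generated code, or a real GPU driver API. It only counts how often each step happens.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for a GPU driver; not a real CUDA/ROCm API.
struct FakeDriver {
  int loads = 0;
  int launches = 0;
  int unloads = 0;
  int next_handle = 1;
  std::unordered_map<int, std::string> loaded;  // handle -> kernel name

  // Models the expensive, synchronizing module load.
  int Load(const std::string& kernel) {
    ++loads;
    int handle = next_handle++;
    loaded[handle] = kernel;
    return handle;
  }

  void Launch(int handle) {
    assert(loaded.count(handle) == 1 && "kernel must be loaded before launch");
    ++launches;
  }

  void Unload(int handle) {
    ++unloads;
    loaded.erase(handle);
  }
};

// Shape of the stock lowering: load, run, and unload on every call.
void RunStock(FakeDriver& driver, const std::string& kernel) {
  int handle = driver.Load(kernel);
  driver.Launch(handle);
  driver.Unload(handle);
}

// Split form, first half: called once, at compile time.
int Preload(FakeDriver& driver, const std::string& kernel) {
  return driver.Load(kernel);
}

// Split form, second half: called on every run with the preloaded handle.
void RunPreloaded(FakeDriver& driver, int handle) {
  driver.Launch(handle);
}
```

In this model, calling `RunStock` N times performs N loads, while one `Preload` followed by N calls to `RunPreloaded` performs a single load, which is the saving the commit message is after.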
28 lines
900 B
C++
/* Copyright 2024 The JAX Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef JAXLIB_MOSAIC_GPU_LAUNCH_LOWERING_H_
#define JAXLIB_MOSAIC_GPU_LAUNCH_LOWERING_H_

namespace mosaic {
namespace gpu {

void registerGpuLaunchLoweringPass();

}  // namespace gpu
}  // namespace mosaic

#endif  // JAXLIB_MOSAIC_GPU_LAUNCH_LOWERING_H_
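
The header exposes a single registration function. A registration function of this kind typically adds a pass factory to a global registry so that a pipeline can later refer to the pass by name. The sketch below models that pattern with a plain `std::map`; it is an illustration only. The `sketch` namespace, `Registry`, `PassFactory`, and the pass name string are assumptions made for this example and are not taken from the real implementation, which is not shown here.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>

namespace sketch {

// Hypothetical factory type: returns a description instead of a real pass.
using PassFactory = std::function<std::string()>;

// Hypothetical global registry keyed by pass name.
std::map<std::string, PassFactory>& Registry() {
  static std::map<std::string, PassFactory> registry;
  return registry;
}

// Mirrors the declared entry point; the pass name is an assumed placeholder.
void registerGpuLaunchLoweringPass() {
  // emplace leaves an existing entry untouched, so repeated calls are harmless.
  Registry().emplace("gpu-launch-lowering",
                     [] { return std::string("GpuLaunchLoweringPass"); });
}

// Looks up a registered pass by name and invokes its factory.
std::string CreatePass(const std::string& name) {
  auto it = Registry().find(name);
  assert(it != Registry().end() && "pass must be registered before use");
  return it->second();
}

}  // namespace sketch
```

A caller would invoke `sketch::registerGpuLaunchLoweringPass()` once at startup, after which `sketch::CreatePass("gpu-launch-lowering")` finds the factory.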