The XLA:GPU custom call design is far from ideal, as there's apparently no way to figure
out which CUDA context will be used to run an HLO module before the custom call is
first invoked. So we can't preload the kernel onto the GPU ahead of time: if we did, we'd get invalid
handle errors because the load and the launch would happen in different CUDA contexts...
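The workaround this leaves us with is to load the module lazily inside the custom call itself and cache the handle per context. A minimal sketch of that pattern (hypothetical helper, error checking omitted; not the actual runtime code):

```c++
// Sketch only: lazily load the kernel in the context that will also launch
// it, caching one function handle per CUDA context. Error checking omitted.
#include <mutex>
#include <unordered_map>

#include "cuda.h"

CUfunction GetOrLoadKernel(const char* image, const char* kernel_name) {
  static std::mutex mu;
  static std::unordered_map<CUcontext, CUfunction> cache;

  CUcontext ctx;
  cuCtxGetCurrent(&ctx);  // Only known once the custom call actually runs.

  std::lock_guard<std::mutex> lock(mu);
  auto it = cache.find(ctx);
  if (it != cache.end()) return it->second;

  CUmodule module;
  cuModuleLoadData(&module, image);  // Load in the same context as the launch.
  CUfunction fn;
  cuModuleGetFunction(&fn, module, kernel_name);
  cache[ctx] = fn;
  return fn;
}
```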
Also fix up build_wheel.py to match the rename of the runtime lib.
PiperOrigin-RevId: 629401858
This ports the remaining few functions we depended on to the Mosaic GPU runtime.
This has the additional benefit of avoiding the expensive driver calls that the MLIR
runtime makes at every kernel launch to determine the maximum SMEM bounds.
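For context, the per-launch cost being avoided is essentially a driver attribute query; a sketch of caching it once instead of re-querying on every launch (hypothetical helper, assuming a single device):

```c++
// Sketch only: query the opt-in per-block SMEM limit once and cache it,
// instead of asking the driver on every kernel launch. Assumes a single
// device; error checking omitted.
#include "cuda.h"

int MaxDynamicSmemBytes(CUdevice device) {
  static int cached = -1;
  if (cached < 0) {
    cuDeviceGetAttribute(&cached,
                         CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
                         device);
  }
  return cached;
}
```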
PiperOrigin-RevId: 629069842
The descriptor-creation function bundled with the default MLIR runtime was convenient, but also
impractical. It allocates memory (which can deadlock due to NCCL), does a
synchronous host-to-device copy and then leaks the descriptor after the kernel...
With this change, we use our own runtime function to create all the descriptors.
What's more, we pack them all into a single buffer so that a single asynchronous
copy is sufficient. Finally, we use a scratch output to allocate the scratch buffer,
letting us lean on XLA:GPU for memory management.
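A rough sketch of what that packing amounts to (hypothetical helper and layout; alignment/padding and error checking omitted):

```c++
// Sketch only: pack all descriptors into one host staging buffer and issue a
// single asynchronous host-to-device copy into scratch memory provided by
// XLA:GPU. The staging buffer must stay alive until the copy has completed.
#include <vector>

#include "cuda.h"

void StageDescriptors(const std::vector<std::vector<char>>& descriptors,
                      std::vector<char>& staging,  // Kept alive by the caller.
                      CUdeviceptr scratch,         // Scratch output from XLA:GPU.
                      CUstream stream) {
  staging.clear();
  for (const std::vector<char>& d : descriptors) {
    staging.insert(staging.end(), d.begin(), d.end());
  }
  // One async copy covers every descriptor; no device-side allocation, since
  // the scratch buffer is managed by XLA:GPU.
  cuMemcpyHtoDAsync(scratch, staging.data(), staging.size(), stream);
}
```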
PiperOrigin-RevId: 628430358
The other JAX profiling tools are a little heavyweight when we only care about
timing a single kernel programmatically.
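The kind of single-kernel measurement this refers to boils down to recording CUDA events around one launch; a minimal sketch of that technique (hypothetical helper, not the tool added here):

```c++
// Sketch only: time a single kernel launch with CUDA driver events.
// Error checking omitted.
#include "cuda.h"

float TimeKernelMs(CUfunction kernel, CUstream stream, void** args) {
  CUevent start, stop;
  cuEventCreate(&start, CU_EVENT_DEFAULT);
  cuEventCreate(&stop, CU_EVENT_DEFAULT);

  cuEventRecord(start, stream);
  cuLaunchKernel(kernel, /*gridDim=*/1, 1, 1, /*blockDim=*/128, 1, 1,
                 /*sharedMemBytes=*/0, stream, args, /*extra=*/nullptr);
  cuEventRecord(stop, stream);
  cuEventSynchronize(stop);

  float ms = 0.0f;
  cuEventElapsedTime(&ms, start, stop);
  cuEventDestroy(start);
  cuEventDestroy(stop);
  return ms;
}
```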
Also adapt wgmma.py to fix failures triggered by upstream MLIR changes.
PiperOrigin-RevId: 628096973
In June 2023 we switched the CPU lowering for FFT to use
the new custom call dynamic_ducc_fft. We are now out of the backwards
compatibility window, so we remove the old ducc_fft custom call.
We need to keep dynamic_ducc_fft a little longer (until May 2024).
PiperOrigin-RevId: 627981921
The stock MLIR pipeline was a good way to get the prototype off the ground, but
its default passes can be problematic. In particular, each gpu.launch is compiled
into a sequence of instructions that load the kernel onto the GPU, run it
and immediately unload it again. This has the correct semantics, but loading the
kernel is expensive and forces a synchronization point, which leads to performance
issues.
To resolve this, I implemented a new MLIR pass that finds the gpu.launch ops and splits
each function that contains one into two functions: one that preloads the kernel onto the
GPU, and another that consumes the handle produced by the first. We call
the first function at compile time, while only the second one is used at run time.
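At the CUDA driver level, the two generated functions have roughly the following shape (hypothetical names; the pass itself rewrites MLIR, this only illustrates the split):

```c++
// Sketch only: the shape of the two functions produced by the split
// (hypothetical names; error checking omitted).
#include "cuda.h"

// Part 1: called once at compile time. Loads the kernel and returns a handle.
CUfunction PreloadKernel(const char* image, const char* kernel_name) {
  CUmodule module;
  cuModuleLoadData(&module, image);  // Expensive and synchronizing.
  CUfunction fn;
  cuModuleGetFunction(&fn, module, kernel_name);
  return fn;
}

// Part 2: called at run time. Only consumes the preloaded handle, so the hot
// path no longer loads and unloads the module around every launch.
void LaunchPreloadedKernel(CUfunction fn, CUstream stream, void** args) {
  cuLaunchKernel(fn, /*gridDim=*/1, 1, 1, /*blockDim=*/128, 1, 1,
                 /*sharedMemBytes=*/0, stream, args, /*extra=*/nullptr);
}
```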
There are other overheads in MLIR's implementation of kernel launch, but I will
fix those later.
PiperOrigin-RevId: 627670773
This allows lowering while loops of a more general form than "for i" loops.
Improving generality here allows us to implement more interesting dynamic looping behaviors, such as progressive scans in VMEM.
PiperOrigin-RevId: 625411151
This avoids:
- a forward declaration of `GpuContext`
- the `:asm_compiler_header` header-only target
The moved code is unchanged - I just move it from one
file to another and fix up includes and dependencies.
Note that this is adding just another `#ifdef` to the redzone allocator code. I will clean this up in a subsequent change.
PiperOrigin-RevId: 623285804
The IFRT `PluginProgram` is simply a wrapper for arbitrary byte-strings: an IFRT backend that recognizes `PluginProgram` can interpret the byte-string in any way it sees fit.
PiperOrigin-RevId: 621258245
This was a little difficult because our current dialect conversion setup assumes 1:1 type conversions.
I think everything works out fine as long as we never pass memrefs between basic blocks (i.e.
as long as we never use memrefs as loop carries or return them from conditionals).
TODO: I still need to make sure that the changes to the TPU dialect are backwards-compatible.
I am afraid that the signature change in MemRefSliceOp might not be.
PiperOrigin-RevId: 617755035