This ports the few remaining functions we depended on to the Mosaic GPU runtime.
This has the additional benefit of avoiding the expensive driver calls that the
MLIR runtime makes at every kernel launch to determine the maximum SMEM bounds.
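As a rough illustration of the kind of caching this makes possible (the function name
and single-device assumption are mine, not the actual runtime code):

```cpp
#include <cuda.h>

// Hypothetical sketch: query the opt-in SMEM limit once and cache it, instead
// of issuing the driver call on every kernel launch. Assumes cuInit() has
// already been called and that the process uses a single device.
int MaxDynamicSmemBytes(CUdevice device) {
  static int cached = [&] {
    int value = 0;
    cuDeviceGetAttribute(
        &value, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device);
    return value;
  }();
  return cached;
}
```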
PiperOrigin-RevId: 629069842
The descriptor-creation helper bundled with the default MLIR runtime was convenient,
but it is also impractical: it allocates memory (which can deadlock due to NCCL),
performs a synchronous host-to-device copy and then leaks the descriptor after the kernel...
With this change, we use our own runtime function to create all the descriptors.
What's more, we pack them all into a single buffer so that a single asynchronous
copy is sufficient. Finally, we allocate the scratch buffer as a scratch output,
letting us lean on XLA:GPU for memory management.
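A minimal sketch of the packing idea, with hypothetical names and layout (the real
runtime function differs):

```cpp
#include <cstddef>
#include <cstring>
#include <cuda.h>

// Hypothetical sketch: place all already-encoded TMA descriptors into one
// staging buffer and issue a single asynchronous host-to-device copy, instead
// of one allocation plus one blocking copy per descriptor. `staging` must
// outlive the copy (e.g. pinned memory owned by the caller), and
// `device_scratch` points into the XLA:GPU-managed scratch output.
void PackAndCopyDescriptors(const CUtensorMap* descriptors, std::size_t count,
                            char* staging, CUdeviceptr device_scratch,
                            CUstream stream) {
  std::memcpy(staging, descriptors, count * sizeof(CUtensorMap));
  cuMemcpyHtoDAsync(device_scratch, staging, count * sizeof(CUtensorMap),
                    stream);
}
```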
PiperOrigin-RevId: 628430358
The stock MLIR pipeline was a good way to get the prototype off the ground, but
its default passes can be problematic. In particular, each gpu.launch op is compiled
into a sequence of instructions that loads the kernel onto the GPU, runs it,
and immediately unloads it again. This has the correct semantics, but loading the
kernel is both expensive and forces a synchronization point, which leads to performance
issues.
To resolve this, I implemented a new MLIR pass that finds the gpu.launch ops and splits
each function that contains one into two functions: one that preloads the kernel onto the
GPU, and another that consumes the handle produced by the first. We call the first
function at compile time, while only the second one is used at run time.
There are other overheads in MLIR's implementation of kernel launch, but I will
fix those later.
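For intuition, the host-side shape of the split looks roughly like this (function and
kernel names are made up; the pass emits MLIR rather than hand-written C++):

```cpp
#include <cuda.h>

// Hypothetical sketch of the two functions the pass produces. The first loads
// the module once and is only invoked at compile time; the second consumes the
// resulting handle and is the only one called at run time.
CUfunction PreloadKernel(const void* cubin_image) {
  CUmodule module;
  CUfunction kernel;
  cuModuleLoadData(&module, cubin_image);              // expensive, synchronizing
  cuModuleGetFunction(&kernel, module, "main_kernel"); // hypothetical name
  return kernel;
}

void LaunchKernel(CUfunction kernel, CUstream stream, void** params,
                  unsigned grid_x, unsigned block_x, unsigned smem_bytes) {
  // Cheap and fully asynchronous: no module load/unload on the hot path.
  cuLaunchKernel(kernel, grid_x, 1, 1, block_x, 1, 1, smem_bytes, stream,
                 params, /*extra=*/nullptr);
}
```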
PiperOrigin-RevId: 627670773
This allows lowering `while` loops of a more general form than "for i" loops.
The added generality lets us implement more interesting dynamic looping behaviors, such as progressive scans in VMEM.
PiperOrigin-RevId: 625411151
This was a little difficult because our current dialect conversion setup assumes 1:1 type conversions.
I think everything works out fine as long as we never pass memrefs between basic blocks (i.e.
as long as we never use memrefs as loop carries or return them from conditionals).
TODO: I still need to make sure that the changes to the TPU dialect are backwards-compatible.
I am afraid that the signature change in MemRefSliceOp might not be.
PiperOrigin-RevId: 617755035
Ideally we would prefer `TypedValue<VectorType>` everywhere possible for static type checking. However, I tried changing the type used for arrays of vregs from `xla::Array<Value>` to `xla::Array<TypedValue<VectorType>>` and ran into issues, because MLIR support for arrays/ranges of `TypedValue`s seems lacking.
For example, I can't find a good way to get a `ValueRange` (which many op constructors take) from an array of `TypedValue`s without creating an intermediate vector of `Value`s. Perhaps an unsafe cast could work, if we make the (probably not guaranteed) assumption that `sizeof(TypedValue)` equals `sizeof(Value)`.
Also note that MLIR itself uses untyped `Value`s for ranges of op results and operands even when the op definition declares them to be of a specific type.
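For reference, the intermediate-vector workaround mentioned above looks roughly like this
(a sketch, not the exact code in the change):

```cpp
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"

// TypedValue<VectorType> derives from Value, so each element converts
// implicitly, but there is no cheap way to view a typed array as a
// ValueRange; we have to materialize an intermediate vector of Values.
llvm::SmallVector<mlir::Value> ToValues(
    llvm::ArrayRef<mlir::TypedValue<mlir::VectorType>> vregs) {
  return llvm::SmallVector<mlir::Value>(vregs.begin(), vregs.end());
}
```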
PiperOrigin-RevId: 610509743
It was using the `op` variable from the `ExtUIOp` case above (because variables declared in the initializer of an if statement remain in scope in the else branch).
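A minimal standalone reproduction of the scoping rule that caused the bug (not the original code):

```cpp
#include <cstdio>

int main() {
  int* p = nullptr;
  if (int* op = p) {
    std::printf("then branch: %p\n", static_cast<void*>(op));
  } else {
    // `op` is still in scope here (and holds the value it was initialized
    // with), which is how the else branch accidentally reused the ExtUIOp
    // handle from the branch above.
    std::printf("else branch: %p\n", static_cast<void*>(op));
  }
  return 0;
}
```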
PiperOrigin-RevId: 610481302
The old `tile_indices` variable was confusing because it sometimes stored indices (in the static case) and sometimes offsets relative to the tile (in the dynamic case).
PiperOrigin-RevId: 609457122
This allows us to rely on it throughout the code and replace some checks with TPU_ASSERT_*. These have the semantics of an assert and make it clearer that a failure is an unexpected internal error (rather than unimplemented or invalid user input that we should handle).
Note: the original error messages for some of these checks used the wrong input names.
PiperOrigin-RevId: 607463728
This enables the index function to select a window starting from
any element. However, the Mosaic implementation still requires the
window start to be at least tile aligned.
PiperOrigin-RevId: 605254616
The semantics are closer to the TPU's: a NaN input results in a NaN output. However, we don't respect the -0.0 vs +0.0 ordering of older TPUs.
This also fixes a mismatch where we were using `arith.maximumf` to lower `vector.kind<maxnumf>` (instead of `arith.maximumf` for `vector.kind<maximumf>` and `arith.maxnumf` for `vector.kind<maxnumf>`).
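To illustrate the difference between the two kinds (standard IEEE maximum vs. maxNum
semantics; not code from this change):

```cpp
#include <cmath>
#include <cstdio>

int main() {
  float nan = std::nanf("");
  // maxnum semantics (what std::fmax and arith.maxnumf implement):
  // a quiet NaN operand is ignored, so the result is 1.0.
  std::printf("fmax(NaN, 1.0f) = %f\n", std::fmax(nan, 1.0f));
  // maximum semantics (what arith.maximumf implements; closer to the TPU):
  // any NaN operand propagates to the result.
  float maximum =
      (std::isnan(nan) || std::isnan(1.0f)) ? NAN : std::fmax(nan, 1.0f);
  std::printf("maximum(NaN, 1.0f) = %f\n", maximum);
  return 0;
}
```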
PiperOrigin-RevId: 604849222
This normalizes loads and stores with dynamic base indices into reference
slicing followed by statically indexed loads/stores. This should both simplify
the code (we only have to deal with dynamism in slicing) and improve performance
(the dynamic address offset only has to be applied once).
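A conceptual analogue of the transformation, in plain C++ rather than the TPU dialect:

```cpp
// Hypothetical sketch of the access pattern before and after normalization.
// Before: every access recomputes an address from a dynamic base index.
// After: the dynamic offset is applied once when slicing; the loads that
// follow use static indices into the slice.
float LoadPair(const float* base, int dynamic_row, int row_stride) {
  const float* slice = base + dynamic_row * row_stride;  // dynamic offset, once
  float a = slice[0];   // statically indexed load
  float b = slice[8];   // statically indexed load
  return a + b;
}
```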
PiperOrigin-RevId: 597546106
The previous patch simply changed the type we use to represent semaphores,
but didn't actually add support for any more operations. With this one,
semaphore memrefs can be allocated and (dynamically) indexed.
PiperOrigin-RevId: 597538913