contraction with f32 accumulator and output, where the accumulator is
extended and the output truncated. For targets that do not support bf16
matmul, the lhs and rhs are extended to f32.
PiperOrigin-RevId: 642051952
- We should recursively clear layouts and any assume_layout ops if we want to override layouts in a block.
- Refactor the logic of assume layouts for block arguments to a helper function.
- Add tests for nested fori loop and while loop.
PiperOrigin-RevId: 641973011
Apparently we were missing interface registration code for LLVM lowering,
which the gpu-to-llvm pass gracefully ignores unless compiled with debug
assertions enabled. But, simply adding the assertions in fact makes the
pass _too powerful_ and makes it lower _all dialects to LLVM_, which is not
what we want. That's why I've replaced it with a minimal version that is
only repsponsible for handling the GPU dialect, making the lowering similar
to the one prior to extra registrations.
PiperOrigin-RevId: 641874183
- Support non-zero minor offsets without having to relayout (they're still a no-op).
- Remove restriction on tiling which now allows 1D packed types to work.
PiperOrigin-RevId: 640967375
- On infer-vector-layout remove some restrictions related to batch dimensions. Reshaping them doesn't matter as long as they don't combine with tiled dimensions.
- On apply-vector-layout, simplify handling of cases where the implicit tiled don't change while removing some unnecessary restrictions.
- Don't require native tiling or natural topology for this.
PiperOrigin-RevId: 640837740
Previously the rule would complain if the layouts were unsupported, but that's not
the right way to handle that situation. With this change, we simply pick a supported
configuration instead (and expect relayout to handle it).
PiperOrigin-RevId: 640190248
Pallas hasn't been using it to lower since cl/604849222, which is before we had serde, so we won't create a serde rule for this. That CL also tried to remove the support, but it had to be restored in cl/605436599 because it was breaking serialized kernels.
PiperOrigin-RevId: 639896981
We should infer layout for each terminator inside its own region and find a compatible layout for a final result if the result is based on terminators from multiple regions like scf::ifOp, scf::whileOp, scf::forOp. If no compatible layout is found, we will fall back to a normalized layout. Finally we also need to ensure the layouts in input, terminator and output are consistent across loops.
PiperOrigin-RevId: 639122434
This lets us avoid bundling a whole another copy of LLVM with JAX packages
and so we can finally start building Mosaic GPU by default.
PiperOrigin-RevId: 638569750
- Enable it for minor or second-minor implicit dims for the non-no-op case.
- Don't allow output offsets for broadcasted dimensions to be non-replicated. Make sure to assign them as replicated in infer-vector-layout for all cases.
- Don't fail when both tiled dimensions are logically broadcasted but only one of them requires actual broadcasting (before, it would hit the unimplemented sublane + lane broadcast case).
PiperOrigin-RevId: 637772134
The implementation uses the new tpu.log operation in the Mosaic TPU dialect.
Note that
* the logging only happens if --xla_tpu_enable_log_recorder is set;
* only scalars can be printed;
* placeholders only accept i32 arguments at the moment.
PiperOrigin-RevId: 636585852
This CL also fixes a bug in `VectorLayout::join` where `ImplicitDim::kMinor` was considered equivalent to `ImplicitDim::kNone` when the shape's minor dimension is 1 (also needed to check that the second-minor dimension is 1).
Often handling every implicit dim case separately is more complex and error-prone.
There will be more follow-up changes to do this consistently elsewhere the code. This is also a first step towards 0D layout support.
PiperOrigin-RevId: 636250759
This drops implicit second minor dim and fixes the infer vector layout for concatenate rule when concatenating constant vectors.
PiperOrigin-RevId: 635917408
The change to tpu.td is not backwards compatible, but I made it so using the
newly added Mosaic stability layer. It's been a good exercise and it seems to
be working just fine.
Co-authored-by: Sharad Vikram <sharadmv@google.com>
PiperOrigin-RevId: 630060418
They are equivalent - the way values are laid out is the same - but relayouts check specifically for (1, 128). We define (1, 128) to be canonical.
PiperOrigin-RevId: 629748121
This ports the remaining few functions we depended on to the Mosaic GPU runtime.
This has the additional benefit of avoiding the expensive driver calls to determine
maximum SMEM bounds that the MLIR runtime does at every kernel launch.
PiperOrigin-RevId: 629069842
The one bundled with the default MLIR runtime was convenient, but it is also
impractical. It allocates memory (which can deadlock due to NCCL), does a
synchronous host-to-device copy and then leaks the descriptor after the kernel...
With this change, we use our own runtime function to create all the descriptors.
What's more, we pack them all into a single buffer so that a single asynchronous
copy is sufficient. Finally, we use a scratch output to allocate the scratch buffer,
letting us lean on XLA:GPU for memory management.
PiperOrigin-RevId: 628430358
The stock MLIR pipeline was a good way to get the prototype off the ground, but
its default passes can be problematic. In particular, the gpu.launch is compiled
into a sequence of instructions that load the kernel onto the GPU, run the kernel
and immediately unload it again. This has the correct semantics, but loading the
kernel is both expensive and forces a synchronization point, which leads to performance
issues.
To resolve this, I implemented a new MLIR pass that finds the gpu.launch ops and splits
each function that has it into two functions: one that preloads the kernel onto the
GPU, and another one that consumes the handle produced by the previous one. We call
the first function at compile-time, while only the second one is used at run-time.
There are other overheads in MLIR's implementation of kernel launch, but I will
fix those later.
PiperOrigin-RevId: 627670773