A couple of dummy transform inference rules had to be added to handle parts of
the lowering that do not use the dialect yet, along with a transform inference
rule for `memref.view`.
PiperOrigin-RevId: 738089782
This allows us to significantly simplify the generated PTX/SASS,
which is currently cluttered with LLVM trying to align slices to
start at bit 0 and failing to CSE the right shifts.
PiperOrigin-RevId: 737967890
XLA:GPU recently changed its endianness to little endian to better match LLVM
and the rest of the CUDA ecosystem, so we can lift the earlier restrictions.
PiperOrigin-RevId: 737934373
For now, propagate transforms for `wgmma`. We do not handle `transpose` for
either operand yet.
The pass isn't called anywhere yet.
PiperOrigin-RevId: 736758754
Note that
* we have no easy way of testing multi-reductions at the moment;
* `reduce_max` assumes the `WGMMA_ROW` layout, which is not currently supported
  by the dialect lowering AFAICT.
PiperOrigin-RevId: 736138554
This one is particularly annoying, because we have to break up the MMA
into two collective N=256 MMAs. However, TensorCore only updates a contiguous
chunk of columns in TMEM and so after executing two of those we end up with
a TMEM layout that looks like this:
```
Contributing CTA | 0 | 1 | 0 | 1 |
N local | 0:128 | 0:128 | 128:256 | 128:256 |
N | 0:128 | 256:384 | 128:256 | 384:512 |
```
You can see that the TMEM columns no longer step monotonically through all
logical columns up to N=512, but instead contain a number of jumps!
We could fix this on the load side, by ensuring that each CTA in the group
does a strided load along the tiled dimension, but that just seems more
trouble than it's worth (and is not that well supported by TMA unless we
increase the number of striding levels).
Instead, we encode this weirdness in the TMEM layout we use and make sure
to rearrange the data properly while loading the tiles into registers.
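To make the mapping concrete, here is a tiny sketch (an illustration of the table
above using a hypothetical helper, not the actual Mosaic GPU code) of how the
contributing CTA and the local N offset combine into the logical N column:

```python
# Transcription of the table above (hypothetical helper, not the actual
# Mosaic GPU code): the logical N column is recovered from the
# contributing CTA and the local N offset.
def logical_n(contributing_cta: int, n_local: int) -> int:
  assert contributing_cta in (0, 1) and 0 <= n_local < 256
  return contributing_cta * 256 + n_local

# Reading the TMEM blocks in order reproduces the jumps from the table:
# N 0:128, 256:384, 128:256, 384:512.
assert logical_n(0, 0) == 0
assert logical_n(1, 0) == 256
assert logical_n(0, 128) == 128
assert logical_n(1, 128) == 384
```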
PiperOrigin-RevId: 735791426
This feature is necessary to fix the SMEM->GMEM waiting behavior in
`emit_pipeline`, which used a pessimistic condition prior to this change,
since every copy was its own commit group.
PiperOrigin-RevId: 734553668
The difficulty here is that our register tiling is based on the (64, 8)
shape, while the memory tiling is now (8, swizzle // bytewidth). Previously,
we could assume that each register tile fits neatly within a single memory
tile, but that is no longer the case. Luckily, adding support for this wasn't
too hard.
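For intuition, here is a small sketch (illustrative numbers only, not the actual
Mosaic GPU code) of how many memory tiles a single register tile now spans:

```python
# Illustrative sketch: count how many memory tiles a (64, 8) register
# tile touches, given the swizzle (in bytes) and the element byte width.
def memory_tiles_per_register_tile(swizzle: int, bytewidth: int) -> tuple[int, int]:
  reg_rows, reg_cols = 64, 8
  mem_rows, mem_cols = 8, swizzle // bytewidth
  # Ceil-divide each register tile dimension by the memory tile shape.
  return (-(-reg_rows // mem_rows), -(-reg_cols // mem_cols))

# E.g. with a 128-byte swizzle and 2-byte elements the memory tile is
# (8, 64), so one register tile spans 8 memory tiles along the rows.
assert memory_tiles_per_register_tile(128, 2) == (8, 1)
```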
PiperOrigin-RevId: 734517000
Without this fix, lowerings of ops within the `for` body are always appended at the end, even if they have users earlier in the body. This caused an `operand #0 does not dominate this use` error.
The fix was tested in the upcoming (but not yet submitted) `test_realistic_matmul` in Pallas with Workgroup semantics.
PiperOrigin-RevId: 734157829
We don't have many Blackwell kernels yet, so let's begin the deprecation there!
Small tiles also have clearer semantics when it comes to transposes, which
allows us to enable more test cases.
PiperOrigin-RevId: 733786884
This makes the code path uniform for LHS/RHS and greatly clarifies the
magical computation of LBO/SBO. This change should make it significantly
easier for us to enable small tile support for the LHS.
PiperOrigin-RevId: 733737302
This is the ordering we want for a proper release of generic SMEM stores
into the async proxy. The old order was problematic: once the warpgroup
barrier completed, some warps could get deselected before they reached the
fence. As long as the first warp kept making progress, it could go through the
fence alone and start issuing TMA copies before the other warps had
synchronized with the async proxy.
I have not observed this problem in any of our kernels so far, but this
order seems safer to me.
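A schematic sketch of the two orderings (Python stubs standing in for the
emitted PTX, not the actual lowering):

```python
# Schematic only: stub functions stand in for the emitted instructions.
def smem_stores(): pass        # generic stores to SMEM
def fence_proxy_async(): pass  # fence.proxy.async
def warpgroup_barrier(): pass  # warpgroup-wide barrier
def issue_tma_copy(): pass     # TMA copy reading the freshly stored SMEM

def old_order():
  # Racy: a warp that clears the barrier first can issue the TMA copy
  # before the other warps have executed their fence.
  smem_stores(); warpgroup_barrier(); fence_proxy_async(); issue_tma_copy()

def new_order():
  # Every warp fences its own stores before the barrier, so once any warp
  # passes the barrier, all stores have been released to the async proxy.
  smem_stores(); fence_proxy_async(); warpgroup_barrier(); issue_tma_copy()
```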
PiperOrigin-RevId: 733333814
Now the small tiling is always `(8, swizzle // bytewidth(dtype))`, regardless of whether
the input is transposed or not. This should simplify the follow-up refactoring of the code
and make it easier to enable small tiling for the LHS too.
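For reference, a quick sketch (illustrative values, not the actual Mosaic GPU
code) of the resulting small tile shapes:

```python
# Illustrative only: small tile shape as a function of the swizzle (in
# bytes) and the element type's byte width.
def small_tile(swizzle: int, bytewidth: int) -> tuple[int, int]:
  return (8, swizzle // bytewidth)

assert small_tile(128, 2) == (8, 64)  # 128-byte swizzle, 16-bit elements
assert small_tile(64, 4) == (8, 16)   # 64-byte swizzle, 32-bit elements
```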
PiperOrigin-RevId: 732933005
This change detects the situation where a `gmem_memref` is read via `async_load` and used directly in a `wgmma`. In such cases, we insert a cast before the load that adds tile, transpose, and swizzle transformations.
PiperOrigin-RevId: 732618760
Explicitly make the assumption that `runtime_smem` starts at `0` in the Pallas
module context, which should be enforced by Mosaic GPU.
This is in preparation of changes implementing transform inference.
PiperOrigin-RevId: 732091266
The CUDA 12.8 release significantly improved the MMA docs, letting us
improve upon the previously used "magic number" scheme. Sadly, the docs
are still incorrect, but at least I can begin to make some sense of those
parameters.
PiperOrigin-RevId: 732033585
LLVM uses a little-endian format for int4 packing. To avoid converting between
the two formats, we should use little-endian packing in XLA as well.
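As a minimal illustration (a sketch of the packing convention, not XLA or LLVM
code), little-endian int4 packing places the first element of each pair in the
low nibble of the byte:

```python
# Minimal illustration of little-endian int4 packing (not XLA/LLVM code):
# the first element of each pair lives in the low nibble of the byte.
def pack_int4_le(a: int, b: int) -> int:
  assert 0 <= a < 16 and 0 <= b < 16
  return (b << 4) | a

def unpack_int4_le(byte: int) -> tuple[int, int]:
  return (byte & 0xF, byte >> 4)

assert unpack_int4_le(pack_int4_le(3, 5)) == (3, 5)
```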
PiperOrigin-RevId: 731731530