333 Commits

Author SHA1 Message Date
Gleb Pobudzey
54691b125a [Mosaic GPU] Support reads/writes from SMEM to WGMMARowFragLayout arrays.
PiperOrigin-RevId: 738121106
2025-03-18 13:23:07 -07:00
Benjamin Chetioui
875099b25d [Mosaic GPU] Enable the new transform inference pass in the warpgroup lowering.
A couple of dummy transform inference rules needed to be added in order to
contend with parts of the lowering that do not use the dialect yet, along with
a transform inference rule for `memref.view`.

PiperOrigin-RevId: 738089782
2025-03-18 11:51:43 -07:00
Benjamin Chetioui
1e36cbe597 [Mosaic GPU] Raise a NotImplementedError if swizzle=16.
Unswizzled MMAs don't lower correctly, and are not currently intended to be
supported.

PiperOrigin-RevId: 737981373
2025-03-18 06:29:13 -07:00
Adam Paszke
8da93249d2 [Mosaic GPU] Fuse slicing into s4 -> bf16 upcasts
This allows us to significantly simplify the generated PTX/SASS,
which is currently cluttered with LLVM trying to align slices to
start at bit 0 and failing to CSE the right shifts.

PiperOrigin-RevId: 737967890
2025-03-18 05:38:49 -07:00
Benjamin Chetioui
ba2f7c9ad9 [Mosaic GPU] Add transform inference rule for mgpu.slice_smem.
PiperOrigin-RevId: 737957778
2025-03-18 04:53:54 -07:00
Adam Paszke
d4bd2570ae [Mosaic GPU] Add a specialized layout for loading 4-bit inputs in WGMMA friendly layouts
PiperOrigin-RevId: 737956598
2025-03-18 04:47:51 -07:00
Adam Paszke
34cd5b0d74 [Mosaic GPU] Remove sub-byte conversion restriction
XLA:GPU recently changed its endianness to little endian to better match LLVM
and the rest of the CUDA ecosystem, so we can lift the earlier restrictions.
PiperOrigin-RevId: 737934373
2025-03-18 03:13:21 -07:00
Benjamin Chetioui
9a686e0bf3 [Mosaic GPU] Add initial transform inference rules for vector.{load,store}.
PiperOrigin-RevId: 737703568
2025-03-17 12:08:07 -07:00
Adam Paszke
3649da56fc [Mosaic GPU] Make the s4 -> bf16 upcast more flexible when it comes to vector length
We can now perform the conversion in groups of 2, 4 or even 8 elements at a time.

PiperOrigin-RevId: 737626600
2025-03-17 08:37:17 -07:00
Sergei Lebedev
a7e5eaee56 [pallas:mosaic_gpu] jnp.reduce_sum now works for >1D arrays
PiperOrigin-RevId: 737578598
2025-03-17 05:32:07 -07:00
Adam Paszke
89b21de62a [Mosaic GPU] Add support for changing the layout before the upcast
This lets us save on 2 ALU instructions (3x select becomes 1x prmt).

PiperOrigin-RevId: 737550598
2025-03-17 03:26:48 -07:00
Adam Paszke
2bdd9c8797 [Mosaic GPU] Add support for fast WGMMA layout changes after 8- to 16-bit upcast
PiperOrigin-RevId: 737542885
2025-03-17 02:52:16 -07:00
Benjamin Chetioui
5098d2ef49 [Mosaic GPU][NFC] Simplify implementation for in_{layout,transforms}_for_operand utils.
PiperOrigin-RevId: 736809960
2025-03-14 03:52:10 -07:00
Benjamin Chetioui
d09df7c8ab [Mosaic GPU] Add transform inference rules for mgpu.async_{load,store}.
PiperOrigin-RevId: 736795784
2025-03-14 02:37:55 -07:00
Benjamin Chetioui
d028354abb [Mosaic GPU] Introduce an initial transform inference pass.
For now, propagate transforms for `wgmma`. We do not handle `transpose` for
either operand yet.

The pass isn't called anywhere yet.

PiperOrigin-RevId: 736758754
2025-03-13 23:22:59 -07:00
Christos Perivolaropoulos
b34f56bfd7 [mosaic_gpu/pallas:mgpu] Eradicate wgmma_layout
PiperOrigin-RevId: 736187550
2025-03-12 10:47:48 -07:00
Sergei Lebedev
e33f3fc48b [pallas:mosaic_gpu] Added support for reductions to the WG lowering
Note that

* we have no easy way of testing multi-reductions at the moment;
* `reduce_max` assumes WGMMA_ROW layout which is not currently supported by
  the dialect lowering AFAICT.

PiperOrigin-RevId: 736138554
2025-03-12 08:18:31 -07:00
Dimitar (Mitko) Asenov
99c9106032 [Mosaic GPU] Replace WGMMAFragLayout with TiledLayout in the mlir dialect and use it in layout inference.
`WGMMAFragLayout` will be completely removed soon.

PiperOrigin-RevId: 735877661
2025-03-11 13:50:42 -07:00
Adam Paszke
30a9e1b3bf [Mosaic GPU] Add support for .cta_group::2 MMA with n=512 on Blackwell
This one is particularly annoying, because we have to break up the MMA
into two collective N=256 MMAs. However, TensorCore only updates a contiguous
chunk of columns in TMEM and so after executing two of those we end up with
a TMEM layout that looks like this:

```
Contributing CTA |    0    |    1    |    0    |    1    |
N local          |   0:128 |   0:128 | 128:256 | 128:256 |
N                |   0:128 | 256:384 | 128:256 | 384:512 |
```

You can see that the TMEM columns no longer monotonically go over all
columns until N=512, but they include a number of jumps!

We could fix this on the load side, by ensuring that each CTA in the group
does a strided load along the tiled dimension, but that just seems more
trouble than it's worth (and is not that well supported by TMA unless we
increase the number of striding levels).

Instead, we encode this weirdness in the TMEM layout we use and make sure
to rearrange the data properly while loading the tiles into registers.

PiperOrigin-RevId: 735791426
2025-03-11 09:53:20 -07:00
Dimitar (Mitko) Asenov
d2bf034c47 [Mosaic GPU] Test the wgmma_op lowering when a is in registers.
I had to add support for wgmma layout in vector_load. Not sure if this is useful outside the test.

PiperOrigin-RevId: 735384104
2025-03-10 08:25:43 -07:00
Sergei Lebedev
91340ea0a7 [pallas:mosaic_gpu] Added support for math functions to the WG lowering
PiperOrigin-RevId: 735333893
2025-03-10 05:08:19 -07:00
Benjamin Chetioui
75d8702023 [Pallas/Mosaic GPU] Add lowerings/layout inference for all the necessary conversion ops when using Warpgroup semantics.
Enable some of the pre-existing Pallas `ops_test`s for testing.

PiperOrigin-RevId: 735293084
2025-03-10 02:14:39 -07:00
Christos Perivolaropoulos
eeccc67c0b [mgpu] Debug print arrays.
PiperOrigin-RevId: 734576543
2025-03-07 08:58:25 -08:00
Adam Paszke
1bef8b61af [Mosaic GPU] Add a better explanation for the transposed layout
Thanks to @bchetioui for the discussion!

PiperOrigin-RevId: 734564672
2025-03-07 08:19:32 -08:00
Sergei Lebedev
928caf83ee [pallas:mosaic_gpu] copy_smem_to_gmem now allows skipping cp.async.commit_group
This feature is necessary to fix the SMEM->GMEM waiting behavior in
`emit_pipeline`, which used a pessimistic condition prior to this change,
since every copy was its own commit group.

PiperOrigin-RevId: 734553668
2025-03-07 07:43:54 -08:00
Adam Paszke
65462fe684 [Mosaic GPU] Add a new layout to help with transposing WGMMA results
PiperOrigin-RevId: 734553651
2025-03-07 07:42:01 -08:00
Adam Paszke
85c6b6a128 [Mosaic GPU] Add support for tiling stores to refs using small tiling
The difficulty here is that our register tiling is based on the (64, 8)
shape, while the memory tiling is now (8, swizzle // bytewidth). Before,
we would assume that each register tile fits neatly within a single
memory tile, but now it is obviously not the case. Luckily, it wasn't
too hard to add.

PiperOrigin-RevId: 734517000
2025-03-07 05:19:11 -08:00
Dimitar (Mitko) Asenov
5d64b3d2dd [Mosaic GPU] Fix scf.ForOp lowering to put lowered ops at the right place.
Without this fix, lowerings of ops within the `for` body are always appended at the end, even if they have users earlier in the body. This caused an `operand #0 does not dominate this use` error.

The fix was tested in the upcoming (but not yet submitted) `test_realistic_matmul` in Pallas with Workgroup semantics.

PiperOrigin-RevId: 734157829
2025-03-06 08:40:19 -08:00
Sergei Lebedev
2a34019388 [pallas:mosaic_gpu] Added WG lowering rule for lax.bitcast_convert_type_p
PiperOrigin-RevId: 734081448
2025-03-06 04:09:55 -08:00
Adam Paszke
8df00e2666 [Mosaic GPU] Remove support for large tiles on Blackwell
We don't have many Blackwell kernels yet, so let's begin the deprecation there!
Small tiles have clearer semantics when it comes to transposes too, which allows
us to enable more test cases.

PiperOrigin-RevId: 733786884
2025-03-05 10:34:53 -08:00
Adam Paszke
4493889cda [Mosaic GPU] Add support for small tiles for (WG)MMA LHS
Thanks to the previous refactor the change is quite trivial and mostly
focuses on adding tests.

PiperOrigin-RevId: 733754797
2025-03-05 09:01:20 -08:00
Adam Paszke
d119138766 [Mosaic GPU][NFC] Refactor MMA SMEM descriptor creation
This makes the code path uniform for LHS/RHS and greatly clarifies the
magical computation of LBO/SBO. This change should make it significantly
easier for us to enable small tile support for the LHS.

PiperOrigin-RevId: 733737302
2025-03-05 08:06:06 -08:00
Christos Perivolaropoulos
51719a1afe [mgpu] Non-vector untiled stores for tiling layouts.
Useful for storing in memrefs where the minormost stride is >1.

PiperOrigin-RevId: 733551038
2025-03-04 19:41:04 -08:00
Gleb Pobudzey
43b6be0e81 [Mosaic GPU] Add lowering for log, and a fast path using log2.
PiperOrigin-RevId: 733411276
2025-03-04 11:50:50 -08:00
Adam Paszke
cdae5fcfc7 [Mosaic GPU] Make sure to do the async proxy fence before wargroup sync
This is the ordering we want for a proper release of generic SMEM stores
into the async proxy. The old order was problematic: once the warpgroup
barrier was complete, some warps could get deselected before they get to
the fence. For as long as the first warp would make progress, it could go
through the fence along and start issuing TMA copies before other warps
have synchronized with the async proxy.

I have not observed this problem in any of our kernels so far, but this
order seems safer to me.

PiperOrigin-RevId: 733333814
2025-03-04 08:11:15 -08:00
Adam Paszke
e9f95cc3a7 [Mosaic GPU] Make the small WGMMA tile independent of transpose flags
Now the small tiling is always `(8, swizzle // bytewidth(dtype))`, no matter whether the input
is transposed or not. This should simply the follow-up refactoring of the code and make it easier
to enable small tiling for LHS too.

PiperOrigin-RevId: 732933005
2025-03-03 08:30:57 -08:00
Christos Perivolaropoulos
b9ebd9188f [mgpu] Forach in tiled layout.
PiperOrigin-RevId: 732872906
2025-03-03 04:31:59 -08:00
Adam Paszke
11e6cfbc6a [Mosaic GPU][NFC] Move the calculation of group strides into _validate_mma
This allows us to unify this logic between Hopper and Blackwell.

PiperOrigin-RevId: 732862875
2025-03-03 03:51:20 -08:00
Adam Paszke
3038348f23 [Mosaic GPU][NFC] Clean up the computation of group strides
PiperOrigin-RevId: 732849235
2025-03-03 02:50:48 -08:00
Dimitar (Mitko) Asenov
3b305c6617 [Mosaic GPU] Infer layouts (transforms) on memrefs that directly feed into the dialect wgmma op.
This change detects a situation where a gmem_memref is read via `async_load` and directly used in a wgmma. In such cases, we insert a cast before the load to add tile, transpose, and swizzle transformations.

PiperOrigin-RevId: 732618760
2025-03-02 03:17:13 -08:00
Dimitar (Mitko) Asenov
c60ef5a2a1 [Mosaic GPU] Wire up the slice_lengths and indices operands in lowering of the MLIR dialect.
This enables slicing via TMA and is needed for pipelining.

PiperOrigin-RevId: 732613803
2025-03-02 02:43:47 -08:00
Adam Paszke
bb96226dd8 [Mosaic GPU] Add support for small RHS tile sizes in WGMMA
This is useful for more fine-grained autotuning and can help avoid
wave quantization effects.

PiperOrigin-RevId: 732105219
2025-02-28 05:41:30 -08:00
Benjamin Chetioui
1bc36e623b [Mosaic GPU][NFC] Delete workaround for dialect bindings before jaxlib 0.5.1.
PiperOrigin-RevId: 732102282
2025-02-28 05:25:53 -08:00
Benjamin Chetioui
7c46480eab [Mosaic GPU] Fix as_dialect_barrier_memref to take into account BarrierRef's offset.
PiperOrigin-RevId: 732098299
2025-02-28 05:06:57 -08:00
Benjamin Chetioui
a9ab614123 [Pallas/Mosaic GPU] Add an abstraction to obtain a slice of dynamic shared memory when using waprgroup semantics.
Explicitly make the assumption that `runtime_smem` starts at `0` in the Pallas
module context---which should be enforced by Mosaic GPU.

This is in preparation of changes implementing transform inference.

PiperOrigin-RevId: 732091266
2025-02-28 04:38:25 -08:00
Benjamin Chetioui
abfe2d080e [Mosaic GPU][NFC] Move some functions to a new file called inference_utils.py.
The intent is to move utils that are useful for both layout inference and
transform inference to a shared location.

PiperOrigin-RevId: 732067659
2025-02-28 03:02:59 -08:00
Adam Paszke
092ea35301 [Mosaic GPU][NFC] Start refactoring the MMA parameter inference
The CUDA 12.8 release significantly improved the MMA docs, letting us
improve upon the previously used "magic number" scheme. Sadly, the docs
are still incorrect, but at least I can begin to make some sense of those
parameters.

PiperOrigin-RevId: 732033585
2025-02-28 00:50:20 -08:00
Adrian Kuegel
de4d047852 Change int4 packing from big-endian to little-endian
LLVM uses little-endian format for int4 packing. To avoid converting between
these formats, we should also use little-endian in XLA.

PiperOrigin-RevId: 731731530
2025-02-27 08:13:43 -08:00
Adam Paszke
99a12ef9ea [Mosaic GPU] Add support for warpgroup lowering of loops with vector carries
PiperOrigin-RevId: 731260912
2025-02-26 04:29:36 -08:00
Adam Paszke
1de2f839d5 [Mosaic GPU] Make sure to relayout FAs when their layouts mismatch in MGPU lowering
PiperOrigin-RevId: 731253431
2025-02-26 04:03:57 -08:00