27 Commits

Author SHA1 Message Date
Sergei Lebedev
a7e5eaee56 [pallas:mosaic_gpu] jnp.reduce_sum now works for >1D arrays
PiperOrigin-RevId: 737578598
2025-03-17 05:32:07 -07:00
Christos Perivolaropoulos
b34f56bfd7 [mosaic_gpu/pallas:mgpu] Eradicate wgmma_layout
PiperOrigin-RevId: 736187550
2025-03-12 10:47:48 -07:00
Sergei Lebedev
e33f3fc48b [pallas:mosaic_gpu] Added support for reductions to the WG lowering
Note that

* we have no easy way of testing multi-reductions at the moment;
* `reduce_max` assumes WGMMA_ROW layout which is not currently supported by
  the dialect lowering AFAICT.

PiperOrigin-RevId: 736138554
2025-03-12 08:18:31 -07:00
Dimitar (Mitko) Asenov
99c9106032 [Mosaic GPU] Replace WGMMAFragLayout with TiledLayout in the mlir dialect and use it in layout inference.
`WGMMAFragLayout` will be completely removed soon.

PiperOrigin-RevId: 735877661
2025-03-11 13:50:42 -07:00
Sergei Lebedev
91340ea0a7 [pallas:mosaic_gpu] Added support for math functions to the WG lowering
PiperOrigin-RevId: 735333893
2025-03-10 05:08:19 -07:00
Benjamin Chetioui
75d8702023 [Pallas/Mosaic GPU] Add lowerings/layout inference for all the necessary conversion ops when using Warpgroup semantics.
Enable some of the pre-existing Pallas `ops_test`s for testing.

PiperOrigin-RevId: 735293084
2025-03-10 02:14:39 -07:00
Sergei Lebedev
2a34019388 [pallas:mosaic_gpu] Added WG lowering rule for lax.bitcast_convert_type_p
PiperOrigin-RevId: 734081448
2025-03-06 04:09:55 -08:00
Dimitar (Mitko) Asenov
3b305c6617 [Mosaic GPU] Infer layouts (transforms) on memrefs that directly feed into the dialect wgmma op.
This change detects a situation where a gmem_memref is read via `async_load` and directly used in a wgmma. In such cases, we insert a cast before the load to add tile, transpose, and swizzle transformations.

PiperOrigin-RevId: 732618760
2025-03-02 03:17:13 -08:00
Benjamin Chetioui
abfe2d080e [Mosaic GPU][NFC] Move some functions to a new file called inference_utils.py.
The intent is to move utils that are useful for both layout inference and
transform inference to a shared location.

PiperOrigin-RevId: 732067659
2025-02-28 03:02:59 -08:00
Adam Paszke
99a12ef9ea [Mosaic GPU] Add support for warpgroup lowering of loops with vector carries
PiperOrigin-RevId: 731260912
2025-02-26 04:29:36 -08:00
Adam Paszke
1de2f839d5 [Mosaic GPU] Make sure to relayout FAs when their layouts mismatch in MGPU lowering
PiperOrigin-RevId: 731253431
2025-02-26 04:03:57 -08:00
Benjamin Chetioui
5024ef213f [Mosaic GPU] Add layout inference for scf.ForOp and scf.YieldOp.
PiperOrigin-RevId: 730873769
2025-02-25 07:13:25 -08:00
Benjamin Chetioui
5312b5e35a [Mosaic GPU] Add layout inference for arith.Ext{F,SI,UI}Op and arith.Trunc{F,I}Op.
PiperOrigin-RevId: 730851596
2025-02-25 05:59:40 -08:00
Sergei Lebedev
74b2e0203f [pallas:mosaic_gpu] Use {min,max}imumf instead of {min,max}numf
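The commit message does not state the motivation. For context, the sketch below illustrates the difference between the two op families as documented for MLIR's `arith` dialect: `arith.maximumf` propagates NaN, while `arith.maxnumf` returns the non-NaN operand. The Python helpers are hypothetical models of those semantics, not part of Pallas or Mosaic GPU, and they ignore signed-zero handling.

```python
import math


def maximumf(a: float, b: float) -> float:
    """Models a NaN-propagating maximum: any NaN operand yields NaN."""
    if math.isnan(a) or math.isnan(b):
        return math.nan
    return max(a, b)


def maxnumf(a: float, b: float) -> float:
    """Models a NaN-ignoring maximum: a single NaN operand is dropped."""
    if math.isnan(a):
        return b
    if math.isnan(b):
        return a
    return max(a, b)


# The two agree on ordinary inputs and differ only when one operand is NaN.
propagated = maximumf(1.0, math.nan)  # NaN
ignored = maxnumf(1.0, math.nan)      # the non-NaN operand
```
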
PiperOrigin-RevId: 730154865
2025-02-23 09:52:48 -08:00
Sergei Lebedev
7438976e79 [pallas:mosaic_gpu] Added support for binary/comparison ops with WG semantics
PiperOrigin-RevId: 729266484
2025-02-20 15:06:27 -08:00
Benjamin Chetioui
c7199fe8a5 [Pallas/Mosaic GPU] Enable progressive lowering for integer addition.
The helpers `_fragmented_array_to_ir` and `_fragmented_array_from_ir` in
`dialect_lowering.py` have been modified, such that a fragmented array's
signedness no longer appears in its IR representation.

This is because signedness is a reflection of how we make use of the value,
and not an inherent property of it. The appropriate signedness value to use
to reload a fragmented array from IR must be provided by the caller.
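
As a rough illustration of why signedness can be left out of the IR representation, the sketch below adds two 8-bit patterns without reference to signedness and only applies a signedness when the result is read back. The helper names are hypothetical and plain Python integers stand in for register contents; this is not the code in `dialect_lowering.py`.

```python
def add_i8_bits(a_bits: int, b_bits: int) -> int:
    """Adds two 8-bit patterns, wrapping modulo 2**8. No signedness is involved."""
    return (a_bits + b_bits) & 0xFF


def interpret_i8(bits: int, *, is_signed: bool) -> int:
    """Reads an 8-bit pattern as two's complement (signed) or as unsigned."""
    if is_signed and bits & 0x80:
        return bits - 0x100
    return bits


# One addition, one resulting bit pattern.
result_bits = add_i8_bits(0xF0, 0x01)

# The caller decides how to read the result, mirroring how the caller must
# supply the signedness when reloading a fragmented array from IR.
as_unsigned = interpret_i8(result_bits, is_signed=False)  # 240 + 1
as_signed = interpret_i8(result_bits, is_signed=True)     # -16 + 1
```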

PiperOrigin-RevId: 726030853
2025-02-12 06:29:25 -08:00
Dimitar (Mitko) Asenov
3a411d883a [Mosaic GPU] Implement basic WGMMAFragLayout inference and propagation
PiperOrigin-RevId: 718781860
2025-01-23 02:48:04 -08:00
Benjamin Chetioui
6746d63364 [Mosaic GPU][NFC] Clean up import to align with stylistic guidance.
PiperOrigin-RevId: 716233876
2025-01-16 07:50:04 -08:00
Benjamin Chetioui
d3bf243342 [Mosaic GPU] Add layout inference for splat arith.ConstantOps and vector.SplatOps.
PiperOrigin-RevId: 716224880
2025-01-16 07:18:35 -08:00
Benjamin Chetioui
bc7204f003 [Mosaic GPU] Allow querying layouts from a FuncOp's block arguments if set.
The motivation behind this change is twofold:

1. it simplifies test writing (no need to produce arbitrary, manual, non-splat
   constants to produce arguments with a strided layout);
2. it'll allow running layout inference on different `FuncOp`s in isolation,
   before inlining.

While the primary motivation is to simplify test writing for upcoming changes,
`2.` is useful if we ever intend to call functions whose body's layout we have
inferred from other functions. It's not clear to me that we have a use case for
that, but the theoretical benefit is worth pointing out.

Crucially, layout inference does not set default layouts for `FuncOp`s, since
the caller may choose a different layout for its arguments. As a result, there
is also no layout inference rule for `func.FuncOp`.

PiperOrigin-RevId: 716158516
2025-01-16 03:05:41 -08:00
Benjamin Chetioui
cdf490a5d0 [Mosaic GPU][NFC] Address some previous stylistic comments.
PiperOrigin-RevId: 715772455
2025-01-15 06:21:23 -08:00
Benjamin Chetioui
1893881b5f [Mosaic GPU] Add initial layout mismatch resolution for splat/strided layouts.
When it is possible to annotate an operation using both a `strided` and a
`splat` layout, we pick the `strided` layout. This is the correct choice when
propagating layouts down from parameters to the root; e.g.

```
? = add(strided, splat)
```

becomes

```
strided = add(strided, strided)
```

and requires a re-layout for the right-hand side argument.

The logic needs to be reversed to handle propagation in the opposite direction.
For example, code like

```
c : ?
use(c) : strided
use(c) : splat
```

should resolve to

```
c : splat
use(c) : strided
use(c) : splat
```

and incur a relayout in the `strided` use of `c`. This direction of propagation
is left as a `TODO` for now, to limit the number of changes in a single commit.
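
The two resolution rules described above can be sketched as follows. This is a minimal model under stated assumptions: the `Layout` enum and the three helper functions are hypothetical names for illustration and do not correspond to the actual layout inference code, and `resolve_from_uses` models the direction that this commit leaves as a `TODO`.

```python
from enum import Enum
from typing import Iterable, Optional


class Layout(Enum):
    """Hypothetical stand-ins for the dialect's splat and strided layouts."""
    SPLAT = "splat"
    STRIDED = "strided"


def resolve_from_operands(operands: Iterable[Optional[Layout]]) -> Optional[Layout]:
    """Propagating down from operands: prefer strided over splat."""
    known = [layout for layout in operands if layout is not None]
    if not known:
        return None
    return Layout.STRIDED if Layout.STRIDED in known else Layout.SPLAT


def resolve_from_uses(uses: Iterable[Optional[Layout]]) -> Optional[Layout]:
    """Propagating up from uses: the preference is reversed, splat wins."""
    known = [layout for layout in uses if layout is not None]
    if not known:
        return None
    return Layout.SPLAT if Layout.SPLAT in known else Layout.STRIDED


def needs_relayout(have: Layout, want: Layout) -> bool:
    """A value must be re-laid-out when its layout differs from what is wanted."""
    return have is not want


# ? = add(strided, splat)  becomes  strided = add(strided, strided)
add_layout = resolve_from_operands([Layout.STRIDED, Layout.SPLAT])
rhs_relayout = needs_relayout(Layout.SPLAT, add_layout)

# c : ?  with one strided use and one splat use resolves to  c : splat
c_layout = resolve_from_uses([Layout.STRIDED, Layout.SPLAT])
strided_use_relayout = needs_relayout(c_layout, Layout.STRIDED)
```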

PiperOrigin-RevId: 714056648
2025-01-10 08:10:57 -08:00
Benjamin Chetioui
3915f4a147 [Mosaic GPU] Commit to using Vectors everywhere (and no Tensors).
PiperOrigin-RevId: 707912637
2024-12-19 07:51:58 -08:00
Benjamin Chetioui
66ad2082ba [Mosaic GPU] Replace the dialect's layout enum with layouts holding the proper sub-attributes.

PiperOrigin-RevId: 707846907
2024-12-19 02:59:26 -08:00
Benjamin Chetioui
036125544e [Mosaic GPU] Add layout inference and initial lowering for vector.{load,store}.
For now, the lowering only works for the strided fragmented layout. This is
mostly an exercise in plugging in lowering rules using `FragmentedArray`, and
will be expanded shortly.

PiperOrigin-RevId: 707031770
2024-12-17 03:51:01 -08:00
Benjamin Chetioui
2386838315 [Mosaic GPU] Fix layout inference traversal to traverse ops recursively.
PiperOrigin-RevId: 706136221
2024-12-13 23:51:20 -08:00
Benjamin Chetioui
4ef7706abb [Mosaic GPU] Split layout inference and dialect lowering files and tests.
PiperOrigin-RevId: 705100503
2024-12-11 07:31:34 -08:00