49 Commits

Author SHA1 Message Date
Sergei Lebedev
a7e5eaee56 [pallas:mosaic_gpu] jnp.reduce_sum now works for >1D arrays
PiperOrigin-RevId: 737578598
2025-03-17 05:32:07 -07:00
Christos Perivolaropoulos
b34f56bfd7 [mosaic_gpu/pallas:mgpu] Eradicate wgmma_layout
PiperOrigin-RevId: 736187550
2025-03-12 10:47:48 -07:00
Sergei Lebedev
e33f3fc48b [pallas:mosaic_gpu] Added support for reductions to the WG lowering
Note that

* we have no easy way of testing multi-reductions at the moment;
* `reduce_max` assumes WGMMA_ROW layout which is not currently supported by
  the dialect lowering AFAICT.

PiperOrigin-RevId: 736138554
2025-03-12 08:18:31 -07:00
Dimitar (Mitko) Asenov
99c9106032 [Mosaic GPU] Replace WGMMAFragLayout with TiledLayout in the mlir dialect and use it in layout inference.
`WGMMAFragLayout` will be completely removed soon.

PiperOrigin-RevId: 735877661
2025-03-11 13:50:42 -07:00
Dimitar (Mitko) Asenov
d2bf034c47 [Mosaic GPU] Test the wgmma_op lowering when a is in registers.
I had to add support for wgmma layout in vector_load. Not sure if this is useful outside the test.

PiperOrigin-RevId: 735384104
2025-03-10 08:25:43 -07:00
Sergei Lebedev
91340ea0a7 [pallas:mosaic_gpu] Added support for math functions to the WG lowering
PiperOrigin-RevId: 735333893
2025-03-10 05:08:19 -07:00
Benjamin Chetioui
75d8702023 [Pallas/Mosaic GPU] Add lowerings/layout inference for all the necessary conversion ops when using Warpgroup semantics.
Enable some of the pre-existing Pallas `ops_test`s for testing.

PiperOrigin-RevId: 735293084
2025-03-10 02:14:39 -07:00
Sergei Lebedev
928caf83ee [pallas:mosaic_gpu] copy_smem_to_gmem now allows skipping cp.async.commit_group
This feature is necessary to fix the SMEM->GMEM waiting behavior in
`emit_pipeline`, which used a pessimistic condition prior to this change,
since every copy was its own commit group.

PiperOrigin-RevId: 734553668
2025-03-07 07:43:54 -08:00
Dimitar (Mitko) Asenov
5d64b3d2dd [Mosaic GPU] Fix scf.ForOp lowering to put lowered ops at the right place.
Without this fix, lowerings of ops within the `for` body are always appended at the end, even if they have users earlier in the body. This caused an `operand #0 does not dominate this use` error.

The fix was tested in the upcoming (but not yet submitted) `test_realistic_matmul` in Pallas with Workgroup semantics.

PiperOrigin-RevId: 734157829
2025-03-06 08:40:19 -08:00
Sergei Lebedev
2a34019388 [pallas:mosaic_gpu] Added WG lowering rule for lax.bitcast_convert_type_p
PiperOrigin-RevId: 734081448
2025-03-06 04:09:55 -08:00
Dimitar (Mitko) Asenov
3b305c6617 [Mosaic GPU] Infer layouts (transforms) on memrefs that directly feed into the dialect wgmma op.
This change detects a situation where a gmem_memref is read via `async_load` and directly used in a wgmma. In such cases, we insert a cast before the load to add tile, transpose, and swizzle transformations.

PiperOrigin-RevId: 732618760
2025-03-02 03:17:13 -08:00
Dimitar (Mitko) Asenov
c60ef5a2a1 [Mosaic GPU] Wire up the slice_lengths and indices operands in lowering of the MLIR dialect.
This enables slicing via TMA and is needed for pipelining.

PiperOrigin-RevId: 732613803
2025-03-02 02:43:47 -08:00
Benjamin Chetioui
1bc36e623b [Mosaic GPU][NFC] Delete workaround for dialect bindings before jaxlib 0.5.1.
PiperOrigin-RevId: 732102282
2025-02-28 05:25:53 -08:00
Benjamin Chetioui
a9ab614123 [Pallas/Mosaic GPU] Add an abstraction to obtain a slice of dynamic shared memory when using waprgroup semantics.
Explicitly make the assumption that `runtime_smem` starts at `0` in the Pallas
module context---which should be enforced by Mosaic GPU.

This is in preparation of changes implementing transform inference.

PiperOrigin-RevId: 732091266
2025-02-28 04:38:25 -08:00
Benjamin Chetioui
abfe2d080e [Mosaic GPU][NFC] Move some functions to a new file called inference_utils.py.
The intent is to move utils that are useful for both layout inference and
transform inference to a shared location.

PiperOrigin-RevId: 732067659
2025-02-28 03:02:59 -08:00
Adam Paszke
99a12ef9ea [Mosaic GPU] Add support for warpgroup lowering of loops with vector carries
PiperOrigin-RevId: 731260912
2025-02-26 04:29:36 -08:00
Adam Paszke
1de2f839d5 [Mosaic GPU] Make sure to relayout FAs when their layouts mismatch in MGPU lowering
PiperOrigin-RevId: 731253431
2025-02-26 04:03:57 -08:00
Adam Paszke
ced28167e8 [Mosaic GPU] Use explicit recursion in rules instead of doing it automatically
Control-flow ops that have vector inputs or outputs will need to be specially adjusted.

PiperOrigin-RevId: 730922072
2025-02-25 09:44:57 -08:00
Benjamin Chetioui
5b13883f8e [Mosaic GPU] Add dialect lowering logic for splat constants.
PiperOrigin-RevId: 730842871
2025-02-25 05:25:56 -08:00
Sergei Lebedev
7eadc64b5a [pallas:mosaic_gpu] Added WG lowering rules for TMA primitives and run_scoped_p
PiperOrigin-RevId: 730780335
2025-02-25 01:32:43 -08:00
Sergei Lebedev
74b2e0203f [pallas:mosaic_gpu] Use {min,max}imumf instead of {min,max}numf
PiperOrigin-RevId: 730154865
2025-02-23 09:52:48 -08:00
Sergei Lebedev
7438976e79 [pallas:mosaic_gpu] Added support for binary/comparison ops with WG semantics
PiperOrigin-RevId: 729266484
2025-02-20 15:06:27 -08:00
Dimitar (Mitko) Asenov
52f8fbeee0 [Mosaic GPU] Implement lowerings for Tile and Transpose transforms from the MLIR dialect.
PiperOrigin-RevId: 727762334
2025-02-17 01:29:47 -08:00
Adam Paszke
cdcf35fd70 Remove an unused import
PiperOrigin-RevId: 726910300
2025-02-14 06:49:34 -08:00
Andrey Portnoy
ea5eb49aa9 [Mosaic GPU] Use gettatr to import version-specific dialect ops 2025-02-13 15:08:42 -05:00
Benjamin Chetioui
c7199fe8a5 [Pallas/Mosaic GPU] Enable progressive lowering for integer addition.
The helpers `_fragmented_array_to_ir` and `_fragmented_array_from_ir` in
`dialect_lowering.py` have been modified, such that a fragmented array's
signedness no longer appears in its IR representation.

This is because signedness is a reflection of how we make use of the value,
and not an inherent property of it. The appropriate signedness value to use
to reload a fragmented array from IR must be provided by the caller.

PiperOrigin-RevId: 726030853
2025-02-12 06:29:25 -08:00
Dimitar (Mitko) Asenov
6fc1c61520 [Mosaic GPU] Use the memref layout to encode transforms (only swizzle for now).
Tile and Transpose transforms to follow.

PiperOrigin-RevId: 725716812
2025-02-11 11:51:25 -08:00
Benjamin Chetioui
d8f3b33ae4 [Mosaic GPU] Eliminate the arrive attribute from mosaic_gpu.async_load.
We plan to explicitly issue an `expect_tx` operation all the time when using
the dialect.

PiperOrigin-RevId: 721411949
2025-01-30 09:08:45 -08:00
Dimitar (Mitko) Asenov
6214c25a6d [Mosaic GPU] Add ArriveExpect and Wait ops on dialect barriers with explicit handling of parities
This makes dialect tests in mgpu_test.py truly express the entire computation at the warpgroup level.

PiperOrigin-RevId: 721371327
2025-01-30 06:44:32 -08:00
Benjamin Chetioui
46512e684b [Mosaic GPU][NFC] Fix wrong type annotations, and do some NFC cleanups.
PiperOrigin-RevId: 721350296
2025-01-30 05:13:58 -08:00
Dimitar (Mitko) Asenov
d9f67ffe13 [Mosaic GPU] Implement a lowering for the dialect WGMMA op
PiperOrigin-RevId: 720663200
2025-01-28 12:08:45 -08:00
Dimitar (Mitko) Asenov
a0db6c5cf6 [Mosaic GPU] Use a single instance of the single_thread_predicate in the MLIR dialect lowering.
PiperOrigin-RevId: 720155654
2025-01-27 07:04:06 -08:00
Dimitar (Mitko) Asenov
a3a285dddc [Mosaic GPU] Handle the swizzle attribute in the lowering of async_store and async_load
PiperOrigin-RevId: 720129408
2025-01-27 05:18:16 -08:00
Dimitar (Mitko) Asenov
6f609926a6 [Mosaic GPU] Remove an unnecessary restriction in the vector.store lowering
This was made obsolete by:
f89accc56a

PiperOrigin-RevId: 718808561
2025-01-23 04:24:14 -08:00
Dimitar (Mitko) Asenov
f89accc56a [Mosaic GPU] Add support for converting all fragmented layouts to ir and back.
This will be used in the layout inference and lowering of the dialect WGMMA op

PiperOrigin-RevId: 717836648
2025-01-21 03:27:03 -08:00
Dimitar (Mitko) Asenov
5e27efd0e0 [MosaicGPU] Cleanup imports in dialect_lowering.py
PiperOrigin-RevId: 716244938
2025-01-16 08:26:02 -08:00
Dimitar (Mitko) Asenov
24884071b9 [MosaicGPU] Remove the single_thread context from top-level dialect code.
- Change the `async_load` lowering to manage the single thread context.
- Use a predicate for the top-level arrive_expect. If we want to hide this further, we can have a warp-group level op that lowers to a single-threaded context.

PiperOrigin-RevId: 716219730
2025-01-16 06:59:32 -08:00
Dimitar (Mitko) Asenov
ce03cf976e [MosaicGPU] Move gpu_address_space_to_nvptx inside utils.py and use it.
PiperOrigin-RevId: 716214822
2025-01-16 06:41:51 -08:00
Dimitar (Mitko) Asenov
22417ae28e [MosaicGPU] Extract code into a new method BarrierRef.from_dialect_barrier_memref and implement support for 1D barrier memrefs.
PiperOrigin-RevId: 716180182
2025-01-16 04:30:43 -08:00
Dimitar (Mitko) Asenov
dad23fed09 [Mosaic GPU] Add a lowering for simple async_load and async_store ops.
Only untransformed and unsliced loads/stores are supported for now. The rest will be a follow up.

PiperOrigin-RevId: 708347442
2024-12-20 09:38:13 -08:00
Benjamin Chetioui
66ad2082ba [Mosaic GPU] Replace the dialect's layout enum with layouts holding the proper
sub-attributes.

PiperOrigin-RevId: 707846907
2024-12-19 02:59:26 -08:00
Benjamin Chetioui
36b12d58f4 [Mosaic GPU] Add end-to-end lowering example for a pointwise kernel using the dialect and layout inference.
Also implement a lowering rule for `arith.AddFOp`.

PiperOrigin-RevId: 707131747
2024-12-17 09:28:05 -08:00
Benjamin Chetioui
036125544e [Mosaic GPU] Add layout inference and initial lowering for vector.{load,store}.
For now, the lowering only works for the strided fragmented layout. This is
mostly an exercise in plugging in lowering rules using `FragmentedArray`, and
will be expanded shortly.

PiperOrigin-RevId: 707031770
2024-12-17 03:51:01 -08:00
Benjamin Chetioui
4ef7706abb [Mosaic GPU] Split layout inference and dialect lowering files and tests.
PiperOrigin-RevId: 705100503
2024-12-11 07:31:34 -08:00
Benjamin Chetioui
07a3515065 [Mosaic GPU] Add an initial skeleton for a layout inference pass.
Layouts are added as annotations on MLIR ops, using the `in_layouts` and
`out_layouts` attributes.

At this point, layout inference is done in two passes: one "backwards" pass
(root-to-parameters), and one "forward" pass (parameters-to-root).

Each pass goes through all the ops in the specified order, and infers a
possible layout from the layout information that is available. We expect to
need two passes because partial layout annotations may be provided on
intermediate nodes (e.g. `wgmma`), and a single pass from the root to the
parameters is therefore insufficient to properly annotate all the operations.

We do not perform any check as to whether the inferred layouts can be further
lowered correctly---meaning that the produced IR can possibly fail to lower
later.

Layouts are only inferred for ops involving at least one operand or result of
type `VectorType`/`RankedTensorType`.

When layouts can't be inferred for an op that should have them, we default to
annotating it with strided fragmented layouts.

PiperOrigin-RevId: 705092403
2024-12-11 07:01:06 -08:00
Benjamin Chetioui
8a7bf2e4b0 [Mosaic GPU] Ensure that lowering InitializeBarrierOp preserves the result's type.
Otherwise, the lowered IR won't be type-correct.

PiperOrigin-RevId: 695339726
2024-11-11 08:02:07 -08:00
Benjamin Chetioui
da89c9e38c [Mosaic GPU] Add base_pointer argument to InitializeBarrierOp.
This corresponds to what's implemented in `BarrierRef`, and ultimately makes it
easier to allocate barriers at a specific address in dynamic shared memory.

PiperOrigin-RevId: 695308297
2024-11-11 06:18:26 -08:00
Dan Foreman-Mackey
4a365670f7 Fix pre-commit to run on all files in CI. 2024-11-08 13:47:27 -05:00
Benjamin Chetioui
1f1d27de2f [Mosaic GPU] Implement the skeleton of a lowering pass for the Mosaic GPU dialect.
Also add a lowering rule for `mosaic_gpu.initialize_barrier`.

PiperOrigin-RevId: 694276698
2024-11-07 15:58:04 -08:00