Note that
* we have no easy way of testing multi-reductions at the moment;
* `reduce_max` assumes the `WGMMA_ROW` layout, which is not currently
  supported by the dialect lowering, AFAICT.
PiperOrigin-RevId: 736138554
This feature is necessary to fix the SMEM->GMEM waiting behavior in
`emit_pipeline`, which had to use a pessimistic wait condition prior to this
change, since every copy formed its own commit group.
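A minimal sketch of the intended structure, assuming TMA-style commit-group
semantics; `issue_copy`, `commit_group`, and `wait_group` are hypothetical
stand-ins for the underlying copy and commit-group primitives:

```python
def flush_stage(copies, commit_group):
  for issue_copy in copies:  # all SMEM->GMEM copies of one pipeline stage
    issue_copy()
  commit_group()             # the whole stage forms a single commit group

def await_stages(wait_group, max_stages_in_flight):
  # With one group per stage, `emit_pipeline` can leave several stages
  # pending instead of pessimistically waiting after every copy.
  wait_group(max_stages_in_flight)
```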
PiperOrigin-RevId: 734553668
Without this fix, the lowerings of ops within the `for` body are always
appended at the end of the body, even if their results have users earlier in
it. This caused an `operand #0 does not dominate this use` error.
The fix was tested in the upcoming (but not yet submitted) `test_realistic_matmul` in Pallas with Workgroup semantics.
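A minimal sketch of the fix using the MLIR Python bindings, where `lower_op`
stands in for the actual per-op lowering:

```python
from mlir import ir  # MLIR Python bindings

def lower_body(block, lower_op):
  # Snapshot the op list, since lowering mutates the block.
  for op in list(block.operations):
    # Emit the lowered ops right where `op` sits, instead of appending
    # them at the end of the block, so that results dominate their uses.
    with ir.InsertionPoint(op):
      lower_op(op)
```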
PiperOrigin-RevId: 734157829
This change detects the situation where a `gmem_memref` is read via
`async_load` and the result is used directly in a `wgmma`. In such cases, we
insert a cast before the load that adds tile, transpose, and swizzle
transformations.
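A hedged sketch of the rewrite; `feeds_wgmma` and `build_transform_cast` are
hypothetical stand-ins for the actual detection and cast-building utilities:

```python
def maybe_cast_gmem_source(async_load, feeds_wgmma, build_transform_cast):
  # If the loaded value is consumed directly by a wgmma, read the GMEM
  # memref through a cast carrying tile/transpose/swizzle transforms.
  if feeds_wgmma(async_load):
    transformed = build_transform_cast(
        async_load.operands[0],  # the gmem_memref being read
        tile=True, transpose=True, swizzle=True)
    async_load.operands[0] = transformed
```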
PiperOrigin-RevId: 732618760
Explicitly make the assumption that `runtime_smem` starts at `0` in the Pallas
module context, which should be enforced by Mosaic GPU.
This is in preparation for changes implementing transform inference.
PiperOrigin-RevId: 732091266
The helpers `_fragmented_array_to_ir` and `_fragmented_array_from_ir` in
`dialect_lowering.py` have been modified so that a fragmented array's
signedness no longer appears in its IR representation.
This is because signedness reflects how we use the value, and is not an
inherent property of it. The appropriate signedness to use when reloading a
fragmented array from IR must therefore be provided by the caller.
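Illustratively (argument names are approximate, not the exact signatures):

```python
def round_trip(fa, ir_type, is_signed):
  # Serializing drops the signedness, since it describes how the value is
  # used rather than the value itself; the caller resupplies it on reload.
  as_ir = _fragmented_array_to_ir(fa, ir_type)
  return _fragmented_array_from_ir(as_ir, is_signed=is_signed)
```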
PiperOrigin-RevId: 726030853
- Change the `async_load` lowering to manage the single-thread context itself.
- Use a predicate for the top-level `arrive_expect`. If we want to hide this
  further, we can add a warp-group-level op that lowers to a single-threaded
  context.
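A hedged sketch of both points, assuming a `single_thread` context manager and
a thread predicate along the lines of Mosaic GPU's utils:

```python
def lower_async_load(emit_copy, single_thread):
  # The lowering now enters the single-thread context itself, so callers
  # no longer need to wrap `async_load` in one.
  with single_thread():
    emit_copy()  # only one thread issues the async copy

def lower_arrive_expect(arrive_expect, predicate):
  # At the top level, predicate the op directly rather than opening a
  # single-threaded region around it.
  arrive_expect(predicate=predicate)
```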
PiperOrigin-RevId: 716219730
For now, the lowering only works for the strided fragmented layout. This is
mostly an exercise in plugging in lowering rules built on `FragmentedArray`,
and it will be expanded shortly.
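A hedged sketch of what such a rule looks like; the registration decorator and
the op are illustrative stand-ins:

```python
@_register_lowering(AddOp)  # hypothetical registry and op
def _lower_add(ctx, op):
  # Reload the operands as FragmentedArrays (strided layout only, for
  # now), compute with FragmentedArray arithmetic, and convert back to IR.
  lhs = _fragmented_array_from_ir(op.lhs)
  rhs = _fragmented_array_from_ir(op.rhs)
  return [_fragmented_array_to_ir(lhs + rhs, op.result.type)]
```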
PiperOrigin-RevId: 707031770
Layouts are added as annotations on MLIR ops, using the `in_layouts` and
`out_layouts` attributes.
At this point, layout inference is done in two passes: one "backwards" pass
(root-to-parameters), and one "forward" pass (parameters-to-root).
Each pass goes through all the ops in the specified order, and infers a
possible layout from the layout information that is available. We expect to
need two passes because partial layout annotations may be provided on
intermediate nodes (e.g. `wgmma`), and a single pass from the root to the
parameters is therefore insufficient to properly annotate all the operations.
We do not check whether the inferred layouts can be lowered correctly, meaning
that the produced IR may fail to lower later.
Layouts are only inferred for ops involving at least one operand or result of
type `VectorType`/`RankedTensorType`.
When layouts can't be inferred for an op that should have them, we default to
annotating it with strided fragmented layouts.
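A hedged sketch of the resulting control flow; all helper names are
illustrative:

```python
def infer_layouts(ops, needs_layout, try_infer, has_layouts, set_strided):
  # `needs_layout`: op has a VectorType/RankedTensorType operand or result.
  # `try_infer`: fills `in_layouts`/`out_layouts` from annotated neighbors.
  vector_ops = [op for op in ops if needs_layout(op)]
  for op in reversed(vector_ops):  # backwards pass: root -> parameters
    try_infer(op)
  for op in vector_ops:            # forward pass: parameters -> root
    try_infer(op)
  for op in vector_ops:            # fall back for anything still bare
    if not has_layouts(op):
      set_strided(op)
```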
PiperOrigin-RevId: 705092403
This corresponds to what's implemented in `BarrierRef`, and ultimately makes it
easier to allocate barriers at a specific address in dynamic shared memory.
PiperOrigin-RevId: 695308297