119 Commits

Author SHA1 Message Date
Peter Hawkins
3fa557289a Port tests away from setUpClass and setUpModule to setUp alone.
This change prepares for upcoming changes in which we run tests in parallel using threads, which we are doing partially to test free threading but also partially to speed up TPU tests via thread-parallelism.

If independent tests run in parallel in no particular order, there's no natural scope around which to call setUpClass or SetUpModule. But for JAX tests this never seems necessary: we can just do the same work in setUp() or do it globally.

PiperOrigin-RevId: 713296722
2025-01-08 08:14:50 -08:00
Dimitar (Mitko) Asenov
dad23fed09 [Mosaic GPU] Add a lowering for simple async_load and async_store ops.
Only untransformed and unsliced loads/stores are supported for now. The rest will be a follow up.

PiperOrigin-RevId: 708347442
2024-12-20 09:38:13 -08:00
Benjamin Chetioui
3915f4a147 [Mosaic GPU] Commit to using Vectors everywhere (and no Tensors).
PiperOrigin-RevId: 707912637
2024-12-19 07:51:58 -08:00
Adam Paszke
006c65d8d4 [Mosaic GPU] Add a new tiled layout, optimized for upcasting before WGMMA
PiperOrigin-RevId: 707860467
2024-12-19 04:03:24 -08:00
Benjamin Chetioui
66ad2082ba [Mosaic GPU] Replace the dialect's layout enum with layouts holding the proper
sub-attributes.

PiperOrigin-RevId: 707846907
2024-12-19 02:59:26 -08:00
Benjamin Chetioui
6a03ea3e73 [Mosaic GPU] Clean up imports in gpu_dialect_test.py.
PiperOrigin-RevId: 707549269
2024-12-18 07:49:53 -08:00
Sergei Lebedev
98067fc10e [mosaic_gpu] Error on static out of bounds indices in utils.parse_indices
It would also be nice to optionally insert runtime assertions for dynamic
indices, but we don't have a way of doing that just yet.

PiperOrigin-RevId: 707532787
2024-12-18 06:47:31 -08:00
Andrey Portnoy
6ea4708214 [Mosaic GPU] Skip testing uint64 unless 64-bit types are enabled 2024-12-17 18:28:29 -05:00
Benjamin Chetioui
36b12d58f4 [Mosaic GPU] Add end-to-end lowering example for a pointwise kernel using the dialect and layout inference.
Also implement a lowering rule for `arith.AddFOp`.

PiperOrigin-RevId: 707131747
2024-12-17 09:28:05 -08:00
Benjamin Chetioui
036125544e [Mosaic GPU] Add layout inference and initial lowering for vector.{load,store}.
For now, the lowering only works for the strided fragmented layout. This is
mostly an exercise in plugging in lowering rules using `FragmentedArray`, and
will be expanded shortly.

PiperOrigin-RevId: 707031770
2024-12-17 03:51:01 -08:00
Sergei Lebedev
ee7226d564 [mosaic_gpu] Allow calling reduce_sum on a fragmented array in splat layout
PiperOrigin-RevId: 706668018
2024-12-16 05:04:10 -08:00
Benjamin Chetioui
2386838315 [Mosaic GPU] Fix layout inference traversal to traverse ops recursively.
PiperOrigin-RevId: 706136221
2024-12-13 23:51:20 -08:00
Benjamin Chetioui
4ef7706abb [Mosaic GPU] Split layout inference and dialect lowering files and tests.
PiperOrigin-RevId: 705100503
2024-12-11 07:31:34 -08:00
Benjamin Chetioui
07a3515065 [Mosaic GPU] Add an initial skeleton for a layout inference pass.
Layouts are added as annotations on MLIR ops, using the `in_layouts` and
`out_layouts` attributes.

At this point, layout inference is done in two passes: one "backwards" pass
(root-to-parameters), and one "forward" pass (parameters-to-root).

Each pass goes through all the ops in the specified order, and infers a
possible layout from the layout information that is available. We expect to
need two passes because partial layout annotations may be provided on
intermediate nodes (e.g. `wgmma`), and a single pass from the root to the
parameters is therefore insufficient to properly annotate all the operations.

We do not perform any check as to whether the inferred layouts can be further
lowered correctly---meaning that the produced IR can possibly fail to lower
later.

Layouts are only inferred for ops involving at least one operand or result of
type `VectorType`/`RankedTensorType`.

When layouts can't be inferred for an op that should have them, we default to
annotating it with strided fragmented layouts.

PiperOrigin-RevId: 705092403
2024-12-11 07:01:06 -08:00
jax authors
0d7eaeb5d8 Merge pull request #24805 from andportnoy:aportnoy/mosaic-gpu-cupti-profiler
PiperOrigin-RevId: 705071782
2024-12-11 05:29:10 -08:00
Dimitar (Mitko) Asenov
3d9c720d42 [Mosaic GPU] Automatically format the Mosaic GPU dialect test python code
This allows me to keep using the formatter going forward and not have to bother manually formatting code.

PiperOrigin-RevId: 705024602
2024-12-11 02:04:08 -08:00
Dimitar (Mitko) Asenov
66f45d039f [Mosaic GPU] Add WGMMA to the Mosaic GPU MLIR Dialect.
The op API is still in flux so I'm leaving some of the verification code untested.

PiperOrigin-RevId: 705020066
2024-12-11 01:47:29 -08:00
Andrey Portnoy
cc22334c21 [Mosaic GPU] Add CUPTI profiler alongside events-based implementation 2024-12-09 14:31:20 -05:00
Sergei Lebedev
bae660002a [pallas:mosaic_gpu] FragmentedArray.reduce_sum now returns a FragmentedArray
This aligns it with the `reduce` method and also makes it clear that the
reduction always produces a scalar.

PiperOrigin-RevId: 703494443
2024-12-06 07:32:21 -08:00
Adam Paszke
11090be0b3 [Mosaic GPU] Add an optimization barrier
The barrier is a no-op at runtime, but appears as a side-effecting op to LLVM
which prevents it from moving the (even pure) computations that involve the
supplied arrays past the barrier.

PiperOrigin-RevId: 702709125
2024-12-04 06:54:48 -08:00
Adam Paszke
bd66f5280b [Mosaic GPU] Add a bank-conflict checker to tiled transfer + transfer planner
Instead of only allowing a fixed set of layouts that we've hand verified as
bank-conflict free, we now simulate the transactions performed within each
warp and verify that no bank conflicts happen. If we detect that the simple
schedule does not work out, we attempt to partition the threads in a warp
into two groups and stagger the transfers in a way that lets us avoid conflicts.

This allows us to match the hand-designed transfer schedule I wrote for 32-bit
types, and even generalizes it to more cases automatically (e.g. swizzle=32).

PiperOrigin-RevId: 701919158
2024-12-02 04:27:01 -08:00
Christos Perivolaropoulos
ea69401eff [mgpu] Fixed off-by-one issue in pointwise argument shuffling when leading argument is splat.
Also adapted the test to catch a possible regression. The issue appeared in >2 operands.

PiperOrigin-RevId: 701306731
2024-11-29 09:35:44 -08:00
Adam Paszke
b801539f5c [Pallas][Mosaic GPU] Add support for compressing squeezed dims in async_copy + grid fixes
This change removes the need to flatten the batch dimension into sequence dimensions
in the flash attention kernel. The critical thing here is the observation that we can
in fact collapse all squeezed dimension into a single one in the TMA descriptor, letting
us reduce its rank when necessary.

Doing this also uncovered some issues with how we were handling the grid in Pallas:MGPU
lowering, which I've fixed.

PiperOrigin-RevId: 701035277
2024-11-28 08:35:00 -08:00
Adam Paszke
b09b0779e0 [Mosaic GPU] Add support for fast upcasts of s8 to bf16 for vectors of 4 elements
To complement the current path that only handles 2 elements.

PiperOrigin-RevId: 700998965
2024-11-28 05:41:08 -08:00
Christos Perivolaropoulos
f3acfa93bb [mgpu] FragentedArray.foreach() can now optionally return a new array
PiperOrigin-RevId: 700708119
2024-11-27 08:20:49 -08:00
Christos Perivolaropoulos
f828f2d7d0 [mgpu] Pointwise min
PiperOrigin-RevId: 700175724
2024-11-25 19:13:51 -08:00
Christos Perivolaropoulos
c5dc980db8 [mgpu/pallas_mgpu] Pointwise tanh support
PiperOrigin-RevId: 700158250
2024-11-25 17:56:11 -08:00
Peter Buchlovsky
69e3f0d37d [pallas:mosaic_gpu] Add test for FragmentedArray.bitcast.
PiperOrigin-RevId: 699919048
2024-11-25 03:30:57 -08:00
Christos Perivolaropoulos
1d2dc17e5f [mgpu] Pointwise op can handle LHS splats.
PiperOrigin-RevId: 698818035
2024-11-21 09:50:21 -08:00
Sergei Lebedev
f442d40f92 [mosaic_gpu] Fixed FragmentedArray comparisons with literals
PiperOrigin-RevId: 698343858
2024-11-20 04:31:28 -08:00
jax authors
a889a95aa1 Merge pull request #24839 from andportnoy:aportnoy/mosaic-gpu-hopper-tests
PiperOrigin-RevId: 695388748
2024-11-11 10:12:29 -08:00
Andrey Portnoy
24af8a676b [Mosaic GPU] Only run tests requiring sm90a on Hopper 2024-11-11 12:02:48 -05:00
Benjamin Chetioui
8a7bf2e4b0 [Mosaic GPU] Ensure that lowering InitializeBarrierOp preserves the result's type.
Otherwise, the lowered IR won't be type-correct.

PiperOrigin-RevId: 695339726
2024-11-11 08:02:07 -08:00
Benjamin Chetioui
da89c9e38c [Mosaic GPU] Add base_pointer argument to InitializeBarrierOp.
This corresponds to what's implemented in `BarrierRef`, and ultimately makes it
easier to allocate barriers at a specific address in dynamic shared memory.

PiperOrigin-RevId: 695308297
2024-11-11 06:18:26 -08:00
Dimitar (Mitko) Asenov
d833066a1f [MOSAIC:GPU] Add async_load, async_store, and supporting attributes to the MLIR Mosaic GPU Dialect.
PiperOrigin-RevId: 694643777
2024-11-08 14:34:23 -08:00
Adam Paszke
6a124ac554 [Mosaic GPU] Implement tiled and swizzled transfers for tiled layouts
PiperOrigin-RevId: 694449664
2024-11-08 04:38:20 -08:00
Christos Perivolaropoulos
5e43220e97 [mosaic_gpu] Scalar arguments to kernels.
PiperOrigin-RevId: 694426328
2024-11-08 02:59:15 -08:00
Benjamin Chetioui
1f1d27de2f [Mosaic GPU] Implement the skeleton of a lowering pass for the Mosaic GPU dialect.
Also add a lowering rule for `mosaic_gpu.initialize_barrier`.

PiperOrigin-RevId: 694276698
2024-11-07 15:58:04 -08:00
Adam Paszke
de06584d98 [Mosaic GPU] Introduce a more flexible layout system
So far all of our layouts have been tailored to a limited set of use
cases we've tried so far, but they're still not general enough to
handle all of the register layouts needed for WGMMA or mixed precision
matmuls (incl. intermediate steps during conversions). Instead of adding
more special cases, I decided to adopt XLA tiled layouts and they do seem
to work quite well!

This change only lays the groundwork for the new layout system. Future
changes will build upon them to add new features and eventually replace
`WGMMA_LAYOUT` altogether.

PiperOrigin-RevId: 694105514
2024-11-07 07:08:51 -08:00
Adam Paszke
506671291a [Mosaic GPU] Fix the ordering of transforms in async_copy
Previously we didn't really fully discharge squeezing the indexed
dims before applying other GMEM transforms, leading to potential
failures because they were not anticipating the increased rank.

PiperOrigin-RevId: 694098739
2024-11-07 06:41:42 -08:00
Benjamin Chetioui
63e59c5fd7 [Mosaic GPU] Ensure that the dialect module can be loaded successfully.
This requires that the file providing the bindings has the same name as the
dialect it defines, since dialect search looks for a module path of the form
`<prefix>.<dialect namespace>`.

PiperOrigin-RevId: 693241875
2024-11-05 00:47:21 -08:00
Benjamin Chetioui
c708a04c6e [Mosaic GPU] Add Python bindings for the Mosaic GPU MLIR dialect.
Also start moving the existing C++ tests to Python.

PiperOrigin-RevId: 691729887
2024-10-31 02:47:30 -07:00
Sergei Lebedev
2652ab5608 [mosaic_gpu] Added support for bitwise and, or and xor to FragmentedArray
PiperOrigin-RevId: 691411447
2024-10-30 07:30:48 -07:00
Sergei Lebedev
04bdd07f66 [mosaic_gpu] mgpu.FragmentedArray now supports //
This is needed to compute grid index from the iteration step counter in `emit_pipeline`.

PiperOrigin-RevId: 690608581
2024-10-28 07:52:22 -07:00
Adam Paszke
343cf18e09 [Pallas:MGPU] Wire up the Mosaic GPU profiler into Pallas
PiperOrigin-RevId: 690574747
2024-10-28 05:40:08 -07:00
Sergei Lebedev
bb271aaff8 [pallas:mosaic_gpu] Added FragmentedArray.to_layout
PiperOrigin-RevId: 686524192
2024-10-16 08:53:02 -07:00
Christos Perivolaropoulos
28b0934272 [mosaic_gpu] Memref reshape by means of folding/unfolding
A reshape function that does fold/unfold by touching minimal number of
dimensions to potentially circumvent issues with strided memrefs.

PiperOrigin-RevId: 683663541
2024-10-08 09:59:54 -07:00
Peter Hawkins
26632fd344 Replace disable_backends with enable_backends on jax_multiplatform_test.
Most users of disable_backends were actually using it to enable only a single backend. So things are simpler if we negate the sense of the option to say that. Change disable_configs to enable_configs, with a default `None` value meaning "everything is enabled".

We change the relationship between enable_backends, disable_configs, enable_configs to be the following:
* `enable_backends` selects a set of initial test configurations to enable, based off backend only.
* `disable_configs` then prunes that set of test configurations, removing elements from the set.
* `enable_configs` then adds additional configurations to the set.

Fix code in jax/experimental/mosaic/gpu/examples not to depend on a Google-internal GPU support target.

PiperOrigin-RevId: 679563155
2024-09-27 06:15:31 -07:00
Sergei Lebedev
a373e37be2 Fixed mgpu.FragmentedArray.reduce_sum for integer types
The implementation previously assumed the type is floating and used addf.

PiperOrigin-RevId: 678718871
2024-09-25 08:50:24 -07:00
Peter Hawkins
a43c7f2ace Enable more H100 tests in CI.
Rename "gpu" config CI tag to "gpu_v100".

PiperOrigin-RevId: 678695003
2024-09-25 07:37:48 -07:00