483 Commits

Author SHA1 Message Date
jax authors
f3b7c5cb9e Integrate LLVM at llvm/llvm-project@0230d63b4a
Updates LLVM usage to match
[0230d63b4a8b](https://github.com/llvm/llvm-project/commit/0230d63b4a8b)

PiperOrigin-RevId: 738222096
2025-03-18 19:23:20 -07:00
Adam Paszke
8da93249d2 [Mosaic GPU] Fuse slicing into s4 -> bf16 upcasts
This allows us to significantly simplify the generated PTX/SASS,
which is currently cluttered with LLVM trying to align slices to
start at bit 0 and failing to CSE the right shifts.

PiperOrigin-RevId: 737967890
2025-03-18 05:38:49 -07:00
Chris Jones
38d52a19ef [mosaic_gpu] Force flush all cupti activity, then unsubscribe.
With default flushing, it is possible for events to be missed. We should only unsubscribe after we are finished with cupti.

PiperOrigin-RevId: 737939327
2025-03-18 03:35:03 -07:00
Tzu-Wei Sung
e235fb9760 [Mosaic] Allow part of x2 int casts.
This should at least allow int2 -> int4 for native tiling vregs. Skip many tests due to XLA compatibility.

PiperOrigin-RevId: 736710186
2025-03-13 18:57:36 -07:00
Tzu-Wei Sung
a0f1be123d [Mosaic] Improve error messages.
PiperOrigin-RevId: 736580673
2025-03-13 11:35:33 -07:00
Jevin Jiang
12c0987e2f [Mosaic TPU][NFC] Throw NYI error instead of crash when squeeze ref to 1d.
PiperOrigin-RevId: 736263705
2025-03-12 14:18:33 -07:00
Dimitar (Mitko) Asenov
99c9106032 [Mosaic GPU] Replace WGMMAFragLayout with TiledLayout in the mlir dialect and use it in layout inference.
`WGMMAFragLayout` will be completely removed soon.

PiperOrigin-RevId: 735877661
2025-03-11 13:50:42 -07:00
Jevin Jiang
0f0636afab [Mosaic TPU][Pallas] Add pl.reciprocal
PiperOrigin-RevId: 734749577
2025-03-07 18:29:30 -08:00
Sergei Lebedev
928caf83ee [pallas:mosaic_gpu] copy_smem_to_gmem now allows skipping cp.async.commit_group
This feature is necessary to fix the SMEM->GMEM waiting behavior in
`emit_pipeline`, which used a pessimistic condition prior to this change,
since every copy was its own commit group.

PiperOrigin-RevId: 734553668
2025-03-07 07:43:54 -08:00
Jevin Jiang
ff4310f640 [Mosaic TPU] Support fp8 upcast to f32
PiperOrigin-RevId: 734345644
2025-03-06 17:19:15 -08:00
jax authors
a13b3cedad Merge pull request #26691 from h-vetinari:packed
PiperOrigin-RevId: 733696873
2025-03-05 05:46:01 -08:00
Tzu-Wei Sung
5179642eb5 [Mosaic] Rename dep name.
PiperOrigin-RevId: 732985217
2025-03-03 11:01:25 -08:00
Dimitar (Mitko) Asenov
3b305c6617 [Mosaic GPU] Infer layouts (transforms) on memrefs that directly feed into the dialect wgmma op.
This change detects a situation where a gmem_memref is read via `async_load` and directly used in a wgmma. In such cases, we insert a cast before the load to add tile, transpose, and swizzle transformations.

PiperOrigin-RevId: 732618760
2025-03-02 03:17:13 -08:00
Benjamin Chetioui
a9ab614123 [Pallas/Mosaic GPU] Add an abstraction to obtain a slice of dynamic shared memory when using warpgroup semantics.
Explicitly make the assumption that `runtime_smem` starts at `0` in the Pallas
module context---which should be enforced by Mosaic GPU.

This is in preparation of changes implementing transform inference.

PiperOrigin-RevId: 732091266
2025-02-28 04:38:25 -08:00
William S. Moses
8262987a1c Fix build dependencies
PiperOrigin-RevId: 731330542
2025-02-26 08:38:31 -08:00
Adam Paszke
cb7402f6de Remove MemoryEffects annotations from async_{load/store} ops
The annotation on async_load didn't indicate its write to SMEM, allowing it
to be DCEd by MLIR canonicalization. We don't get much mileage out of those
annotations, so let's just delete them for simplicity.

PiperOrigin-RevId: 731003033
2025-02-25 13:15:00 -08:00
jax authors
083ffd3717 [Easy][Mosaic] Tiny refactor for clarity in getTypeBitwidth
PiperOrigin-RevId: 730906329
2025-02-25 08:58:19 -08:00
H. Vetinari
91cae595e4 fix member access to packed CUDA struct
2025-02-24 08:03:07 +11:00
jax authors
b510127a13 Internal compatibility change
PiperOrigin-RevId: 729428478
2025-02-21 01:21:56 -08:00
jax authors
b7968474c2 [Pallas][Mosaic] Support float8_e4m3b11fnuz
PiperOrigin-RevId: 729169181
2025-02-20 10:44:33 -08:00
jax authors
37af0135b0 [Mosaic] Consider divisibility when doing large tiling
PiperOrigin-RevId: 728980108
2025-02-19 23:56:07 -08:00
Jevin Jiang
bb68124c33 [Mosaic TPU] Support mask concat
PiperOrigin-RevId: 728349788
2025-02-18 14:03:46 -08:00
jax authors
725087e13f Integrate LLVM at llvm/llvm-project@9d24f94379
Updates LLVM usage to match
[9d24f9437944](https://github.com/llvm/llvm-project/commit/9d24f9437944)

PiperOrigin-RevId: 728265165
2025-02-18 10:30:48 -08:00
jax authors
e78a469b42 Integrate LLVM at llvm/llvm-project@912b154f3a
Updates LLVM usage to match
[912b154f3a3f](https://github.com/llvm/llvm-project/commit/912b154f3a3f)

PiperOrigin-RevId: 727895384
2025-02-17 10:08:37 -08:00
Dimitar (Mitko) Asenov
52f8fbeee0 [Mosaic GPU] Implement lowerings for Tile and Transpose transforms from the MLIR dialect.
PiperOrigin-RevId: 727762334
2025-02-17 01:29:47 -08:00
jax authors
a6fcb7415f [TPU][Mosaic][Easy] Add verification for AssumeMultipleOp.
A user must use AssumeMultipleOp to annotate integer constants that are divisible by the given multiple.

PiperOrigin-RevId: 727699186
2025-02-16 21:16:05 -08:00
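The check described in a6fcb7415f can be illustrated with a minimal sketch. This is not the Mosaic verifier itself (which lives in the MLIR dialect); `verify_assume_multiple` is a hypothetical name, and the only behavior taken from the commit message is that an annotated integer constant must be divisible by the given multiple.

```python
def verify_assume_multiple(value: int, multiple: int) -> None:
    """Hypothetical stand-in for a verifier on an 'assume multiple' annotation.

    Raises ValueError if a known constant `value` is not divisible by
    `multiple`. Rejecting non-positive multiples is an assumption of this
    sketch, not something stated in the commit.
    """
    if multiple <= 0:
        raise ValueError(f"multiple must be positive, got {multiple}")
    if value % multiple != 0:
        raise ValueError(
            f"constant {value} is not divisible by the assumed multiple {multiple}"
        )


# A constant that really is a multiple passes silently.
verify_assume_multiple(256, 128)
```

Calling `verify_assume_multiple(100, 128)` raises `ValueError`, which mirrors the idea of rejecting an annotation that contradicts a constant known at verification time.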
jax authors
eaceac3bf9 [Pallas] Reductions with replicated axes.
PiperOrigin-RevId: 727292293
2025-02-15 07:41:16 -08:00
Adam Paszke
a493df4dd8 Fix Windows build for Mosaic GPU extension
We only export symbols that begin with `mlir` and a few other prefixes, so this renames our C API functions for consistency with that.

PiperOrigin-RevId: 726468092
2025-02-13 06:58:17 -08:00
Jevin Jiang
876668faa1 [Mosaic TPU] Support bf16 div if HW does not directly support.
PiperOrigin-RevId: 726212286
2025-02-12 15:04:09 -08:00
Dimitar (Mitko) Asenov
6fc1c61520 [Mosaic GPU] Use the memref layout to encode transforms (only swizzle for now).
Tile and Transpose transforms to follow.

PiperOrigin-RevId: 725716812
2025-02-11 11:51:25 -08:00
jax authors
ffd3faad72 [TPU][Mosaic] Fix missing sfences in smem DMAs
PiperOrigin-RevId: 725376627
2025-02-10 15:51:35 -08:00
jax authors
6740165e4f [Pallas] Add pipeline mode to pltpu
PiperOrigin-RevId: 725133131
2025-02-10 02:36:44 -08:00
Adam Paszke
e7a4f89343 [Mosaic TPU] Add optimized casts for bf16->s4 in TPUv6
PiperOrigin-RevId: 723455843
2025-02-05 04:21:55 -08:00
Jevin Jiang
d8b9211359 [Mosaic TPU] Support dynamic gather along axis 0 or 1 for 32-bit vreg-sized vector.
PiperOrigin-RevId: 721980453
2025-01-31 18:47:25 -08:00
Jevin Jiang
785a63ad0f [Mosaic TPU] Support non-32 bit mask relayout
PiperOrigin-RevId: 721552594
2025-01-30 16:13:23 -08:00
Tzu-Wei Sung
d4758b6d5e [Mosaic][NFC] Factor out xla-array related utils in a separate file.
Also added tests.

PiperOrigin-RevId: 721424194
2025-01-30 09:49:41 -08:00
Benjamin Chetioui
d8f3b33ae4 [Mosaic GPU] Eliminate the arrive attribute from mosaic_gpu.async_load.
We plan to explicitly issue an `expect_tx` operation all the time when using
the dialect.

PiperOrigin-RevId: 721411949
2025-01-30 09:08:45 -08:00
Dimitar (Mitko) Asenov
6214c25a6d [Mosaic GPU] Add ArriveExpect and Wait ops on dialect barriers with explicit handling of parities
This makes dialect tests in mgpu_test.py truly express the entire computation at the warpgroup level.

PiperOrigin-RevId: 721371327
2025-01-30 06:44:32 -08:00
Adam Paszke
29b658b358 [Mosaic TPU] Optimize clipping implementation in arith.fptosi
We can use maxf/minf to avoid extra comparisons

PiperOrigin-RevId: 720601304
2025-01-28 09:20:16 -08:00
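The idea in 29b658b358 (clamping with max/min rather than separate comparisons and selects) can be sketched in plain Python. This is only an illustration of the technique, not the Mosaic TPU lowering: the function name, the int32 target width, and the choice to map NaN to 0 are all assumptions of this sketch.

```python
import math

INT32_MIN = -(2**31)
INT32_MAX = 2**31 - 1


def fptosi_saturating(x: float) -> int:
    """Convert a float to a signed 32-bit integer with saturation.

    Out-of-range inputs (including +/-inf) are clamped using one max and
    one min, instead of comparing against each bound and selecting.
    """
    if math.isnan(x):
        return 0  # assumption: NaN maps to 0 in this sketch
    clamped = min(max(x, float(INT32_MIN)), float(INT32_MAX))
    return int(clamped)  # int() truncates toward zero


print(fptosi_saturating(1e20))   # clamps to INT32_MAX
print(fptosi_saturating(-1.7))   # truncates toward zero
```

Both bounds are exactly representable as Python floats (doubles), so the clamp never rounds past the integer range here; a narrower float type would need more care at the upper bound.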
Dimitar (Mitko) Asenov
a3a285dddc [Mosaic GPU] Handle the swizzle attribute in the lowering of async_store and async_load
PiperOrigin-RevId: 720129408
2025-01-27 05:18:16 -08:00
Sergei Lebedev
9ee7123c39 [mosaic_gpu] Fixed mosaic_gpu-serde pass registration
We previously registered the pass in the :_mosaic_gpu_ext which didn't work
because the extension has its own pass registry. The fix instead is to move
the registration to :register_jax_dialects in jaxlib.

PiperOrigin-RevId: 719280601
2025-01-24 06:35:54 -08:00
Adam Paszke
7043b852ec [Mosaic GPU] Add basic support for TMA with sub-byte types
PiperOrigin-RevId: 719240287
2025-01-24 03:54:12 -08:00
Jevin Jiang
8e1f956804 [Mosaic TPU] Use vmask pack if possible for mask's bitwidth change and introduce relayout op.
PiperOrigin-RevId: 719089676
2025-01-23 18:15:08 -08:00
Dimitar (Mitko) Asenov
f57d603c45 [Mosaic GPU] Simplify enums in the MLIR Mosaic GPU dialect.
This enables us to use them more simply in the current and upcoming Python code. The Python bindings for enum and enum attributes leave much to be desired.

PiperOrigin-RevId: 718795667
2025-01-23 03:38:26 -08:00
Dimitar (Mitko) Asenov
6b747b4109 [Mosaic GPU] Add a result to the WGMMA op definition in the MLIR dialect
PiperOrigin-RevId: 718788390
2025-01-23 03:10:07 -08:00
jax authors
6c76cc4e36 Integrate LLVM at llvm/llvm-project@d33e33fde7
Updates LLVM usage to match
[d33e33fde770](https://github.com/llvm/llvm-project/commit/d33e33fde770)

PiperOrigin-RevId: 718414171
2025-01-22 09:22:07 -08:00
jax authors
54bb7f5ddb Remove meaningless template keywords.
This will fix -Wmissing-template-arg-list-after-template-kw warnings.
This warning is error-by-default in Clang.

PiperOrigin-RevId: 718133601
2025-01-21 17:22:04 -08:00
Dimitar (Mitko) Asenov
f89accc56a [Mosaic GPU] Add support for converting all fragmented layouts to ir and back.
This will be used in the layout inference and lowering of the dialect WGMMA op

PiperOrigin-RevId: 717836648
2025-01-21 03:27:03 -08:00
Adam Paszke
543dd94762 [Mosaic TPU] Add a faster implementation for packing b16 to s8 in TPUv6
PiperOrigin-RevId: 717583425
2025-01-20 11:18:22 -08:00
Peter Hawkins
034e967e11 Remove CUDA rpaths from jaxlib build.
These are also set in the TSL build rules as part of the CUDA stub libraries, which these libraries depend on, so these copies of the rpath settings are redundant.

PiperOrigin-RevId: 716844265
2025-01-17 17:09:30 -08:00