471 Commits

Author SHA1 Message Date
Tzu-Wei Sung
5179642eb5 [Mosaic] Rename dep name.
PiperOrigin-RevId: 732985217
2025-03-03 11:01:25 -08:00
Dimitar (Mitko) Asenov
3b305c6617 [Mosaic GPU] Infer layouts (transforms) on memrefs that directly feed into the dialect wgmma op.
This change detects a situation where a gmem_memref is read via `async_load` and directly used in a wgmma. In such cases, we insert a cast before the load to add tile, transpose, and swizzle transformations.

PiperOrigin-RevId: 732618760
2025-03-02 03:17:13 -08:00
Benjamin Chetioui
a9ab614123 [Pallas/Mosaic GPU] Add an abstraction to obtain a slice of dynamic shared memory when using waprgroup semantics.
Explicitly make the assumption that `runtime_smem` starts at `0` in the Pallas
module context---which should be enforced by Mosaic GPU.

This is in preparation of changes implementing transform inference.

PiperOrigin-RevId: 732091266
2025-02-28 04:38:25 -08:00
William S. Moses
8262987a1c Fix build dependencies
PiperOrigin-RevId: 731330542
2025-02-26 08:38:31 -08:00
Adam Paszke
cb7402f6de Remove MemoryEffects annotations from async_{load/store} ops
The annotation on async_load didn't indicate its write to SMEM, allowing it
to be DCEd by MLIR canonicalization. We don't get much mileage out of those
annotations, so let's just delete them for simplicity.

PiperOrigin-RevId: 731003033
2025-02-25 13:15:00 -08:00
jax authors
083ffd3717 [Easy][Mosaic] Tiny refactor for clarity in getTypeBitwidth
PiperOrigin-RevId: 730906329
2025-02-25 08:58:19 -08:00
jax authors
b510127a13 Internal compatibility change
PiperOrigin-RevId: 729428478
2025-02-21 01:21:56 -08:00
jax authors
b7968474c2 [Pallas][Mosaic] Support float8_e4m3b11fnuz
PiperOrigin-RevId: 729169181
2025-02-20 10:44:33 -08:00
jax authors
37af0135b0 [Mosaic] Consider divisibility when doing large tiling
PiperOrigin-RevId: 728980108
2025-02-19 23:56:07 -08:00
Jevin Jiang
bb68124c33 [Mosaic TPU] Support mask concat
PiperOrigin-RevId: 728349788
2025-02-18 14:03:46 -08:00
jax authors
725087e13f Integrate LLVM at llvm/llvm-project@9d24f94379
Updates LLVM usage to match
[9d24f9437944](https://github.com/llvm/llvm-project/commit/9d24f9437944)

PiperOrigin-RevId: 728265165
2025-02-18 10:30:48 -08:00
jax authors
e78a469b42 Integrate LLVM at llvm/llvm-project@912b154f3a
Updates LLVM usage to match
[912b154f3a3f](https://github.com/llvm/llvm-project/commit/912b154f3a3f)

PiperOrigin-RevId: 727895384
2025-02-17 10:08:37 -08:00
Dimitar (Mitko) Asenov
52f8fbeee0 [Mosaic GPU] Implement lowerings for Tile and Transpose transforms from the MLIR dialect.
PiperOrigin-RevId: 727762334
2025-02-17 01:29:47 -08:00
jax authors
a6fcb7415f [TPU][Mosaic][Easy] Add verification for AssumeMultipleOp.
A user must use AssumeMultipleOp to annotate integer constants that are divisible by the given multiple.

PiperOrigin-RevId: 727699186
2025-02-16 21:16:05 -08:00
jax authors
eaceac3bf9 [Pallas] Reductions with replicated axes.
PiperOrigin-RevId: 727292293
2025-02-15 07:41:16 -08:00
Adam Paszke
a493df4dd8 Fix Windows build for Mosaic GPU extension
We only export symbols that being with `mlir` and a few other prefixes, so this renames our C API functions for consistency with that.

PiperOrigin-RevId: 726468092
2025-02-13 06:58:17 -08:00
Jevin Jiang
876668faa1 [Mosaic TPU] Support bf16 div if HW does not directly support.
PiperOrigin-RevId: 726212286
2025-02-12 15:04:09 -08:00
Dimitar (Mitko) Asenov
6fc1c61520 [Mosaic GPU] Use the memref layout to encode transforms (only swizzle for now).
Tile and Transpose transforms to follow.

PiperOrigin-RevId: 725716812
2025-02-11 11:51:25 -08:00
jax authors
ffd3faad72 [TPU[Mosaic] Fix missing sfences in smem DMAs
PiperOrigin-RevId: 725376627
2025-02-10 15:51:35 -08:00
jax authors
6740165e4f [Pallas] Add pipeline mode to pltpu
PiperOrigin-RevId: 725133131
2025-02-10 02:36:44 -08:00
Adam Paszke
e7a4f89343 [Mosaic TPU] Add optimized casts for bf16->s4 in TPUv6
PiperOrigin-RevId: 723455843
2025-02-05 04:21:55 -08:00
Jevin Jiang
d8b9211359 [Mosaic TPU] Support dynamic gather along axis 0 or 1 for 32-bit vreg-sized vector.
PiperOrigin-RevId: 721980453
2025-01-31 18:47:25 -08:00
Jevin Jiang
785a63ad0f [Mosaic TPU] Support non-32 bit mask relayout
PiperOrigin-RevId: 721552594
2025-01-30 16:13:23 -08:00
Tzu-Wei Sung
d4758b6d5e [Mosaic][NFC] Factor out xla-array related utils in a separate file.
Also added tests.

PiperOrigin-RevId: 721424194
2025-01-30 09:49:41 -08:00
Benjamin Chetioui
d8f3b33ae4 [Mosaic GPU] Eliminate the arrive attribute from mosaic_gpu.async_load.
We plan to explicitly issue an `expect_tx` operation all the time when using
the dialect.

PiperOrigin-RevId: 721411949
2025-01-30 09:08:45 -08:00
Dimitar (Mitko) Asenov
6214c25a6d [Mosaic GPU] Add ArriveExpect and Wait ops on dialect barriers with explicit handling of parities
This makes dialect tests in mgpu_test.py truly express the entire computation at the warpgroup level.

PiperOrigin-RevId: 721371327
2025-01-30 06:44:32 -08:00
Adam Paszke
29b658b358 [Mosaic TPU] Optimize clipping impelmentation in arith.fptosi
We can use maxf/minf to avoid extra comparisons

PiperOrigin-RevId: 720601304
2025-01-28 09:20:16 -08:00
Dimitar (Mitko) Asenov
a3a285dddc [Mosaic GPU] Handle the swizzle attribute in the lowering of async_store and async_load
PiperOrigin-RevId: 720129408
2025-01-27 05:18:16 -08:00
Sergei Lebedev
9ee7123c39 [mosaic_gpu] Fixed mosaic_gpu-serde pass registration
We previously registered the pass in the :_mosaic_gpu_ext which didn't work
because the extension has its own pass registry. The fix instead is to move
the registration to :register_jax_dialects in jaxlib.

PiperOrigin-RevId: 719280601
2025-01-24 06:35:54 -08:00
Adam Paszke
7043b852ec [Mosaic GPU] Add basic support for TMA with sub-byte types
PiperOrigin-RevId: 719240287
2025-01-24 03:54:12 -08:00
Jevin Jiang
8e1f956804 [Mosaic TPU] Use vmask pack if possible for mask's bitwidth change and introduce relayout op.
PiperOrigin-RevId: 719089676
2025-01-23 18:15:08 -08:00
Dimitar (Mitko) Asenov
f57d603c45 [Mosaic GPU] Simplify enums in the MLIR Mosaic GPU dialect.
This enables us to use them more simply in the current and upcoming Python code. The Python bindings for enum and enum attributes leave much to be desired.

PiperOrigin-RevId: 718795667
2025-01-23 03:38:26 -08:00
Dimitar (Mitko) Asenov
6b747b4109 [Mosaic GPU] Add a result to the WGMMA op definition in the MLIR dialect
PiperOrigin-RevId: 718788390
2025-01-23 03:10:07 -08:00
jax authors
6c76cc4e36 Integrate LLVM at llvm/llvm-project@d33e33fde7
Updates LLVM usage to match
[d33e33fde770](https://github.com/llvm/llvm-project/commit/d33e33fde770)

PiperOrigin-RevId: 718414171
2025-01-22 09:22:07 -08:00
jax authors
54bb7f5ddb Remove meaningless template keywords.
This will fix -Wmissing-template-arg-list-after-template-kw warnings.
This warning is error-by-default in Clang.

PiperOrigin-RevId: 718133601
2025-01-21 17:22:04 -08:00
Dimitar (Mitko) Asenov
f89accc56a [Mosaic GPU] Add support for converting all fragmented layouts to ir and back.
This will be used in the layout inference and lowering of the dialect WGMMA op

PiperOrigin-RevId: 717836648
2025-01-21 03:27:03 -08:00
Adam Paszke
543dd94762 [Mosaic TPU] Add a faster implementation for packing b16 to s8 in TPUv6
PiperOrigin-RevId: 717583425
2025-01-20 11:18:22 -08:00
Peter Hawkins
034e967e11 Remove CUDA rpaths from jaxlib build.
These are also set in the TSL build rules as part of the CUDA stub libraries, which these libraries depend on, so these copies of the rpath settings are redundant.

PiperOrigin-RevId: 716844265
2025-01-17 17:09:30 -08:00
jax authors
a527aba646 Reverts f1b894d14a28ac22a037fb79177b991275c75a18
PiperOrigin-RevId: 716653711
2025-01-17 07:00:31 -08:00
Benjamin Chetioui
d3be190efb [Mosaic GPU] Delete unused declarations of mosaic_gpu_memcpy_async_h2d.
PiperOrigin-RevId: 716616807
2025-01-17 04:34:48 -08:00
Sergei Lebedev
d34c40f6b6 [mosaic_gpu] Added a serialization pass
The pass adds versioning to the Mosaic GPU IR in the lowered custom calls
and can apply forward/backward migration rules. Currently, no rules are
necessary since we are at version 1.

PiperOrigin-RevId: 716596848
2025-01-17 03:12:51 -08:00
Adam Paszke
bd22bfef71 [Mosaic TPU] Use large to compact 2nd minor retiling for conversions going both ways
This specific retiling is its own inverse and it faster than alternatives.

PiperOrigin-RevId: 716360070
2025-01-16 13:35:26 -08:00
Tzu-Wei Sung
5c020ee317 [Mosaic] Fix infer/apply extensions.
1. For apply, llvm::StringMap()::insert(MapEntryTy*) will cause dangling reference if not constructing mlir::tpu::extensions::rules() with const-reference. However, if we do construct it with const-reference, the signature is not const-qualified and fails to compile. Hence, change it to llvm::StringMap()::insert(std::pair<...>) and get extension rules by const-reference.
2. Pass default tiling to infer rule, we need it to infer single op. See infer of tpu::MatmulOp.

PiperOrigin-RevId: 716274818
2025-01-16 09:57:14 -08:00
Sergei Lebedev
4221f109d1 [mosaic] Extracted serialization pass traversal logic into a reusable function
I will use it to implement Mosaic GPU serialization pass in a follow up.

PiperOrigin-RevId: 716156650
2025-01-16 02:58:06 -08:00
Tzu-Wei Sung
4a9cc9ffc1 [Mosaic] Allow passing ApplyVectorLayoutCtx to tpu.apply_layout_op.
To make it the same with C++ API. While I'm here, fix a bug in test_concatenate.

PiperOrigin-RevId: 716016244
2025-01-15 17:47:36 -08:00
Naums Mogers
d3ba1eb339 [Mosaic] Add a macro to convert abseil StatusOr to LLVM FailureOr
PiperOrigin-RevId: 715943314
2025-01-15 14:19:29 -08:00
George Necula
f1b894d14a Reverts 391bad8ff59c07c8fad7b8ce05cd0e29dee4cf1a
PiperOrigin-RevId: 715435319
2025-01-14 10:31:59 -08:00
Ayaka
9ba1fd2801 [Pallas TPU] Add vector support to pl.debug_print
PiperOrigin-RevId: 715085454
2025-01-13 13:22:21 -08:00
Adam Paszke
391bad8ff5 [Mosaic TPU] Add support for arith.fptosi with non-32bit source and target types
This effectively moves some of the Pallas logic to the layer below.

PiperOrigin-RevId: 714965374
2025-01-13 07:49:13 -08:00
Tomás Longeri
7852045582 [Mosaic TPU] Enable non-sublane-aligned bf16 2D load/stores for earlier TPU gens
It is still not efficiently implemented, this is mostly to clean up some logic. We may be able to fuse the creation of masks for different tiles into the creation of a single one. But this is also a problem for the later gens.

This also cleans up an unreachable return statement.

PiperOrigin-RevId: 714847066
2025-01-12 23:58:40 -08:00