314 Commits

Author SHA1 Message Date
jax authors
81a95f78b9 [Mosaic] Parameterize the number of lanes and sublanes in TPU dialects.
PiperOrigin-RevId: 684392184
2024-10-10 04:28:36 -07:00
Jevin Jiang
f52b016de1 [Mosaic TPU] Change getLayout to force offset to 0 when inferring input has offset out of the first tile.
PiperOrigin-RevId: 684145987
2024-10-09 13:11:49 -07:00
Jevin Jiang
f96c5661ac [Mosaic TPU][NFC] Refactor tpu matmul rule.
* Separate MXU size to MXU contracting size and MXU non-contracting size.
* Rename tile to group for MXU shaped tiling since tile is overused in Mosaic.

PiperOrigin-RevId: 684116306
2024-10-09 11:45:25 -07:00
jax authors
9748e2ab1a [JAX] Fix error message for matmul operand shape check.
PiperOrigin-RevId: 683778484
2024-10-08 15:07:20 -07:00
Adam Paszke
f62941d126 [Mosaic TPU] The previous change does not actually force the input offsets read by the rules, but simply disables all the checks. Reverting so that we at least regain the checks until we have a proper fix.
Reverts 4a596aee1e8920f5b51d5bd573df976390bbd437

PiperOrigin-RevId: 680925509
2024-10-01 02:23:52 -07:00
Jevin Jiang
4a596aee1e [Mosaic TPU] Force offset to 0 when inferring input has offset out of the first tile.
We still have this temporary check in apply vector layout, but in infer vector layout, instead of throwing error, we should just reset offset to zero. Because some ops which has relaxed this restriction might be passed as input for un-relaxed ops and cause failure.

PiperOrigin-RevId: 680706301
2024-09-30 13:52:48 -07:00
Jevin Jiang
7e2f487ada [Mosaic TPU] Canonicalize arith.select's condition to vector if other types are vector.
This fixes the failure in elementwise rule of apply vector layout pass.

If the condition scalar is static, it will be simplified to corresponding vector from true value and false value by MLIR.

If the condition scalar is dynamic, we want to use vselect over scf.if anyway. Because latter creates a inner region.

PiperOrigin-RevId: 680674560
2024-09-30 12:26:44 -07:00
Justin Fu
9f4e8d0039 [XLA:Mosaic][Pallas] Enable vector.ExtractOp for non-zero indices.
PiperOrigin-RevId: 679283281
2024-09-26 13:57:45 -07:00
Jevin Jiang
e4ca4f5a57 Roll back cl/678765762 [Mosaic TPU] Support bitcast without forcing retiling.
Reverts 37641dd4fade625563321b7e1e87165df23cf4a8

PiperOrigin-RevId: 678881199
2024-09-25 16:02:58 -07:00
Jevin Jiang
37641dd4fa [Mosaic TPU] Support bitcast without forcing retiling.
PiperOrigin-RevId: 678765762
2024-09-25 10:57:09 -07:00
Jevin Jiang
407dc774f7 [Mosaic TPU] Support all cases for extui.
PiperOrigin-RevId: 678331795
2024-09-24 11:35:03 -07:00
Jevin Jiang
6b93b35842 [Mosaic:TPU] Efficient relayout with internal scratch
We should support all different retilings (x*packing1, 128) <-> (y*packing2, 128) with any dtype in this cl at this moment. The efficient relayout with scratch brings significant improvements on current retiling in <= TPUv4 and retiling with (packing, 128) in TPUv5. All missing retiling supports are added in this cl, including increase sublane retiling and packed type retiling.

PiperOrigin-RevId: 676982957
2024-09-20 15:00:58 -07:00
Adam Paszke
99195ead83 [Mosaic TPU] Try reducing sublane tiling to support more vector.shape_casts
In particular, 32-bit values should now support all reshapes that do not modify the
last dimension.

PiperOrigin-RevId: 676855401
2024-09-20 08:36:22 -07:00
Jevin Jiang
47b177bd03 [Mosaic TPU][NFC] Remove FailureOr in getNativeVregOrVmaskTypeImpl
PiperOrigin-RevId: 676566796
2024-09-19 14:35:41 -07:00
jax authors
4e6f690724 Merge pull request #23653 from apaszke:torchsaic
PiperOrigin-RevId: 675967844
2024-09-18 06:35:15 -07:00
Adam Paszke
611ad63060 Add basic PyTorch integration for Mosaic GPU
We have already had most of the relevant pieces and we only needed
to connect them together. The most sensitive change is perhaps that
I needed to expose one more symbol from the XLA GPU plugin, but I don't
think it should be a problem.
2024-09-18 12:55:23 +00:00
Jevin Jiang
8d93e101b9 [Mosaic TPU] Propagate the memory space change for memref bitcast and reshape.
PiperOrigin-RevId: 674067380
2024-09-12 17:14:41 -07:00
Jevin Jiang
178fb03050 [Mosaic TPU] Better error message when shape of memref bitcast is invalid.
PiperOrigin-RevId: 674062237
2024-09-12 16:56:50 -07:00
Jevin Jiang
dba674153e [Mosaic TPU] Fix operands order in try canonicalize add of matmul.
PiperOrigin-RevId: 671437100
2024-09-05 11:06:57 -07:00
Adam Paszke
8feab68209 [Mosaic GPU] Remove the unnecessary scratch space operand
And clean up the C++ dispatch code. We don't use HBM scratch anymore
since we pass TMA descriptors as kernel arguments.

PiperOrigin-RevId: 671327420
2024-09-05 04:57:52 -07:00
Jevin Jiang
c1d3c2db9f [Mosaic TPU] Fix mosaic alignment check in concatenate rule.
PiperOrigin-RevId: 670837792
2024-09-03 22:57:27 -07:00
Sergei Lebedev
7dd9adba05 Fixed stack-use-after-scope in Mosaic GPU
PiperOrigin-RevId: 668958750
2024-08-29 09:07:58 -07:00
Jevin Jiang
a3cccd34e2 [Mosaic TPU] Print expected Mosaic version after finding unsupported version.
PiperOrigin-RevId: 668632116
2024-08-28 15:33:31 -07:00
Jevin Jiang
b01075054a [Mosaic TPU] Support memref bitcast.
If element bitwidth changes, the ratio of bitwidth is multiplied to the 2nd minormost dim size and the leading dim in tiling. For example, we can bitcast Memref<8x128xf32> with tiling (8, 128) to Memref<16x128xi16> with tiling (16, 128).

PiperOrigin-RevId: 668619683
2024-08-28 15:00:46 -07:00
Peter Hawkins
6d1f51e63d Clean up BUILD files.
PiperOrigin-RevId: 667604964
2024-08-26 09:11:17 -07:00
Adam Paszke
9c3f2dcefc [Mosaic GPU] Make CUDA context part of the hash key + replace kernel id with a SHA256 digest
XLA runtime creates a context per device, so we need to make sure that a kernel is loaded
separately on each device.

PiperOrigin-RevId: 666353098
2024-08-22 08:06:37 -07:00
Benjamin Kramer
0105254ab1 Unbreak Mosaic after 42944da5ba
PiperOrigin-RevId: 665973530
2024-08-21 11:59:09 -07:00
Tomás Longeri
020513f300 [Mosaic] Update serde to handle upstream MLIR changes
For changes from
5f26497da7

PiperOrigin-RevId: 663020509
2024-08-14 12:48:29 -07:00
jax authors
807dcb5a06 Integrate LLVM at llvm/llvm-project@c8b5d30f70
Updates LLVM usage to match
[c8b5d30f7077](https://github.com/llvm/llvm-project/commit/c8b5d30f7077)

PiperOrigin-RevId: 662906261
2024-08-14 07:09:53 -07:00
Jevin Jiang
8f23392a8c [Mosaic:TPU] Refactor relayout helper functions to take ctx instead of only target shape.
PiperOrigin-RevId: 662672417
2024-08-13 15:22:46 -07:00
Jevin Jiang
2dea3d6a0c [Mosaic:TPU] Add shuffled load and store.
we also emulate shuffled store using (store + shuffled load + store) for previous generations.

PiperOrigin-RevId: 662657663
2024-08-13 14:41:16 -07:00
Tomás Longeri
77afe251e7 [Mosaic TPU][Python] Check validity of VectorLayout on init
PiperOrigin-RevId: 661226283
2024-08-09 05:28:00 -07:00
Tomás Longeri
e57a7e3f05 [Mosaic] Column shift relayouts for non-native tilings and packed types, except for (1, n) and packed
PiperOrigin-RevId: 661091012
2024-08-08 20:14:08 -07:00
Adam Paszke
04a753ad02 [Mosaic TPU] Improve an error message in case someone tries to extract a non-32-bit scalar.
PiperOrigin-RevId: 660826696
2024-08-08 07:22:10 -07:00
Adam Paszke
42fe45f34b [Mosaic TPU] Add support for removal of implicit 2nd minor for all 32-bit tilings
PiperOrigin-RevId: 660724215
2024-08-08 01:00:32 -07:00
Sergei Lebedev
28ca734d9b Added another boxDim check to mosaic_gpu_init_tma_desc
PiperOrigin-RevId: 660314586
2024-08-07 03:16:54 -07:00
Adam Paszke
f85b8e677b [Mosaic TPU] Add support for bf16 reductions
PiperOrigin-RevId: 658787017
2024-08-02 07:42:27 -07:00
Adam Paszke
e88887eda5 [Mosaic TPU] Add a missing reshape in relayout
The fact that src generalizes dst does not mean that they have the same implicit
tile shape (if one has an implicit dim and the other one doesn't, then they will
differ by a singleton dimension).

PiperOrigin-RevId: 658775019
2024-08-02 06:44:31 -07:00
Adam Paszke
959657a489 [Mosaic TPU] Remove special handling of implicit dim in relayout
Now all changes happen inside the dedicated functions.

PiperOrigin-RevId: 658763465
2024-08-02 05:46:26 -07:00
Adam Paszke
99625ff577 [Mosaic TPU] Break out implicit dim changes from relayout
PiperOrigin-RevId: 658752228
2024-08-02 04:50:40 -07:00
Adam Paszke
0307438c3d [NFC][Mosaic TPU] Separate out retiling from relayout
PiperOrigin-RevId: 658335679
2024-08-01 03:09:15 -07:00
Adam Paszke
0734345279 [NFC][Mosaic TPU] Start breaking up relayout into smaller pieces
We're constantly hitting unimpelmented relayouts, but it's hard to even know what's
in there given the way the code is written. This is the first of a few clean-up CLs
that aims to partition the process into steps with clear responsibilities. It should
help us better understand what's missing.

PiperOrigin-RevId: 658318811
2024-08-01 02:02:09 -07:00
jax authors
a911d76982 Rollback due to internal test failure
PiperOrigin-RevId: 658185213
2024-07-31 16:40:03 -07:00
Adam Paszke
9dba6eb16a [Mosaic TPU] Add support for 1D windows
PiperOrigin-RevId: 657976726
2024-07-31 05:58:19 -07:00
Adam Paszke
e0415c1865 [Mosaic TPU] Don't fold the accumulator into matmul if it has multiple uses
PiperOrigin-RevId: 657967724
2024-07-31 05:19:52 -07:00
Tomás Longeri
0f834cdf24 [Mosaic TPU] Enable lane broadcast for packed types and offsets outside of first tile, and fix some broadcast infer logic
PiperOrigin-RevId: 656201666
2024-07-25 19:48:20 -07:00
Jevin Jiang
b1b7d0465e [XLA:Mosaic] Support any int type upcast.
Also fixed the int4 unpacking.

PiperOrigin-RevId: 656119043
2024-07-25 15:39:38 -07:00
Sergei Lebedev
5e418f5ab2 Added argument validation to mosaic_gpu_init_tma_desc
This should help with understanding cuTensorMapEncodeTiled failures, since
CUDA doesn't provide any details beyond the error return code.

Note that this change also ensures that TMA descriptors are 64-byte aligned.

PiperOrigin-RevId: 656062820
2024-07-25 13:16:34 -07:00
Tomás Longeri
220ec2aa69 [Mosaic TPU] (8,128),-2 -> (8,128) for non-zero and replicated 2nd minor offset
Also fix bug where relayouts for fully replicated source assumed it was a no-op without checking implicit dims

PiperOrigin-RevId: 655746766
2024-07-24 16:58:35 -07:00
Adam Paszke
dbe8f56353 [Mosaic GPU] Strengthen cluster-related tests by covering more cluster shapes
In particular test trivial collectives (over singleton cluster axes), collectives
over more than 2 devices and clusters larger than 8 devices. This uncovered a few
more bugs in the implementation.

PiperOrigin-RevId: 655686102
2024-07-24 13:43:52 -07:00