279 Commits

Author SHA1 Message Date
Sergei Lebedev
28ca734d9b Added another boxDim check to mosaic_gpu_init_tma_desc
PiperOrigin-RevId: 660314586
2024-08-07 03:16:54 -07:00
Adam Paszke
f85b8e677b [Mosaic TPU] Add support for bf16 reductions
PiperOrigin-RevId: 658787017
2024-08-02 07:42:27 -07:00
Adam Paszke
e88887eda5 [Mosaic TPU] Add a missing reshape in relayout
The fact that src generalizes dst does not mean that they have the same implicit
tile shape (if one has an implicit dim and the other one doesn't, then they will
differ by a singleton dimension).

PiperOrigin-RevId: 658775019
2024-08-02 06:44:31 -07:00
Adam Paszke
959657a489 [Mosaic TPU] Remove special handling of implicit dim in relayout
Now all changes happen inside the dedicated functions.

PiperOrigin-RevId: 658763465
2024-08-02 05:46:26 -07:00
Adam Paszke
99625ff577 [Mosaic TPU] Break out implicit dim changes from relayout
PiperOrigin-RevId: 658752228
2024-08-02 04:50:40 -07:00
Adam Paszke
0307438c3d [NFC][Mosaic TPU] Separate out retiling from relayout
PiperOrigin-RevId: 658335679
2024-08-01 03:09:15 -07:00
Adam Paszke
0734345279 [NFC][Mosaic TPU] Start breaking up relayout into smaller pieces
We're constantly hitting unimplemented relayouts, but it's hard to even know what's
in there given the way the code is written. This is the first of a few clean-up CLs
that aim to partition the process into steps with clear responsibilities. It should
help us better understand what's missing.

PiperOrigin-RevId: 658318811
2024-08-01 02:02:09 -07:00
jax authors
a911d76982 Rollback due to internal test failure
PiperOrigin-RevId: 658185213
2024-07-31 16:40:03 -07:00
Adam Paszke
9dba6eb16a [Mosaic TPU] Add support for 1D windows
PiperOrigin-RevId: 657976726
2024-07-31 05:58:19 -07:00
Adam Paszke
e0415c1865 [Mosaic TPU] Don't fold the accumulator into matmul if it has multiple uses
PiperOrigin-RevId: 657967724
2024-07-31 05:19:52 -07:00
Tomás Longeri
0f834cdf24 [Mosaic TPU] Enable lane broadcast for packed types and offsets outside of first tile, and fix some broadcast infer logic
PiperOrigin-RevId: 656201666
2024-07-25 19:48:20 -07:00
Jevin Jiang
b1b7d0465e [XLA:Mosaic] Support any int type upcast.
Also fixed the int4 unpacking.

PiperOrigin-RevId: 656119043
2024-07-25 15:39:38 -07:00
Sergei Lebedev
5e418f5ab2 Added argument validation to mosaic_gpu_init_tma_desc
This should help with understanding cuTensorMapEncodeTiled failures, since
CUDA doesn't provide any details beyond the error return code.

Note that this change also ensures that TMA descriptors are 64-byte aligned.

PiperOrigin-RevId: 656062820
2024-07-25 13:16:34 -07:00
Tomás Longeri
220ec2aa69 [Mosaic TPU] (8,128),-2 -> (8,128) for non-zero and replicated 2nd minor offset
Also fix bug where relayouts for fully replicated source assumed it was a no-op without checking implicit dims

PiperOrigin-RevId: 655746766
2024-07-24 16:58:35 -07:00
Adam Paszke
dbe8f56353 [Mosaic GPU] Strengthen cluster-related tests by covering more cluster shapes
In particular test trivial collectives (over singleton cluster axes), collectives
over more than 2 devices and clusters larger than 8 devices. This uncovered a few
more bugs in the implementation.

PiperOrigin-RevId: 655686102
2024-07-24 13:43:52 -07:00
Sharad Vikram
cfd9d8f548 [Pallas/TPU] Allow reading DMA semaphores in Pallas
PiperOrigin-RevId: 655384701
2024-07-23 19:08:45 -07:00
Jevin Jiang
59e944dadf [XLA:Mosaic] Pass rewrite ctx of apply-vector-layout pass to relayout function.
We will implement a more efficient relayout according to the configs in rewrite ctx, such as `hardware_generation`, `max_sublanes_in_scratch` and so on. So it makes sense to change the relayout interface to take ctx (including Python bindings). Now we can define rewrite ctx in `apply_vector_layout_test` as well. This makes it easier to test more advanced cases (e.g., mxu_shape change, max_sublanes_in_scratch change for rotate and relayout).

PiperOrigin-RevId: 655350013
2024-07-23 16:50:45 -07:00
Jevin Jiang
7e2107b1ee [XLA:Mosaic] Create apply layout pass with ctx instead of config list.
This CL removes the funcOp from RewriteContext of apply-vector-layout-pass (since only one function is using it) and uses the context to create the pass instead of a long list of arguments. We will need to add more args (target's bank counts) to create apply-vector-layout.

PiperOrigin-RevId: 655329321
2024-07-23 15:43:26 -07:00
Adam Paszke
a2b2fbf513 [Mosaic GPU] Add early support for block clusters and multicast TMA
PiperOrigin-RevId: 655057490
2024-07-23 00:50:20 -07:00
Tomás Longeri
d350ef779c [Mosaic TPU][apply-vector-layout] Do not broadcast in copy_one_sublane
This affects the (packing, 128) -> (8 * packing, 128) and 32-bit (8, 128),-2 -> (8, 128) retilings:
- No longer always broadcast the first sublane of a vreg before blending, which is usually unnecessary. Rotate instead, unless dst requires replicated offsets in (1, 128) -> (8, 128).
  For (8, 128),-2 -> (8, 128), with our current restrictions, the first vreg always already has the sublane in the right position, so the broadcast is always wasteful.
- Unclear if rotate is always better than broadcast, but it doesn't make sense to broadcast the first vreg yet rotate the others.

This is some cleanup prior to removing some offset restrictions for (8, 128),-2 -> (8, 128)

PiperOrigin-RevId: 654935883
2024-07-22 16:31:27 -07:00
Tomás Longeri
5f18a2e27b [Mosaic TPU] Enable (packing, 128) -> (8 * packing, 128) retiling
PiperOrigin-RevId: 654922099
2024-07-22 15:47:21 -07:00
Tomás Longeri
bf42564172 [Mosaic TPU][NFC] Remove unused toArrayRef for std::array
It's unused, buggy (will return a reference to local copy of array) and `ArrayRef` already has a ctor that takes a `std::array`

PiperOrigin-RevId: 654916697
2024-07-22 15:29:06 -07:00
Tomás Longeri
f4b09234a0 [Mosaic TPU] Set in_bounds for transfer_read used in replicated loads
This is in preparation for integrating changes from MLIR:
2ee5586ac7 (diff-3cbcc8f6c740f2d6e16f5a0c19daf4bb8224ad92d9e430fc10c935587a67dcce)

Also don't pass in `padding` since there is a builder that uses `padding` of zero as default.

PiperOrigin-RevId: 654370142
2024-07-20 16:26:18 -07:00
Jevin Jiang
faf89ab0da [XLA:Mosaic] Simplify the logic in converting dynamic roll to Log(N) static ops.
PiperOrigin-RevId: 654065156
2024-07-19 11:11:22 -07:00
Tomás Longeri
a9772494b2 [Mosaic] Simplify vector.shape_cast rules and cover more cases
- Sublane unfolding was not being checked for non-empty implicit dims e.g. (2, 2, 128, 1) -> (2, 256) would not work
- Noop squeeze/unsqueeze paths in infer-vector-layout, when the source has ImplicitDim::kNone, were forcing native tiling for some reason
- 1D lane squeeze was always assigning bitwidth of 32.
- Maybe others

PiperOrigin-RevId: 653910942
2024-07-19 00:55:48 -07:00
jax authors
378a830322 Add support for multi row shift.
PiperOrigin-RevId: 653395441
2024-07-17 16:19:14 -07:00
Jevin Jiang
63a3e6736c [XLA:Mosaic] Extend support of tpu bitcast with offsets and implicit dim.
* if bitwidth does not change after bitcast:
  - We can bitcast the input with any vector layout.
* if bitwidth changes after bitcast:
  - We can bitcast the input with sublane offset which is a multiple of the ratio of bitwidths.
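
The two cases above can be read as a small predicate. The following is only a sketch of that rule, not the code from this commit: the function name is hypothetical, and it assumes the ratio is taken as the larger bitwidth divided by the smaller one.

```python
def bitcast_offset_supported(src_bitwidth, dst_bitwidth, sublane_offset):
    """Hypothetical helper mirroring the rule stated in the commit message."""
    if src_bitwidth == dst_bitwidth:
        # Bitwidth is unchanged: any vector layout is accepted.
        return True
    # Assumption: the ratio is larger bitwidth / smaller bitwidth.
    ratio = max(src_bitwidth, dst_bitwidth) // min(src_bitwidth, dst_bitwidth)
    # Bitwidth changes: the sublane offset must be a multiple of the ratio.
    return sublane_offset % ratio == 0


print(bitcast_offset_supported(32, 32, 3))  # same bitwidth, any offset
print(bitcast_offset_supported(32, 8, 4))   # ratio 4, offset 4
print(bitcast_offset_supported(32, 8, 2))   # ratio 4, offset 2
```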

PiperOrigin-RevId: 653375579
2024-07-17 15:10:27 -07:00
Adam Paszke
a335839ab8 [Mosaic TPU] Update transpose unrolling for new TPUs
PiperOrigin-RevId: 653348218
2024-07-17 13:44:01 -07:00
Tomás Longeri
da02ba196e [Mosaic] Most relayouts should work for any matched implicit dim, or on mismatched but equivalent ones
Also fix bug in (1, 128 * packing) -> (packing, 128) retiling where the part index could be incremented OOB.

Note: Many relayouts might be inefficient for implicit dims. If, for example, implicit dim is kSecondMinor, retiling might blend tiles that are only padding. This also applies to kNone implicit dim with small shapes, however, so any optimizations should be written based on the implicit shape.
PiperOrigin-RevId: 653209744
2024-07-17 06:17:32 -07:00
jax authors
764ec92118 Add support for elementwise op canonicalization in fp32 for older hardware.
PiperOrigin-RevId: 651959463
2024-07-12 19:58:55 -07:00
Jevin Jiang
aa16485457 [XLA:Mosaic] Support memref shapecast.
This CL supports memref shapecast:
1. if tile is (1, 128), we support shapecast on any dim.
2. if shapecast on sublane dim, we only support tile aligned shape.
3. if shapecast on non-tiling dim, we support any shapecast.
4. all other cases would be considered as invalid memref shapecast.
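
The four rules above can be transcribed as a decision function. This is a sketch of the stated rules only, not the actual verifier from this commit; the function name and the `cast_dim` labels are hypothetical.

```python
def memref_shapecast_supported(tile, cast_dim, tile_aligned):
    """Hypothetical transcription of the four rules in the commit message.

    tile:         the memref tiling, e.g. (8, 128) or (1, 128)
    cast_dim:     'lane', 'sublane' or 'non_tiling' (illustrative labels)
    tile_aligned: whether the reshaped dims are aligned to the tile
    """
    if tile == (1, 128):
        return True           # rule 1: any dim may be reshaped
    if cast_dim == "sublane":
        return tile_aligned   # rule 2: only tile-aligned shapes
    if cast_dim == "non_tiling":
        return True           # rule 3: any shapecast
    return False              # rule 4: everything else is invalid


print(memref_shapecast_supported((1, 128), "lane", False))
print(memref_shapecast_supported((8, 128), "sublane", False))
```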

PiperOrigin-RevId: 651924552
2024-07-12 17:05:03 -07:00
Sharad Vikram
7016ca4829 [Mosaic] Strengthen check on return types from RegionOp
PiperOrigin-RevId: 651879359
2024-07-12 13:59:50 -07:00
Sharad Vikram
2cbe6caa50 [Pallas/Mosaic] Add support for returning values from run_scoped
PiperOrigin-RevId: 651600628
2024-07-11 18:37:09 -07:00
jax authors
3ee3c2a1cc Add support, via conversion in canonicalization pass, for mixed int/float matmul.
PiperOrigin-RevId: 651141070
2024-07-10 13:54:18 -07:00
Justin Fu
0cb82cea65 [Pallas] Add better reduction support.
Adds lowering rules for reduce_all, reduce_any, reduce_min, and reductions to scalars.

PiperOrigin-RevId: 650689871
2024-07-09 11:03:17 -07:00
jax authors
0da9b69285 Use default tiling in scratch buffers if XLA enables it
PiperOrigin-RevId: 650493683
2024-07-08 22:49:10 -07:00
Tomás Longeri
5c7c29bc6e [Mosaic] Remove restriction of offsets falling in first tile of vreg, start rolling out op support for it, starting with vector.extract_strided_slice
VectorLayout offsets are now allowed to fall anywhere within the vreg slice. This way, tiling is still applied after offsets and offsets are still applied after implicit dimensions.
Note that offsets outside of the vreg slice would mean a vreg full of padding, which is why we disallow them.
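
As a rough illustration of the new constraint, the sketch below computes a vreg slice and checks offsets against it. It is not taken from the Mosaic sources: the function names are hypothetical, and it assumes an (8, 128) target vreg of 32-bit words, a packing of `32 // bitwidth`, and a vreg slice of `(tiling[0], tiles_per_vreg * tiling[1])`. `None` stands for a replicated offset.

```python
def vreg_slice(tiling, bitwidth, target_shape=(8, 128)):
    """Array region covered by one vreg, under the assumptions stated above."""
    packing = 32 // bitwidth
    vreg_capacity = packing * target_shape[0] * target_shape[1]
    tiles_per_vreg = vreg_capacity // (tiling[0] * tiling[1])
    return (tiling[0], tiles_per_vreg * tiling[1])


def offsets_allowed(offsets, tiling, bitwidth, target_shape=(8, 128)):
    """True if every non-replicated offset falls inside the vreg slice."""
    bounds = vreg_slice(tiling, bitwidth, target_shape)
    return all(o is None or 0 <= o < b for o, b in zip(offsets, bounds))


# With (1, 128) tiling of a 32-bit type, one vreg holds 8 tiles side by side,
# so a lane offset of 512 lies outside the first tile but inside the slice.
print(vreg_slice((1, 128), 32))
print(offsets_allowed((0, 512), (1, 128), 32))
print(offsets_allowed((0, 1024), (1, 128), 32))
```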

PiperOrigin-RevId: 650408597
2024-07-08 16:23:10 -07:00
jax authors
2561ba5d37 Introduce a canonicalize pass, rewrite all contractions as matmuls (vector::ContractionOp as tpu::MatMulOp), remove special handling for contraction op in other passes.
PiperOrigin-RevId: 649205635
2024-07-03 14:49:35 -07:00
Jevin Jiang
484d09f4af [Pallas][Mosaic] Relax dynamic index on 2nd minor dim in load/store.
We support any dynamic index on 2nd minor dim in either of the cases:
1. The minormost dim size of an unsliced memref matches VREG lane count.
2. Load/store one row on the second minormost dim, which triggers implicit strided load/store.

Note: For the default cases which cannot skip the alignment check, we still use the dynamic slice + static load/store solution to reduce scalar core work. We should figure out a way to optimize this in all cases.
PiperOrigin-RevId: 648771794
2024-07-02 10:52:11 -07:00
Adam Paszke
265a54da31 [Mosaic GPU] Pass in TMA descriptors through kernel parameters
As we've established (sigh) we can't pass in TMA descriptors through global memory.
The current workaround was to use constant memory instead, but this raises a number of
potential concurrency issues. So, instead, we use the freshly added support for grid_constant
parameters in upstream LLVM to pass the descriptors as kernel arguments. This seems to work
fine and should in fact have lower overheads than both previous methods.

PiperOrigin-RevId: 648744363
2024-07-02 09:30:52 -07:00
Tomás Longeri
3a21c81eac [Mosaic] Fix bug in VectorLayout::generalizes after cl/647395486
PiperOrigin-RevId: 647603239
2024-06-28 02:14:04 -07:00
Tomás Longeri
10e598a3fc [Mosaic] In VectorLayout::generalizes, for (1, n) tiling, we can always squeeze out a 2nd minor dimension
PiperOrigin-RevId: 647395486
2024-06-27 11:52:02 -07:00
Christos Perivolaropoulos
ea49194926 [mosaic_gpu] Control dumping mlir with MOSAIC_GPU_DUMP_MLIR_PASSES
PiperOrigin-RevId: 647341364
2024-06-27 09:17:52 -07:00
jax authors
9df105c18f Pass the assigned layout to infer_memref_layout for correct memref layout.

PiperOrigin-RevId: 647323218
2024-06-27 08:16:00 -07:00
Tomás Longeri
94c5d0d747 [Mosaic][apply-vector-layout] Fix possible segfault in arith.extsi/arith.extf after cl/644495447
This only happens for layout pairs that are never inferred.

PiperOrigin-RevId: 646303509
2024-06-24 19:54:08 -07:00
Tomás Longeri
21bf3d196d [Mosaic][Python] Define __repr__ for VectorLayout
Loosely follows the example of MLIR's bindings for Attribute

PiperOrigin-RevId: 646270865
2024-06-24 17:18:15 -07:00
Tomás Longeri
a730f6bfd3 [Mosaic][infer-vector-layout] Allow non-32-bit types for vector.extract_strided_slice
PiperOrigin-RevId: 645481424
2024-06-21 13:17:37 -07:00
Kyle Lucke
84d748f43c Stop using xla/statusor.h now that it just contains an alias for absl::StatusOr.
In some situations, this meant also changing unrelated files to directly include tsl/platform/statusor.h to get the definitions for TF_ASSIGN_OR_RETURN, etc., where they were getting transitively included for free.

PiperOrigin-RevId: 645169743
2024-06-20 15:09:40 -07:00
Chris Jones
de8fd3b00d [mosaic:gpu] Fix MLIR canonicalization pass region-simplify option.
`region-simplify` now has `normal` and `aggressive` modes (using `normal` for now).

PiperOrigin-RevId: 644724434
2024-06-19 06:02:11 -07:00
Jevin Jiang
cac1791f7c [XLA:Mosaic] Support dynamic roll
We will choose the best solution based on the size of internal scratch memory.
- Sol 1: Convert dynamic roll to Log(N) static ops
- Sol 2: Static Store + Dynamic Load with internal scratch
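
The commit message does not spell out Sol 1. One common way to get a Log(N) bound, shown here purely as a sketch and not as the implementation in this commit, is to decompose the dynamic shift into its binary digits: for each power of two, perform a static roll and select between the rolled and unrolled value based on the corresponding bit. The helper names are hypothetical and plain Python lists stand in for vectors.

```python
def static_roll(xs, shift):
    """Roll by a compile-time-known amount (elements move to higher indices)."""
    shift %= len(xs)
    return xs[-shift:] + xs[:-shift] if shift else list(xs)


def dynamic_roll_via_static(xs, amount):
    """Roll by a runtime amount using about log2(len(xs)) static rolls."""
    n = len(xs)
    amount %= n
    result = list(xs)
    step = 1
    while step < n:
        rolled = static_roll(result, step)  # static shift by a power of two
        # Select: keep the rolled value only if this bit of `amount` is set.
        result = rolled if amount & step else result
        step *= 2
    return result


print(dynamic_roll_via_static([0, 1, 2, 3, 4, 5, 6, 7], 3))
```

In a real lowering both branches of the select would be computed unconditionally and the choice made with a vector select on the bit test; the Python conditional here only models that choice.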

PiperOrigin-RevId: 644509328
2024-06-18 14:18:56 -07:00