187 Commits

jax authors
eeca8d81b9 Fix example in mosaic tpu dialect layout.h
PiperOrigin-RevId: 629424833
2024-04-30 08:42:54 -07:00
Blake Hechtman
5b996f7680 [JAX:MOSAIC] Support transposes that are smaller than the transpose unit and infer native layout to avoid unsupported relayouts.
PiperOrigin-RevId: 629289267
2024-04-29 22:03:32 -07:00
Adam Paszke
32cb7c3f94 [Mosaic GPU] Stop using the MLIR CUDA runtime
This ports the remaining few functions we depended on to the Mosaic GPU runtime.
This has the additional benefit of avoiding the expensive driver calls to determine
maximum SMEM bounds that the MLIR runtime does at every kernel launch.

PiperOrigin-RevId: 629069842
2024-04-29 08:04:51 -07:00
jax authors
d9b75350b7 Adds rewrite patterns for arith.{cmpi,select} and tensor.splat as sources to a vector.transfer_read op.
PiperOrigin-RevId: 628561147
2024-04-26 18:11:18 -07:00
jax authors
8c2425e571 Adds rewrite patterns to LinalgVectorizationPass to eliminate transfer_read and transfer_write ops.
PiperOrigin-RevId: 628500668
2024-04-26 13:51:04 -07:00
Adam Paszke
9b0319512a [Mosaic GPU] Use a custom TMA descriptor initialization method
The one bundled with the default MLIR runtime was convenient, but it is also
impractical. It allocates memory (which can deadlock due to NCCL), does a
synchronous host-to-device copy and then leaks the descriptor after the kernel...

With this change, we use our own runtime function to create all the descriptors.
What's more, we pack them all into a single buffer so that a single asynchronous
copy is sufficient. Finally, we use a scratch output to allocate the scratch buffer,
letting us lean on XLA:GPU for memory management.
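The packing scheme can be sketched as follows. This is an illustrative Python model, not the actual runtime code: the fixed 32-byte record layout (`DESC_FMT`) is invented for the example (a real TMA descriptor is an opaque 128-byte `CUtensorMap`), but the point is the same — one contiguous buffer, so one asynchronous copy suffices.

```python
import struct

# Hypothetical fixed-size descriptor record: base pointer, rank, and two
# dimension sizes. Invented for illustration only.
DESC_FMT = "<QQQQ"  # 32 bytes per descriptor
DESC_SIZE = struct.calcsize(DESC_FMT)

def pack_descriptors(descs):
    # Pack every descriptor into one contiguous buffer so that a single
    # host-to-device copy can transfer all of them, instead of one
    # synchronous copy per descriptor.
    buf = bytearray()
    offsets = []
    for base, rank, d0, d1 in descs:
        offsets.append(len(buf))  # where the kernel finds this descriptor
        buf += struct.pack(DESC_FMT, base, rank, d0, d1)
    return bytes(buf), offsets

buf, offsets = pack_descriptors([(0x1000, 2, 128, 64), (0x2000, 1, 256, 0)])
assert offsets == [0, DESC_SIZE]
assert len(buf) == 2 * DESC_SIZE
```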

PiperOrigin-RevId: 628430358
2024-04-26 09:40:47 -07:00
Adam Paszke
36c471b6f5 [Mosaic] Add support for concatenating arrays of packed types (<32 bits)
PiperOrigin-RevId: 628001232
2024-04-25 02:04:08 -07:00
Adam Paszke
a72a204c39 [Mosaic] Always use 32-bit selects while retiling
Retiling never needs to use packed masks, and those aren't supported on all TPUs.

PiperOrigin-RevId: 627692517
2024-04-24 05:11:58 -07:00
Adam Paszke
5a2d7a2df4 Switch Mosaic GPU to a custom pass pipeline and improve the lowering of GPU launch
The stock MLIR pipeline was a good way to get the prototype off the ground, but
its default passes can be problematic. In particular, the gpu.launch is compiled
into a sequence of instructions that load the kernel onto the GPU, run the kernel
and immediately unload it again. This has the correct semantics, but loading the
kernel is both expensive and forces a synchronization point, which leads to performance
issues.

To resolve this, I implemented a new MLIR pass that finds the gpu.launch ops and splits
each function that has it into two functions: one that preloads the kernel onto the
GPU, and another one that consumes the handle produced by the previous one. We call
the first function at compile-time, while only the second one is used at run-time.

There are other overheads in MLIR's implementation of kernel launch, but I will
fix those later.

PiperOrigin-RevId: 627670773
2024-04-24 03:27:45 -07:00
Jevin Jiang
167161706c [XLA:Mosaic] Support trunc/ext op for 1D vector with any implicit dim.
PiperOrigin-RevId: 626466602
2024-04-19 14:14:31 -07:00
Adam Paszke
8e3f5b1018 Initial commit for Mosaic GPU
Moving this to JAX to make it easier to explore Pallas integration.

PiperOrigin-RevId: 625982382
2024-04-18 04:04:10 -07:00
Jevin Jiang
d44b16cfde [XLA:Mosaic] Generalize (8,128) -> (8 * packing,128) retiling for packed type.
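The arithmetic behind the generalization, sketched in Python under the assumption that "packing" is the number of narrow elements fitting into one 32-bit lane (2 for bf16, 4 for int8):

```python
def packing(bitwidth: int) -> int:
    # Number of narrow elements packed into one 32-bit register lane.
    assert 32 % bitwidth == 0
    return 32 // bitwidth

def retiled_shape(bitwidth: int, base_tile=(8, 128)):
    # (8, 128) -> (8 * packing, 128): taller tiles so the packed values
    # fill whole 32-bit lanes.
    rows, cols = base_tile
    return (rows * packing(bitwidth), cols)

assert retiled_shape(16) == (16, 128)  # bf16
assert retiled_shape(8) == (32, 128)   # int8
```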
PiperOrigin-RevId: 625816937
2024-04-17 15:01:37 -07:00
jax authors
5bd6013e76 [Mosaic] Support scf.while and scf.condition.
This allows lowering while loops of a more general form than "for i" loops.
Improving generality here allows us to implement more interesting dynamic looping behaviors, such as progressive scans in VMEM.
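The structure of `scf.while` — a "before" region ending in `scf.condition`, which forwards values to the "after" region — can be modeled with a small Python combinator. This is only a semantic sketch of the op being lowered, not the lowering itself:

```python
def scf_while(before, after, init):
    # 'before' returns (continue?, forwarded_values) -- the role of
    # scf.condition; 'after' computes the next iteration's arguments.
    args = init
    while True:
        cond, results = before(*args)
        if not cond:
            return results
        args = after(*results)

# Example: a counting loop that a plain "for i" lowering also handles,
# but scf.while admits arbitrary data-dependent exit conditions.
out = scf_while(
    before=lambda i, acc: (i < 5, (i, acc)),
    after=lambda i, acc: (i + 1, acc + i),
    init=(0, 0),
)
assert out == (5, 10)
```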

PiperOrigin-RevId: 625411151
2024-04-16 12:07:46 -07:00
Jevin Jiang
e3018dbaa1 [Pallas][Mosaic] Expose semaphore read.
PiperOrigin-RevId: 623593440
2024-04-10 13:45:03 -07:00
Christian Sigg
5d54043336 Switch llo and tpu dialects to MLIR properties.
PiperOrigin-RevId: 622760469
2024-04-08 01:05:20 -07:00
Jevin Jiang
67f4f6032a [XLA:Mosaic] Remove duplicate headers in debug assert insertion.
PiperOrigin-RevId: 619801919
2024-03-27 23:14:05 -07:00
Michael Hudgins
023930decf Fix some load orderings for buildifier
PiperOrigin-RevId: 619575196
2024-03-27 10:28:57 -07:00
Tomás Longeri
7f7e0c00df [Mosaic] Support left shifting relayouts
PiperOrigin-RevId: 618008857
2024-03-21 17:20:30 -07:00
Adam Paszke
7d431ad33b Add support for slicing dynamically-shaped memrefs + DMAs between them
This was a little difficult because our current dialect conversion setup assumes 1-1 type conversions.
I think everything works out fine as long as we never pass memrefs between basic blocks (i.e.,
as long as we never use memrefs as loop carries or return them from conditionals).

TODO: I still need to make sure that the changes to the TPU dialect are backwards-compatible.
I am afraid that the signature change in MemRefSliceOp might not be.
PiperOrigin-RevId: 617755035
2024-03-21 00:56:51 -07:00
Jevin Jiang
7578e10ce3 [XLA:Mosaic] Support dynamic indices in strided load/store.
PiperOrigin-RevId: 615931990
2024-03-14 16:02:08 -07:00
Jevin Jiang
30208fa9cc [XLA:Mosaic] Support strided load/store memref with arbitrary shape as long as the last dim size is 128 and the dtype is 32-bit.
PiperOrigin-RevId: 614862128
2024-03-11 18:22:11 -07:00
Jevin Jiang
75f2f7510f [XLA:Mosaic] Support input offset (replicated, 0) in shapecast.
PiperOrigin-RevId: 613340933
2024-03-06 14:26:00 -08:00
Jevin Jiang
05f54b665c [XLA:Mosaic] Use different MXU shape based on the target
PiperOrigin-RevId: 612906617
2024-03-05 11:14:24 -08:00
Tomás Longeri
57e34e1a2c [Mosaic][NFC] Use TypedValue<VectorType> instead of Value for applicable arguments/return values in disassemble and relayout
Ideally we would prefer `TypedValue<VectorType>` everywhere possible for static type checking. However, I tried changing the type for arrays of vregs from `xla::Array<Value>` to `xla::Array<TypedValue<VectorType>>` and ran into issues, because MLIR support for arrays/ranges of `TypedValue`s seems lacking.

For example, I can't find a good way to get a `ValueRange` (which many op constructors take) from an array of `TypedValue`s without creating an intermediate vector of `Value`s. Perhaps an unsafe cast would work, if we make the (probably not guaranteed) assumption that `sizeof(TypedValue)` equals `sizeof(Value)`.

Also note that MLIR itself uses untyped `Value`s for ranges of op results and operands even when the op definition declares them to be of a specific type.

PiperOrigin-RevId: 610509743
2024-02-26 13:34:58 -08:00
Tomás Longeri
2f882ad092 [Mosaic][infer-vector-layout] Fix crash with TPU_CHECK_OP
It was using the `op` variable from the `ExtUIOp` above (because variables declared in the initializer of an if statement are also in scope in the else branch).

PiperOrigin-RevId: 610481302
2024-02-26 11:58:51 -08:00
Adam Paszke
516b75dc24 Add pl.num_programs to make it easier to query the dynamic grid size
The new function can be used both in the kernel body and in the block specs.

PiperOrigin-RevId: 610391119
2024-02-26 06:39:03 -08:00
Tomás Longeri
61aa7e89aa [Mosaic] Fix bug in divisibility check in infer_vector_layout load and store rules
PiperOrigin-RevId: 609876232
2024-02-23 17:16:53 -08:00
Tomás Longeri
75cdef7626 [Mosaic][NFC] Prefer mlir aliases for llvm classes/functions within mlir namespace for consistency
(also fix a missing cstdint header to fix linter error)

PiperOrigin-RevId: 609826731
2024-02-23 13:48:03 -08:00
Tomás Longeri
8a43140c2e [Mosaic][apply_vector_layout][NFC] Use LLVM_UNLIKELY in TPU_ASSERT_* macros
PiperOrigin-RevId: 609805325
2024-02-23 12:27:49 -08:00
Jevin Jiang
f5c0021071 [XLA:Mosaic] Unify ext/trunc in infer vector layout.
PiperOrigin-RevId: 609765653
2024-02-23 10:19:26 -08:00
Tomás Longeri
c9eaca2282 [Mosaic] In apply_vector_layout, verify VectorLayout invariants that depend on target shape when loading them.
VectorLayout::verify was unused.
PiperOrigin-RevId: 609754730
2024-02-23 09:40:57 -08:00
Jevin Jiang
8d6bb0197b [XLA:Mosaic] Support broadcast scalar with a narrower type.
PiperOrigin-RevId: 609475719
2024-02-22 13:22:17 -08:00
Tomás Longeri
8172c067a4 [Mosaic][NFC] Replace tile_indices variable with tile_offsets with more consistent semantics
The old `tile_indices` variable was misleading and confusing because it sometimes stored indices (in the static case) and sometimes offsets with respect to the tile (in the dynamic case).

PiperOrigin-RevId: 609457122
2024-02-22 12:17:20 -08:00
Jevin Jiang
1fcb84dc90 [XLA:Mosaic] Support broadcast one row with padded tiling.
PiperOrigin-RevId: 609435269
2024-02-22 11:13:28 -08:00
Tomás Longeri
14474acf76 [Mosaic] Fix mistake in error message
PiperOrigin-RevId: 607700109
2024-02-16 08:38:30 -08:00
Tomás Longeri
243e7edc56 [Mosaic] In apply_vector_layout.cc, check layout validity when reading the attribute
This allows us to rely on this throughout the code and replace some checks with TPU_ASSERT_*. They have the semantics of an assert and make it clearer that it is an unexpected internal error (instead of unimplemented or invalid user input that we should handle).

Note: the original error messages for some of these checks were using the wrong input names.
PiperOrigin-RevId: 607463728
2024-02-15 14:51:45 -08:00
Jevin Jiang
a37f2d4a09 [XLA:Mosaic] Prevent generating rotate op if both shift and stride are zeros.
PiperOrigin-RevId: 607433547
2024-02-15 13:19:10 -08:00
Tomás Longeri
72c5aea161 [Mosaic] In apply_vector_layout, prefer returning failure over CHECKs and add macros for this
Rationale is that it's easier to debug from Python.

PiperOrigin-RevId: 607426243
2024-02-15 12:54:33 -08:00
Jevin Jiang
273b5e29de [XLA:Mosaic] Fix dim check in concatenate helper function.
PiperOrigin-RevId: 607405219
2024-02-15 11:53:42 -08:00
Jevin Jiang
50308553e0 [XLA:Mosaic] Expose tpu::RotateOp with stride.
PiperOrigin-RevId: 606772470
2024-02-13 15:51:08 -08:00
Jevin Jiang
9b320f23f0 [XLA:Mosaic] Use join to find a compatible output layout in scf.if.
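One way to picture the join is as a lattice over per-dimension offsets, where a replicated offset (modeled here as `None`) is compatible with any concrete one. This toy Python sketch ignores the rest of the layout (tiling, implicit dims) and is not the actual Mosaic implementation:

```python
def join_offsets(a, b):
    """Join two per-dimension offset tuples coming from the two branches
    of an scf.if. None means 'replicated', compatible with any concrete
    offset. Returns None if no common layout exists (a relayout is then
    required)."""
    out = []
    for x, y in zip(a, b):
        if x is None:
            out.append(y)
        elif y is None or x == y:
            out.append(x)
        else:
            return None  # branches disagree on a concrete offset
    return tuple(out)

assert join_offsets((None, 0), (2, 0)) == (2, 0)  # replicated joins freely
assert join_offsets((1, 0), (2, 0)) is None       # incompatible offsets
```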
PiperOrigin-RevId: 605442646
2024-02-08 15:20:26 -08:00
Blake Hechtman
d29c86eb52 [MOSAIC] Accept kernels for a compatibility window with MAXNUMF instead of MAXIMUMF
PiperOrigin-RevId: 605436599
2024-02-08 14:57:35 -08:00
Adam Paszke
0b04ff1241 Add support for non-disjoint windows in Pallas/Mosaic
This enables the index function to select a window starting from
any element. However, the Mosaic implementation still requires it
to be at least tile aligned.
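A minimal sketch of what "non-disjoint" buys: the window start is no longer forced to be a multiple of the window size, so consecutive windows may overlap (illustrative Python on a 1-D list; the real block specs are multi-dimensional):

```python
def window(x, start, size):
    # Non-disjoint windows: the index function may pick a start that is
    # not a multiple of the window size, so windows can overlap.
    return x[start:start + size]

x = list(range(16))
# Overlapping 8-element windows with stride 4 (the start need only be
# tile aligned, not window aligned):
w0, w1 = window(x, 0, 8), window(x, 4, 8)
assert w0[4:] == w1[:4]  # the two windows share 4 elements
```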

PiperOrigin-RevId: 605254616
2024-02-08 02:48:43 -08:00
Tomás Longeri
62b818c0e5 [Mosaic] Use vector.kind<maximumf> instead of vector.kind<maxnumf>
The semantics are closer to the TPU: having a NaN input results in NaN. However, we don't respect the -0.0 vs +0.0 ordering in older TPUs.

This also fixes a mismatch where we are using `arith.maximumf` for lowering `vector.kind<maxnumf>` (instead of `arith.maximumf` for `vector.kind<maximumf>` or `arith.maxnumf` for `vector.kind<maxnumf>`).
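The two reduction kinds differ only in NaN handling, which this Python sketch makes explicit (like the lowering described above, it ignores the -0.0 vs +0.0 ordering):

```python
import math

def maximumf(a, b):
    # vector.kind<maximumf> semantics: any NaN input produces NaN.
    if math.isnan(a) or math.isnan(b):
        return math.nan
    return max(a, b)

def maxnumf(a, b):
    # vector.kind<maxnumf> semantics: a NaN input is ignored in favor
    # of the other (numeric) operand.
    if math.isnan(a):
        return b
    if math.isnan(b):
        return a
    return max(a, b)

assert math.isnan(maximumf(1.0, math.nan))  # NaN propagates
assert maxnumf(1.0, math.nan) == 1.0        # NaN is dropped
```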

PiperOrigin-RevId: 604849222
2024-02-06 20:24:31 -08:00
jax authors
87d1670db8 [Mosaic] Improves error messages for infer vector layout.
PiperOrigin-RevId: 604697612
2024-02-06 10:47:58 -08:00
Tomás Longeri
ca98ed7c40 [Mosaic] In apply_vector_layout, remove old layout attribute formats
PiperOrigin-RevId: 602723113
2024-01-30 07:44:10 -08:00
Tomás Longeri
1cbda65afc [Mosaic][NFC] Fix comment format
PiperOrigin-RevId: 602416317
2024-01-29 09:47:41 -08:00
Adam Paszke
f625fb69da [Mosaic] Add support for tile-aligned dynamic offsets in loads, stores and ref slices
PiperOrigin-RevId: 597798116
2024-01-12 03:42:58 -08:00
Adam Paszke
8f771b4211 [Mosaic] Simplify the handling of dynamic indices in vector.load and store
This normalizes loads and stores with dynamic base indices into reference
slicing followed by statically indexed loads/stores. This should both simplify
the code (we only have to deal with dynamism in slicing) and improve performance
(the address offset only needs to be applied once).
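The normalization can be pictured as follows (illustrative Python over a list; `dynamic_load` is an invented name, with the slice playing the role of the dynamically based reference slice and the final subscript the statically indexed load):

```python
def dynamic_load(ref, dyn_base, static_idx):
    # Normalized form: one dynamic slice (the address offset is computed
    # here, once), followed by a statically indexed load into the slice.
    window = ref[dyn_base:]       # dynamically based reference slice
    return window[static_idx]     # statically indexed load

x = list(range(10))
assert dynamic_load(x, 3, 2) == x[5]
```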

PiperOrigin-RevId: 597546106
2024-01-11 07:08:07 -08:00
Adam Paszke
ce00e10d9b [Pallas][Mosaic] Add support for nontrivial semaphore memrefs
The previous patch simply changed the type we use to represent semaphores,
but didn't actually add support for any more operations. With this one,
semaphore memrefs can be allocated and (dynamically) indexed.

PiperOrigin-RevId: 597538913
2024-01-11 06:33:49 -08:00