225 Commits

Author SHA1 Message Date
jax authors
a0e5e0f411 Integrate LLVM at llvm/llvm-project@c012e487b7
Updates LLVM usage to match
[c012e487b724](https://github.com/llvm/llvm-project/commit/c012e487b724)

PiperOrigin-RevId: 642581785
2024-06-12 05:11:10 -07:00
Tomás Longeri
3e1e98992c [Mosaic] Handle adding singleton minor dimension that was already implicit for non-32-bit types, and do not force native tiling
Also fix extra comma in apply_vector_layout_test which was being annoying with autoformatter

PiperOrigin-RevId: 642454594
2024-06-11 18:10:26 -07:00
jax authors
d20b9e324f Integrate LLVM at llvm/llvm-project@8c5d9c79b9
Updates LLVM usage to match
[8c5d9c79b96e](https://github.com/llvm/llvm-project/commit/8c5d9c79b96e)

PiperOrigin-RevId: 642352474
2024-06-11 12:24:43 -07:00
Jevin Jiang
5b38549810 [XLA:Mosaic] No need to assume a multiple of tile if tile dim size is 1.
PiperOrigin-RevId: 642301822
2024-06-11 09:53:13 -07:00
Adam Paszke
1256ceb266 [Mosaic GPU] Rearrange the pass pipeline (again)
PiperOrigin-RevId: 642256145
2024-06-11 06:59:50 -07:00
jax authors
71c19b779d Rewrite vector.contraction with bf16 accumulator and output into a
contraction with f32 accumulator and output, where the accumulator is
extended and the output truncated. For targets that do not support bf16
matmul, the lhs and rhs are extended to f32.

PiperOrigin-RevId: 642051952
2024-06-10 16:02:46 -07:00
Jevin Jiang
53daa0c742 [XLA:Mosaic] Fix infer layout for nested loop.
- We should recursively clear layouts and any assume_layout ops if we want to override layouts in a block.
- Refactor the logic of assume layouts for block arguments to a helper function.
- Add tests for nested fori loop and while loop.

PiperOrigin-RevId: 641973011
2024-06-10 11:49:01 -07:00
Adam Paszke
0739d520b1 [Mosaic GPU] Don't always run with llvm::DebugFlag enabled
This slipped past during code review.

PiperOrigin-RevId: 641899993
2024-06-10 07:50:26 -07:00
Adam Paszke
3b4039c850 [Mosaic GPU] Load LLVM lowering interfaces for all dialects
Apparently we were missing interface registration code for LLVM lowering,
which the gpu-to-llvm pass gracefully ignores unless compiled with debug
assertions enabled. But, simply adding the assertions in fact makes the
pass _too powerful_ and makes it lower _all dialects to LLVM_, which is not
what we want. That's why I've replaced it with a minimal version that is
only repsponsible for handling the GPU dialect, making the lowering similar
to the one prior to extra registrations.

PiperOrigin-RevId: 641874183
2024-06-10 05:55:01 -07:00
jax authors
f51af87fc5 fp8 matmul in pallas
PiperOrigin-RevId: 641254832
2024-06-07 08:17:06 -07:00
Tomás Longeri
a65d3ae0da [Mosaic] Expand vector.shape_cast support for sublane (un)folding no-ops
- Support non-zero minor offsets without having to relayout (they're still a no-op).
- Remove restriction on tiling which now allows 1D packed types to work.

PiperOrigin-RevId: 640967375
2024-06-06 11:35:19 -07:00
Tomás Longeri
20d9aac6be [Mosaic] Remove some restrictions for vector.shape_cast in infer-vector-layout and apply-vector-layout
- On infer-vector-layout remove some restrictions related to batch dimensions. Reshaping them doesn't matter as long as they don't combine with tiled dimensions.
- On apply-vector-layout, simplify handling of cases where the implicit tiled don't change while removing some unnecessary restrictions.
  - Don't require native tiling or natural topology for this.

PiperOrigin-RevId: 640837740
2024-06-06 03:26:43 -07:00
Adam Paszke
6a1fcc6cb2 [Mosaic TPU] Normalize inferred layouts to supported ones in matmul rule
Previously the rule would complain if the layouts were unsupported, but that's not
the right way to handle that situation. With this change, we simply pick a supported
configuration instead (and expect relayout to handle it).

PiperOrigin-RevId: 640190248
2024-06-04 10:03:00 -07:00
Tomás Longeri
8a1445a038 [Mosaic] Document tpu.create_subelement_mask
PiperOrigin-RevId: 639898224
2024-06-03 13:45:09 -07:00
Tomás Longeri
e620acfa17 [Mosaic] Remove "support" for MAXNUMF vector reductions
Pallas hasn't been using it to lower since cl/604849222, which is before we had serde, so we won't create a serde rule for this. That CL also tried to remove the support, but it had to be restored in cl/605436599 because it was breaking serialized kernels.

PiperOrigin-RevId: 639896981
2024-06-03 13:40:31 -07:00
jax authors
1edb94ec46 [XLA][Mosaic] Add support for fp8 matmuls in TPUv5+
Needed a little more backfill for TPU load

PiperOrigin-RevId: 639206243
2024-05-31 17:51:25 -07:00
Jevin Jiang
389bf93abf [XLA:Mosaic] Fix infer/apply vector layout rule for terminators (scf::yieldOp, scf::conditionOp).
We should infer layout for each terminator inside its own region and find a compatible layout for a final result if the result is based on terminators from multiple regions like scf::ifOp, scf::whileOp, scf::forOp. If no compatible layout is found, we will fall back to a normalized layout. Finally we also need to ensure the layouts in input, terminator and output are consistent across loops.

PiperOrigin-RevId: 639122434
2024-05-31 12:47:33 -07:00
Adam Paszke
d01496a09a [Mosaic GPU] Restore the PTX/PTXAS/SASS dump flags
They're very useful while prototyping the kernels.

PiperOrigin-RevId: 639027506
2024-05-31 07:27:36 -07:00
Sergei Lebedev
d2a39bc61b Updated the layer norm implementation in Mosaic GPU tests
jnp.var now needs lax.gt_p, which we don't currently support.

PiperOrigin-RevId: 639011383
2024-05-31 06:11:48 -07:00
Sergei Lebedev
8729952d82 Added a missing return to MosaicGPUCustomCall
PiperOrigin-RevId: 638627696
2024-05-30 06:13:01 -07:00
Adam Paszke
cfe64cd5ce [Mosaic GPU] Integrate the ExecutionEngine with the jaxlib GPU plugin
This lets us avoid bundling a whole another copy of LLVM with JAX packages
and so we can finally start building Mosaic GPU by default.

PiperOrigin-RevId: 638569750
2024-05-30 01:46:23 -07:00
Tomás Longeri
8f8b976421 [Mosaic] Packed loads and stores with 1D tiling should use (1, 128 * packing)
There are multiple representations for 1D tiling in vector layouts and we need to choose one consistently.

PiperOrigin-RevId: 638331061
2024-05-29 10:25:07 -07:00
Tomás Longeri
a07c7816ab [Mosaic] Fix bug in VectorLayout::generalizes introduced in cl/636250759
PiperOrigin-RevId: 638253907
2024-05-29 05:48:19 -07:00
Tomás Longeri
8b95853609 [Mosaic] Add relayout for (1, 128 * packing) -> (packing, 128).
PiperOrigin-RevId: 637951690
2024-05-28 10:47:41 -07:00
Tomás Longeri
97f9a5e80e [Mosaic] Expand vector.shape_cast no-op detection for expanding/shrinking lane shape casts
- Remove restriction on sublane tiling being 1 or a multiple of 8 on the expanded shape.
- Support packed types.

PiperOrigin-RevId: 637777493
2024-05-27 22:32:08 -07:00
Tomás Longeri
3fb9acf01a [Mosaic] Expand support of vector.broadcast
- Enable it for minor or second-minor implicit dims for the non-no-op case.
- Don't allow output offsets for broadcasted dimensions to be non-replicated. Make sure to assign them as replicated in infer-vector-layout for all cases.
- Don't fail when both tiled dimensions are logically broadcasted but only one of them requires actual broadcasting (before, it would hit the unimplemented sublane + lane broadcast case).

PiperOrigin-RevId: 637772134
2024-05-27 21:59:37 -07:00
Justin Fu
683ca2cd40 [Pallas][Mosaic] Add lowering rules for PRNG ops.
PiperOrigin-RevId: 636999151
2024-05-24 12:22:15 -07:00
Sergei Lebedev
daa81e6fb5 Added support for printing scalar values in Pallas TPU kernels
The implementation uses the new tpu.log operation in the Mosaic TPU dialect.

Note that

* the logging only happens if --xla_tpu_enable_log_recorder is set;
* only scalars can be printed;
* placeholders only accept i32 arguments at the moment.

PiperOrigin-RevId: 636585852
2024-05-23 10:02:00 -07:00
Adam Paszke
63a13f516d [Mosaic TPU] Add support for tpu.iota over untiled dimensions
PiperOrigin-RevId: 636567090
2024-05-23 08:56:54 -07:00
Tomás Longeri
5ae2491853 [Mosaic] Use implicit shape over directly using implicit_dim() in layout.{cc, h}
This CL also fixes a bug in `VectorLayout::join` where `ImplicitDim::kMinor` was considered equivalent to `ImplicitDim::kNone` when the shape's minor dimension is 1 (also needed to check that the second-minor dimension is 1).

Often handling every implicit dim case separately is more complex and error-prone.

There will be more follow-up changes to do this consistently elsewhere the code. This is also a first step towards 0D layout support.

PiperOrigin-RevId: 636250759
2024-05-22 12:11:28 -07:00
Jevin Jiang
ccaf466a60 [XLA:Mosaic] Support retiling from (8, 128, -2) to (8, 128) for 32-bit data.
This drops implicit second minor dim and fixes the infer vector layout for concatenate rule when concatenating constant vectors.

PiperOrigin-RevId: 635917408
2024-05-21 13:52:26 -07:00
jax authors
92d892b425 Adds rewrite patterns for arith and math operations with bf16 operands/results that are not supported by the underlying hardware.
PiperOrigin-RevId: 635865752
2024-05-21 11:08:19 -07:00
Tomás Longeri
b197ae527e [Mosaic] Also check bitwidth in apply-vector-layout's layoutIsValidForValue.
PiperOrigin-RevId: 635595321
2024-05-20 15:57:08 -07:00
jax authors
61ff828715 Add support for TPU delay in Mosaic
PiperOrigin-RevId: 635473532
2024-05-20 09:07:56 -07:00
Tomás Longeri
0ad5167da8 Add support for i1 vmasks with packed tiling and 16-bit comparisons (requires hardware support)
PiperOrigin-RevId: 633677477
2024-05-14 12:54:48 -07:00
Adam Paszke
8692355220 [Mosaic] Add support for remote DMAs and semaphores in megacore mode
The change to tpu.td is not backwards compatible, but I made it so using the
newly added Mosaic stability layer. It's been a good exercise and it seems to
be working just fine.

Co-authored-by: Sharad Vikram <sharadmv@google.com>
PiperOrigin-RevId: 630060418
2024-05-02 07:43:36 -07:00
Tomás Longeri
b099eb28a0 [Mosaic] Expand support of vector.extract and vector.extract_strided_slice
- Support non-zero offsets and non-tile-aligned slices for 2D layouts.
- Support vector.extract for non-scalar results.

PiperOrigin-RevId: 629787740
2024-05-01 11:46:02 -07:00
Tomás Longeri
9bf1148e74 [Mosaic] Always define tiling as (1, 128) for 1D loaded or stored vectors (not for the memref), instead of sometimes using (1, 128 * n).
They are equivalent - the way values are laid out is the same - but relayouts check specifically for (1, 128). We define (1, 128) to be canonical.

PiperOrigin-RevId: 629748121
2024-05-01 09:37:48 -07:00
jax authors
eeca8d81b9 Fix example in mosaic tpu dialect layout.h
PiperOrigin-RevId: 629424833
2024-04-30 08:42:54 -07:00
Blake Hechtman
5b996f7680 [JAX:MOSAIC] Support transposes that are smaller than the transpose unit and infer native layout to avoid unsupported relayouts.
PiperOrigin-RevId: 629289267
2024-04-29 22:03:32 -07:00
Adam Paszke
32cb7c3f94 [Mosaic GPU] Stop using the MLIR CUDA runtime
This ports the remaining few functions we depended on to the Mosaic GPU runtime.
This has the additional benefit of avoiding the expensive driver calls to determine
maximum SMEM bounds that the MLIR runtime does at every kernel launch.

PiperOrigin-RevId: 629069842
2024-04-29 08:04:51 -07:00
jax authors
d9b75350b7 Adds rewrite patterns for arith.{cmpi,select} and tensor.splat as sources to a vector.transfer_read op.
PiperOrigin-RevId: 628561147
2024-04-26 18:11:18 -07:00
jax authors
8c2425e571 Adds rewrite patterns to LinalgVectorizationPass to eliminate transfer_read and transfer_write ops.
PiperOrigin-RevId: 628500668
2024-04-26 13:51:04 -07:00
Adam Paszke
9b0319512a [Mosaic GPU] Use a custom TMA descriptor initialization method
The one bundled with the default MLIR runtime was convenient, but it is also
impractical. It allocates memory (which can deadlock due to NCCL), does a
synchronous host-to-device copy and then leaks the descriptor after the kernel...

With this change, we use our own runtime function to create all the descriptors.
What's more, we pack them all into a single buffer so that a single asynchronous
copy is sufficient. Finally, we use a scratch output to allocate the scratch buffer,
letting us lean on XLA:GPU for memory management.

PiperOrigin-RevId: 628430358
2024-04-26 09:40:47 -07:00
Adam Paszke
36c471b6f5 [Mosaic] Add support for concatenating arrays of packed types (<32 bits)
PiperOrigin-RevId: 628001232
2024-04-25 02:04:08 -07:00
Adam Paszke
a72a204c39 [Mosaic] Always use 32-bit selects while retiling
Retiling never needs to use packed masks, and those aren't supported on all TPUs.

PiperOrigin-RevId: 627692517
2024-04-24 05:11:58 -07:00
Adam Paszke
5a2d7a2df4 Switch Mosaic GPU to a custom pass pipeline and improve the lowering of GPU launch
The stock MLIR pipeline was a good way to get the prototype off the ground, but
its default passes can be problematic. In particular, the gpu.launch is compiled
into a sequence of instructions that load the kernel onto the GPU, run the kernel
and immediately unload it again. This has the correct semantics, but loading the
kernel is both expensive and forces a synchronization point, which leads to performance
issues.

To resolve this, I implemented a new MLIR pass that finds the gpu.launch ops and splits
each function that has it into two functions: one that preloads the kernel onto the
GPU, and another one that consumes the handle produced by the previous one. We call
the first function at compile-time, while only the second one is used at run-time.

There are other overheads in MLIR's implementation of kernel launch, but I will
fix those later.

PiperOrigin-RevId: 627670773
2024-04-24 03:27:45 -07:00
Jevin Jiang
167161706c [XLA:Mosaic] Support trunc/ext op for 1D vector with any implicit dim.
PiperOrigin-RevId: 626466602
2024-04-19 14:14:31 -07:00
Adam Paszke
8e3f5b1018 Initial commit for Mosaic GPU
Moving this to JAX to make it easier to explore Pallas integration.

PiperOrigin-RevId: 625982382
2024-04-18 04:04:10 -07:00
Jevin Jiang
d44b16cfde [XLA:Mosaic] Generalize (8,128) -> (8 * packing,128) retiling for packed type.
PiperOrigin-RevId: 625816937
2024-04-17 15:01:37 -07:00