854 Commits

Author SHA1 Message Date
Kyle Lucke
418b68828a Automated Code Change
PiperOrigin-RevId: 635818645
2024-05-21 08:40:34 -07:00
Tomás Longeri
b197ae527e [Mosaic] Also check bitwidth in apply-vector-layout's layoutIsValidForValue.
PiperOrigin-RevId: 635595321
2024-05-20 15:57:08 -07:00
jax authors
61ff828715 Add support for TPU delay in Mosaic
PiperOrigin-RevId: 635473532
2024-05-20 09:07:56 -07:00
jax authors
974c72b9a1 Merge pull request #21292 from ROCm:rv_stable_051624
PiperOrigin-RevId: 635430659
2024-05-20 05:52:36 -07:00
Ruturaj4
79fccf6c82 add cholesky changes in bazel
2024-05-18 00:37:09 +00:00
jax authors
c4559115ec Internal BUILD file change
PiperOrigin-RevId: 634713068
2024-05-17 04:30:21 -07:00
Vadym Matsishevskyi
517e299a9d Use hermetic Python in JAX, see "Managing hermetic Python" in developer.md for details
PiperOrigin-RevId: 634146391
2024-05-15 18:20:56 -07:00
jax authors
e8b06ccf56 Cholesky rank-1 update kernel for JAX.
PiperOrigin-RevId: 633722940
2024-05-14 15:21:38 -07:00
Tomás Longeri
0ad5167da8 Add support for i1 vmasks with packed tiling and 16-bit comparisons (requires hardware support)
PiperOrigin-RevId: 633677477
2024-05-14 12:54:48 -07:00
jax authors
c3cab2e3d3 Reverts 6c425338d20c0c9be3fc69d2f07ababf79c881d3
PiperOrigin-RevId: 632579101
2024-05-10 12:56:10 -07:00
Peter Hawkins
6c425338d2 Reverts 0267ed0ba9584bbc137792361b53aa80e9c4d306
PiperOrigin-RevId: 632548226
2024-05-10 11:06:38 -07:00
jax authors
0267ed0ba9 Replace xla_extension symlink with genrule that makes xla_extension module accessible from jax._src.lib.
The runfiles of the original targets were lost when the symlinked files were used.

This change is needed for the future Hermetic CUDA implementation. Bazel will download CUDA distributions into its cache, and CUDA executables and libraries will be added to the runfiles of the targets. When `xla_extension` is symlinked, the content of the runfiles is lost. With `genrule` the content of the runfiles is preserved.

PiperOrigin-RevId: 632508121
2024-05-10 08:48:12 -07:00
jax authors
5f702674f7 Merge pull request #21103 from superbobry:mosaic-gpu-fix
PiperOrigin-RevId: 631521771
2024-05-07 13:11:43 -07:00
Sergei Lebedev
8ccbebae4b Fixed Mosaic GPU build following #21029
2024-05-07 17:08:00 +01:00
Sergei Lebedev
51fc4f85ad Ported LuPivotsToPermutation to the typed XLA FFI
The typed FFI

* allows passing custom call attributes directly to backend_config= instead
  of serializing them into a C++ struct.
* handles validation and deserialization of custom call operands.

PiperOrigin-RevId: 630067005
2024-05-02 08:12:05 -07:00
Adam Paszke
8692355220 [Mosaic] Add support for remote DMAs and semaphores in megacore mode
The change to tpu.td is not backwards compatible on its own, but I made it
compatible using the newly added Mosaic stability layer. It's been a good
exercise and it seems to be working just fine.

Co-authored-by: Sharad Vikram <sharadmv@google.com>
PiperOrigin-RevId: 630060418
2024-05-02 07:43:36 -07:00
jax authors
e691c19bb2 Merge pull request #21029 from superbobry:jaxlib-mlir-pyi
PiperOrigin-RevId: 629836927
2024-05-01 14:22:21 -07:00
Tomás Longeri
b099eb28a0 [Mosaic] Expand support of vector.extract and vector.extract_strided_slice
- Support non-zero offsets and non-tile-aligned slices for 2D layouts.
- Support vector.extract for non-scalar results.

PiperOrigin-RevId: 629787740
2024-05-01 11:46:02 -07:00
Sergei Lebedev
442526869f Bundle MLIR .pyi files with jaxlib
This allows mypy and pyright to type check the code using MLIR Python APIs.
2024-05-01 19:37:26 +01:00
Tomás Longeri
9bf1148e74 [Mosaic] Always define tiling as (1, 128) for 1D loaded or stored vectors (not for the memref), instead of sometimes using (1, 128 * n).
They are equivalent - the way values are laid out is the same - but relayouts check specifically for (1, 128). We define (1, 128) to be canonical.

PiperOrigin-RevId: 629748121
2024-05-01 09:37:48 -07:00
Sergei Lebedev
e75e4a5991 Do not require a capsule to have a specific name in the CUDA plugin
This aligns the implementation in the plugin with the one in xla_client.

PiperOrigin-RevId: 629724197
2024-05-01 08:00:44 -07:00
Jieying Luo
a949ce772b Add get_device_ordinal to cuda plugin so that CUDA dependency can be removed from py_array (jaxlib).
py_array still has a CUDA dependency as a fallback to keep jaxlib[cuda] working before the migration to the CUDA plugin.

PiperOrigin-RevId: 629499893
2024-04-30 12:50:50 -07:00
jax authors
eeca8d81b9 Fix example in mosaic tpu dialect layout.h
PiperOrigin-RevId: 629424833
2024-04-30 08:42:54 -07:00
Adam Paszke
4051ac2a2f [Mosaic GPU] Only call kernel initializer from inside a custom call
XLA:GPU custom call design is far from ideal, as there's apparently no way to figure
out the CUDA context that will be used to run an HLO module before the custom call is
first called. So, we can't preload the kernel onto the GPU, or else we'll get invalid
handle errors due to the load and launch happening in different CUDA contexts...

Also fix up build_wheel.py to match the rename of the runtime lib.

PiperOrigin-RevId: 629401858
2024-04-30 07:10:05 -07:00
Blake Hechtman
5b996f7680 [JAX:MOSAIC] Support transposes that are smaller than the transpose unit and infer native layout to avoid unsupported relayouts.
PiperOrigin-RevId: 629289267
2024-04-29 22:03:32 -07:00
Adam Paszke
32cb7c3f94 [Mosaic GPU] Stop using the MLIR CUDA runtime
This ports the remaining few functions we depended on to the Mosaic GPU runtime.
This has the additional benefit of avoiding the expensive driver calls to determine
maximum SMEM bounds that the MLIR runtime does at every kernel launch.

PiperOrigin-RevId: 629069842
2024-04-29 08:04:51 -07:00
jax authors
d9b75350b7 Adds rewrite patterns for arith.{cmpi,select} and tensor.splat as sources to a vector.transfer_read op.
PiperOrigin-RevId: 628561147
2024-04-26 18:11:18 -07:00
jax authors
8c2425e571 Adds rewrite patterns to LinalgVectorizationPass to eliminate transfer_read and transfer_write ops.
PiperOrigin-RevId: 628500668
2024-04-26 13:51:04 -07:00
Adam Paszke
9b0319512a [Mosaic GPU] Use a custom TMA descriptor initialization method
The one bundled with the default MLIR runtime was convenient, but it is also
impractical. It allocates memory (which can deadlock due to NCCL), does a
synchronous host-to-device copy and then leaks the descriptor after the kernel...

With this change, we use our own runtime function to create all the descriptors.
What's more, we pack them all into a single buffer so that a single asynchronous
copy is sufficient. Finally, we use a scratch output to allocate the scratch buffer,
letting us lean on XLA:GPU for memory management.

PiperOrigin-RevId: 628430358
2024-04-26 09:40:47 -07:00
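The "pack them all into a single buffer" idea from the commit above can be sketched in plain Python. This is only an illustration, not the Mosaic GPU runtime code: `pack_descriptors`, `DESCRIPTOR_SIZE` and the two-field descriptor layout are hypothetical, and the 128-byte slot size is an assumption made for the example.

```python
import struct

DESCRIPTOR_SIZE = 128  # assumed slot size in bytes, for illustration only


def pack_descriptors(descriptors):
    """Pack byte strings into one zero-padded buffer; return (buffer, offsets)."""
    buffer = bytearray(DESCRIPTOR_SIZE * len(descriptors))
    offsets = []
    for i, desc in enumerate(descriptors):
        if len(desc) > DESCRIPTOR_SIZE:
            raise ValueError("descriptor does not fit in one slot")
        offset = i * DESCRIPTOR_SIZE
        # Same-length slice assignment keeps the buffer size unchanged.
        buffer[offset:offset + len(desc)] = desc
        offsets.append(offset)
    return bytes(buffer), offsets


# Hypothetical descriptors: (base address, row length) as two 64-bit integers.
descs = [struct.pack("<QQ", 0x1000, 64), struct.pack("<QQ", 0x2000, 32)]
packed, offsets = pack_descriptors(descs)
```

With every descriptor in one contiguous buffer, a single copy of `packed` moves all of them, and each kernel argument only needs its offset into that buffer.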
George Necula
d92f4ae157 Reverts 9db5e693ebb4ad786c6e52b562cf32aeaba2e7e1
PiperOrigin-RevId: 628362293
2024-04-26 04:14:34 -07:00
Adam Paszke
ded9272a5b [Mosaic GPU] Implement a simple profiling tool using CUDA events
The other JAX profiling tools are a little heavyweight when we only care about
timing a single kernel programmatically.

Also adapt wgmma.py to match failures triggered by upstream MLIR changes.

PiperOrigin-RevId: 628096973
2024-04-25 09:18:39 -07:00
jax authors
6ba4dbade5 Merge pull request #20911 from apaszke:mlir-update
PiperOrigin-RevId: 628040111
2024-04-25 05:14:24 -07:00
jax authors
9db5e693eb Reverts 6bfbb4593a42fced91ba50de47271af425c74c20
PiperOrigin-RevId: 628035616
2024-04-25 04:53:22 -07:00
Adam Paszke
340b9e3739 Update GPU and NVGPU MLIR bindings to match upstream MLIR changes
Upstream MLIR Python bindings now require two more extension libraries
to work properly. The dialects fail to import without this change.
2024-04-25 11:41:19 +00:00
Adam Paszke
36c471b6f5 [Mosaic] Add support for concatenating arrays of packed types (<32 bits)
PiperOrigin-RevId: 628001232
2024-04-25 02:04:08 -07:00
George Necula
6bfbb4593a Remove old ducc_fft custom call.
In June 2023 we switched the CPU lowering for FFT to use
the new custom call dynamic_ducc_fft. We are now out of the backwards
compatibility window, so we remove the old ducc_fft.

We need to keep dynamic_ducc_fft a little bit longer (May 2024).

PiperOrigin-RevId: 627981921
2024-04-25 00:29:11 -07:00
Adam Paszke
a72a204c39 [Mosaic] Always use 32-bit selects while retiling
Retiling never needs to use packed masks, and those aren't supported on all TPUs.

PiperOrigin-RevId: 627692517
2024-04-24 05:11:58 -07:00
Adam Paszke
5a2d7a2df4 Switch Mosaic GPU to a custom pass pipeline and improve the lowering of GPU launch
The stock MLIR pipeline was a good way to get the prototype off the ground, but
its default passes can be problematic. In particular, the gpu.launch is compiled
into a sequence of instructions that load the kernel onto the GPU, run the kernel
and immediately unload it again. This has the correct semantics, but loading the
kernel is both expensive and forces a synchronization point, which leads to performance
issues.

To resolve this, I implemented a new MLIR pass that finds the gpu.launch ops and splits
each function that has it into two functions: one that preloads the kernel onto the
GPU, and another one that consumes the handle produced by the previous one. We call
the first function at compile-time, while only the second one is used at run-time.

There are other overheads in MLIR's implementation of kernel launch, but I will
fix those later.

PiperOrigin-RevId: 627670773
2024-04-24 03:27:45 -07:00
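The preload/launch split described in the commit above can be illustrated with a minimal plain-Python sketch. It is not the MLIR pass itself: `FakeDriver`, `naive_call`, `preload_kernel` and `launch_kernel` are hypothetical names that only model the cost of loading a kernel on every call versus loading it once.

```python
class FakeDriver:
    """Hypothetical stand-in for a GPU driver; counts kernel loads and launches."""

    def __init__(self):
        self.loads = 0
        self.launches = 0

    def load(self, kernel_name):
        self.loads += 1  # expensive and synchronizing in a real driver
        return ("handle", kernel_name)

    def launch(self, handle, args):
        self.launches += 1
        return sum(args)  # placeholder for the kernel's result


def naive_call(driver, kernel_name, args):
    # Models the stock behavior: load the kernel, run it, and drop the handle.
    handle = driver.load(kernel_name)
    return driver.launch(handle, args)


def preload_kernel(driver, kernel_name):
    # First half of the split: called once, at compile time.
    return driver.load(kernel_name)


def launch_kernel(driver, handle, args):
    # Second half of the split: consumes the handle on every run.
    return driver.launch(handle, args)


naive = FakeDriver()
for _ in range(3):
    naive_call(naive, "k", [1, 2])

split = FakeDriver()
handle = preload_kernel(split, "k")
results = [launch_kernel(split, handle, [1, 2]) for _ in range(3)]
```

In this model the naive path loads the kernel three times for three calls, while the split path loads it once and reuses the handle.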
Jieying Luo
16b4f69769 Rename arg in build script to be more clear.
The flag means that the GPU plugin extension in jaxlib is skipped.

PiperOrigin-RevId: 627203738
2024-04-22 17:22:24 -07:00
Jevin Jiang
167161706c [XLA:Mosaic] Support trunc/ext op for 1D vector with any implicit dim.
PiperOrigin-RevId: 626466602
2024-04-19 14:14:31 -07:00
Marvin Kim
90e9e47a55 [Jax/Triton] Skip benchmarking while autotuning for configs that cannot be launched.
For configs that cannot be launched, we should not launch them via benchmark.

PiperOrigin-RevId: 626153377
2024-04-18 14:35:51 -07:00
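The autotuning behavior described in the commit above can be sketched as a filter in front of the benchmark loop. This is a plain-Python illustration under assumptions, not the jaxlib Triton code: `autotune`, `can_launch`, `run` and the shared-memory limit are hypothetical.

```python
import time


def autotune(configs, can_launch, run):
    """Return the fastest launchable config, never benchmarking the others."""
    best_config, best_time = None, float("inf")
    for config in configs:
        if not can_launch(config):
            continue  # skip configs that could not be launched at all
        start = time.perf_counter()
        run(config)
        elapsed = time.perf_counter() - start
        if elapsed < best_time:
            best_config, best_time = config, elapsed
    return best_config


SMEM_LIMIT = 48 * 1024  # hypothetical per-block shared memory limit in bytes
configs = [{"smem": 16 * 1024}, {"smem": 200 * 1024}]
ran = []  # records which configs were actually benchmarked

best = autotune(
    configs,
    can_launch=lambda c: c["smem"] <= SMEM_LIMIT,
    run=lambda c: ran.append(c["smem"]),
)
```

Only the first config is benchmarked here, because the second exceeds the assumed limit and is filtered out before `run` is called.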
Adam Paszke
8e3f5b1018 Initial commit for Mosaic GPU
Moving this to JAX to make it easier to explore Pallas integration.

PiperOrigin-RevId: 625982382
2024-04-18 04:04:10 -07:00
Jevin Jiang
d44b16cfde [XLA:Mosaic] Generalize (8,128) -> (8 * packing,128) retiling for packed type.
PiperOrigin-RevId: 625816937
2024-04-17 15:01:37 -07:00
jax authors
5bd6013e76 [Mosaic] Support scf.while and scf.condition.
This allows lowering while loops of a more general form than "for i" loops.
Improving generality here allows us to implement more interesting dynamic looping behaviors, such as progressive scans in VMEM.

PiperOrigin-RevId: 625411151
2024-04-16 12:07:46 -07:00
Jieying Luo
44e83d4e0a Add a few custom call registrations to gpu_kernel to keep in-sync with callers of xla_client.register_custom_call_target.
PiperOrigin-RevId: 624275186
2024-04-12 13:30:18 -07:00
Jevin Jiang
e3018dbaa1 [Pallas][Mosaic] Expose semaphore read.
PiperOrigin-RevId: 623593440
2024-04-10 13:45:03 -07:00
Henning Becker
9809aa1929 Move CUDA specific functions from asm_compiler to cuda_asm_compiler target
This avoids:
- a forward declaration of `GpuContext`
- the `:asm_compiler_header` header only target

The moved code is unchanged - I just moved it from one
file to another and fixed up includes and dependencies.

Note that this adds yet another `#ifdef` to the redzone allocator code. I will clean this up in a subsequent change.

PiperOrigin-RevId: 623285804
2024-04-09 14:43:41 -07:00
David Dunleavy
cd2b91c398 Update references to TSL config_settings to their new home in XLA
PiperOrigin-RevId: 623249851
2024-04-09 12:36:10 -07:00
David Dunleavy
d18323f3c4 Move tsl/BUILD, tsl.bzl, and tsl.default.bzl to XLA
PiperOrigin-RevId: 623215553
2024-04-09 10:47:06 -07:00
jax authors
f5cc272615 Merge pull request #20646 from ROCm:rcom-ci-tsl-path-fix
PiperOrigin-RevId: 623129753
2024-04-09 05:09:11 -07:00