832 Commits

Author SHA1 Message Date
jax authors
eeca8d81b9 Fix example in mosaic tpu dialect layout.h
PiperOrigin-RevId: 629424833
2024-04-30 08:42:54 -07:00
Adam Paszke
4051ac2a2f [Mosaic GPU] Only call kernel initializer from inside a custom call
XLA:GPU custom call design is far from ideal, as there's apparently no way to figure
out the CUDA context that will be used to run an HLO module before the custom call is
first called. So, we can't preload the kernel onto the GPU, or else we'll get invalid
handle errors due to the load and launch happening in different CUDA contexts...

Also fix up build_wheel.py to match the rename of the runtime lib.

PiperOrigin-RevId: 629401858
2024-04-30 07:10:05 -07:00
Blake Hechtman
5b996f7680 [JAX:MOSAIC] Support transposes that are smaller than the transpose unit and infer native layout to avoid unsupported relayouts.
PiperOrigin-RevId: 629289267
2024-04-29 22:03:32 -07:00
Adam Paszke
32cb7c3f94 [Mosaic GPU] Stop using the MLIR CUDA runtime
This ports the remaining few functions we depended on to the Mosaic GPU runtime.
This has the additional benefit of avoiding the expensive driver calls to determine
maximum SMEM bounds that the MLIR runtime does at every kernel launch.

PiperOrigin-RevId: 629069842
2024-04-29 08:04:51 -07:00
jax authors
d9b75350b7 Adds rewrite patterns for arith.{cmpi,select} and tensor.splat as sources to a vector.transfer_read op.
PiperOrigin-RevId: 628561147
2024-04-26 18:11:18 -07:00
jax authors
8c2425e571 Adds rewrite patterns to LinalgVectorizationPass to eliminate transfer_read and transfer_write ops.
PiperOrigin-RevId: 628500668
2024-04-26 13:51:04 -07:00
Adam Paszke
9b0319512a [Mosaic GPU] Use a custom TMA descriptor initialization method
The one bundled with the default MLIR runtime was convenient, but it is also
impractical. It allocates memory (which can deadlock due to NCCL), does a
synchronous host-to-device copy and then leaks the descriptor after the kernel...

With this change, we use our own runtime function to create all the descriptors.
What's more, we pack them all into a single buffer so that a single asynchronous
copy is sufficient. Finally, we use a scratch output to allocate the scratch buffer,
letting us lean on XLA:GPU for memory management.

PiperOrigin-RevId: 628430358
2024-04-26 09:40:47 -07:00
George Necula
d92f4ae157 Reverts 9db5e693ebb4ad786c6e52b562cf32aeaba2e7e1
PiperOrigin-RevId: 628362293
2024-04-26 04:14:34 -07:00
Adam Paszke
ded9272a5b [Mosaic GPU] Implement a simple profiling tool using CUDA events
The other JAX profiling tools are a little heavyweight when we only care about
timing a single kernel programmatically.

Also adapt wgmma.py to match failures triggered by upstream MLIR changes.

PiperOrigin-RevId: 628096973
2024-04-25 09:18:39 -07:00
jax authors
6ba4dbade5 Merge pull request #20911 from apaszke:mlir-update
PiperOrigin-RevId: 628040111
2024-04-25 05:14:24 -07:00
jax authors
9db5e693eb Reverts 6bfbb4593a42fced91ba50de47271af425c74c20
PiperOrigin-RevId: 628035616
2024-04-25 04:53:22 -07:00
Adam Paszke
340b9e3739 Update GPU and NVGPU MLIR bindings to match upstream MLIR changes
Upstream MLIR Python bindings now require two more extension libraries
to work properly. The dialects fail to import without this change.
2024-04-25 11:41:19 +00:00
Adam Paszke
36c471b6f5 [Mosaic] Add support for concatenating arrays of packed types (<32 bits)
PiperOrigin-RevId: 628001232
2024-04-25 02:04:08 -07:00
George Necula
6bfbb4593a Remove old ducc_fft custom call.
Starting in June 2023 we have switched the CPU lowering for FFT to use
the new custom call dynamic_ducc_fft. We are now out of the backwards
compatibility window and we remove the old ducc_fft.

We need to keep dynamic_ducc_fft a little bit longer (May 2024).

PiperOrigin-RevId: 627981921
2024-04-25 00:29:11 -07:00
Adam Paszke
a72a204c39 [Mosaic] Always use 32-bit selects while retiling
Retiling never needs to use packed masks, and those aren't supported on all TPUs.

PiperOrigin-RevId: 627692517
2024-04-24 05:11:58 -07:00
Adam Paszke
5a2d7a2df4 Switch Mosaic GPU to a custom pass pipeline and improve the lowering of GPU launch
The stock MLIR pipeline was a good way to get the prototype off the ground, but
its default passes can be problematic. In particular, the gpu.launch is compiled
into a sequence of instructions that load the kernel onto the GPU, run the kernel
and immediately unload it again. This has the correct semantics, but loading the
kernel is both expensive and forces a synchronization point, which leads to performance
issues.

To resolve this, I implemented a new MLIR pass that finds the gpu.launch ops and splits
each function that has it into two functions: one that preloads the kernel onto the
GPU, and another one that consumes the handle produced by the previous one. We call
the first function at compile-time, while only the second one is used at run-time.

There are other overheads in MLIR's implementation of kernel launch, but I will
fix those later.

PiperOrigin-RevId: 627670773
2024-04-24 03:27:45 -07:00
Jieying Luo
16b4f69769 Rename arg in build script to be more clear.
The flag skips the GPU plugin extension in jaxlib.

PiperOrigin-RevId: 627203738
2024-04-22 17:22:24 -07:00
Jevin Jiang
167161706c [XLA:Mosaic] Support trunc/ext op for 1D vector with any implicit dim.
PiperOrigin-RevId: 626466602
2024-04-19 14:14:31 -07:00
Marvin Kim
90e9e47a55 [Jax/Triton] Skip benchmarking while autotuning for configs that cannot be launched.
For configs that cannot be launched, we should not launch them via benchmark.

PiperOrigin-RevId: 626153377
2024-04-18 14:35:51 -07:00
Adam Paszke
8e3f5b1018 Initial commit for Mosaic GPU
Moving this to JAX to make it easier to explore Pallas integration.

PiperOrigin-RevId: 625982382
2024-04-18 04:04:10 -07:00
Jevin Jiang
d44b16cfde [XLA:Mosaic] Generalize (8,128) -> (8 * packing,128) retiling for packed type.
PiperOrigin-RevId: 625816937
2024-04-17 15:01:37 -07:00
jax authors
5bd6013e76 [Mosaic] Support scf.while and scf.condition.
This allows lowering while loops of a more general form than "for i" loops.
Improving generality here allows us to implement more interesting dynamic looping behaviors, such as progressive scans in VMEM.

PiperOrigin-RevId: 625411151
2024-04-16 12:07:46 -07:00
Jieying Luo
44e83d4e0a Add a few custom call registrations to gpu_kernel to keep in-sync with callers of xla_client.register_custom_call_target.
PiperOrigin-RevId: 624275186
2024-04-12 13:30:18 -07:00
Jevin Jiang
e3018dbaa1 [Pallas][Mosaic] Expose semaphore read.
PiperOrigin-RevId: 623593440
2024-04-10 13:45:03 -07:00
Henning Becker
9809aa1929 Move CUDA specific functions from asm_compiler to cuda_asm_compiler target
This avoids:
- a forward declaration of `GpuContext`
- the `:asm_compiler_header` header only target

The moved code is unchanged - I just moved it from one
file to another and fixed up includes and dependencies.

Note that this is adding just another `#ifdef` to the redzone allocator code. I will clean this up in a subsequent change.

PiperOrigin-RevId: 623285804
2024-04-09 14:43:41 -07:00
David Dunleavy
cd2b91c398 Update references to TSL config_settings to their new home in XLA
PiperOrigin-RevId: 623249851
2024-04-09 12:36:10 -07:00
David Dunleavy
d18323f3c4 Move tsl/BUILD, tsl.bzl, and tsl.default.bzl to XLA
PiperOrigin-RevId: 623215553
2024-04-09 10:47:06 -07:00
jax authors
f5cc272615 Merge pull request #20646 from ROCm:rcom-ci-tsl-path-fix
PiperOrigin-RevId: 623129753
2024-04-09 05:09:11 -07:00
Ruturaj4
97bf2d2bb8 [ROCm]: fix tsl path
2024-04-08 19:58:41 -05:00
Christian Sigg
5d54043336 Switch llo and tpu dialects to MLIR properties.
PiperOrigin-RevId: 622760469
2024-04-08 01:05:20 -07:00
Olli Lupton
c97d955771 cuInit before querying compute capability
2024-04-04 15:27:57 +00:00
Marvin Kim
722708052c [JAX] Fix typo in comment.
PiperOrigin-RevId: 621827985
2024-04-04 05:35:28 -07:00
jax authors
fb55d59143 This CL introduces 'PluginProgram' in IFRT and exposes this in python via xla_client.compile_ifrt_program().
The IFRT `PluginProgram` is simply a wrapper for arbitrary byte-strings: an IFRT backend that recognizes `PluginProgram` can interpret the byte-string in any way it sees fit.

PiperOrigin-RevId: 621258245
2024-04-02 12:20:35 -07:00
David Dunleavy
aade591fdf Move tsl/python to xla/tsl/python
PiperOrigin-RevId: 620320903
2024-03-29 13:15:21 -07:00
Peter Hawkins
478cfa9944 Add an upper bound on JAX's CUDNN version constraint.
Major releases of CUDNN break ABI compatibility, so we cannot allow new major versions.

PiperOrigin-RevId: 620030416
2024-03-28 13:00:36 -07:00
Jevin Jiang
67f4f6032a [XLA:Mosaic] Remove duplicate headers in debug assert insertion.
PiperOrigin-RevId: 619801919
2024-03-27 23:14:05 -07:00
Michael Hudgins
023930decf Fix some load orderings for buildifier
PiperOrigin-RevId: 619575196
2024-03-27 10:28:57 -07:00
jax authors
0be07e6aec Remove support for CUDA 11.
Pin minimal required versions for CUDA to 12.1.

Reverts 910a31d7b7510e3375718ab1ea0d38df7bd2c0d5

PiperOrigin-RevId: 618911489
2024-03-25 11:46:39 -07:00
Sandeep Dasgupta
6ffd55c405 Fixing StableHLO python dependencies on stablehlo:reference_api
PiperOrigin-RevId: 618294054
2024-03-22 14:52:06 -07:00
jax authors
910a31d7b7 Reverts bed4f65438a62777ed100ecec2b0eb3f7cf87a0e
PiperOrigin-RevId: 618249855
2024-03-22 12:10:53 -07:00
jax authors
bed4f65438 Remove support for CUDA 11.
Pin minimal required versions for CUDA to 12.1.

PiperOrigin-RevId: 618195554
2024-03-22 09:05:39 -07:00
Tomás Longeri
7f7e0c00df [Mosaic] Support left shifting relayouts
PiperOrigin-RevId: 618008857
2024-03-21 17:20:30 -07:00
jax authors
2848cda34c Merge pull request #20341 from ROCm:rocm_add_hipStreamWaitEvent
PiperOrigin-RevId: 617893634
2024-03-21 10:41:38 -07:00
Adam Paszke
7d431ad33b Add support for slicing dynamically-shaped memrefs + DMAs between them
This was a little difficult because our current dialect conversion setup assumes 1-1 type conversions.
I think everything works out fine for as long as we never pass memrefs between basic blocks (i.e.
for as long as we never have memrefs as loop carry or return them from conditionals).

TODO: I still need to make sure that the changes to the TPU dialect are backwards-compatible.
I am afraid that the signature change in MemRefSliceOp might not be.

PiperOrigin-RevId: 617755035
2024-03-21 00:56:51 -07:00
Rahul Batra
8575055571 [ROCm]: Add missing hipStreamWaitEvent API call
2024-03-20 16:58:21 +00:00
jax authors
df9cefabc1 jaxlib: Add ifrt_proxy.pyi to build_wheel.py.
PiperOrigin-RevId: 617275734
2024-03-19 13:27:39 -07:00
Peter Hawkins
c2bbf9c577 Remove some code to support older CUDA and CUSPARSE versions.
The minimum CUDA version supported by JAX is CUDA 11.8, which ships with CUSPARSE 11.7.5.

PiperOrigin-RevId: 616892230
2024-03-18 11:25:03 -07:00
Jevin Jiang
7578e10ce3 [XLA:Mosaic] Support dynamic indices in strided load/store.
PiperOrigin-RevId: 615931990
2024-03-14 16:02:08 -07:00
Jevin Jiang
30208fa9cc [XLA:Mosaic] Support strided load/store memref with arbitrary shape as long as last dim size is 128 and dtype is 32bit.
PiperOrigin-RevId: 614862128
2024-03-11 18:22:11 -07:00
Sergei Lebedev
778933dfda Removed inspect.signature() call from jaxlib.triton.dialect.ScanOp
PiperOrigin-RevId: 614772594
2024-03-11 13:30:41 -07:00