The one bundled with the default MLIR runtime was convenient, but it is impractical.
It allocates memory (which can deadlock due to NCCL), performs a synchronous
host-to-device copy, and then leaks the descriptor after the kernel call.
With this change, we use our own runtime function to create all the descriptors.
What's more, we pack them all into a single buffer so that a single asynchronous
copy is sufficient. Finally, we use a scratch output to allocate the scratch buffer,
letting us lean on XLA:GPU for memory management.
PiperOrigin-RevId: 628430358
Updates the Pallas TPU lowering function to insert trace start/stop operations to replicate the original functionality of pltpu.trace.
PiperOrigin-RevId: 628254882
In particular, avoid using lax.cond to switch to CholeskyQR for later iterations, as under vmap this can result in both branches being executed.
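For context, a small illustration (a made-up function, not the code touched here) of why: under vmap with a batched predicate, lax.cond is lowered to a select, so both branches are evaluated for every batch element.
```python
import jax
import jax.numpy as jnp
from jax import lax

# Under vmap, a batched predicate turns lax.cond into a select, so both
# branches run for every element of the batch.
def f(pred, x):
  return lax.cond(pred, lambda v: jnp.sqrt(v), lambda v: v * 2.0, x)

preds = jnp.array([True, False, True])
xs = jnp.array([4.0, 4.0, 9.0])
print(jax.vmap(f)(preds, xs))  # [2. 8. 3.] -- sqrt and v*2 are both computed
```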
PiperOrigin-RevId: 628144162
The other JAX profiling tools are a little heavyweight when we only care about
timing a single kernel programmatically.
Also adapt wgmma.py to upstream MLIR changes that triggered failures.
PiperOrigin-RevId: 628096973
In June 2023 we switched the CPU lowering for FFT to use the new custom call
dynamic_ducc_fft. We are now out of the backwards-compatibility window, so we
remove the old ducc_fft custom call.
We need to keep dynamic_ducc_fft a little longer (until May 2024).
PiperOrigin-RevId: 627981921
The stock MLIR pipeline was a good way to get the prototype off the ground, but
its default passes can be problematic. In particular, each gpu.launch is compiled
into a sequence of instructions that load the kernel onto the GPU, run the kernel,
and immediately unload it again. This has the correct semantics, but loading the
kernel is both expensive and forces a synchronization point, which leads to
performance issues.
To resolve this, I implemented a new MLIR pass that finds the gpu.launch ops and
splits each function containing one into two functions: one that preloads the
kernel onto the GPU, and another that consumes the handle produced by the first.
We call the first function at compile time, while only the second one is used at
run time.
There are other overheads in MLIR's implementation of kernel launch, but I will
fix those later.
PiperOrigin-RevId: 627670773
Usage:
from jax.experimental.sparse import nm
res = nm.nm_spmm(lhs, rhs, nm.nm_pack(mask))
where:
lhs.shape = [M, K/2]
rhs.shape = [K, N]
`mask` has the same shape as `lhs` with boolean type
If batch dimensions are present, the `dimension_numbers` argument has to be set to:
((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))
The lowering only works on NVIDIA GPUs that provide hardware support for sparse dots.
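A minimal sketch of the call above (my own example, not from this change): the sizes, the bfloat16 dtype, and the all-True mask are illustrative assumptions, and a real mask must encode a valid N:M sparsity pattern; the lowering also requires an NVIDIA GPU with sparse-dot support.
```python
# Sketch only: sizes, dtype, and mask pattern are assumptions; a real mask must
# encode a valid N:M sparsity pattern, and lowering needs a sparse-dot NVIDIA GPU.
import jax
import jax.numpy as jnp
from jax.experimental.sparse import nm

M, K, N = 64, 128, 32
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
lhs = jax.random.normal(k1, (M, K // 2), dtype=jnp.bfloat16)   # shape [M, K/2]
rhs = jax.random.normal(k2, (K, N), dtype=jnp.bfloat16)        # shape [K, N]
mask = jnp.ones((M, K // 2), dtype=bool)                       # same shape as lhs, boolean
res = nm.nm_spmm(lhs, rhs, nm.nm_pack(mask))                   # result shape [M, N]
```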
PiperOrigin-RevId: 627640553
Fix this by calculating the donation vector by looking at the in_tree.
As a bonus, we can now cache the calculation of the donation vector, leading to faster tracing times in JAX.
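For context, a tiny illustration (made-up function and shapes) of what a per-leaf donation vector covers when an argument is a pytree:
```python
from functools import partial
import jax
import jax.numpy as jnp

# Donating the `params` pytree marks every one of its leaf buffers as
# reusable for the outputs; the donation vector has one entry per leaf.
@partial(jax.jit, donate_argnums=(0,))
def update(params, grads):
  return jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

params = {"w": jnp.ones((4, 4)), "b": jnp.zeros(4)}
grads = {"w": jnp.full((4, 4), 0.5), "b": jnp.full((4,), 0.5)}
new_params = update(params, grads)  # `params` buffers may be reused
```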
PiperOrigin-RevId: 627512710
The default `fn.__name__` was added in `_one_to_one_unop` but not in other functions, which leads to many downstream function wrappers reporting unmeaningful names while debugging. For instance,
when the JAX primitive `lax.add` is wrapped by `lu.WrappedFun`, `print(wrapped)` gives:
```
Wrapped function:
0 : _argnums_partial ((0, 1), ())
1 : flatten_fun (PyTreeDef(((*, *), {})),)
2 : result_paths ()
Core: fn
```
instead of
```
Wrapped function:
0 : _argnums_partial ((0, 1), ())
1 : flatten_fun (PyTreeDef(((*, *), {})),)
2 : result_paths ()
Core: add
```
PiperOrigin-RevId: 627417452