20407 Commits

Author SHA1 Message Date
Adam Paszke
9b0319512a [Mosaic GPU] Use a custom TMA descriptor initialization method
The one bundled with the default MLIR runtime was convenient, but it is impractical: it
allocates memory (which can deadlock due to NCCL), does a synchronous host-to-device
copy, and then leaks the descriptor after the kernel...

With this change, we use our own runtime function to create all the descriptors.
What's more, we pack them all into a single buffer so that a single asynchronous
copy is sufficient. Finally, we use a scratch output to allocate the scratch buffer,
letting us lean on XLA:GPU for memory management.

PiperOrigin-RevId: 628430358
2024-04-26 09:40:47 -07:00
Sergei Lebedev
268b39d426 Added a GPU-specific approximate tanh to Pallas
See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-tanh.

PiperOrigin-RevId: 628424047
2024-04-26 09:15:38 -07:00
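The PTX `tanh.approx.f32` instruction referenced above trades accuracy for speed. As an illustrative sketch only (not the hardware's actual method), a low-order Padé approximant shows the same tradeoff — accurate near zero, drifting for larger inputs:

```python
import math

def tanh_pade(x: float) -> float:
    # [3/2] Pade approximant of tanh around 0: x*(15 + x^2) / (15 + 6*x^2).
    # Matches the Taylor series of tanh through the x^5 term; error grows
    # as |x| increases, the same character as hardware "approx" modes.
    return x * (15.0 + x * x) / (15.0 + 6.0 * x * x)

for x in (0.5, 1.0, 2.0):
    print(x, tanh_pade(x), math.tanh(x))
```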
Justin Fu
94766b8f2b [Mosaic] Add guard for absl-py in tpu_custom_call.py.
Addresses https://github.com/google/jax/issues/20908.
https://github.com/google/jax/pull/12806 for reference.

PiperOrigin-RevId: 628414523
2024-04-26 08:36:32 -07:00
jax authors
00663474ce Merge pull request #19868 from dan-zheng:jaxpr-docs
PiperOrigin-RevId: 628393328
2024-04-26 06:59:37 -07:00
Adam Paszke
4c3d4323dd [Mosaic GPU] Disable matmul tests in internal CI
PiperOrigin-RevId: 628379779
2024-04-26 05:50:35 -07:00
Sergey Kozub
8738e7e5dc Use correct kWidth in sparse dots with int8 input (on Ampere)
PiperOrigin-RevId: 628368832
2024-04-26 04:53:25 -07:00
George Necula
d92f4ae157 Reverts 9db5e693ebb4ad786c6e52b562cf32aeaba2e7e1
PiperOrigin-RevId: 628362293
2024-04-26 04:14:34 -07:00
John QiangZhang
0b343b9ac1 [jax] Fix jax_export issue with static args.
PiperOrigin-RevId: 628337221
2024-04-26 02:12:24 -07:00
jax authors
c176201386 Update XLA dependency to use revision
822c341cdc.

PiperOrigin-RevId: 628286324
2024-04-25 21:31:11 -07:00
Justin Fu
7067efb6dc Redefines pltpu.trace as an alias of jax.named_scope.
Updates the Pallas TPU lowering function to insert trace start/stop operations to replicate the original functionality of pltpu.trace.

PiperOrigin-RevId: 628254882
2024-04-25 18:34:38 -07:00
Justin Fu
7844bac5d2 Add proper handling of OOB array accesses in pallas interpret mode.
PiperOrigin-RevId: 628202600
2024-04-25 15:05:52 -07:00
Adam Paszke
c6ca1ef204 [Mosaic GPU] Add the first example: pipelined matmul
PiperOrigin-RevId: 628156068
2024-04-25 12:27:25 -07:00
Enrique Piqueras
491618130b Fix jax.tree_util.register_dataclass in older JAX versions.
PiperOrigin-RevId: 628149376
2024-04-25 12:04:41 -07:00
James Lottes
9fd5f7c6a2 Refactor QDWH to be more efficient when run batched under vmap.
In particular, avoid using lax.cond to switch to CholeskyQR for later iterations, as under vmap this can result in both branches being executed.

PiperOrigin-RevId: 628144162
2024-04-25 11:48:21 -07:00
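The pitfall the refactor avoids is that under vmap, a `cond` whose predicate varies across the batch effectively becomes a select: both branches run for every element. A pure-Python sketch of that semantics (hypothetical `batched_cond`, not JAX's implementation):

```python
def batched_cond(preds, true_fn, false_fn, xs):
    """Batched 'cond' in the form vmap forces: when the predicate can differ
    per batch element, both branches are evaluated for the whole batch and
    the results are selected elementwise."""
    calls = {"true": 0, "false": 0}

    def counting_true(x):
        calls["true"] += 1
        return true_fn(x)

    def counting_false(x):
        calls["false"] += 1
        return false_fn(x)

    t = [counting_true(x) for x in xs]   # true branch runs for ALL elements
    f = [counting_false(x) for x in xs]  # false branch runs for ALL elements
    out = [ti if p else fi for p, ti, fi in zip(preds, t, f)]
    return out, calls

out, calls = batched_cond(
    [True, False, True], lambda x: x * 2, lambda x: x + 100, [1, 2, 3]
)
```

Both branches execute three times each even though each element needs only one — which is why switching QDWH iterations via `lax.cond` is wasteful under vmap.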
Jake VanderPlas
beb49af678 sparse_nm_test: skip on incompatible GPUs
PiperOrigin-RevId: 628120697
2024-04-25 10:38:07 -07:00
Adam Paszke
ded9272a5b [Mosaic GPU] Implement a simple profiling tool using CUDA events
The other JAX profiling tools are a little heavyweight when we only care about
timing a single kernel programmatically.

Also adapt wgmma.py to match failures triggered by upstream MLIR changes.

PiperOrigin-RevId: 628096973
2024-04-25 09:18:39 -07:00
jax authors
fad2c0e315 Merge pull request #20858 from rajasekharporeddy:doc_typos
PiperOrigin-RevId: 628061707
2024-04-25 06:58:27 -07:00
Dan Zheng
e9e3c80258 [jaxpr.rst] Remove extraneous 'let' in Jaxpr grammar.
Remove `let` from `Eqn` rule. `let` appears only once in the `Jaxpr` rule.

Consistently capitalize grammar nonterminals: `jaxpr` → `Jaxpr`.
2024-04-25 09:53:16 -04:00
Tomás Longeri
e9fdfc7204 [Mosaic] Fix Python pipeline not working correctly when HLO passes are enabled.
PiperOrigin-RevId: 628044507
2024-04-25 05:36:16 -07:00
jax authors
6ba4dbade5 Merge pull request #20911 from apaszke:mlir-update
PiperOrigin-RevId: 628040111
2024-04-25 05:14:24 -07:00
jax authors
9db5e693eb Reverts 6bfbb4593a42fced91ba50de47271af425c74c20
PiperOrigin-RevId: 628035616
2024-04-25 04:53:22 -07:00
Adam Paszke
340b9e3739 Update GPU and NVGPU MLIR bindings to match upstream MLIR changes
Upstream MLIR Python bindings now require two more extension libraries
to work properly. The dialects fail to import without this change.
2024-04-25 11:41:19 +00:00
Adam Paszke
36c471b6f5 [Mosaic] Add support for concatenating arrays of packed types (<32 bits)
PiperOrigin-RevId: 628001232
2024-04-25 02:04:08 -07:00
George Necula
6bfbb4593a Remove old ducc_fft custom call.
In June 2023 we switched the CPU lowering for FFT to use the new custom call
dynamic_ducc_fft. We are now out of the backwards compatibility window, so we
remove the old ducc_fft.

We need to keep dynamic_ducc_fft a little bit longer (May 2024).

PiperOrigin-RevId: 627981921
2024-04-25 00:29:11 -07:00
jax authors
53c7c3708b Update XLA dependency to use revision
e7c3bc72d2.

PiperOrigin-RevId: 627942055
2024-04-24 20:42:45 -07:00
Jake VanderPlas
cbe48cad1e Finalize deprecation of arr.device_buffer and arr.device_buffers
PiperOrigin-RevId: 627899901
2024-04-24 17:27:25 -07:00
jax authors
66190d10e7 Merge pull request #20917 from jakevdp:flash-fix
PiperOrigin-RevId: 627764834
2024-04-24 10:04:20 -07:00
Jake VanderPlas
4806083d05 pallas flash attention: explicitly use dtype 2024-04-24 08:43:25 -07:00
jax authors
aa350ab7e7 Merge pull request #20896 from pearu:pearu/asin-asinh
PiperOrigin-RevId: 627739066
2024-04-24 08:34:13 -07:00
Adam Paszke
a72a204c39 [Mosaic] Always use 32-bit selects while retiling
Retiling never needs to use packed masks, and those aren't supported on all TPUs.

PiperOrigin-RevId: 627692517
2024-04-24 05:11:58 -07:00
Adam Paszke
5a2d7a2df4 Switch Mosaic GPU to a custom pass pipeline and improve the lowering of GPU launch
The stock MLIR pipeline was a good way to get the prototype off the ground, but
its default passes can be problematic. In particular, the gpu.launch is compiled
into a sequence of instructions that load the kernel onto the GPU, run the kernel
and immediately unload it again. This has the correct semantics, but loading the
kernel is both expensive and forces a synchronization point, which leads to performance
issues.

To resolve this, I implemented a new MLIR pass that finds the gpu.launch ops and splits
each function containing one into two functions: one that preloads the kernel onto the
GPU, and another that consumes the handle produced by the first. We call the first
function at compile-time, while only the second one is used at run-time.

There are other overheads in MLIR's implementation of kernel launch, but I will
fix those later.

PiperOrigin-RevId: 627670773
2024-04-24 03:27:45 -07:00
Sergey Kozub
aebe82a78f Add JAX API that provides sparse matmul support (2:4 structured sparsity)
Usage:
from jax.experimental.sparse import nm
res = nm.nm_spmm(lhs, rhs, nm.nm_pack(mask))

where:
lhs.shape = [M, K/2]
rhs.shape = [K, N]
`mask` has the same shape as `lhs` with boolean type

If batch dimensions are present, the `dimension_numbers` argument has to be set to:
((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))

The lowering only works on NVIDIA GPUs that provide hardware support for sparse dots.

PiperOrigin-RevId: 627640553
2024-04-24 01:06:19 -07:00
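The `lhs.shape = [M, K/2]` above follows from 2:4 structured sparsity: every group of 4 values along the contracting dimension keeps exactly 2 nonzeros, so the packed operand stores half the columns. A pure-Python illustration of that packing semantics (not the real `nm_pack`):

```python
def pack_2_to_4(row, mask):
    """Compress one 2:4-sparse row: each group of 4 entries keeps exactly
    the 2 entries where mask is True (illustrative, not jax's nm_pack)."""
    assert len(row) % 4 == 0
    packed = []
    for i in range(0, len(row), 4):
        keep = [v for v, m in zip(row[i:i + 4], mask[i:i + 4]) if m]
        assert len(keep) == 2, "2:4 sparsity keeps exactly 2 of every 4"
        packed.extend(keep)
    return packed

row = [1.0, 0.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0]   # K == 8
mask = [True, False, True, False, False, True, True, False]
packed = pack_2_to_4(row, mask)                    # length K/2 == 4
```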
jax authors
b5fdc0d90f Update XLA dependency to use revision
f0591d62b6.

PiperOrigin-RevId: 627596967
2024-04-23 21:13:06 -07:00
jax authors
b9b5f9b39a Merge pull request #20862 from jakevdp:select_n
PiperOrigin-RevId: 627571544
2024-04-23 18:53:05 -07:00
jax authors
26a3d3dc02 Only perform checks on slice sizes if they're static.
PiperOrigin-RevId: 627560765
2024-04-23 18:02:07 -07:00
Yash Katariya
8239674dab Replace donation_vector's logic with donation_vector_with_in_tree which is now deleted
PiperOrigin-RevId: 627556267
2024-04-23 17:38:30 -07:00
jax authors
8842c0bc91 Merge pull request #20901 from carlosgmartin:lax-scan-xs-none
PiperOrigin-RevId: 627546834
2024-04-23 17:01:24 -07:00
Yash Katariya
3f17626f4b Fix donation with kwargs. The problem is that pytrees sort dictionaries by default, so if we create the donation vector in the original kwargs order, it won't match the aval order (which comes from the sorted kwargs dict) and we end up donating the wrong input.
Fix this by calculating the donation vector by looking at the in_tree.

A bonus is that we can now cache the calculation of donation vector leading to faster tracing times in JAX.

PiperOrigin-RevId: 627512710
2024-04-23 14:50:04 -07:00
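The mismatch is easy to reproduce with plain dicts: positional donation flags built from the user's kwargs order misalign once the keys are sorted. A sketch with hypothetical names (the real fix walks the `in_tree`, but the spirit is resolving flags by name rather than by the order the kwargs were typed):

```python
kwargs = {"weights": "W", "batch": "B"}   # user intends to donate "weights"
donate_user_order = [True, False]         # position 0 == "weights" as typed

aval_order = sorted(kwargs)               # pytrees sort dict keys: ["batch", "weights"]

# Naive: reuse the positional flags against the sorted order -> wrong input donated.
naive = {k: d for k, d in zip(aval_order, donate_user_order)}

# Fixed (in spirit): resolve flags by name, then align with the sorted order.
by_name = {"weights": True, "batch": False}
correct = {k: by_name[k] for k in aval_order}
```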
Paul Wohlhart
6b85557cc1 Use xla_client.Device in jax.numpy.
PiperOrigin-RevId: 627507470
2024-04-23 14:32:08 -07:00
carlosgmartin
2b332de9d7 Let xs=None by default in lax.scan. 2024-04-23 17:26:23 -04:00
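A reference-semantics sketch of what `xs=None` means for `lax.scan`: the body runs `length` times with no per-step input. This is a pure-Python model of the documented semantics, not JAX's implementation:

```python
def scan(f, init, xs=None, length=None):
    """Pure-Python model of lax.scan's semantics. With xs=None, `length`
    must be given and the body sees x=None at every step."""
    if xs is None:
        assert length is not None, "length is required when xs is None"
        xs = [None] * length
    carry, ys = init, []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys

# Carry-only computation run five times, with no per-step inputs:
final, outs = scan(lambda c, _: (c + 1, c), init=0, xs=None, length=5)
```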
Enrique Piqueras
cf9c08589e Add builtin cc dataclass pytree node for performance.
PiperOrigin-RevId: 627502102
2024-04-23 14:14:49 -07:00
jax authors
8b1418244b Merge pull request #20885 from rajasekharporeddy:test_branch4
PiperOrigin-RevId: 627486343
2024-04-23 13:29:40 -07:00
jax authors
8ead2df7bb Merge pull request #20897 from jakevdp:doc-updates
PiperOrigin-RevId: 627486188
2024-04-23 13:24:43 -07:00
Jake VanderPlas
a8af2b788a DOC: respond to mattjj comments 2024-04-23 13:15:20 -07:00
rajasekharporeddy
c536eea1e5 Fix jax.scipy.stats.beta.logpdf to emulate scipy.stats.beta.logpdf 2024-04-24 01:24:09 +05:30
jax authors
ba57ce3bd1 Merge pull request #20891 from rajasekharporeddy:test_branch1
PiperOrigin-RevId: 627472771
2024-04-23 12:39:18 -07:00
Peter Hawkins
ab30bcf071 [jax2tf] Bump asinh test tolerance in graph and eager modes.
Fixes CI test failure due to LLVM update.

PiperOrigin-RevId: 627462404
2024-04-23 12:03:00 -07:00
rajasekharporeddy
95ed0538fd Fix jax.scipy.stats.poisson.logpmf to emulate scipy.stats.poisson.logpmf for non-integer values of k 2024-04-24 00:29:52 +05:30
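The behavior being emulated: `scipy.stats.poisson` is a discrete distribution, so its pmf is zero (logpmf is -inf) at non-integer `k`. An illustrative stdlib-only sketch of that convention (not JAX's implementation):

```python
import math

def poisson_logpmf(k, mu):
    """logpmf of Poisson(mu), following scipy's convention that the pmf is
    zero (logpmf == -inf) at non-integer or negative k (illustrative)."""
    if k < 0 or k != math.floor(k):
        return -math.inf
    return k * math.log(mu) - mu - math.lgamma(k + 1)
```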
Pearu Peterson
e8ff7028f4 Workaround mpmath 1.3 issues in asin and asinh evaluation at infinities and on branch cuts. 2024-04-23 21:01:43 +03:00
Yunlong Liu
2df6b35dce Adds meaningful function names for better debugging.
The default `fn.__name__` was set in `_one_to_one_unop` but not in other functions, which leads to many downstream function wrappers reporting unhelpful names while debugging. For instance,

When a JAX numpy primitive `lax.add` is wrapped by `lu.WrappedFun`, `print(wrapped)` will give,

```
Wrapped function:
0   : _argnums_partial   ((0, 1), ())
1   : flatten_fun   (PyTreeDef(((*, *), {})),)
2   : result_paths   ()
Core: fn
```
instead of
```
Wrapped function:
0   : _argnums_partial   ((0, 1), ())
1   : flatten_fun   (PyTreeDef(((*, *), {})),)
2   : result_paths   ()
Core: add
```
PiperOrigin-RevId: 627417452
2024-04-23 09:45:57 -07:00
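The underlying effect is the usual Python one: a wrapper that doesn't copy `__name__` reports its own generic name in debug output. A minimal sketch (generic Python, not JAX's `lu.WrappedFun` machinery):

```python
import functools

def wrap_without_name(fn):
    def wrapper(*args, **kwargs):   # debug output will just say "wrapper"
        return fn(*args, **kwargs)
    return wrapper

def wrap_with_name(fn):
    @functools.wraps(fn)            # copies fn.__name__ (and more) onto the wrapper
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper

def add(x, y):
    return x + y

anonymous = wrap_without_name(add)  # reports "wrapper"
named = wrap_with_name(add)         # reports "add"
```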