16367 Commits

Author SHA1 Message Date
Zac Mustin
8095d842c8 roofline: Support computing flops for unary ops.
PiperOrigin-RevId: 734351741
2025-03-06 17:44:36 -08:00
Yash Katariya
e9486920e8 Auto complete specs in a sharding if aval.ndim > len(sharding.spec) with None. So that for a 2D input, P('data') continues to work.
PiperOrigin-RevId: 734325209
2025-03-06 16:10:14 -08:00
jax authors
4cab118344 Merge pull request #26927 from skye:merge_release
PiperOrigin-RevId: 734323206
2025-03-06 16:06:09 -08:00
jax authors
cd7f03f272 Updates the Colocated Python's serialization (and deserialization) implementation to utilize the recently added support for string arrays.
Currently the serialized data and its length are being carried in two separate arrays, a fixed-with bytes array (with a hard-coded max size) and a unit32 array respectively.

PiperOrigin-RevId: 734299259
2025-03-06 14:57:52 -08:00
Jake VanderPlas
b441b2b7a5 Prevent tracer leaks in scipy.special.expn 2025-03-06 14:38:11 -08:00
Jevin Jiang
4b49c03523 Open source TPU-friendly ragged paged attention kernel.
Key features:
* ***Support mixed prefill and decode*** to increase throughput for inference. (eg., ***5x*** speedup compared to padded Muti-Queries Paged Attention implementation for llama-3-8b.)
* ***No explicit `swapaxes`*** for `seq_len` and `num_head` in pre/post kernel. The kernel takes `num_head` in 2nd minor as it naturally was. We fold swapaxes to strided load/store in the kernel and apply transpose on the fly.
* ***No GMM (Grouped Matmul) Metadata required!*** We calculate the metadata on the fly in the kernel. This can speed up ***10%***!
* ***Increase MXU utilization 8x in GQA*** by grouping shared q heads for MXU in decode.
* ***Minimize recompilation:*** The only factors can cause recompilation are model specs, `max_num_batched_tokens` and `max_num_seqs` in the setting of mixed engine.

PiperOrigin-RevId: 734269519
2025-03-06 13:36:45 -08:00
Dimitar (Mitko) Asenov
5d64b3d2dd [Mosaic GPU] Fix scf.ForOp lowering to put lowered ops at the right place.
Without this fix, lowerings of ops within the `for` body are always appended at the end, even if they have users earlier in the body. This caused an `operand #0 does not dominate this use` error.

The fix was tested in the upcoming (but not yet submitted) `test_realistic_matmul` in Pallas with Workgroup semantics.

PiperOrigin-RevId: 734157829
2025-03-06 08:40:19 -08:00
Ayaka
8c89da7cdc Minor bug fixes in error checking
PiperOrigin-RevId: 734126415
2025-03-06 06:57:52 -08:00
Sergei Lebedev
2a34019388 [pallas:mosaic_gpu] Added WG lowering rule for lax.bitcast_convert_type_p
PiperOrigin-RevId: 734081448
2025-03-06 04:09:55 -08:00
Chris Jones
d6b97c2026 [pallas] Add support for pl.dot with int8 inputs.
PiperOrigin-RevId: 734081057
2025-03-06 04:08:04 -08:00
Yash Katariya
a67ab9fade Just use jit as the string in error messages instead of jit and pjit based on resource_env. This is to start deprecating the need for with mesh and replace it with use_mesh(mesh).
PiperOrigin-RevId: 733959962
2025-03-05 20:09:30 -08:00
Yash Katariya
ba5349f896 Add a note about uneven sharding and with_sharding_constraint. Fixes https://github.com/jax-ml/jax/issues/26946
PiperOrigin-RevId: 733953836
2025-03-05 19:35:03 -08:00
Jacob Burnim
016b351f00 [Pallas] Adds a simple dynamic race detector for TPU interpret mode.
PiperOrigin-RevId: 733885890
2025-03-05 15:15:21 -08:00
Gary Miguel
69d66f66df vmap mismatch size error message: handle *args
Fixes: https://github.com/jax-ml/jax/issues/26908
2025-03-05 13:08:54 -08:00
Owen Lockwood
3e4dc0d490 add pmap axes hints 2025-03-05 12:14:24 -08:00
Adam Paszke
8df00e2666 [Mosaic GPU] Remove support for large tiles on Blackwell
We don't have many Blackwell kernels yet, so let's begin the deprecation there!
Small tiles have clearer semantics when it comes to transposes too, which allows
us to enable more test cases.

PiperOrigin-RevId: 733786884
2025-03-05 10:34:53 -08:00
Dan Foreman-Mackey
4a93c8b30c Reverts 342cb7b99a09180472823a33c7cdad8a8db77875
PiperOrigin-RevId: 733782497
2025-03-05 10:22:40 -08:00
shuw
c099e8081d support e2m1fn 2025-03-05 17:44:34 +00:00
Adam Paszke
4493889cda [Mosaic GPU] Add support for small tiles for (WG)MMA LHS
Thanks to the previous refactor the change is quite trivial and mostly
focuses on adding tests.

PiperOrigin-RevId: 733754797
2025-03-05 09:01:20 -08:00
Adam Paszke
d119138766 [Mosaic GPU][NFC] Refactor MMA SMEM descriptor creation
This makes the code path uniform for LHS/RHS and greatly clarifies the
magical computation of LBO/SBO. This change should make it significantly
easier for us to enable small tile support for the LHS.

PiperOrigin-RevId: 733737302
2025-03-05 08:06:06 -08:00
jax authors
f3b2c84126 Merge pull request #26627 from Cjkkkk:remove_fmha_rewriter
PiperOrigin-RevId: 733690769
2025-03-05 05:20:25 -08:00
Dan Foreman-Mackey
342cb7b99a Attempt 2 at landing custom_vjp.optimize_remat using custom_dce.
The original change was rolled back because there were real world use cases of custom_vjp where the fwd function had the wrong signature. To preserve backwards compatibility, we shouldn't resolve the input arguments to fwd using fwds signature. Instead, we can just ignore the signature because custom_vjp handles the resolution before we ever get here.

Reverts 1f3176636d304398b00a7d2cb0933859618affd8

PiperOrigin-RevId: 733643149
2025-03-05 02:06:35 -08:00
Christos Perivolaropoulos
51719a1afe [mgpu] Non-vector untiled stores for tiling layouts.
Useful for storing in memrefs where the minormost stride is >1.

PiperOrigin-RevId: 733551038
2025-03-04 19:41:04 -08:00
Skye Wanderman-Milne
cebedb9f1a Update version number after 0.5.2 release 2025-03-04 18:49:12 -08:00
Yash Katariya
766315f791 Make sure concat + vmap of sharded input and replicated input works properly.
In this case, the example boils down to:

```
inp1 = f32[16@x, 4]
inp2 = f32[4]

def f(x: f32[4], y: f32[4])
  return jnp.concat([x, y], axis=-1)

vmap(f, in_axes=(0, None))(inp1)
```

This example was breaking in concat batching rule because we didn't broadcast with the right sharding.

PiperOrigin-RevId: 733536944
2025-03-04 18:35:13 -08:00
Jake Harmon
cdeeacabcf Update references to JAX's GitHub repo
JAX has moved from https://github.com/google/jax to https://github.com/jax-ml/jax

PiperOrigin-RevId: 733536104
2025-03-04 18:31:09 -08:00
jax authors
c145102ef4 Merge pull request #26641 from jakevdp:jnp-ndim
PiperOrigin-RevId: 733484459
2025-03-04 15:21:01 -08:00
jax authors
b238bad703 Merge pull request #26901 from NeilGirdhar:etils
PiperOrigin-RevId: 733466732
2025-03-04 14:28:51 -08:00
Gleb Pobudzey
43b6be0e81 [Mosaic GPU] Add lowering for log, and a fast path using log2.
PiperOrigin-RevId: 733411276
2025-03-04 11:50:50 -08:00
Jake VanderPlas
8cec6e636a jax.numpy ndim/shape/size: deprecate non-array input 2025-03-04 10:42:32 -08:00
jax authors
4a73134b2f Merge pull request #26912 from dfm:resolve-args-error-message
PiperOrigin-RevId: 733378431
2025-03-04 10:26:43 -08:00
Neil Girdhar
52ab8c4cc2 Fix detection of epath
Unfortunately, the old detection code doesn't guarantee that `epath` is
installed:
```
[utM] In [7]: importlib.util.find_spec("etils.epath")
Out[7]: ModuleSpec(name='etils.epath',
loader=<_frozen_importlib_external.SourceFileLoader object at
0x73b8492a7230>,
origin='/home/neil/src/cmm/.venv/lib/python3.12/site-packages/etils/epath/__init__.py',
submodule_search_locations=['/home/neil/src/cmm/.venv/lib/python3.12/site-packages/etils/epath'])

[utM] In [8]: import etils.epath
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent
call last)
Cell In[8], line 1
----> 1 import etils.epath
...
ModuleNotFoundError: No module named 'importlib_resources'
```
This happened every time I ran jax with a clean environment.
2025-03-04 11:44:27 -05:00
jax authors
97db925a7d Merge pull request #26765 from Qwlouse:patch-1
PiperOrigin-RevId: 733339465
2025-03-04 08:30:45 -08:00
Adam Paszke
cdae5fcfc7 [Mosaic GPU] Make sure to do the async proxy fence before wargroup sync
This is the ordering we want for a proper release of generic SMEM stores
into the async proxy. The old order was problematic: once the warpgroup
barrier was complete, some warps could get deselected before they get to
the fence. For as long as the first warp would make progress, it could go
through the fence along and start issuing TMA copies before other warps
have synchronized with the async proxy.

I have not observed this problem in any of our kernels so far, but this
order seems safer to me.

PiperOrigin-RevId: 733333814
2025-03-04 08:11:15 -08:00
Dan Foreman-Mackey
8b1b039e0d Improve error messages when input argument resolution fails in custom_* APIs. 2025-03-04 10:31:35 -05:00
Sergei Lebedev
155839bb4d [pallas:triton] Emit a better error message for matmul with non-2D operands
Triton seems to support both 2D and 3D operands now, the latter case being a
batched matmul. We need more changes in the lowering to support 3D, so I will
leave it out of scope here.

Fixes #26013.

PiperOrigin-RevId: 733293299
2025-03-04 05:46:29 -08:00
Dan Foreman-Mackey
6c5ef1a404 Update jnp.unique to support upstream interface changes. 2025-03-04 05:24:52 -05:00
Ayaka
ea53c7616b Fix thread safety of JAX error checking
Fix thread safety of JAX error checking by making the global states thread local

PiperOrigin-RevId: 733164878
2025-03-03 20:56:01 -08:00
Sharad Vikram
00d9f4529d [Pallas/Fuser] Add support for custom_call_jvp/pjit to push_block_spec
PiperOrigin-RevId: 733122108
2025-03-03 17:43:13 -08:00
Sharad Vikram
d32e282ff9 Add fuser to jax.experimental.pallas
Note that fuser is considered experimental within Pallas and APIs are subject to change

PiperOrigin-RevId: 733117882
2025-03-03 17:26:44 -08:00
Sharad Vikram
0b6c355083 [Pallas] Add experimental (private for now) API for manual fusion into Pallas kernels
PiperOrigin-RevId: 733112191
2025-03-03 17:05:51 -08:00
jax authors
2c7043f63d Merge pull request #26865 from jakevdp:fix-indexing-error
PiperOrigin-RevId: 733085471
2025-03-03 15:38:20 -08:00
jax authors
f9f47217df Merge pull request #26862 from jakevdp:logsumexp-docs
PiperOrigin-RevId: 733080943
2025-03-03 15:24:10 -08:00
jax authors
4944dcb977 Merge pull request #26897 from jakevdp:cond-doc
PiperOrigin-RevId: 733077065
2025-03-03 15:13:23 -08:00
jax authors
07d1cd0290 Merge pull request #26876 from carlosgmartin:fix_matrix_norm_empty_matrix
PiperOrigin-RevId: 733077011
2025-03-03 15:11:31 -08:00
Jake VanderPlas
84ca80d215 doc: in lax.cond, note that both branches will be traced 2025-03-03 13:05:24 -08:00
Peter Hawkins
7f05b74bca Fix wrong results in multidimensional pad.
When there are multiple dimensions, NumPy's semantics are as if the padding is applied to each dimension in order.

We lacked test coverage for this case because constant values ((0, 2),) and (0, 2) were handled by different code paths.

Fixes https://github.com/jax-ml/jax/issues/26888
2025-03-03 15:25:08 -05:00
carlosgmartin
897e1a1310 Fix linalg.norm to return zero for proper norms of empty matrices. 2025-03-03 15:02:34 -05:00
Adam Paszke
e9f95cc3a7 [Mosaic GPU] Make the small WGMMA tile independent of transpose flags
Now the small tiling is always `(8, swizzle // bytewidth(dtype))`, no matter whether the input
is transposed or not. This should simply the follow-up refactoring of the code and make it easier
to enable small tiling for LHS too.

PiperOrigin-RevId: 732933005
2025-03-03 08:30:57 -08:00
Bart Chrzaszcz
ed4a7bbab1 #sdy Add JAX backwards compatibility test.
This tests saving a module with one set of axis names, but loading it with another set of axis names.

This does also test the custom calls:

- `@Sharding`
- `@xla.sdy.GlobalToLocalShape`
- `@xla.sdy.LocalToGlobalShape`

But note that there are a bunch of other custom calls that will be tested in the Shardy and XLA codebases. The way the testing utils is tested here doesn't allow me to set `out_shardings` for example. So JAX can rely on the existence of those tests as stability guarantees just like for StableHLO.

PiperOrigin-RevId: 732893432
2025-03-03 06:01:34 -08:00