26078 Commits

Author SHA1 Message Date
jax authors
4cab118344 Merge pull request #26927 from skye:merge_release
PiperOrigin-RevId: 734323206
2025-03-06 16:06:09 -08:00
jax authors
cd7f03f272 Updates the Colocated Python's serialization (and deserialization) implementation to utilize the recently added support for string arrays.
Currently the serialized data and its length are being carried in two separate arrays, a fixed-width bytes array (with a hard-coded max size) and a uint32 array respectively.

PiperOrigin-RevId: 734299259
2025-03-06 14:57:52 -08:00
Jevin Jiang
4b49c03523 Open source TPU-friendly ragged paged attention kernel.
Key features:
* ***Support mixed prefill and decode*** to increase throughput for inference. (e.g., ***5x*** speedup compared to padded Multi-Queries Paged Attention implementation for llama-3-8b.)
* ***No explicit `swapaxes`*** for `seq_len` and `num_head` in pre/post kernel. The kernel takes `num_head` in 2nd minor as it naturally was. We fold swapaxes to strided load/store in the kernel and apply transpose on the fly.
* ***No GMM (Grouped Matmul) Metadata required!*** We calculate the metadata on the fly in the kernel. This can speed up ***10%***!
* ***Increase MXU utilization 8x in GQA*** by grouping shared q heads for MXU in decode.
* ***Minimize recompilation:*** The only factors that can cause recompilation are model specs, `max_num_batched_tokens` and `max_num_seqs` in the setting of mixed engine.

PiperOrigin-RevId: 734269519
2025-03-06 13:36:45 -08:00
Dimitar (Mitko) Asenov
5d64b3d2dd [Mosaic GPU] Fix scf.ForOp lowering to put lowered ops at the right place.
Without this fix, lowerings of ops within the `for` body are always appended at the end, even if they have users earlier in the body. This caused an `operand #0 does not dominate this use` error.

The fix was tested in the upcoming (but not yet submitted) `test_realistic_matmul` in Pallas with Workgroup semantics.

PiperOrigin-RevId: 734157829
2025-03-06 08:40:19 -08:00
Ayaka
8c89da7cdc Minor bug fixes in error checking
PiperOrigin-RevId: 734126415
2025-03-06 06:57:52 -08:00
Nitin Srinivasan
623865fe95 Build JAX wheels instead of installing it from the source repository
This change allows us to get rid of extra env vars which used to control whether to install `jax` at head. Now, `jax` will be built and consumed in the same way as the other wheels in the continuous jobs.

PiperOrigin-RevId: 734123590
2025-03-06 06:48:16 -08:00
Sergei Lebedev
2a34019388 [pallas:mosaic_gpu] Added WG lowering rule for lax.bitcast_convert_type_p
PiperOrigin-RevId: 734081448
2025-03-06 04:09:55 -08:00
Chris Jones
d6b97c2026 [pallas] Add support for pl.dot with int8 inputs.
PiperOrigin-RevId: 734081057
2025-03-06 04:08:04 -08:00
jax authors
16bb919020 Update XLA dependency to use revision
6e396aae2e.

PiperOrigin-RevId: 734059108
2025-03-06 02:40:28 -08:00
Benjamin Chetioui
fe577b5dc4 [Pallas/Mosaic GPU] Enable ops_test for Mosaic GPU.
For now, most of the tests are skipped.

PiperOrigin-RevId: 734026728
2025-03-06 00:45:05 -08:00
Yash Katariya
a67ab9fade Just use `jit` as the string in error messages instead of choosing between `jit` and `pjit` based on `resource_env`. This is to start deprecating the need for `with mesh` and replace it with `use_mesh(mesh)`.
PiperOrigin-RevId: 733959962
2025-03-05 20:09:30 -08:00
Yash Katariya
ba5349f896 Add a note about uneven sharding and with_sharding_constraint. Fixes https://github.com/jax-ml/jax/issues/26946
PiperOrigin-RevId: 733953836
2025-03-05 19:35:03 -08:00
jax authors
c16f37d89d Set USERPROFILE for Windows builds to fix CI issue.
This change fixes https://github.com/jax-ml/jax/actions/runs/13686468791/job/38270929632.

From the [documentation](https://docs.python.org/3/library/os.path.html#os.path.expanduser):
`On Windows, USERPROFILE will be used if set, otherwise a combination of HOMEPATH and HOMEDRIVE will be used.`

PiperOrigin-RevId: 733935305
2025-03-05 18:09:14 -08:00
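The lookup order quoted in commit c16f37d89d can be checked on any platform, because `ntpath` (the Windows flavour of `os.path`) is importable everywhere. The sketch below is only an illustration of the documented `expanduser` behaviour, not the CI change itself; the helper name `windows_home` and the `C:\Users\ci` paths are hypothetical.

```python
import ntpath
import os
from unittest import mock


def windows_home(env: dict) -> str:
    """Expand "~" using Windows rules under a given (hypothetical) environment."""
    # Temporarily replace the process environment so only `env` is visible.
    with mock.patch.dict(os.environ, env, clear=True):
        return ntpath.expanduser("~")


# USERPROFILE is used when set.
print(windows_home({"USERPROFILE": r"C:\Users\ci"}))
# Otherwise HOMEDRIVE and HOMEPATH are combined.
print(windows_home({"HOMEDRIVE": "C:", "HOMEPATH": r"\Users\ci"}))
# With none of them set, "~" is returned unchanged.
print(windows_home({}))
```

The last case is the one a CI job can run into: if none of the variables is defined, `~` is left unexpanded, and setting `USERPROFILE` explicitly avoids that.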
Jacob Burnim
016b351f00 [Pallas] Adds a simple dynamic race detector for TPU interpret mode.
PiperOrigin-RevId: 733885890
2025-03-05 15:15:21 -08:00
jax authors
8571ad9ff2 Merge pull request #26952 from garymm:vmap-arg
PiperOrigin-RevId: 733865978
2025-03-05 14:19:11 -08:00
jax authors
0913cd7583 Fix build rule for free-threaded python builds.
PiperOrigin-RevId: 733857126
2025-03-05 13:54:24 -08:00
Gary Miguel
69d66f66df vmap mismatch size error message: handle *args
Fixes: https://github.com/jax-ml/jax/issues/26908
2025-03-05 13:08:54 -08:00
jax authors
3edc068f8c Fix ambiguous cpu definition for JAX wheels.
Should fix the error in https://github.com/jax-ml/jax/actions/runs/13682579939/job/38258344926.

PiperOrigin-RevId: 733838895
2025-03-05 12:59:21 -08:00
Adam Paszke
8df00e2666 [Mosaic GPU] Remove support for large tiles on Blackwell
We don't have many Blackwell kernels yet, so let's begin the deprecation there!
Small tiles have clearer semantics when it comes to transposes too, which allows
us to enable more test cases.

PiperOrigin-RevId: 733786884
2025-03-05 10:34:53 -08:00
jax authors
1ae7dd7f76 Update .bazelrc with Apple CC toolchain changes.
PiperOrigin-RevId: 733784816
2025-03-05 10:31:16 -08:00
Dan Foreman-Mackey
4a93c8b30c Reverts 342cb7b99a09180472823a33c7cdad8a8db77875
PiperOrigin-RevId: 733782497
2025-03-05 10:22:40 -08:00
Adam Paszke
4493889cda [Mosaic GPU] Add support for small tiles for (WG)MMA LHS
Thanks to the previous refactor the change is quite trivial and mostly
focuses on adding tests.

PiperOrigin-RevId: 733754797
2025-03-05 09:01:20 -08:00
jax authors
4e1f969a76 Merge pull request #26934 from hawkinsp:tsan
PiperOrigin-RevId: 733738121
2025-03-05 08:08:29 -08:00
Adam Paszke
d119138766 [Mosaic GPU][NFC] Refactor MMA SMEM descriptor creation
This makes the code path uniform for LHS/RHS and greatly clarifies the
magical computation of LBO/SBO. This change should make it significantly
easier for us to enable small tile support for the LHS.

PiperOrigin-RevId: 733737302
2025-03-05 08:06:06 -08:00
jax authors
9c19afd9b6 Merge pull request #26938 from superbobry:maint-2
PiperOrigin-RevId: 733725386
2025-03-05 07:29:16 -08:00
Sergei Lebedev
6230ef1d51 Removed unused import
2025-03-05 15:18:43 +00:00
jax authors
a13b3cedad Merge pull request #26691 from h-vetinari:packed
PiperOrigin-RevId: 733696873
2025-03-05 05:46:01 -08:00
Peter Hawkins
40e1a2a561 Remove a TSAN suppression.
https://github.com/python/cpython/issues/130547 has been marked as fixed and backported to 3.13, so this suppression should no longer be necessary.
2025-03-05 08:39:58 -05:00
jax authors
f3b2c84126 Merge pull request #26627 from Cjkkkk:remove_fmha_rewriter
PiperOrigin-RevId: 733690769
2025-03-05 05:20:25 -08:00
Dan Foreman-Mackey
342cb7b99a Attempt 2 at landing custom_vjp.optimize_remat using custom_dce.
The original change was rolled back because there were real-world use cases of custom_vjp where the fwd function had the wrong signature. To preserve backwards compatibility, we shouldn't resolve the input arguments to fwd using fwd's signature. Instead, we can just ignore the signature because custom_vjp handles the resolution before we ever get here.

Reverts 1f3176636d304398b00a7d2cb0933859618affd8

PiperOrigin-RevId: 733643149
2025-03-05 02:06:35 -08:00
jax authors
06b760eea2 Update XLA dependency to use revision
e0e56a1190.

PiperOrigin-RevId: 733636860
2025-03-05 01:43:45 -08:00
Christos Perivolaropoulos
51719a1afe [mgpu] Non-vector untiled stores for tiling layouts.
Useful for storing in memrefs where the minormost stride is >1.

PiperOrigin-RevId: 733551038
2025-03-04 19:41:04 -08:00
Skye Wanderman-Milne
cebedb9f1a Update version number after 0.5.2 release
2025-03-04 18:49:12 -08:00
Skye Wanderman-Milne
a6c858f04b Merge branch 'release/0.5.2' into main
2025-03-04 18:47:20 -08:00
Yash Katariya
766315f791 Make sure concat + vmap of sharded input and replicated input works properly.
In this case, the example boils down to:

```
inp1 = f32[16@x, 4]
inp2 = f32[4]

def f(x: f32[4], y: f32[4]):
  return jnp.concat([x, y], axis=-1)

vmap(f, in_axes=(0, None))(inp1, inp2)
```

This example was breaking in the concat batching rule because we didn't broadcast with the right sharding.

PiperOrigin-RevId: 733536944
2025-03-04 18:35:13 -08:00
Jake Harmon
cdeeacabcf Update references to JAX's GitHub repo
JAX has moved from https://github.com/google/jax to https://github.com/jax-ml/jax

PiperOrigin-RevId: 733536104
2025-03-04 18:31:09 -08:00
David Dunleavy
1a19d5594a Update all uses of @tsl//third_party to @xla//third_party
PiperOrigin-RevId: 733495240
2025-03-04 15:55:23 -08:00
jax authors
c145102ef4 Merge pull request #26641 from jakevdp:jnp-ndim
PiperOrigin-RevId: 733484459
2025-03-04 15:21:01 -08:00
jax authors
b238bad703 Merge pull request #26901 from NeilGirdhar:etils
PiperOrigin-RevId: 733466732
2025-03-04 14:28:51 -08:00
Nitin Srinivasan
721d1a3211 Add functionality to allow promoting RC wheels during release
List of changes:
1. Allow us to build an RC wheel when building release artifacts. This is done by modifying the build CLI to use the new JAX build rule and passing in the build options that control the wheel tag. A new build argument `use_new_wheel_build_rule` is introduced to the build CLI to avoid breaking anyone that uses the CLI and the old build rule. Note that this option will go away in the future when the build CLI migrates fully to the new build rule.
2. Change the upload script to upload both RC and release tagged wheels (internal changes)

PiperOrigin-RevId: 733464219
2025-03-04 14:21:12 -08:00
Skye Wanderman-Milne
ce224293b1 Prepare for JAX release 0.5.2 (patch release over 0.5.1)
2025-03-04 12:59:24 -08:00
Skye Wanderman-Milne
bb80a56898 Update setup.py to automatically pick up libtpu patch releases
2025-03-04 12:32:58 -08:00
Gleb Pobudzey
43b6be0e81 [Mosaic GPU] Add lowering for log, and a fast path using log2.
PiperOrigin-RevId: 733411276
2025-03-04 11:50:50 -08:00
Kanglan Tang
d112c85e6d Internal config change
PiperOrigin-RevId: 733398579
2025-03-04 11:17:48 -08:00
Jake VanderPlas
8cec6e636a jax.numpy ndim/shape/size: deprecate non-array input
2025-03-04 10:42:32 -08:00
jax authors
8af6f70fe0 [JAX] Disable msan and asan for the profiler test running on nvidia gpu
PiperOrigin-RevId: 733380848
2025-03-04 10:34:11 -08:00
jax authors
7d0aab5a98 Merge pull request #26916 from jakevdp:update-array-api-tests
PiperOrigin-RevId: 733379865
2025-03-04 10:32:25 -08:00
jax authors
4a73134b2f Merge pull request #26912 from dfm:resolve-args-error-message
PiperOrigin-RevId: 733378431
2025-03-04 10:26:43 -08:00
Jake VanderPlas
f0bbd26d03 Update array-api-tests to latest commit
2025-03-04 10:17:51 -08:00
Neil Girdhar
52ab8c4cc2 Fix detection of epath
Unfortunately, the old detection code doesn't guarantee that `epath` is
installed:
```
[utM] In [7]: importlib.util.find_spec("etils.epath")
Out[7]: ModuleSpec(name='etils.epath',
loader=<_frozen_importlib_external.SourceFileLoader object at
0x73b8492a7230>,
origin='/home/neil/src/cmm/.venv/lib/python3.12/site-packages/etils/epath/__init__.py',
submodule_search_locations=['/home/neil/src/cmm/.venv/lib/python3.12/site-packages/etils/epath'])

[utM] In [8]: import etils.epath
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent
call last)
Cell In[8], line 1
----> 1 import etils.epath
...
ModuleNotFoundError: No module named 'importlib_resources'
```
This happened every time I ran jax with a clean environment.
2025-03-04 11:44:27 -05:00
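The failure mode described in commit 52ab8c4cc2 can be reproduced without `etils`: `importlib.util.find_spec` only locates a module on disk, so it can report success for a module whose import later fails on a missing dependency. The sketch below is an illustration under stated assumptions, not the code from the commit; the helper names and the throwaway package `demo_pkg_broken` (with its made-up dependency `demo_missing_dependency_xyz`) are hypothetical.

```python
import importlib
import importlib.util
import pathlib
import sys
import tempfile


def spec_says_available(module_name: str) -> bool:
    """Locate-only check: true if the module can be found, without importing it."""
    try:
        return importlib.util.find_spec(module_name) is not None
    except ModuleNotFoundError:
        # Raised when a parent package of a dotted name cannot be imported.
        return False


def import_succeeds(module_name: str) -> bool:
    """Stricter check: actually import, so missing transitive dependencies surface."""
    try:
        importlib.import_module(module_name)
    except ImportError:
        return False
    return True


def make_broken_package(root: pathlib.Path) -> str:
    """Create a hypothetical package whose submodule imports a missing dependency."""
    sub = root / "demo_pkg_broken" / "sub"
    sub.mkdir(parents=True)
    (root / "demo_pkg_broken" / "__init__.py").write_text("")
    (sub / "__init__.py").write_text("import demo_missing_dependency_xyz\n")
    return "demo_pkg_broken.sub"


with tempfile.TemporaryDirectory() as tmp:
    name = make_broken_package(pathlib.Path(tmp))
    sys.path.insert(0, tmp)
    importlib.invalidate_caches()
    try:
        located = spec_says_available(name)  # the spec is found on disk
        importable = import_succeeds(name)   # the import itself fails
    finally:
        sys.path.remove(tmp)

print(located, importable)
```

Attempting the import costs a little more than the spec lookup, but it is the only one of the two checks that reflects whether the module is actually usable.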