27018 Commits

Author SHA1 Message Date
Peter Hawkins
c69e61e1a9 Remove jax.lib.xla_client.{XlaComputation,Shape}.
PiperOrigin-RevId: 746803082
2025-04-12 06:18:02 -07:00
Roy Frostig
566d0775a8 unify stages.Lowering and stages.XlaLowering
We no longer have many different implicit types conforming to `Lowering`, only `pxla.MeshComputation` and `pxla.PmapComputation`. Both are `XlaLowering` subtypes. So define just one common base class, call it `Lowering`, and inherit from just that in both concrete internal computation/lowering subtypes.

PiperOrigin-RevId: 746735857
2025-04-12 00:31:14 -07:00
Roy Frostig
99ca14601d revert making Executable an ABC
PiperOrigin-RevId: 746726071
2025-04-11 23:49:25 -07:00
Yash Katariya
4ff78e6a0e Remove various methods from MeshExecutable
These are thin and their implementations can be inlined directly at call sites in `XlaExecutable`.

Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 746716734
2025-04-11 23:02:54 -07:00
Roy Frostig
19d3d954bf unify stages.Executable and stages.XlaExecutable
We no longer have many different implicit types conforming to `Executable`, only `pxla.MeshExecutable` and `pxla.PmapExecutable`. Both are `XlaExecutable` subtypes. So define just one common base class, call it `Executable`, and inherit from just that in both concrete internal executable subtypes.

PiperOrigin-RevId: 746706712
2025-04-11 22:09:47 -07:00
George Necula
dc10200906 [explain-cache-miss] Improve the detection of user file names
When we print explanations for tracing cache misses,
we use traceback_util to ignore JAX-internal functions.
Here we change the detection mechanism to use
source_info_util, which has a more exhaustive
list of JAX internals.

This removes a lot of uninteresting explanations
from a large benchmark.

jax-fixit

PiperOrigin-RevId: 746703003
2025-04-11 21:53:55 -07:00
jax authors
e1cad34522 Add ChunkedCausalMask for Splash Attention to support attention masking similar to Llama4.
Llama4 uses (interleaved) chunk attention to support long context.

PiperOrigin-RevId: 746661156
2025-04-11 18:59:07 -07:00
Yash Katariya
8afc833c24 Rename is_closed to is_open in the shardy shardings
PiperOrigin-RevId: 746645422
2025-04-11 17:42:34 -07:00
jax authors
1a4a86aa48 Merge pull request #27970 from mattjj:while-readonly-carry-optimization
PiperOrigin-RevId: 746639385
2025-04-11 17:15:30 -07:00
Matthew Johnson
29f65f04ed re-index jaxpr input effects in move_binders_to_front
2025-04-11 23:50:14 +00:00
Jevin Jiang
0fa732ea45 [ragged-paged-attn][NFC] Make validate_inputs functions take same inputs as attention call.
PiperOrigin-RevId: 746616128
2025-04-11 15:49:53 -07:00
Matthew Johnson
b3f49e42d9 Re-landing #27937 with fewer bugs and more tests.
2025-04-11 22:42:08 +00:00
jax authors
b2a8df7183 Add the method argument to jax.numpy.isin stub.
This parameter is available from https://github.com/google/jax/pull/23040 and documented in https://docs.jax.dev/en/latest/_autosummary/jax.numpy.isin.html.

PiperOrigin-RevId: 746606206
2025-04-11 15:15:22 -07:00
Peter Hawkins
6fc78a5a6d Deprecate jax.lax.infeed and jax.lax.outfeed.
These APIs are already broken on GPU and TPU by virtue of not being implemented in the PJRT C API, so it seems unlikely that they have any users.

PiperOrigin-RevId: 746595857
2025-04-11 14:42:14 -07:00
Parker Schuh
c0d97a6872 The removed type annotations appear to be used and are actually defined in Python as a patch; rolling back.
Reverts b1c96d47ed9876a74ee2686234201aacd7cd7791

PiperOrigin-RevId: 746565341
2025-04-11 13:10:03 -07:00
Yash Katariya
6efcf44b1a Deprecate PositionalSharding and GSPMDSharding
PiperOrigin-RevId: 746564071
2025-04-11 13:06:43 -07:00
Matthew Johnson
e9364f4b0a Reverts 907725dfd7a7fb612c4f6d975bb462f1ae1a21d7
PiperOrigin-RevId: 746554582
2025-04-11 12:37:20 -07:00
Peter Hawkins
904419cb0e Rename TPU bazel test tags.
Use a count of chips (or omit it if 1) rather than specifying an ICI topology.

Examples:
* tpu_v5e_1x1 -> tpu_v5e
* tpu_v5e_4x2 -> tpu_v5e_x8
PiperOrigin-RevId: 746547477
2025-04-11 12:15:15 -07:00
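The tag renaming in 904419cb0e can be sketched as a small string transformation. This is only an illustration that reproduces the two examples in the commit message: `rename_tpu_tag` is a hypothetical helper, and both the assumed old tag shape (`tpu_<version>_<A>x<B>`) and the rule that the chip count is the product of the two topology dimensions are assumptions, not something the commit states.

```python
import re

# Assumed old tag shape: "tpu_<version>_<A>x<B>", e.g. "tpu_v5e_4x2".
_OLD_TAG = re.compile(r"^(tpu_[a-z0-9]+)_(\d+)x(\d+)$")


def rename_tpu_tag(tag: str) -> str:
    """Map a topology-style tag to a chip-count tag (hypothetical helper)."""
    match = _OLD_TAG.match(tag)
    if match is None:
        return tag  # Not a topology-style tag; leave it unchanged.
    prefix, rows, cols = match.groups()
    chips = int(rows) * int(cols)  # Assumption: chip count = product of dims.
    # A single chip omits the count entirely.
    return prefix if chips == 1 else f"{prefix}_x{chips}"
```

Tags that are already in the new form do not match the pattern and pass through unchanged.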
Justin Fu
27c07f7cd3 [Pallas] Allow 1D iota
PiperOrigin-RevId: 746546870
2025-04-11 12:13:33 -07:00
Nitin Srinivasan
5cf74cc72b Use dash instead of underscore for extras.
The new `METADATA` specification disallows underscores in extra names and automatically converts them to dashes.

https://packaging.python.org/en/latest/specifications/core-metadata/#provides-extra-multiple-use

This should prevent the following error from appearing in future JAX releases: https://github.com/jax-ml/jax/issues/27874

PiperOrigin-RevId: 746546162
2025-04-11 12:11:38 -07:00
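The conversion mentioned in 5cf74cc72b follows the name normalization rule in the core metadata specification linked above: runs of `-`, `_`, and `.` collapse to a single `-`, and the result is lowercased. A minimal sketch, where `normalize_extra` and the extra name `my_extra` are hypothetical and not taken from JAX's packaging code:

```python
import re


def normalize_extra(name: str) -> str:
    """Normalize an extra name: runs of '-', '_' and '.' become one '-', lowercased."""
    return re.sub(r"[-_.]+", "-", name).lower()


# A hypothetical extra spelled with an underscore normalizes to the dashed form,
# so declaring the dashed spelling up front avoids the mismatch.
assert normalize_extra("my_extra") == normalize_extra("my-extra")
```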
jax authors
8e9fca1d08 document SPMD pipeline parallelism
PiperOrigin-RevId: 746543312
2025-04-11 12:03:45 -07:00
Peter Hawkins
ab88273596 Deprecate jax.dlpack.to_dlpack.
This is not needed under the newer DLPack protocol for users, and there's an equivalent (`__dlpack__`).

PiperOrigin-RevId: 746530351
2025-04-11 11:26:20 -07:00
Yash Katariya
a39b6232be Make sure the order passed to make_jit and _parse_jit_arguments is the same as the order of arguments received in jit API and make it keyword-only
PiperOrigin-RevId: 746527807
2025-04-11 11:18:59 -07:00
Parker Schuh
b1c96d47ed Remove unused execute_sharded_* functions.
PiperOrigin-RevId: 746520758
2025-04-11 10:58:43 -07:00
George Necula
5adac1cb8a Fix the printing of the function name in tracing-cache-miss explanations
jax-fixit

PiperOrigin-RevId: 746496570
2025-04-11 09:53:57 -07:00
Peter Hawkins
3736e5ba85 Bump the JAX version to v0.6.0, which will be the next release version.
PiperOrigin-RevId: 746490665
2025-04-11 09:34:42 -07:00
Peter Hawkins
8b7319afe9 [JAX] Remove calls to jax.dlpack.to_dlpack(), and avoid passing DLPack capsules to jax.dlpack.from_dlpack().
to_dlpack() is not needed in the current version of the dlpack protocol. The from_dlpack() method accepts an object that implements __dlpack__(). In most cases, a JAX array can be passed directly to functions like torch.dlpack.from_dlpack(), and vice versa for other frameworks. The main exception is TensorFlow which does not implement the current protocol.

PiperOrigin-RevId: 746464890
2025-04-11 08:09:41 -07:00
jax authors
b3c0ec0486 Update XLA dependency to use revision
ca9011742b.

PiperOrigin-RevId: 746448816
2025-04-11 07:11:55 -07:00
Sergei Lebedev
d543df1324 [pallas:mosaic_gpu] Added support for unroll=True to the lax.fori_loop lowering
PiperOrigin-RevId: 746444372
2025-04-11 06:56:05 -07:00
Peter Hawkins
614ef37ce7 Fix test flakiness in tpu_pallas_test when JAX_TEST_NUM_THREADS > 1.
stdout redirection is inherently racy; mark test cases doing it as thread unsafe.

PiperOrigin-RevId: 746443039
2025-04-11 06:51:52 -07:00
George Necula
8082186fa7 Fix api_test on persistent cache enabled platform
Follow-up from https://github.com/jax-ml/jax/pull/27916.
jax-fixit

PiperOrigin-RevId: 746442635
2025-04-11 06:49:51 -07:00
Peter Hawkins
b49972d1ce Move test skip for unary_ops_accuracy_test to a setUp method.
The skip decorator being used here only worked for test methods, not test classes, so it accidentally had the effect of skipping all the tests.
But we don't really need a special decorator here anyway.

PiperOrigin-RevId: 746434607
2025-04-11 06:19:45 -07:00
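The approach in b49972d1ce, checking the skip condition in `setUp` instead of relying on a decorator, can be sketched with the standard `unittest` module. The class name, the `backend_supported` flag, and the test bodies below are hypothetical; the real skip criterion is not shown in the commit message.

```python
import math
import unittest


class AccuracyTest(unittest.TestCase):
    # Hypothetical flag standing in for the real skip condition.
    backend_supported = False

    def setUp(self):
        super().setUp()
        # Evaluated once per test method, so every test in the class is
        # skipped exactly when the condition holds, and only then.
        if not self.backend_supported:
            self.skipTest("backend not supported")

    def test_exp(self):
        self.assertEqual(math.exp(0.0), 1.0)

    def test_log(self):
        self.assertEqual(math.log(1.0), 0.0)


def run_suite(supported: bool) -> unittest.TestResult:
    """Run the test class with the flag set and return the raw result."""
    AccuracyTest.backend_supported = supported
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(AccuracyTest)
    result = unittest.TestResult()
    suite.run(result)
    return result
```

Because the check runs at test time rather than at class-definition time, flipping the condition changes which tests run without touching any decorators.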
Henning Becker
896557f07b Register NVPTX LLVM backend from Mosaic custom call
So far Mosaic was implicitly relying on XLA to register the NVPTX target, which caused problems in cases where only a Mosaic kernel is compiled and XLA has not initialized the LLVM NVPTX target.

PiperOrigin-RevId: 746433654
2025-04-11 06:15:34 -07:00
jax authors
a1c06fcb3b Merge pull request #27873 from gnecula:aot_wraps2
PiperOrigin-RevId: 746425307
2025-04-11 05:43:38 -07:00
George Necula
7eb397d1e5 Make trace and lower class attributes for jax.jit.
Previously, jax.jit returned a function with extra attributes, e.g., `trace`, and `lower`, such that we can use:

```
jax.jit(f).trace(...)
```

The new attributes create problems when `jax.jit` is used along with `functools.wraps`.
Essentially, `functools.wraps(jax.jit(f))(wrapper)` is supposed to result in a
function that when invoked will invoke `wrapper` and then presumably `jax.jit(f)`.
This works as expected if you just call the result, but if you try to use it with
`lower` and `trace`, the `wrapper` is bypassed. This is because `wraps` copies the
attributes `trace` and `lower` from `jax.jit(f)` onto the resulting function,
so when `trace` is invoked the `wrapper` is bypassed entirely.

See #27829 and #27825.

The solution proposed here is to make `trace` and `lower` class attributes,
so that they are not copied by `functools.wraps`.
Thus, if you try to use `lower` or `trace` on the result of
`functools.wraps(jax.jit(f))()` you will get an error.
That is better than silently ignoring the wrapper.
The workaround is to apply `jax.jit` last among your wrappers.

Fixes: #27829
2025-04-11 14:51:12 +03:00
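The `functools.wraps` behavior described in 7eb397d1e5 can be reproduced without JAX. The sketch below is a plain-Python illustration, not JAX's implementation: `jit_with_instance_attrs`, `JitWithClassAttrs`, and `logged` are hypothetical names. It relies on the fact that `functools.wraps` updates the wrapper's `__dict__` from the wrapped object's `__dict__`, so attributes attached to a function object are copied, while attributes defined on a class are not.

```python
import functools


def jit_with_instance_attrs(f):
    """Hypothetical stand-in for the old behavior: attach `lower` to the function."""
    def jitted(*args):
        return f(*args)
    # Stored in jitted.__dict__, so functools.wraps copies it onto wrappers.
    jitted.lower = lambda *args: ("lowered", args)
    return jitted


class JitWithClassAttrs:
    """Hypothetical stand-in for the new behavior: `lower` lives on the class."""

    def __init__(self, f):
        self._f = f

    def __call__(self, *args):
        return self._f(*args)

    def lower(self, *args):
        return ("lowered", args)


def logged(jitted):
    """A user wrapper that records every call before delegating."""
    calls = []

    @functools.wraps(jitted)
    def wrapper(*args):
        calls.append(args)
        return jitted(*args)

    wrapper.calls = calls
    return wrapper


def f(x):
    return x + 1


old = logged(jit_with_instance_attrs(f))
old.lower(1)                # Succeeds, but `wrapper` never sees this call.
new = logged(JitWithClassAttrs(f))
# `new.lower` does not exist, so misuse fails loudly instead of silently
# bypassing the wrapper.
```

In this sketch, applying the jit-like transformation last (outside `logged`) keeps `lower` available on the outermost object, which mirrors the workaround in the commit message.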
jax authors
c9cbf82164 Merge pull request #27876 from gnecula:aot_compute_on
PiperOrigin-RevId: 746402180
2025-04-11 04:08:18 -07:00
jax authors
1035c9a118 Merge pull request #27916 from gnecula:tracing_cache_ignore_internals
PiperOrigin-RevId: 746397452
2025-04-11 03:53:47 -07:00
jax authors
ac285a138b Merge pull request #27685 from Cjkkkk:return_cudnn_sdpa_residual
PiperOrigin-RevId: 746397395
2025-04-11 03:51:40 -07:00
Dan Foreman-Mackey
81722201fd Remove legacy CPU custom call kernels that have been unused since v0.4.34.
As of today it has been 180 days since the release of 0.4.34 where the following legacy LAPACK kernels were no longer used when lowering:

* getrf
* geqrf / orgqr
* potrf
* gesdd
* syevd
* geev
* gehrd

Following our compatibility policy, these are now safe to remove.

PiperOrigin-RevId: 746388529
2025-04-11 03:17:19 -07:00
George Necula
96d38a6b66 [cache_misses] Skip tracing-cache-miss explanations for JAX internal functions
About half of the tracing-cache-miss explanations in a large benchmark
end up being from JAX-internal functions, such as `jax.numpy` functions.
These cache misses are not what the JAX user wants to see, so we filter
them out, using the same mechanism used for filtering tracebacks.
2025-04-11 12:53:38 +03:00
jax authors
d42d2e88b4 [Pallas] Interpret dimensions with parallel semantics by traversing the corresponding grid coordinates in randomized order.
Note that dynamic grid dimensions with 'parallel' semantics are disallowed. This enables the computation of grid points, with randomized coordinates along 'parallel' dimensions, in JAX/on device.
If randomization of grid dimensions with dynamic sizes (i.e. sizes not known at JAX trace time) were allowed, this would require computing these randomizations on the host/on CPU (where one can have arrays of dynamic shape).

PiperOrigin-RevId: 746365669
2025-04-11 01:54:11 -07:00
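The traversal order described in d42d2e88b4 can be pictured with a host-side, pure-Python sketch. This is not the Pallas interpreter's implementation, which computes the grid points in JAX on device; `randomized_grid_points` and its arguments are hypothetical. It also shows why a static size is needed: each shuffled axis is materialized as a list of known length.

```python
import itertools
import random


def randomized_grid_points(grid, is_parallel, seed=0):
    """Return all grid points, visiting 'parallel' dimensions in shuffled order.

    grid: tuple of static dimension sizes.
    is_parallel: tuple of bools, True where the dimension has parallel semantics.
    """
    rng = random.Random(seed)
    axes = []
    for size, parallel in zip(grid, is_parallel):
        coords = list(range(size))  # Requires a size known ahead of time.
        if parallel:
            rng.shuffle(coords)     # Randomize only the parallel dimensions.
        axes.append(coords)
    # The last axis varies fastest, as in a nested loop over the grid.
    return list(itertools.product(*axes))


points = randomized_grid_points((4, 3), (True, False))
```

Every grid point is still visited exactly once; only the order along the parallel dimension changes, which is what makes order-dependent bugs in a kernel more likely to surface.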
jax authors
7b7d36a8e6 Add a 2D test in memories_test.
PiperOrigin-RevId: 746295338
2025-04-10 21:32:56 -07:00
Ayaka
9f5f6edb85 [Pallas] Fix integer array indexing
Fixes https://github.com/google/jax/issues/22783

jax-fixit

PiperOrigin-RevId: 746260869
2025-04-10 19:10:35 -07:00
jax authors
c5d6a19997 Merge pull request #27938 from hawkinsp:scipy
PiperOrigin-RevId: 746257919
2025-04-10 19:00:34 -07:00
Peter Hawkins
ffc33abb5d Bump scipy build requirement on Python 3.13.
We need v1.15.2 for Linux aarch64 3.13-t wheels to exist.
2025-04-11 01:41:31 +00:00
jax authors
907725dfd7 Merge pull request #27937 from mattjj:while-readonly-carry-optimization
PiperOrigin-RevId: 746250385
2025-04-10 18:29:49 -07:00
Matthew Johnson
6e52b1e95b optimize while_loop by moving readonly carry components to be consts
also fix a bug in ordered effects in cond_fun lowering

fixes google/flax#4700
2025-04-11 00:48:52 +00:00
Tomás Longeri
6d57f00b58 [Mosaic:TPU][Relayout] Add implicit 2nd minor
PiperOrigin-RevId: 746228503
2025-04-10 17:04:17 -07:00
Peter Hawkins
b352763a17 Fix Pallas tests so they work with JAX_TEST_NUM_THREADS >= 1.
PiperOrigin-RevId: 746226562
2025-04-10 16:57:34 -07:00
Nitin Srinivasan
b73bf1a03a Update JAX continuous workflow to run once every 3 hours instead of 2.
We are seeing a higher number of cancellations of the continuous job recently:
```
Canceling since a higher priority waiting request for CI - Wheel Tests (Continuous)-refs/heads/main exists
```

PiperOrigin-RevId: 746222323
2025-04-10 16:43:45 -07:00