16891 Commits

Author SHA1 Message Date
Chris Jones
1926b99bfd [pallas] Fix spelling of 'fusible'.
PiperOrigin-RevId: 747663692
2025-04-14 19:35:59 -07:00
Mark Sandler
0ed0fb7c54 Adds a debugging message to assert, otherwise the error is pretty cryptic.
PiperOrigin-RevId: 747657234
2025-04-14 19:11:02 -07:00
Sharad Vikram
4fa3cd91d3 [Pallas/Fuser] Add basic closed over consts support to pull_block_spec
PiperOrigin-RevId: 747657069
2025-04-14 19:09:04 -07:00
Peter Hawkins
57e33bcbcd Deprecate the contents of jax.util.
PiperOrigin-RevId: 747629222
2025-04-14 17:20:30 -07:00
Ivy Zheng
ab600c3e82 Remove obsolete python key path registry.
PiperOrigin-RevId: 747613761
2025-04-14 16:33:05 -07:00
jax authors
19be20fc6f Merge pull request #27919 from kaixih:enable_doc_scaled_dot_fix
PiperOrigin-RevId: 747578845
2025-04-14 14:55:23 -07:00
Peter Hawkins
8930a67e63 Fix stablehlo version comparison in test utilities.
PiperOrigin-RevId: 747547427
2025-04-14 13:34:32 -07:00
jax authors
d014912671 Merge pull request #28007 from jakevdp:int-power
PiperOrigin-RevId: 747498460
2025-04-14 11:26:05 -07:00
jax authors
6fcb036b96 Merge pull request #27966 from jakevdp:jit-signature
PiperOrigin-RevId: 747492659
2025-04-14 11:11:02 -07:00
Jake VanderPlas
42542feac6 jnp.power: better docs for invalid input 2025-04-14 10:42:29 -07:00
jax authors
30669dc219 Merge pull request #27993 from gnecula:explain_timing
PiperOrigin-RevId: 747480248
2025-04-14 10:41:05 -07:00
Jake VanderPlas
ceca6ec1fc jax.jit: deprecate non-standard call signature. 2025-04-14 10:13:05 -07:00
Dan Foreman-Mackey
1b1bd071bc Finalize deprecation of vectorized argument in callbacks.
The `vectorized` argument to `pure_callback` and `ffi_call` was deprecated in JAX v0.4.34 (released Oct 4 2024), then added to the CHANGELOG in v0.4.35 (doh! released Oct 22). The JAX compatibility policy requires 3 months of compatible releases before a deprecation is finalized, so it is time to remove this parameter from the public API. The `vmap_method` parameter can be used instead, and the docs for [`pure_callback`](https://docs.jax.dev/en/latest/_autosummary/jax.pure_callback.html) provide more details.

This change has one other (non-obvious!) affect on the user facing APIs. (Note that this change in behavior has also been protected by a deprecation warning since the `vectorized` parameter was deprecated.) The default behavior of `pure_callback` and `ffi_call` under `vmap` is to now raise an exception, rather than silently producing a loop. To opt in to the previous default behavior, use `vmap_method="sequential"`.

PiperOrigin-RevId: 747413383
2025-04-14 07:43:59 -07:00
jax authors
b6c6c1c258 Merge pull request #27971 from ywrt:patch-1
PiperOrigin-RevId: 747399343
2025-04-14 07:00:10 -07:00
George Necula
b8df474965 [explain_cache_miss] Add to explanations the duration of the missed function call
This enables the user to focus on the most important
call sites.

jax-fixit
2025-04-14 16:08:24 +03:00
jax authors
6ca623f79b Merge pull request #27980 from gnecula:tracing_cache
PiperOrigin-RevId: 747274185
2025-04-13 23:53:16 -07:00
George Necula
f070cdecb3 [explain-cache-miss] Improve tracing-cache-miss explanations
The previous approach was to report, for several elements
of the cache key, the closest mismatch. Some parts of
the cache key were ignored, which led to "explanation unavailable".
The same happened when we had two keys close to the current
one, each differring in a different part of the key.
No explanation was produced because for each part of the key,
there was a matching key already in the cache, even though
the key taken as a whole did not match.

Now, we scan *all* parts of they key and compute the differences.
We keep track of the "size" of the differences, and we explain
the differences to those keys that are closest (possibly more
than one key if equidistant).
For example, for shape differences we'll report the
closest matching shape. If a type differs in both the dtype
and some parts of the shape, or sharding, it is considered
farther away.

We add new tests and explanations for  different
static argnums and argnames.

There are still cases when we do not produce an explanation, but
now the "explanation unavailable" includes a description
of which component of the key is different, and what the
difference is. This may still be hard to understand by the
user but at least they can file a clearer bug.

Refactored the tests, and added a few new ones.
2025-04-13 20:44:46 +03:00
Peter Hawkins
c69e61e1a9 Remove jax.lib.xla_client.{XlaComputation,Shape}.
PiperOrigin-RevId: 746803082
2025-04-12 06:18:02 -07:00
Roy Frostig
566d0775a8 unify stages.Lowering and stages.XlaLowering
We no longer have many different implicit types conforming to `Lowering`, only `pxla.MeshComputation` and `pxla.PmapComputation`. Both are `XlaLowering` subtypes. So define just one common base class, call it `Lowering`, and inherit from just that in both concrete internal computation/lowering subtypes.

PiperOrigin-RevId: 746735857
2025-04-12 00:31:14 -07:00
Roy Frostig
99ca14601d revert making Executable an ABC
PiperOrigin-RevId: 746726071
2025-04-11 23:49:25 -07:00
Yash Katariya
4ff78e6a0e Remove various methods from MeshExecutable
These are thin and their implementations can be inlined directly at call sites in `XlaExecutable`.

Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 746716734
2025-04-11 23:02:54 -07:00
Roy Frostig
19d3d954bf unify stages.Executable and stages.XlaExecutable
We no longer have many different implicit types conforming to `Executable`, only `pxla.MeshExectuable` and `pxla.PmapExecutable`. Both are `XlaExecutable` subtypes. So define just one common base class, call it `Exectuable`, and inherit from just that in both concrete internal executable subtypes.

PiperOrigin-RevId: 746706712
2025-04-11 22:09:47 -07:00
George Necula
dc10200906 [explain-cache-miss] Improve the detection of user file names
When we print explanations for tracing cache misses,
we use traceback_util to ignore JAX-internal functions.
Here we change the detection mechanism to use
source_info_util, which has a more exhaustive
list of JAX internals.

This removes a lot of uninteresting explanations
from a large benchmark.

jax-fixit

PiperOrigin-RevId: 746703003
2025-04-11 21:53:55 -07:00
jax authors
e1cad34522 Add ChunkedCausalMask for Splash Attention to support attention masking similar to Llama4.
Llama4 uses (interleaved) chunk attention to support long context.

PiperOrigin-RevId: 746661156
2025-04-11 18:59:07 -07:00
Yash Katariya
8afc833c24 Rename is_closed to is_open in the shardy shardings
PiperOrigin-RevId: 746645422
2025-04-11 17:42:34 -07:00
jax authors
1a4a86aa48 Merge pull request #27970 from mattjj:while-readonly-carry-optimization
PiperOrigin-RevId: 746639385
2025-04-11 17:15:30 -07:00
Matthew Johnson
29f65f04ed re-index jaxpr input effects in move_binders_to_front 2025-04-11 23:50:14 +00:00
Jevin Jiang
0fa732ea45 [ragged-paged-attn][NFC] Make validate_inputs functions take same inputs as attention call.
PiperOrigin-RevId: 746616128
2025-04-11 15:49:53 -07:00
Matthew Johnson
b3f49e42d9 Re-landing #27937 with fewer bugs and more tests. 2025-04-11 22:42:08 +00:00
jax authors
b2a8df7183 Add the method argument to jax.numpy.isin stub.
This parameter is available from https://github.com/google/jax/pull/23040 and documented in https://docs.jax.dev/en/latest/_autosummary/jax.numpy.isin.html.

PiperOrigin-RevId: 746606206
2025-04-11 15:15:22 -07:00
Peter Hawkins
6fc78a5a6d Deprecate jax.lax.infeed and jax.lax.outfeed.
These APIs are already broken on GPU and TPU by virtue of not being implemented in the PJRT C API, so it seems unlikely that they have any users.

PiperOrigin-RevId: 746595857
2025-04-11 14:42:14 -07:00
ywrt
c90751bc54
Fix typo in jax.lax.linalg.symmetric_product description
Missing space in '..math::' meant that the math wasn't rendering correctly.
2025-04-12 07:20:39 +10:00
Yash Katariya
6efcf44b1a Deprecate PositionalSharding and GSPMDSharding
PiperOrigin-RevId: 746564071
2025-04-11 13:06:43 -07:00
Matthew Johnson
e9364f4b0a Reverts 907725dfd7a7fb612c4f6d975bb462f1ae1a21d7
PiperOrigin-RevId: 746554582
2025-04-11 12:37:20 -07:00
Peter Hawkins
904419cb0e Rename TPU bazel test tags.
Use a count of chips (or omit it if 1) rather than specifying an ICI topology.

Examples:
* tpu_v5e_1x1 -> tpu_v5e
* tpu_v5e_4x2 -> tpu_v5e_x8
PiperOrigin-RevId: 746547477
2025-04-11 12:15:15 -07:00
Justin Fu
27c07f7cd3 [Pallas] Allow 1D iota
PiperOrigin-RevId: 746546870
2025-04-11 12:13:33 -07:00
Peter Hawkins
ab88273596 Deprecate jax.dlpack.to_dlpack.
This is not needed under the newer DLPack protocol for users, and there's an equivalent (`__dlpack__`).

PiperOrigin-RevId: 746530351
2025-04-11 11:26:20 -07:00
Yash Katariya
a39b6232be Make sure the order passed to make_jit and _parse_jit_arguments is the same as the order of arguments received in jit API and make it keyword-only
PiperOrigin-RevId: 746527807
2025-04-11 11:18:59 -07:00
George Necula
5adac1cb8a Fix the printing of the function name in tracing-cache-miss explanations
jax-fixit

PiperOrigin-RevId: 746496570
2025-04-11 09:53:57 -07:00
Peter Hawkins
3736e5ba85 Bump the JAX version to v0.6.0, which will be the next release version.
PiperOrigin-RevId: 746490665
2025-04-11 09:34:42 -07:00
Peter Hawkins
8b7319afe9 [JAX] Remove calls to jax.dlpack.to_dlpack(), and avoid passing DLPack capsules to jax.dlpack.from_dlpack().
to_dlpack() is not needed in the current version of the dlpack protocol. The from_dlpack() method accepts an object that implements __dlpack__(). In most cases, a JAX array can be passed directly to functions like torch.dlpack.from_dlpack(), and vice versa for other frameworks. The main exception is TensorFlow which does not implement the current protocol.

PiperOrigin-RevId: 746464890
2025-04-11 08:09:41 -07:00
Sergei Lebedev
d543df1324 [pallas:mosaic_gpu] Added support for unroll=True to the lax.fori_loop lowering
PiperOrigin-RevId: 746444372
2025-04-11 06:56:05 -07:00
Peter Hawkins
b49972d1ce Move test skip for unary_ops_accuracy_test to a setUp method.
The skip decorator being used here only worked for test methods, not test classes, so it accidentally had the effect of skipping all the tests.
But we don't really need a special decorator here anyway.

PiperOrigin-RevId: 746434607
2025-04-11 06:19:45 -07:00
George Necula
7eb397d1e5 Make trace and lower class attributes for jax.jit.
Previously, jax.jit returned a function with extra attributes, e.g., `trace`, and `lower`, such that we can use:

```
jax.jit(f).trace(...)
```

The new attributes create problems when `jax.jit` is used along `functools.wraps`.
Essentially, `functools.wraps(jax.jit(f))(wrapper)` is supposed to result in a
function that when invoked will invoke `wrapper` and then presumably `jax.jit(f)`.
This works as expected if you just call the result, but if you try to use it with
`lower` and `trace`, the `wrapper` is bypassed. This is because `wraps` copies the
attributes `trace` and `lower` from `jax.jit(f)` onto the resulting function,
so when `trace` is invoked the `wrapper` is bypassed entirely.

See #27829 and #27825.

The solution proposed here is to make the `trace` and `lower` be class attributes,
so that they are not copied by `functools.wraps`.
Thus, if you try to use `lower` or `trace` on the result of
`functools.wraps(jax.jit(f))()` you will get an error.
That is better than silently ignoring the wrapper.
The workaround is to apply `jax.jit` last among your wrappers.

Fixes: #27829
2025-04-11 14:51:12 +03:00
jax authors
c9cbf82164 Merge pull request #27876 from gnecula:aot_compute_on
PiperOrigin-RevId: 746402180
2025-04-11 04:08:18 -07:00
jax authors
1035c9a118 Merge pull request #27916 from gnecula:tracing_cache_ignore_internals
PiperOrigin-RevId: 746397452
2025-04-11 03:53:47 -07:00
jax authors
ac285a138b Merge pull request #27685 from Cjkkkk:return_cudnn_sdpa_residual
PiperOrigin-RevId: 746397395
2025-04-11 03:51:40 -07:00
Dan Foreman-Mackey
81722201fd Remove legacy CPU custom call kernels that have been unused since v0.4.34.
As of today it has been 180 days since the release of 0.4.34 where the following legacy LAPACK kernels were no longer used when lowering:

* getrf
* geqrf / orgqr
* potrf
* gesdd
* syevd
* geev
* gehrd

Following our compatibility policy, these are now safe to remove.

PiperOrigin-RevId: 746388529
2025-04-11 03:17:19 -07:00
George Necula
96d38a6b66 [cache_misses] Skip tracing-cache-miss explanations for JAX internal functions
About half of the tracing-cache-miss explanations in a large benchmark
end up being from JAX-internal functions, such as `jax.numpy` functions.
These cache misses are not what the JAX user wants to see, so we filter
them out, using the same mechanism used for filtering tracebacks.
2025-04-11 12:53:38 +03:00
jax authors
d42d2e88b4 [Pallas] Interpret dimensions with parallel semantics by traversing the corresponding grid coordinates in randomized order.
Note that dynamic grid dimensions with 'parallel' semantics are disallowed. This enables the computation of grid points, with randomized coordinates along 'parallel' dimensions, in Jax/on device.
If randomization of grid dimensions with dynamic sizes (i.e. sizes not known at Jax trace time) were allowed, this would require computing these randomizations on the host/on CPU (where one can have arrays of dynamic shape).

PiperOrigin-RevId: 746365669
2025-04-11 01:54:11 -07:00