9119 Commits

Author SHA1 Message Date
Jake VanderPlas
42542feac6 jnp.power: better docs for invalid input 2025-04-14 10:42:29 -07:00
jax authors
30669dc219 Merge pull request #27993 from gnecula:explain_timing
PiperOrigin-RevId: 747480248
2025-04-14 10:41:05 -07:00
Jake VanderPlas
ceca6ec1fc jax.jit: deprecate non-standard call signature. 2025-04-14 10:13:05 -07:00
Dan Foreman-Mackey
1b1bd071bc Finalize deprecation of vectorized argument in callbacks.
The `vectorized` argument to `pure_callback` and `ffi_call` was deprecated in JAX v0.4.34 (released Oct 4 2024), then added to the CHANGELOG in v0.4.35 (doh! released Oct 22). The JAX compatibility policy requires 3 months of compatible releases before a deprecation is finalized, so it is time to remove this parameter from the public API. The `vmap_method` parameter can be used instead, and the docs for [`pure_callback`](https://docs.jax.dev/en/latest/_autosummary/jax.pure_callback.html) provide more details.

This change has one other (non-obvious!) effect on the user-facing APIs. (Note that this change in behavior has also been protected by a deprecation warning since the `vectorized` parameter was deprecated.) The default behavior of `pure_callback` and `ffi_call` under `vmap` is now to raise an exception, rather than silently producing a loop. To opt in to the previous default behavior, use `vmap_method="sequential"`.

PiperOrigin-RevId: 747413383
2025-04-14 07:43:59 -07:00
jax authors
b6c6c1c258 Merge pull request #27971 from ywrt:patch-1
PiperOrigin-RevId: 747399343
2025-04-14 07:00:10 -07:00
George Necula
b8df474965 [explain_cache_miss] Add to explanations the duration of the missed function call
This enables the user to focus on the most important
call sites.

jax-fixit
2025-04-14 16:08:24 +03:00
jax authors
6ca623f79b Merge pull request #27980 from gnecula:tracing_cache
PiperOrigin-RevId: 747274185
2025-04-13 23:53:16 -07:00
carlosgmartin
2336cd1695 Minor improvements to doc for jax.nn.logsumexp. 2025-04-13 15:17:11 -04:00
George Necula
f070cdecb3 [explain-cache-miss] Improve tracing-cache-miss explanations
The previous approach was to report, for several elements
of the cache key, the closest mismatch. Some parts of
the cache key were ignored, which led to "explanation unavailable".
The same happened when we had two keys close to the current
one, each differing in a different part of the key.
No explanation was produced because for each part of the key,
there was a matching key already in the cache, even though
the key taken as a whole did not match.

Now, we scan *all* parts of the key and compute the differences.
We keep track of the "size" of the differences, and we explain
the differences to those keys that are closest (possibly more
than one key if equidistant).
For example, for shape differences we'll report the
closest matching shape. If a type differs in both the dtype
and some parts of the shape, or sharding, it is considered
farther away.
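
The idea can be sketched in plain Python. This is not the JAX implementation: cache keys are modeled here as dicts, the "size" of a difference is assumed to be the number of mismatching components, and `key_differences` and `closest_keys` are hypothetical names.

```python
def key_differences(current, cached):
    """Map each mismatching component to its (cached, current) values."""
    return {name: (cached.get(name), value)
            for name, value in current.items()
            if cached.get(name) != value}

def closest_keys(current, cache):
    """Return the smallest distance and every cached key at that distance.

    Assumes `cache` is non-empty.
    """
    scored = [(key, key_differences(current, key)) for key in cache]
    best = min(len(diffs) for _, diffs in scored)
    return best, [(key, diffs) for key, diffs in scored if len(diffs) == best]

cache = [
    {"shape": (3, 4), "dtype": "float32", "static_argnums": (0,)},
    {"shape": (8, 4), "dtype": "int32", "static_argnums": (0,)},
]
current = {"shape": (8, 4), "dtype": "float32", "static_argnums": (0,)}

# Each cached key matches `current` in all but one component, and every
# component of `current` is matched by some cached key.
distance, nearest = closest_keys(current, cache)
```

Here both cached keys are equidistant, so both would be explained: one differs in shape, the other in dtype. A per-component check would find a match for every component and report nothing.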

We add new tests and explanations for different
static argnums and argnames.

There are still cases when we do not produce an explanation, but
now the "explanation unavailable" includes a description
of which component of the key is different, and what the
difference is. This may still be hard for the user to
understand, but at least they can file a clearer bug.

Refactored the tests, and added a few new ones.
2025-04-13 20:44:46 +03:00
Roy Frostig
566d0775a8 unify stages.Lowering and stages.XlaLowering
We no longer have many different implicit types conforming to `Lowering`, only `pxla.MeshComputation` and `pxla.PmapComputation`. Both are `XlaLowering` subtypes. So define just one common base class, call it `Lowering`, and inherit from just that in both concrete internal computation/lowering subtypes.

PiperOrigin-RevId: 746735857
2025-04-12 00:31:14 -07:00
Roy Frostig
99ca14601d revert making Executable an ABC
PiperOrigin-RevId: 746726071
2025-04-11 23:49:25 -07:00
Yash Katariya
4ff78e6a0e Remove various methods from MeshExecutable
These are thin and their implementations can be inlined directly at call sites in `XlaExecutable`.

Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 746716734
2025-04-11 23:02:54 -07:00
Roy Frostig
19d3d954bf unify stages.Executable and stages.XlaExecutable
We no longer have many different implicit types conforming to `Executable`, only `pxla.MeshExecutable` and `pxla.PmapExecutable`. Both are `XlaExecutable` subtypes. So define just one common base class, call it `Executable`, and inherit from just that in both concrete internal executable subtypes.

PiperOrigin-RevId: 746706712
2025-04-11 22:09:47 -07:00
George Necula
dc10200906 [explain-cache-miss] Improve the detection of user file names
When we print explanations for tracing cache misses,
we use traceback_util to ignore JAX-internal functions.
Here we change the detection mechanism to use
source_info_util, which has a more exhaustive
list of JAX internals.

This removes a lot of uninteresting explanations
from a large benchmark.

jax-fixit

PiperOrigin-RevId: 746703003
2025-04-11 21:53:55 -07:00
Yash Katariya
8afc833c24 Rename is_closed to is_open in the shardy shardings
PiperOrigin-RevId: 746645422
2025-04-11 17:42:34 -07:00
Matthew Johnson
29f65f04ed re-index jaxpr input effects in move_binders_to_front 2025-04-11 23:50:14 +00:00
Matthew Johnson
b3f49e42d9 Re-landing #27937 with fewer bugs and more tests. 2025-04-11 22:42:08 +00:00
ywrt
c90751bc54 Fix typo in jax.lax.linalg.symmetric_product description
Missing space in '..math::' meant that the math wasn't rendering correctly.
2025-04-12 07:20:39 +10:00
Yash Katariya
6efcf44b1a Deprecate PositionalSharding and GSPMDSharding
PiperOrigin-RevId: 746564071
2025-04-11 13:06:43 -07:00
Matthew Johnson
e9364f4b0a Reverts 907725dfd7a7fb612c4f6d975bb462f1ae1a21d7
PiperOrigin-RevId: 746554582
2025-04-11 12:37:20 -07:00
Justin Fu
27c07f7cd3 [Pallas] Allow 1D iota
PiperOrigin-RevId: 746546870
2025-04-11 12:13:33 -07:00
Yash Katariya
a39b6232be Make sure the order passed to make_jit and _parse_jit_arguments is the same as the order of arguments received in jit API and make it keyword-only
PiperOrigin-RevId: 746527807
2025-04-11 11:18:59 -07:00
George Necula
5adac1cb8a Fix the printing of the function name in tracing-cache-miss explanations
jax-fixit

PiperOrigin-RevId: 746496570
2025-04-11 09:53:57 -07:00
Sergei Lebedev
d543df1324 [pallas:mosaic_gpu] Added support for unroll=True to the lax.fori_loop lowering
PiperOrigin-RevId: 746444372
2025-04-11 06:56:05 -07:00
Peter Hawkins
b49972d1ce Move test skip for unary_ops_accuracy_test to a setUp method.
The skip decorator being used here only worked for test methods, not test classes, so it accidentally had the effect of skipping all the tests.
But we don't really need a special decorator here anyway.
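
A generic sketch of the `setUp` pattern with the standard `unittest` module. The class name, the test, and the `ACCELERATOR_AVAILABLE` flag are hypothetical stand-ins, not the actual JAX test code.

```python
import unittest

# Hypothetical condition; a real test would query the available devices.
ACCELERATOR_AVAILABLE = False

class AccuracyTest(unittest.TestCase):
    def setUp(self):
        super().setUp()
        # Skipping here applies to every test method in the class,
        # and only when the condition actually holds.
        if not ACCELERATOR_AVAILABLE:
            self.skipTest("requires an accelerator")

    def test_sqrt(self):
        self.assertAlmostEqual(2.0 ** 0.5 * 2.0 ** 0.5, 2.0)

suite = unittest.defaultTestLoader.loadTestsFromTestCase(AccuracyTest)
result = unittest.TestResult()
suite.run(result)
```

Because the condition is evaluated inside `setUp`, flipping the flag to `True` lets the same tests run without touching any decorator.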

PiperOrigin-RevId: 746434607
2025-04-11 06:19:45 -07:00
George Necula
7eb397d1e5 Make trace and lower class attributes for jax.jit.
Previously, jax.jit returned a function with extra attributes, e.g., `trace`, and `lower`, such that we can use:

```
jax.jit(f).trace(...)
```

The new attributes create problems when `jax.jit` is used along with `functools.wraps`.
Essentially, `functools.wraps(jax.jit(f))(wrapper)` is supposed to result in a
function that when invoked will invoke `wrapper` and then presumably `jax.jit(f)`.
This works as expected if you just call the result, but if you try to use it with
`lower` and `trace`, the `wrapper` is bypassed. This is because `wraps` copies the
attributes `trace` and `lower` from `jax.jit(f)` onto the resulting function,
so when `trace` is invoked the `wrapper` is bypassed entirely.

See #27829 and #27825.

The solution proposed here is to make `trace` and `lower` class attributes,
so that they are not copied by `functools.wraps`.
Thus, if you try to use `lower` or `trace` on the result of
`functools.wraps(jax.jit(f))()` you will get an error.
That is better than silently ignoring the wrapper.
The workaround is to apply `jax.jit` last among your wrappers.
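
The mechanism can be reproduced without JAX. In this sketch `fake_jit` and `FakeJitted` are hypothetical stand-ins for the old and new behavior; the real jitted object is implemented differently.

```python
import functools

def fake_jit(f):
    """Old style: `trace` is attached to the returned function's __dict__."""
    def jitted(*args):
        return f(*args)
    jitted.trace = lambda *args: ("traced", f.__name__)
    return jitted

class FakeJitted:
    """New style: `trace` is defined on the class, not on the instance."""
    def __init__(self, f):
        self._f = f
    def __call__(self, *args):
        return self._f(*args)
    def trace(self, *args):
        return ("traced", self._f.__name__)

def f(x):
    return x + 1

def make_wrapper(inner):
    @functools.wraps(inner)  # copies inner.__dict__ onto wrapper
    def wrapper(*args):
        return inner(*args) * 10
    return wrapper

old_style = make_wrapper(fake_jit(f))
new_style = make_wrapper(FakeJitted(f))

old_style(1)                 # 20: the wrapper runs
old_style.trace()            # copied attribute, the wrapper is bypassed
hasattr(new_style, "trace")  # False: class attributes are not copied
```

`functools.wraps` updates the wrapper's `__dict__` from the wrapped object's instance `__dict__`, which is why only the instance attribute leaks through.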

Fixes: #27829
2025-04-11 14:51:12 +03:00
jax authors
c9cbf82164 Merge pull request #27876 from gnecula:aot_compute_on
PiperOrigin-RevId: 746402180
2025-04-11 04:08:18 -07:00
jax authors
1035c9a118 Merge pull request #27916 from gnecula:tracing_cache_ignore_internals
PiperOrigin-RevId: 746397452
2025-04-11 03:53:47 -07:00
jax authors
ac285a138b Merge pull request #27685 from Cjkkkk:return_cudnn_sdpa_residual
PiperOrigin-RevId: 746397395
2025-04-11 03:51:40 -07:00
Dan Foreman-Mackey
81722201fd Remove legacy CPU custom call kernels that have been unused since v0.4.34.
As of today it has been 180 days since the release of 0.4.34 where the following legacy LAPACK kernels were no longer used when lowering:

* getrf
* geqrf / orgqr
* potrf
* gesdd
* syevd
* geev
* gehrd

Following our compatibility policy, these are now safe to remove.

PiperOrigin-RevId: 746388529
2025-04-11 03:17:19 -07:00
George Necula
96d38a6b66 [cache_misses] Skip tracing-cache-miss explanations for JAX internal functions
About half of the tracing-cache-miss explanations in a large benchmark
end up being from JAX-internal functions, such as `jax.numpy` functions.
These cache misses are not what the JAX user wants to see, so we filter
them out, using the same mechanism used for filtering tracebacks.
2025-04-11 12:53:38 +03:00
jax authors
d42d2e88b4 [Pallas] Interpret dimensions with parallel semantics by traversing the corresponding grid coordinates in randomized order.
Note that dynamic grid dimensions with 'parallel' semantics are disallowed. This enables the computation of grid points, with randomized coordinates along 'parallel' dimensions, in JAX/on device.
If randomization of grid dimensions with dynamic sizes (i.e. sizes not known at JAX trace time) were allowed, this would require computing these randomizations on the host/on CPU (where one can have arrays of dynamic shape).
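
A rough host-side sketch of the traversal idea in plain Python, not the Pallas interpreter itself. It assumes each 'parallel' axis is permuted independently and that grid sizes are static; `traversal_order`, the seed handling, and the "arbitrary" label for non-parallel axes are illustrative assumptions.

```python
import itertools
import random

def traversal_order(grid, semantics, seed=0):
    """List grid points, visiting 'parallel' axes in a shuffled order."""
    rng = random.Random(seed)
    axes = []
    for size, kind in zip(grid, semantics):
        coords = list(range(size))  # needs a static size
        if kind == "parallel":
            rng.shuffle(coords)
        axes.append(coords)
    # The last axis varies fastest, as in a nested loop over the grid.
    return list(itertools.product(*axes))

order = traversal_order((3, 2), ("parallel", "arbitrary"), seed=0)
```

Every grid point is still visited exactly once; only the order along the 'parallel' axis changes, which helps expose kernels that wrongly depend on a fixed iteration order.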

PiperOrigin-RevId: 746365669
2025-04-11 01:54:11 -07:00
Ayaka
9f5f6edb85 [Pallas] Fix integer array indexing
Fixes https://github.com/google/jax/issues/22783

jax-fixit

PiperOrigin-RevId: 746260869
2025-04-10 19:10:35 -07:00
jax authors
907725dfd7 Merge pull request #27937 from mattjj:while-readonly-carry-optimization
PiperOrigin-RevId: 746250385
2025-04-10 18:29:49 -07:00
Matthew Johnson
6e52b1e95b optimize while_loop by moving readonly carry components to be consts
also fix a bug in ordered effects in cond_fun lowering
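
For illustration, a loop of the shape this optimization appears to target, where one carry component is returned unchanged on every iteration. The function names and values are made up for the example.

```python
import jax
import jax.numpy as jnp

def cond_fun(carry):
    i, _, _ = carry
    return i < 3

def body_fun(carry):
    i, acc, step = carry
    # `step` is passed through untouched, so it is effectively read-only.
    return i + 1, acc + step, step

init = (0, jnp.float32(0.0), jnp.float32(2.0))
i, acc, step = jax.lax.while_loop(cond_fun, body_fun, init)
```

The user-visible result is the same either way; the change only affects how the loop is represented internally.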

fixes google/flax#4700
2025-04-11 00:48:52 +00:00
Peter Hawkins
b352763a17 Fix Pallas tests so they work with JAX_TEST_NUM_THREADS >= 1.
PiperOrigin-RevId: 746226562
2025-04-10 16:57:34 -07:00
Christos Perivolaropoulos
41a8805d96 [pallas:mgpu] Return types allowed in mgpu.inline_mgpu.
PiperOrigin-RevId: 746217405
2025-04-10 16:28:34 -07:00
Justin Fu
92be510f0b [Mosaic GPU] Implement warp-level thread semantics.
Adds a new WarpMesh object which, when used in conjunction with core_map, allows the user to drop into warp-level code rather than programming at the warpgroup level.

PiperOrigin-RevId: 746163942
2025-04-10 13:59:22 -07:00
Parker Schuh
48e14dcc0c Implement mutation by replacing the contents of a jax.Array with a result
jax.Array.

PiperOrigin-RevId: 746147571
2025-04-10 13:17:50 -07:00
Justin Fu
2807ae4e34 [Pallas] Fix ()-shaped vectors being materialized in Pallas lowering.
This fixes some non-intuitive errors where scalar-shaped values in VREGs were being used in operations that expected SREGs.

PiperOrigin-RevId: 746146037
2025-04-10 13:13:30 -07:00
kaixih
ae29f63e81 Don't use default quant config 2025-04-10 19:23:11 +00:00
jax authors
16ffbca542 Merge pull request #27849 from ZacCranko:docfig
PiperOrigin-RevId: 746098316
2025-04-10 11:06:37 -07:00
Dan Foreman-Mackey
f3115d32a2 Fix dtype failures in JaxGroupedQueryAttentionReferenceTest.
PiperOrigin-RevId: 746097962
2025-04-10 11:04:43 -07:00
kaixih
2090dadfde Deprecation warning 2025-04-10 17:45:51 +00:00
kaixih
0f29716986 One alias one 2025-04-10 17:19:53 +00:00
kaixih
a39a81ae7a Keep old scale_matmul arg names 2025-04-10 17:03:43 +00:00
jax authors
9011d66a29 Merge pull request #27903 from mattjj:pvary-errors
PiperOrigin-RevId: 746070501
2025-04-10 09:56:16 -07:00
Yash Katariya
6c0ac7a503 Do a pvary in dynamic_slice_transpose_rule so that the zeros are varying with the correct vma as the operands were.
PiperOrigin-RevId: 746065965
2025-04-10 09:43:17 -07:00
George Necula
9af0c05bbc [export] Add test that exporting works for experimental.compute_on. 2025-04-10 19:26:39 +03:00
Kostiantyn Liepieshov
c730bbda74 fix bug in export_module when mesh axes are empty for shardy.
If mesh axes are empty, we are setting mesh as None, resulting in an error in
this test.

This fix provides an empty mesh when the mesh axes in the dumped module are empty.

PiperOrigin-RevId: 746058506
2025-04-10 09:21:58 -07:00