9119 Commits

Author SHA1 Message Date
Adam Paszke
7e2148b800 [Pallas:MGPU] Don't assume we'll be running at least max_concurrent_steps in the memory WG
max_concurrent_steps is an upper bound: we no longer guarantee that it accurately
reflects the actual number of steps when the grid has dynamic bounds

PiperOrigin-RevId: 746036125
2025-04-10 08:12:46 -07:00
jax authors
95f1207fbf Merge pull request #27843 from dfm:lin-call-jvp
PiperOrigin-RevId: 746026464
2025-04-10 07:43:50 -07:00
Adam Paszke
ec59178d29 [Pallas:MGPU] Make sure to await all arrivals on consumed barriers
Without this, `emit_pipeline_warp_specialized` would leave the barriers in a bad
state, causing deadlocks or crashes when it was called multiple times in sequence.

PiperOrigin-RevId: 746022784
2025-04-10 07:32:30 -07:00
Dan Foreman-Mackey
e1aa83ad67 Add JVP rule for linear_call.
2025-04-10 09:12:01 -04:00
Zac Cranko
8f9f1aa35a add sphinx extension and placeholder config docs rst
improve layout, information

add dummy import to hopefully fix build issue

parse help text for markdown

whoops didn't mean to do it twice

jax prefix text no longer applies here

two space indents

address definition list ending without blank line error

provide deprecation mechanism

document context managers if they exist

remove mention of context manager

try and fix formatting

improve formatting, fail to fix warnings

fail to fix bug, make better looking anyway

okay bug was in the parsing of help text to rst, some of which does not parse

wow, found the bug, turns out help strings were not valid rst
2025-04-10 05:55:10 -07:00
jax authors
f7a2760822 Merge pull request #27831 from dfm:linear-call-recursion
PiperOrigin-RevId: 745992513
2025-04-10 05:50:34 -07:00
Peter Hawkins
b4c3e38022 When running test cases concurrently, log the start and end of each test case.
This is very helpful for debugging deadlocks!

PiperOrigin-RevId: 745986596
2025-04-10 05:26:02 -07:00
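As a rough illustration of the start/end logging described in b4c3e38022, the sketch below uses only the standard `unittest` and `logging` modules. It is not the code from that commit: `LoggingTestResult`, `ExampleTest`, `run_with_trace`, and the `test_trace` logger name are all hypothetical names chosen for the example.

```python
import io
import logging
import unittest


class LoggingTestResult(unittest.TextTestResult):
    """Hypothetical result class that logs when each test case starts and ends."""

    def startTest(self, test):
        logging.getLogger("test_trace").info("START %s", test.id())
        super().startTest(test)

    def stopTest(self, test):
        super().stopTest(test)
        logging.getLogger("test_trace").info("END %s", test.id())


class ExampleTest(unittest.TestCase):
    def test_addition(self):
        self.assertEqual(1 + 1, 2)


def run_with_trace():
    """Runs ExampleTest and returns the captured log lines."""
    buffer = io.StringIO()
    handler = logging.StreamHandler(buffer)
    logger = logging.getLogger("test_trace")
    logger.setLevel(logging.INFO)
    logger.addHandler(handler)
    try:
        suite = unittest.defaultTestLoader.loadTestsFromTestCase(ExampleTest)
        runner = unittest.TextTestRunner(
            stream=io.StringIO(), resultclass=LoggingTestResult
        )
        runner.run(suite)
    finally:
        logger.removeHandler(handler)
    return buffer.getvalue().splitlines()


trace_lines = run_with_trace()
```

When tests run concurrently, a START line without a matching END line points at the test case that is still blocked, which is presumably why the commit message calls this helpful for debugging deadlocks.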
Peter Hawkins
382285d315 Split JaxTestLoader and related classes into a separate file.
Refactoring only, no functional changes intended.

PiperOrigin-RevId: 745813442
2025-04-09 18:45:45 -07:00
Matthew Johnson
892cb65308 [shard-map] good errors for pvary issues
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2025-04-10 01:25:49 +00:00
Yash Katariya
75e4279e32 Set jax_varying_axes_in_types to True by default.
PiperOrigin-RevId: 745739477
2025-04-09 14:40:31 -07:00
Kevin Gleason
7c1595ac61 Skip jax/tests:unary_ops_accuracy_test when running with older versions of StableHLO.
PiperOrigin-RevId: 745717137
2025-04-09 13:42:25 -07:00
Arno Eigenwillig
e772a08ac2 Fix docstrings of segment_{prod,max,min} after commit 4679f45.
PiperOrigin-RevId: 745707421
2025-04-09 13:17:00 -07:00
jax authors
c418495b13 Merge pull request #27886 from mattjj:26763
PiperOrigin-RevId: 745704582
2025-04-09 13:09:11 -07:00
jax authors
2863b48801 Merge pull request #27759 from mattjj:vmappable-bind-fix
PiperOrigin-RevId: 745692664
2025-04-09 12:37:03 -07:00
Matthew Johnson
2b3839d248 [shard-map] make shard_map work with custom_jvp symbolic zeros
also resolves a TODO(mattjj,frostig) from #14570 to make vmap-of-custom_jvp not instantiate symbolic zeros

fixes #26763
2025-04-09 19:28:37 +00:00
jax authors
038566713b Merge pull request #27853 from carlosgmartin:merge_tuple_update_tuple_replace
PiperOrigin-RevId: 745683630
2025-04-09 12:11:52 -07:00
Daniel Suo
e750d7ead6 Add an option for debug print to be called on partitioned arguments rather than all-gathering and printing the entire logical arguments.
PiperOrigin-RevId: 745678930
2025-04-09 12:00:00 -07:00
carlosgmartin
b6a46310d1 Merge tuple_replace and tuple_update in jax._src.util.
2025-04-09 12:50:42 -04:00
Yash Katariya
76c6b5b00d More changes for enabling vma by default in JAX
PiperOrigin-RevId: 745621495
2025-04-09 09:33:33 -07:00
Dan Foreman-Mackey
e0cda84d90 Fix linear_call to allow recursive definitions.
2025-04-09 09:45:32 -04:00
Dimitar (Mitko) Asenov
9adc3cc05c [Mosaic GPU] Add a LayoutCast op to the Mosaic GPU mlir dialect.
PiperOrigin-RevId: 745557539
2025-04-09 06:11:15 -07:00
Adam Paszke
6792703dbe Fix failing documentation tests
The CUDA-specific primitives need to be explicitly skipped.

PiperOrigin-RevId: 745504040
2025-04-09 02:53:04 -07:00
Sergei Lebedev
866e32b329 [pallas:mosaic_gpu] ModuleContext.reserve_barrier is now a context manager
This allows unreserving the barrier once it is no longer needed and is consistent
with how resource estimation works, e.g. for `cond`.

PiperOrigin-RevId: 745483567
2025-04-09 01:45:43 -07:00
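The idea behind 866e32b329, reserving a resource for the duration of a `with` block and releasing it afterwards, can be sketched with `contextlib.contextmanager`. This is an illustration under stated assumptions, not the Pallas implementation: `BarrierPool`, its `capacity`, `in_use`, and `peak` fields, and the integer barrier index are hypothetical, and only the method name `reserve_barrier` comes from the commit message.

```python
import contextlib


class BarrierPool:
    """Hypothetical stand-in for a context that tracks reserved barriers."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.in_use = 0
        self.peak = 0  # highest number of barriers held at the same time

    @contextlib.contextmanager
    def reserve_barrier(self):
        if self.in_use >= self.capacity:
            raise RuntimeError("no barriers left")
        index = self.in_use
        self.in_use += 1
        self.peak = max(self.peak, self.in_use)
        try:
            yield index
        finally:
            # Unreserve the barrier once the block is done with it.
            self.in_use -= 1


pool = BarrierPool(capacity=1)
with pool.reserve_barrier() as first:
    pass
with pool.reserve_barrier() as second:
    pass
```

Because the first reservation is released before the second begins, both blocks fit in a pool of capacity one, and the peak usage stays at one rather than two.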
Yash Katariya
f95f6a8bdb Rename psum2 to psum_invariant and put it in lax_parallel. We shouldn't expose this in the public API; users should call psum instead, which will dispatch to psum_invariant when check_rep=True.
PiperOrigin-RevId: 745352875
2025-04-08 17:28:59 -07:00
Yash Katariya
84016bc368 Rename pbroadcast to pvary and expose it as jax.lax.pvary.
PiperOrigin-RevId: 745342103
2025-04-08 16:51:27 -07:00
jax authors
373ac2ef7e Merge pull request #27804 from vfdev-5:ft-adapt-state-test-2
PiperOrigin-RevId: 745341315
2025-04-08 16:48:54 -07:00
jax authors
b8d9e7f427 Merge pull request #27503 from kaixih:enable_doc_scaled_dot
PiperOrigin-RevId: 745322012
2025-04-08 15:50:54 -07:00
cjkkkk
d19a458f32 fix docstring
2025-04-08 22:48:47 +00:00
vfdev-5
5a340a9781 Disable second order vjp tests in RunStateHypothesisTest.test_vjp if JAX_SKIP_SLOW_TESTS=true
Description:
- Disable second order vjp tests in RunStateHypothesisTest.test_vjp if JAX_SKIP_SLOW_TESTS=true to reduce the test execution time
  - especially for TSAN CI job where this test takes ~700 seconds to pass with the recent 3.13 cpython
- Removed optional deps for 3.14
2025-04-08 21:02:55 +00:00
Yash Katariya
8301c304c1 Make changes to shard_map to prepare for setting varying_axes_in_types to True.
The main changes here are:

* Don't take the `_efficient_transpose_rewrite` transformation path anymore. In other words, `RewriteTrace` and all the rewriting machinery is dead.

* Wherever internally we were setting `check_rep=False` explicitly like `_prim_applier`, `_match`, `_unmatch`, `_shard_map_partial_eval`, `_shard_map_partial_eval_custom` (for remat), don't do that anymore. Instead set `check_rep` to the `check_rep` value so that it can be True if the user hasn't passed `check_rep=False`.

* Introduce an internal `_check_rep` context manager and set it wherever `extend_axis_env_nd` is used so that if `check_rep=False` on `shard_map`, JAX will set `vma` in `ShapedArray` to empty `frozenset`.

* Because of point (2), if `check_rep=True`, we can't set `in_specs` and `out_specs` of shmap internally to all manual axes of the mesh on the 0th dim. It needs to be whatever the argument was varying on.

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 745276474
2025-04-08 13:47:13 -07:00
Peter Hawkins
b4629c230c Split weakref_lru_cache into its own extension.
Now that db11efab3b has landed, we're free to split up xla_extension without creating binary size problems or having to be quite so careful about cross-module dependencies. Here weakref_lru_cache has absolutely nothing to do with XLA.

There's no reason weakref_lru_cache is in the same Python extension as everything else.

PiperOrigin-RevId: 745271825
2025-04-08 13:34:47 -07:00
Dan Foreman-Mackey
2d44f985c3 Finalize deprecation of ffi_call with inline arguments.
PiperOrigin-RevId: 745261995
2025-04-08 13:09:42 -07:00
jax authors
f1bcf3bb94 Merge pull request #27821 from j-towns:clarify-make-jaxpr-docstr
PiperOrigin-RevId: 745216259
2025-04-08 11:11:57 -07:00
jax authors
76825a2d45 Merge pull request #27807 from jakevdp:eigvalsh-symmetrize
PiperOrigin-RevId: 745216021
2025-04-08 11:09:58 -07:00
jax authors
b073e8db25 Merge pull request #27836 from jakevdp:fix-repeat
PiperOrigin-RevId: 745215941
2025-04-08 11:07:57 -07:00
jax authors
03c1bf96c6 Merge pull request #27803 from mattjj:27644
PiperOrigin-RevId: 745201720
2025-04-08 10:34:01 -07:00
Jake VanderPlas
b7d430f96b jnp.repeat: don't cast repeats to array, as they must be static.
2025-04-08 10:32:03 -07:00
Matthew Johnson
4d2808c115 [mutable-arrays] limit implicit ref_swap dtype promotion
fixes #27683

In b7715e279, specifically this line:

b7715e279d (diff-8a1ad6e3b750565d66d30dbf4c9df0825bf5e87c4721e3352f44efbfb8b4a29cR193)

we started ignoring the value dtype completely when it was weakly typed. But that could lead to surprising implicit bitcasts like in #27683. A repro looks like:

```python
import jax.numpy as jnp
from jax._src import core

v = core.mutable_array(jnp.array([0, 0, 0]))
v[...] += 1.0
print(v)  # MutableArray([1065353216, 1065353216, 1065353216], dtype=int32)
```

We can't easily just drop this behavior because it seems many GPU x64 tests depend on it.

So in this change we're trying to
1. do the casting outside the bind, so that in jaxpr typechecking we can assert the value to assign has to match the ref dtype;
2. make that casting more restrictive, supporting only casts on weak-typed values between different precisions of floats or ints; and
3. do an ordinary cast rather than a bitcast.
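
The difference between the bitcast seen in the repro and the ordinary cast in item 3 can be reproduced without JAX. The sketch below is a standard-library stand-in, not the code from this change; `bitcast_float32_to_int32` is a hypothetical helper name.

```python
import struct


def bitcast_float32_to_int32(value):
    """Reinterprets the 4 bytes of a float32 as a signed 32-bit integer."""
    return struct.unpack("<i", struct.pack("<f", value))[0]


# Bitcast: the float32 bit pattern of 1.0 (0x3F800000) read as an int32.
bitcast_result = bitcast_float32_to_int32(1.0)

# Ordinary cast: the numeric value 1.0 converted to an integer.
cast_result = int(1.0)

print(bitcast_result, cast_result)  # 1065353216 1
```

The first number matches the `1065353216` entries printed by the repro, which is what adding a bitcast `1.0` to a zero-initialized int32 array produces.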

I left a TODO to change this behavior, since it seems a bit ad-hoc. But we may not want to remove all implicit casting; for example, it's probably reasonable to support implicit casting of Python builtin numeric types when we don't lose any precision, e.g.

```python
v = core.mutable_array(jnp.array(0, dtype='bfloat16'))
v[...] += 1.0  # don't error!
```

But we can do that with special-purpose carve-outs for Python builtin numeric types. I left one way to do it in a comment.

PiperOrigin-RevId: 745198669
2025-04-08 10:25:10 -07:00
Peter Hawkins
e02faabfb2 Replace references to jax.readthedocs.io with docs.jax.dev.
PiperOrigin-RevId: 745156931
2025-04-08 08:33:49 -07:00
Sergei Lebedev
f5d73b89ca [pallas:mosaic_gpu] Added test for custom pretty-printing rules
PiperOrigin-RevId: 745145207
2025-04-08 07:59:44 -07:00
Adam Paszke
b926fac66e [Mosaic GPU] Simplify load/store methods now that we have fewer layouts
PiperOrigin-RevId: 745139008
2025-04-08 07:39:53 -07:00
jax authors
73ecf0bb48 Remove unused return wrapper in annotate_function that creates a self-reference cycle in Python.
PiperOrigin-RevId: 745099538
2025-04-08 05:25:23 -07:00
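For readers unfamiliar with why a self-reference cycle matters, here is a minimal sketch of one way a wrapper function can end up referring to itself. The actual structure of `annotate_function` is not shown in this log, so `make_self_referencing` and `wrapper` are hypothetical, and the behavior described assumes CPython's reference counting.

```python
import gc
import weakref


def make_self_referencing():
    def wrapper():
        return wrapper  # the closure cell points back at wrapper itself

    return wrapper


gc.disable()  # keep the cyclic collector from running on its own
try:
    func = make_self_referencing()
    ref = weakref.ref(func)
    del func
    # The closure cell still holds a reference, so the refcount is not zero.
    alive_after_del = ref() is not None
    gc.collect()  # only the cyclic garbage collector can reclaim the object
    alive_after_collect = ref() is not None
finally:
    gc.enable()
```

Objects caught in such a cycle linger until a garbage collection pass runs, so removing the cycle lets them be freed as soon as the last outside reference goes away.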
Sergei Lebedev
5f33280ded [pallas:mosaic_gpu] emit_pipeline* now allows the grid to be dynamic
PiperOrigin-RevId: 745091128
2025-04-08 04:55:15 -07:00
Sergei Lebedev
12811f08a8 Removed eager_pmap config option
It defaults to True and is not flipped to False by any internal JAX users.

PiperOrigin-RevId: 745067361
2025-04-08 03:30:36 -07:00
Dimitar (Mitko) Asenov
c4cc94a10c [Mosaic GPU] Add warpgroup lowering for RunState in Pallas.
After this change we no longer skip tests that required `RunState`. This necessitated a small fix in the Pallas lowering of `while` and also enabling multiple i32 register bundling in the `optimization_barrier` lowering.

PiperOrigin-RevId: 745065173
2025-04-08 03:23:20 -07:00
Sergei Lebedev
af072feb5a Removed redundant passes
If a function or class has a docstring, it does not need a `pass`.

PiperOrigin-RevId: 745052107
2025-04-08 02:38:21 -07:00
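The rule behind af072feb5a is plain Python: a docstring is an expression statement, so it is already a valid body on its own. A small sketch, with hypothetical names `EmptyError` and `placeholder`:

```python
class EmptyError(Exception):
    """An exception type with no extra behavior."""


def placeholder():
    """Does nothing; the docstring alone is a valid body."""


result = placeholder()  # a function with only a docstring returns None
```

Adding `pass` after either docstring would change nothing about how the class or function behaves.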
Sergei Lebedev
8ed59d8b5d Removed jax._src.raise_to_shaped
It is just an identity after the "stackless" rewrite.

PiperOrigin-RevId: 745042532
2025-04-08 02:06:40 -07:00
jax authors
c2eaedfe94 Merge pull request #27776 from gnecula:export_keys
PiperOrigin-RevId: 745038060
2025-04-08 01:53:11 -07:00
Jamie Townsend
bc11a63113 Clarify jax.make_jaxpr docstring
2025-04-08 09:50:31 +02:00
George Necula
51dbcd4dad [export] Add backwards compatibility test for annotate_device_placement.
This enables exporting functions that use memory kinds to place
data in different memories.

jax-fixit

PiperOrigin-RevId: 745008959
2025-04-08 00:10:23 -07:00