1815 Commits

Author SHA1 Message Date
Yash Katariya
7de522c5a3 Enter into auto mode for .at[...].get(...) a bit earlier so that all ops inside _gather are in auto mode.
Fix select's batching rule where `explicit_mesh_axis` that we capture in `axis_data` was not propagated properly to the `broadcast` happening in `bdim_at_front`.

PiperOrigin-RevId: 748867490
2025-04-17 17:42:13 -07:00
Yash Katariya
06ad3528e9 Use _make_lengths_same for explicit mode too.
We add `None`'s when ndim > len(sharding.spec) and only remove `None`s when `ndim < len(sharding.spec)`. If sharded axes exist, then we error out when removing specs.

PiperOrigin-RevId: 748735303
2025-04-17 10:48:46 -07:00
Sergei Lebedev
c576d328bd Added lax.axis_size and switched all existing usage of psum(1, ...) to it
PiperOrigin-RevId: 748604842
2025-04-17 02:22:25 -07:00
Yash Katariya
82215f660e Remove jax_varying_axes_in_types config and rewrite from shard_map_p
PiperOrigin-RevId: 748545142
2025-04-16 22:27:50 -07:00
Yash Katariya
5f6b99a143 Fix a bug in reduce_window sharding rule where padding is a tuple but we were checking for a scalar instead. Fixes https://github.com/jax-ml/jax/issues/28070
PiperOrigin-RevId: 748418451
2025-04-16 14:10:13 -07:00
Yash Katariya
655bfcac39 Enable standard_insert_pvary for optimization_barrier which was disabled before.
PiperOrigin-RevId: 748027360
2025-04-15 14:41:08 -07:00
Yash Katariya
6e00b5e02d [NFC] Rename standard_insert_pbroadcast to standard_insert_pvary
PiperOrigin-RevId: 747943230
2025-04-15 11:02:45 -07:00
jax authors
b6c6c1c258 Merge pull request #27971 from ywrt:patch-1
PiperOrigin-RevId: 747399343
2025-04-14 07:00:10 -07:00
Matthew Johnson
b3f49e42d9 Re-landing #27937 with fewer bugs and more tests. 2025-04-11 22:42:08 +00:00
ywrt
c90751bc54
Fix typo in jax.lax.linalg.symmetric_product description
Missing space in '..math::' meant that the math wasn't rendering correctly.
2025-04-12 07:20:39 +10:00
Matthew Johnson
e9364f4b0a Reverts 907725dfd7a7fb612c4f6d975bb462f1ae1a21d7
PiperOrigin-RevId: 746554582
2025-04-11 12:37:20 -07:00
jax authors
907725dfd7 Merge pull request #27937 from mattjj:while-readonly-carry-optimization
PiperOrigin-RevId: 746250385
2025-04-10 18:29:49 -07:00
Matthew Johnson
6e52b1e95b optimize while_loop by moving readonly carry components to be consts
also fix a bug in ordered effects in cond_fun lowering

fixes google/flax#4700
2025-04-11 00:48:52 +00:00
Dan Foreman-Mackey
f3115d32a2 Fix dtype failures in JaxGroupedQueryAttentionReferenceTest.
PiperOrigin-RevId: 746097962
2025-04-10 11:04:43 -07:00
jax authors
9011d66a29 Merge pull request #27903 from mattjj:pvary-errors
PiperOrigin-RevId: 746070501
2025-04-10 09:56:16 -07:00
Yash Katariya
6c0ac7a503 Do a pvary in dynamic_slice_transpose_rule so that the zeros are varying with the correct vma as the operands were.
PiperOrigin-RevId: 746065965
2025-04-10 09:43:17 -07:00
Adam Paszke
7e2148b800 [Pallas:MGPU] Don't assume we'll be running at least max_concurrent_steps in the memory WG
max_concurrent_steps is an upper bound: we no longer guarantee that it accurately
reflects the actual number of steps when the grid has dynamic bounds

PiperOrigin-RevId: 746036125
2025-04-10 08:12:46 -07:00
Matthew Johnson
892cb65308 [shard-map] good errors for pvary issues
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2025-04-10 01:25:49 +00:00
Yash Katariya
76c6b5b00d More changes for enabling vma by default in JAX
PiperOrigin-RevId: 745621495
2025-04-09 09:33:33 -07:00
Yash Katariya
f95f6a8bdb Rename psum2 to psum_invariant and put it in lax_parallel. We shouldn't expose this to public API and have users use psum instead which will dispatch to psum_invariant when check_rep=True.
PiperOrigin-RevId: 745352875
2025-04-08 17:28:59 -07:00
Yash Katariya
84016bc368 Rename pbroadcast to pvary and expose it as jax.lax.pvary.
PiperOrigin-RevId: 745342103
2025-04-08 16:51:27 -07:00
Yash Katariya
8301c304c1 Make changes to shard_map to prepare for setting varying_axes_in_types to True.
The main changes here are:

* Don't take the `_efficient_transpose_rewrite` transformation path anymore. In other words, `RewriteTrace` and all the rewriting machinery is dead.

* Wherever internally we were setting `check_rep=False` explicitly like `_prim_applier`, `_match`, `_unmatch`, `_shard_map_partial_eval`, `_shard_map_partial_eval_custom` (for remat), don't do that anymore. Instead set `check_rep` to the `check_rep` value so that it can be True if the user hasn't passed `check_rep=False`.

* Introduce an internal `_check_rep` context manager and set it wherever `extend_axis_env_nd` is used so that if `check_rep=False` on `shard_map`, JAX will set `vma` in `ShapedArray` to empty `frozenset`.

* Because of point (2), if `check_rep=True`, we can't set `in_specs` and `out_specs` of shmap internally to all manual axes of the mesh on the 0th dim. It needs to be whatever the argument was varying on.

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 745276474
2025-04-08 13:47:13 -07:00
Peter Hawkins
e02faabfb2 Replace references to jax.readthedocs.io with docs.jax.dev.
PiperOrigin-RevId: 745156931
2025-04-08 08:33:49 -07:00
Rachel Han
84e04fe608 Add custom pretty print rule for the unary ops with accuracy s.t. accuracy is not printed if it's None.
PiperOrigin-RevId: 744889524
2025-04-07 16:25:01 -07:00
jax authors
48a9ad0796 Reverts 006a6a63feb64bf9984526030ba008186d69d2b4
PiperOrigin-RevId: 744864022
2025-04-07 15:08:33 -07:00
Roy Frostig
f8bbe98a86 require out_shardings as a keyword-only argument on public functions
PiperOrigin-RevId: 743753215
2025-04-03 17:26:05 -07:00
jax authors
45e6808bb5 Merge pull request #27084 from danielsuo:switch-fwd
PiperOrigin-RevId: 743452172
2025-04-03 01:07:50 -07:00
Gunhyun Park
92f7aeab48 Add simple vmap support for lax.ragged_all_to_all.
PiperOrigin-RevId: 743230485
2025-04-02 12:10:34 -07:00
Yash Katariya
3d70fc8197 Add pbroadcast insertion for psum_p in the traceable. This effectively replaces psum_p with psum2_p if varying_axes_in_types is on. psum_p can be replaced with psum2_p in follow up CLs
Also populate the aval of `ShardMapTracer` with `vma`

PiperOrigin-RevId: 743188081
2025-04-02 10:21:21 -07:00
Yash Katariya
10b2cda90e Relax the aval check in select_hlo_lowering_opaque to only check for shardings if they are not empty. The same thing happens in select_p's sharding rule
PiperOrigin-RevId: 743105350
2025-04-02 06:10:33 -07:00
Matthew Johnson
a80f6279e9 make random_gamma_grad not a primitive anymore
Fixes #16076

Co-authored-by: Roy Frostig <frostig@google.com>
2025-04-01 17:04:50 +00:00
Yash Katariya
76271d638a Add scan_p and cond_p vma rule.
PiperOrigin-RevId: 742737384
2025-04-01 09:50:38 -07:00
jax authors
006a6a63fe [Easy] Make pallas mesh grid handling more resilient to tuple names.
PiperOrigin-RevId: 742531956
2025-03-31 22:02:29 -07:00
Yash Katariya
d6b4fed5ed Propagate sharding and vma rule for axis_index_p. There's no need for pbroadcast insertion for axis_index_p in the traceable
PiperOrigin-RevId: 742334213
2025-03-31 11:33:59 -07:00
Yash Katariya
7ca50844f3 Fix an edge-case in reshape sharding rule where the last splitting/merging dim was 1.
PiperOrigin-RevId: 741740811
2025-03-28 21:43:27 -07:00
Yash Katariya
80061ad4c4 Add vma rules for pmin and pmax
PiperOrigin-RevId: 741685454
2025-03-28 16:55:16 -07:00
Matthew Johnson
6fba4ecc58 PR #27576: [attrs] experimental appendattr
Imported from GitHub PR https://github.com/jax-ml/jax/pull/27576

This is an experimental extension to attrs. Attrs should be considered both experimental and deprecated.

This PR also includes some fixes for getattr/setattr.
Copybara import of the project:

--
3b1ea1a5f90b28744522670d0498ce5a6b194274 by Matthew Johnson <mattjj@google.com>:

[attrs] experimental appendattr

Merging this change closes #27576

COPYBARA_INTEGRATE_REVIEW=https://github.com/jax-ml/jax/pull/27576 from mattjj:appendattr b93795201b39b8f75890c9228368c994ae1e38e8
PiperOrigin-RevId: 741662724
2025-03-28 15:21:12 -07:00
Yash Katariya
177193662c Add vma rules for all_gather, all_to_all, ppermute and reduce_scatter primitives
PiperOrigin-RevId: 741661360
2025-03-28 15:16:06 -07:00
jax authors
2d63b6e56d Merge pull request #27583 from jakevdp:scan-doc
PiperOrigin-RevId: 741653320
2025-03-28 14:45:52 -07:00
jax authors
6edc31ae1d Merge pull request #27525 from jakevdp:ml-dtypes-cleanup
PiperOrigin-RevId: 741651222
2025-03-28 14:38:38 -07:00
Jake VanderPlas
91dac631fb scan: improve docs & errors around dynamic length 2025-03-28 14:15:25 -07:00
Yash Katariya
5950e722e2 Make sure vma on ShapedArray exists by default to make development easier. The field is populated inside shard_map guarded on the varying_axes_in_types config though.
PiperOrigin-RevId: 741554623
2025-03-28 09:44:03 -07:00
Jake VanderPlas
431c2c0807 cleanup now that we depend on ml_dtypes>=0.5 2025-03-28 07:44:38 -07:00
Yash Katariya
563c3e2244 Add standard pbroadcast rules to more primitives. This should cover all primitives from which shard_map registered standard_rewrite rules
PiperOrigin-RevId: 741516445
2025-03-28 07:20:12 -07:00
Rachel Han
a52f7b26e7 Add accuracy field to unary ops
* Cbrt
  * Cos
  * Exp, Exp2
  * Expm1
  * Log
  * Logistic
  * Log1p
  * Rsqrt
  * Sin
  * Sqrt
  * Tan
  * Tanh
which allows users to select implementation that will satisfy the requested accuracy.

PiperOrigin-RevId: 741331787
2025-03-27 17:12:59 -07:00
Yash Katariya
25c106d132 Add standard_insert_pbroadcasts and standard_vma_rule to all primitives in following files: (Don't add standard_insert_broadcast for unary ops though)
* slicing.py
* windowed_reductions.py
* special.py
* convolution.py
* fft.py
* linalg.py
* ann.py

PiperOrigin-RevId: 741327361
2025-03-27 16:56:39 -07:00
Gunhyun Park
e1762b0af6 Assert unused variable in lax.all_to_all batching rule
P.S. minor improvement to code readability

PiperOrigin-RevId: 741051082
2025-03-27 00:47:13 -07:00
Daniel Suo
e364abe961 Prune passthrough outputs in lax.switch. 2025-03-26 18:53:14 +00:00
Yash Katariya
f1a9241187 Add standard_insert_broadcasts to all traceables in lax.py and checks in abstract_eval rules of those primitives.
PiperOrigin-RevId: 740536031
2025-03-25 17:03:18 -07:00
Yash Katariya
ed75189c92 [sharding_in_types] Add support for rng_bit_generator
PiperOrigin-RevId: 740492876
2025-03-25 14:48:27 -07:00