Yash Katariya
7de522c5a3
Enter into auto mode for .at[...].get(...)
a bit earlier so that all ops inside _gather
are in auto mode.
...
Fix select's batching rule where `explicit_mesh_axis` that we capture in `axis_data` was not propagated properly to the `broadcast` happening in `bdim_at_front`.
PiperOrigin-RevId: 748867490
2025-04-17 17:42:13 -07:00
Yash Katariya
06ad3528e9
Use _make_lengths_same for explicit mode too.
...
We add `None`'s when ndim > len(sharding.spec) and only remove `None`s when `ndim < len(sharding.spec)`. If sharded axes exist, then we error out when removing specs.
PiperOrigin-RevId: 748735303
2025-04-17 10:48:46 -07:00
Sergei Lebedev
c576d328bd
Added lax.axis_size
and switched all existing usage of psum(1, ...)
to it
...
PiperOrigin-RevId: 748604842
2025-04-17 02:22:25 -07:00
Yash Katariya
82215f660e
Remove jax_varying_axes_in_types config and rewrite
from shard_map_p
...
PiperOrigin-RevId: 748545142
2025-04-16 22:27:50 -07:00
Yash Katariya
5f6b99a143
Fix a bug in reduce_window sharding rule where padding is a tuple but we were checking for a scalar instead. Fixes https://github.com/jax-ml/jax/issues/28070
...
PiperOrigin-RevId: 748418451
2025-04-16 14:10:13 -07:00
Yash Katariya
655bfcac39
Enable standard_insert_pvary for optimization_barrier which was disabled before.
...
PiperOrigin-RevId: 748027360
2025-04-15 14:41:08 -07:00
Yash Katariya
6e00b5e02d
[NFC] Rename standard_insert_pbroadcast
to standard_insert_pvary
...
PiperOrigin-RevId: 747943230
2025-04-15 11:02:45 -07:00
jax authors
b6c6c1c258
Merge pull request #27971 from ywrt:patch-1
...
PiperOrigin-RevId: 747399343
2025-04-14 07:00:10 -07:00
Matthew Johnson
b3f49e42d9
Re-landing #27937 with fewer bugs and more tests.
2025-04-11 22:42:08 +00:00
ywrt
c90751bc54
Fix typo in jax.lax.linalg.symmetric_product description
...
Missing space in '..math::' meant that the math wasn't rendering correctly.
2025-04-12 07:20:39 +10:00
Matthew Johnson
e9364f4b0a
Reverts 907725dfd7a7fb612c4f6d975bb462f1ae1a21d7
...
PiperOrigin-RevId: 746554582
2025-04-11 12:37:20 -07:00
jax authors
907725dfd7
Merge pull request #27937 from mattjj:while-readonly-carry-optimization
...
PiperOrigin-RevId: 746250385
2025-04-10 18:29:49 -07:00
Matthew Johnson
6e52b1e95b
optimize while_loop by moving readonly carry components to be consts
...
also fix a bug in ordered effects in cond_fun lowering
fixes google/flax#4700
2025-04-11 00:48:52 +00:00
Dan Foreman-Mackey
f3115d32a2
Fix dtype failures in JaxGroupedQueryAttentionReferenceTest.
...
PiperOrigin-RevId: 746097962
2025-04-10 11:04:43 -07:00
jax authors
9011d66a29
Merge pull request #27903 from mattjj:pvary-errors
...
PiperOrigin-RevId: 746070501
2025-04-10 09:56:16 -07:00
Yash Katariya
6c0ac7a503
Do a pvary in dynamic_slice_transpose_rule so that the zeros
are varying with the correct vma as the operands were.
...
PiperOrigin-RevId: 746065965
2025-04-10 09:43:17 -07:00
Adam Paszke
7e2148b800
[Pallas:MGPU] Don't assume we'll be running at least max_concurrent_steps in the memory WG
...
max_concurrent_steps is an upper bound: we no longer guarantee that it accurately
reflects the actual number of steps when the grid has dynamic bounds
PiperOrigin-RevId: 746036125
2025-04-10 08:12:46 -07:00
Matthew Johnson
892cb65308
[shard-map] good errors for pvary issues
...
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2025-04-10 01:25:49 +00:00
Yash Katariya
76c6b5b00d
More changes for enabling vma
by default in JAX
...
PiperOrigin-RevId: 745621495
2025-04-09 09:33:33 -07:00
Yash Katariya
f95f6a8bdb
Rename psum2
to psum_invariant
and put it in lax_parallel
. We shouldn't expose this to public API and have users use psum
instead which will dispatch to psum_invariant
when check_rep=True
.
...
PiperOrigin-RevId: 745352875
2025-04-08 17:28:59 -07:00
Yash Katariya
84016bc368
Rename pbroadcast
to pvary
and expose it as jax.lax.pvary
.
...
PiperOrigin-RevId: 745342103
2025-04-08 16:51:27 -07:00
Yash Katariya
8301c304c1
Make changes to shard_map to prepare for setting varying_axes_in_types
to True.
...
The main changes here are:
* Don't take the `_efficient_transpose_rewrite` transformation path anymore. In other words, `RewriteTrace` and all the rewriting machinery is dead.
* Wherever internally we were setting `check_rep=False` explicitly like `_prim_applier`, `_match`, `_unmatch`, `_shard_map_partial_eval`, `_shard_map_partial_eval_custom` (for remat), don't do that anymore. Instead set `check_rep` to the `check_rep` value so that it can be True if the user hasn't passed `check_rep=False`.
* Introduce an internal `_check_rep` context manager and set it wherever `extend_axis_env_nd` is used so that if `check_rep=False` on `shard_map`, JAX will set `vma` in `ShapedArray` to empty `frozenset`.
* Because of point (2), if `check_rep=True`, we can't set `in_specs` and `out_specs` of shmap internally to all manual axes of the mesh on the 0th dim. It needs to be whatever the argument was varying on.
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 745276474
2025-04-08 13:47:13 -07:00
Peter Hawkins
e02faabfb2
Replace references to jax.readthedocs.io with docs.jax.dev.
...
PiperOrigin-RevId: 745156931
2025-04-08 08:33:49 -07:00
Rachel Han
84e04fe608
Add custom pretty print rule for the unary ops with accuracy s.t. accuracy is not printed if it's None.
...
PiperOrigin-RevId: 744889524
2025-04-07 16:25:01 -07:00
jax authors
48a9ad0796
Reverts 006a6a63feb64bf9984526030ba008186d69d2b4
...
PiperOrigin-RevId: 744864022
2025-04-07 15:08:33 -07:00
Roy Frostig
f8bbe98a86
require out_shardings
as a keyword-only argument on public functions
...
PiperOrigin-RevId: 743753215
2025-04-03 17:26:05 -07:00
jax authors
45e6808bb5
Merge pull request #27084 from danielsuo:switch-fwd
...
PiperOrigin-RevId: 743452172
2025-04-03 01:07:50 -07:00
Gunhyun Park
92f7aeab48
Add simple vmap support for lax.ragged_all_to_all.
...
PiperOrigin-RevId: 743230485
2025-04-02 12:10:34 -07:00
Yash Katariya
3d70fc8197
Add pbroadcast insertion for psum_p
in the traceable. This effectively replaces psum_p
with psum2_p
if varying_axes_in_types
is on. psum_p can be replaced with psum2_p in follow up CLs
...
Also populate the aval of `ShardMapTracer` with `vma`
PiperOrigin-RevId: 743188081
2025-04-02 10:21:21 -07:00
Yash Katariya
10b2cda90e
Relax the aval check in select_hlo_lowering_opaque
to only check for shardings if they are not empty. The same thing happens in select_p's sharding rule
...
PiperOrigin-RevId: 743105350
2025-04-02 06:10:33 -07:00
Matthew Johnson
a80f6279e9
make random_gamma_grad not a primitive anymore
...
Fixes #16076
Co-authored-by: Roy Frostig <frostig@google.com>
2025-04-01 17:04:50 +00:00
Yash Katariya
76271d638a
Add scan_p and cond_p vma rule.
...
PiperOrigin-RevId: 742737384
2025-04-01 09:50:38 -07:00
jax authors
006a6a63fe
[Easy] Make pallas mesh grid handling more resilient to tuple names.
...
PiperOrigin-RevId: 742531956
2025-03-31 22:02:29 -07:00
Yash Katariya
d6b4fed5ed
Propagate sharding and vma rule for axis_index_p. There's no need for pbroadcast insertion for axis_index_p in the traceable
...
PiperOrigin-RevId: 742334213
2025-03-31 11:33:59 -07:00
Yash Katariya
7ca50844f3
Fix an edge-case in reshape sharding rule where the last splitting/merging dim was 1
.
...
PiperOrigin-RevId: 741740811
2025-03-28 21:43:27 -07:00
Yash Katariya
80061ad4c4
Add vma rules for pmin and pmax
...
PiperOrigin-RevId: 741685454
2025-03-28 16:55:16 -07:00
Matthew Johnson
6fba4ecc58
PR #27576 : [attrs] experimental appendattr
...
Imported from GitHub PR https://github.com/jax-ml/jax/pull/27576
This is an experimental extension to attrs. Attrs should be considered both experimental and deprecated.
This PR also includes some fixes for getattr/setattr.
Copybara import of the project:
--
3b1ea1a5f90b28744522670d0498ce5a6b194274 by Matthew Johnson <mattjj@google.com>:
[attrs] experimental appendattr
Merging this change closes #27576
COPYBARA_INTEGRATE_REVIEW=https://github.com/jax-ml/jax/pull/27576 from mattjj:appendattr b93795201b39b8f75890c9228368c994ae1e38e8
PiperOrigin-RevId: 741662724
2025-03-28 15:21:12 -07:00
Yash Katariya
177193662c
Add vma rules for all_gather, all_to_all, ppermute and reduce_scatter primitives
...
PiperOrigin-RevId: 741661360
2025-03-28 15:16:06 -07:00
jax authors
2d63b6e56d
Merge pull request #27583 from jakevdp:scan-doc
...
PiperOrigin-RevId: 741653320
2025-03-28 14:45:52 -07:00
jax authors
6edc31ae1d
Merge pull request #27525 from jakevdp:ml-dtypes-cleanup
...
PiperOrigin-RevId: 741651222
2025-03-28 14:38:38 -07:00
Jake VanderPlas
91dac631fb
scan: improve docs & errors around dynamic length
2025-03-28 14:15:25 -07:00
Yash Katariya
5950e722e2
Make sure vma
on ShapedArray exists by default to make development easier. The field is populated inside shard_map guarded on the varying_axes_in_types config though.
...
PiperOrigin-RevId: 741554623
2025-03-28 09:44:03 -07:00
Jake VanderPlas
431c2c0807
cleanup now that we depend on ml_dtypes>=0.5
2025-03-28 07:44:38 -07:00
Yash Katariya
563c3e2244
Add standard pbroadcast rules to more primitives. This should cover all primitives from which shard_map registered standard_rewrite rules
...
PiperOrigin-RevId: 741516445
2025-03-28 07:20:12 -07:00
Rachel Han
a52f7b26e7
Add accuracy field to unary ops
...
* Cbrt
* Cos
* Exp, Exp2
* Expm1
* Log
* Logistic
* Log1p
* Rsqrt
* Sin
* Sqrt
* Tan
* Tanh
which allows users to select implementation that will satisfy the requested accuracy.
PiperOrigin-RevId: 741331787
2025-03-27 17:12:59 -07:00
Yash Katariya
25c106d132
Add standard_insert_pbroadcasts and standard_vma_rule to all primitives in following files: (Don't add standard_insert_broadcast
for unary ops though)
...
* slicing.py
* windowed_reductions.py
* special.py
* convolution.py
* fft.py
* linalg.py
* ann.py
PiperOrigin-RevId: 741327361
2025-03-27 16:56:39 -07:00
Gunhyun Park
e1762b0af6
Assert unused variable in lax.all_to_all batching rule
...
P.S. minor improvement to code readability
PiperOrigin-RevId: 741051082
2025-03-27 00:47:13 -07:00
Daniel Suo
e364abe961
Prune passthrough outputs in lax.switch.
2025-03-26 18:53:14 +00:00
Yash Katariya
f1a9241187
Add standard_insert_broadcasts to all traceables in lax.py and checks in abstract_eval rules of those primitives.
...
PiperOrigin-RevId: 740536031
2025-03-25 17:03:18 -07:00
Yash Katariya
ed75189c92
[sharding_in_types] Add support for rng_bit_generator
...
PiperOrigin-RevId: 740492876
2025-03-25 14:48:27 -07:00