8681 Commits

Author SHA1 Message Date
jax authors
e9ce8fb92d Merge pull request #27227 from jburnim:jburnim_pallas_interpret_mode4
PiperOrigin-RevId: 738235363
2025-03-18 20:22:27 -07:00
Sharad Vikram
e949effcda [Pallas/Fuser] DCE fusion jaxprs before pulling (to avoid unnecessary computations being staged out in block functions)
PiperOrigin-RevId: 738218113
2025-03-18 19:00:41 -07:00
Sharad Vikram
4d715753c4 Make sure to DCE read effects
PiperOrigin-RevId: 738215055
2025-03-18 18:42:14 -07:00
Yash Katariya
663ef7ae01 Check the type of mesh in use_abstract_mesh and use_concrete_mesh
PiperOrigin-RevId: 738190879
2025-03-18 16:57:40 -07:00
Peter Hawkins
3f91b4b43a Move jaxlib/{cuda,rocm}_plugin_extension into jaxlib/{cuda/rocm}/
Move the common jaxlib/gpu_plugin_extension into jaxlib/gpu/

Cleanup only, no functional changes intended.

PiperOrigin-RevId: 738183402
2025-03-18 16:29:37 -07:00
jax authors
01a110c4c9 Better mosaic lowering for dynamic shapes, extend an interpreter into shape_poly dimexpr and lower them alongside the graph if we are in a dynamic export regime.
PiperOrigin-RevId: 738171437
2025-03-18 15:51:15 -07:00
Parker Schuh
0fb59747f0 Support tuples in custom_partitioning.
PiperOrigin-RevId: 738154413
2025-03-18 14:57:08 -07:00
Matthew Johnson
942ff38e36 fix to ragged_all_to_all transpose
PiperOrigin-RevId: 738110447
2025-03-18 12:51:21 -07:00
Jacob Burnim
47e8effdce Adds option to initialize buffers to NaNs or zeros in TPU interpret mode.
2025-03-18 12:24:45 -07:00
Yash Katariya
a5c0f200e7 set_mesh should return the prev_mesh instead of nothing. Users can choose to use the return value or ignore it.
PiperOrigin-RevId: 738039559
2025-03-18 09:43:25 -07:00
jax authors
7c5871f464 [Pallas TPU] Hoist prologue and epilogue outside of pipeline loop
PiperOrigin-RevId: 738038138
2025-03-18 09:40:43 -07:00
jax authors
30941480a1 Merge pull request #27198 from jakevdp:lax-docs
PiperOrigin-RevId: 738038116
2025-03-18 09:38:58 -07:00
jax authors
13541e9f12 Make blocked_fold_in consistent when the block sizes induce padding
Add coverage for padded shapes to unit tests.

PiperOrigin-RevId: 738029476
2025-03-18 09:12:11 -07:00
Jake VanderPlas
8b46e53a4f jax.lax: improve docs for several APIs
2025-03-18 08:55:38 -07:00
Yash Katariya
549973dec6 Allow pspec to be passed to device_put if there is a mesh in the surrounding context
PiperOrigin-RevId: 737812111
2025-03-17 17:47:56 -07:00
Emily Fertig
8c35191725 Enable jax.device_put to a sharding with no local devices.
PiperOrigin-RevId: 737797815
2025-03-17 16:49:46 -07:00
Sergei Lebedev
051687dc4c [pallas] pallas_call_p is now parameterized by a mesh
The mesh is necessary to add support for clusters to the Mosaic GPU backend.

PiperOrigin-RevId: 737792129
2025-03-17 16:30:40 -07:00
jax authors
b4966130a3 Compute tile index using tile-based coordinates
This reduces the chances of overflowing a 32-bit integer when computing tile indices.
Add unit test to reproduce the overflow with the previous implementation of `blocked_fold_in`.

PiperOrigin-RevId: 737778853
2025-03-17 15:46:27 -07:00
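The overflow concern in b4966130a3 can be illustrated with plain integer arithmetic. This is a minimal sketch and not the `blocked_fold_in` implementation: the function names, the array shape, and the 128x128 tile size are all hypothetical and chosen only to show the difference in magnitude between a per-element linear index and a per-tile linear index.

```python
INT32_MAX = 2**31 - 1


def element_linear_index(row, col, num_cols):
    """Row-major index of a single element (hypothetical helper)."""
    return row * num_cols + col


def tile_linear_index(row, col, num_cols, tile_rows, tile_cols):
    """Row-major index of the tile that contains (row, col) (hypothetical helper)."""
    tiles_per_row = -(-num_cols // tile_cols)  # ceiling division
    return (row // tile_rows) * tiles_per_row + (col // tile_cols)


# Hypothetical shape and tile size, picked only to show the magnitudes.
num_rows, num_cols = 65536, 65536
last_row, last_col = num_rows - 1, num_cols - 1

elem = element_linear_index(last_row, last_col, num_cols)
tile = tile_linear_index(last_row, last_col, num_cols, 128, 128)

print(elem > INT32_MAX)   # the element index no longer fits in a signed 32-bit integer
print(tile <= INT32_MAX)  # the tile index is far smaller and still fits
```

Working in tile coordinates keeps the intermediate values smaller by roughly a factor of the tile area, which is why it lowers the chance of overflow without removing it entirely.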
Peter Hawkins
20658fabb3 Replace cached function get_replicated_hlo_sharding() with a constant.
Small cleanup, no functional changes intended.

PiperOrigin-RevId: 737727727
2025-03-17 13:17:33 -07:00
jax authors
ebcae0d30a Merge pull request #26980 from carlosgmartin:categorical_replace
PiperOrigin-RevId: 737720590
2025-03-17 12:58:01 -07:00
Peter Hawkins
be5d13af77 Remove code that preserved _original_py_fns on C++ classes.
This no longer appears to be used.

PiperOrigin-RevId: 737715578
2025-03-17 12:43:04 -07:00
carlosgmartin
3f59fa6888 Add replace option to random.categorical to enable sampling without replacement.
2025-03-17 13:41:46 -04:00
jax authors
de9ad6bad9 Merge pull request #27157 from mar-muel:improve-random-choice-performance
PiperOrigin-RevId: 737665351
2025-03-17 10:30:15 -07:00
Sergei Lebedev
a7e5eaee56 [pallas:mosaic_gpu] jnp.reduce_sum now works for >1D arrays
PiperOrigin-RevId: 737578598
2025-03-17 05:32:07 -07:00
jax authors
761b35c59e Merge pull request #27176 from jakevdp:lax-docs
PiperOrigin-RevId: 737338493
2025-03-16 05:39:55 -07:00
Joan Puigcerver
466ef6a132 Change the way that batching.spec_types is updated.
There's no reason why two custom vmappable types cannot share the same spec_type. However, spec_types was a set, which can cause bugs or exceptions.

Suppose that I register two vmappable data_types sharing the same spec_type, and then unregister one of the two. Then, the spec_type is no longer in the set to support the second data_type. Also, an exception will be raised if I try to unregister the two vmappable types (the second call to spec_types.remove).

When unregistering a data type, instead of removing its spec_type from the set, we regenerate the set from the remaining vmappable types.

PiperOrigin-RevId: 737280270
2025-03-15 22:58:44 -07:00
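The failure mode described in 466ef6a132 and the fix can be sketched with a small stand-alone registry. This is an illustration under assumptions, not the code in `batching`: the class `VmappableRegistry` and its methods are hypothetical, and only the idea of regenerating the set from the remaining registrations is taken from the commit message.

```python
class VmappableRegistry:
    """Hypothetical registry mapping data types to their spec types."""

    def __init__(self):
        self._entries = {}       # data_type -> spec_type
        self.spec_types = set()  # all spec types currently in use

    def register(self, data_type, spec_type):
        self._entries[data_type] = spec_type
        self.spec_types.add(spec_type)

    def unregister(self, data_type):
        del self._entries[data_type]
        # Rebuild the set from what is still registered instead of calling
        # spec_types.remove(...), so a spec type shared by another data type
        # survives and a second unregister does not raise KeyError.
        self.spec_types = set(self._entries.values())


class SharedSpec: ...
class DataA: ...
class DataB: ...


registry = VmappableRegistry()
registry.register(DataA, SharedSpec)
registry.register(DataB, SharedSpec)

registry.unregister(DataA)
still_supported = SharedSpec in registry.spec_types  # DataB still needs it

registry.unregister(DataB)
now_empty = len(registry.spec_types) == 0
```

Rebuilding the set costs time proportional to the number of registered types, which is a reasonable trade when unregistering is rare.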
Jake VanderPlas
de8b0564ce Better docs for jax.lax add/sub/mul/div
2025-03-15 11:49:51 -07:00
Ayaka
9b0ace4a11 Support error checking in explicit mode
PiperOrigin-RevId: 737051146
2025-03-14 18:58:26 -07:00
jax authors
7db59cdcca Merge pull request #27174 from mattjj:opt-barrier-ad-rules
PiperOrigin-RevId: 737040381
2025-03-14 17:59:07 -07:00
Peter Hawkins
14cb7453f0 Add a C++ implementation of a topological sort.
This is an exact port of the current Python implementation to C++ for speed.

I am being careful not to change the topological order we return in any way in this change, although we may do so in a future change.

PiperOrigin-RevId: 737014989
2025-03-14 16:04:25 -07:00
Matthew Johnson
dadc68b6c1 add experimental lax.optimization_barrier autodiff rules
2025-03-14 22:40:55 +00:00
jax authors
b00a3a1986 Merge pull request #27015 from mattjj:direct-linearize-fixes-4
PiperOrigin-RevId: 737003323
2025-03-14 15:24:11 -07:00
Sergei Lebedev
64230d1c93 [pallas:mosaic_gpu] WG lowering now supports while_p
PiperOrigin-RevId: 736996154
2025-03-14 14:59:29 -07:00
Matthew Johnson
174dcc771a [direct-linearize] shmap fixes
2025-03-14 21:38:50 +00:00
Justin Fu
dbd8d92075 [Pallas] Add legacy PRNG key support to Pallas PRNG
PiperOrigin-RevId: 736949584
2025-03-14 12:30:04 -07:00
Yash Katariya
88d4bc3d45 Rename AxisTypes enum to AxisType
PiperOrigin-RevId: 736935746
2025-03-14 11:48:21 -07:00
Emily Fertig
bdb6d03322 Allow make_array_from_callback to construct nonaddressable arrays.
PiperOrigin-RevId: 736922870
2025-03-14 11:10:32 -07:00
Martin Muller
4a82fe94de Use lax.top_k instead of jnp.argsort in Gumbel top-k trick for weighted sampling without replacement in jax.random.choice
2025-03-14 19:02:24 +01:00
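For readers unfamiliar with the Gumbel top-k trick mentioned in 4a82fe94de, here is a minimal pure-Python sketch. It is not the `jax.random.choice` implementation: the function names are hypothetical, and `heapq.nlargest` stands in for a top-k selection so that only the k largest perturbed keys are extracted rather than sorting every key.

```python
import heapq
import math
import random


def _gumbel(rng):
    """Draw one standard Gumbel sample, avoiding log(0)."""
    u = rng.random()
    while u == 0.0:
        u = rng.random()
    return -math.log(-math.log(u))


def sample_without_replacement(weights, k, rng):
    """Pick k distinct indices with probability driven by positive weights.

    Each log-weight is perturbed with independent Gumbel noise and the
    indices of the k largest perturbed values are returned.
    """
    candidates = [(i, w) for i, w in enumerate(weights) if w > 0]
    if k > len(candidates):
        raise ValueError("k exceeds the number of items with positive weight")
    keys = [(math.log(w) + _gumbel(rng), i) for i, w in candidates]
    # Selecting the top k avoids a full sort of all keys.
    return [i for _, i in heapq.nlargest(k, keys)]


rng = random.Random(0)
picked = sample_without_replacement([0.1, 0.0, 0.4, 0.5], k=2, rng=rng)
print(picked)
```

Items with zero weight are skipped because their log-weight is negative infinity, so they can never be among the largest keys.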
Ilya Tikhonovskiy
c9ac82c826 [XLA:GPU] Add missing BF16_BF16_F32_X9 matmul option in config.py
Extend the list of possible default algorithms that dot could use.

PiperOrigin-RevId: 736879149
2025-03-14 08:58:59 -07:00
Peter Hawkins
6fa98fc0a4 Use "x is y" rather than "id(x) == id(y)".
The latter involves at least two object constructions.

PiperOrigin-RevId: 736878098
2025-03-14 08:54:46 -07:00
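The equivalence behind 6fa98fc0a4 can be checked directly. The snippet below is only an illustration with hypothetical variable names: for objects that are alive at the same time, `x is y` and `id(x) == id(y)` give the same answer, but the `id` form first builds two integer objects and then compares them.

```python
a = [1, 2, 3]
b = a           # second name for the same list object
c = [1, 2, 3]   # equal contents, but a different object

same_by_is = (a is b, a is c)
same_by_id = (id(a) == id(b), id(a) == id(c))

print(same_by_is)  # identity checked directly
print(same_by_id)  # identity checked through two id() results
```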
Peter Hawkins
8ab33669e2 Add a variant of safe_map() that has no return value, named foreach().
This avoids a bunch of list bookkeeping in cases where we are iterating only for a side effect and do not care about the results.

I would have named this iter() to match OCaml's list.iter(), but unfortunately iter() is a Python builtin.

PiperOrigin-RevId: 736859418
2025-03-14 07:42:48 -07:00
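The difference between the two helpers named in 8ab33669e2 can be sketched as follows. These are stand-alone illustrations, not JAX's internal definitions: the sketch assumes `safe_map` checks that all argument sequences have the same length and returns a list, and that `foreach` performs the same check but discards the results.

```python
def _check_equal_lengths(seqs):
    n = len(seqs[0])
    for s in seqs[1:]:
        if len(s) != n:
            raise ValueError(f"length mismatch: {[len(x) for x in seqs]}")


def safe_map(f, *args):
    """Map f over equal-length sequences and collect the results in a list."""
    seqs = [list(a) for a in args]
    _check_equal_lengths(seqs)
    return [f(*xs) for xs in zip(*seqs)]


def foreach(f, *args):
    """Call f over equal-length sequences for side effects only; returns None."""
    seqs = [list(a) for a in args]
    _check_equal_lengths(seqs)
    for xs in zip(*seqs):
        f(*xs)


seen = []
result = foreach(lambda x, y: seen.append(x + y), [1, 2], [10, 20])
mapped = safe_map(lambda x, y: x + y, [1, 2], [10, 20])
```

With `foreach`, no result list is allocated or appended to, which is the bookkeeping the commit message refers to.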
Peter Hawkins
074216e07a Precompute a weakref to a Trace.
We use Trace weakrefs frequently, so we may as well construct one eagerly.

PiperOrigin-RevId: 736841778
2025-03-14 06:26:17 -07:00
Ilya Tikhonovskiy
43b78c539f [JAX] Add missing preset for X9 dot optimization on BF16/BF16 -> F32.
Two PRs that support this feature have been submitted to stablehlo and openxla.
Now we could do the last step - enable it in JAX.

PiperOrigin-RevId: 736799241
2025-03-14 02:57:55 -07:00
jax authors
cbece0b00b Add explicit support for float8_e4m3b11fnuz in pl.dot
PiperOrigin-RevId: 736798315
2025-03-14 02:51:55 -07:00
Emily Fertig
d79472101d Plumb layout through the creation of IFRT Arrays (roll-forward with fix).
Reverts 7f9e7473cfe7e2b3c4eb43ce6df916b3159c1cff

PiperOrigin-RevId: 736739556
2025-03-13 21:32:52 -07:00
Yash Katariya
d3a41d8448 get_sharding doesn't need to be conditioned on the context mesh
PiperOrigin-RevId: 736710468
2025-03-13 18:59:31 -07:00
Yash Katariya
e615e2acb3 Raise a better error with more info when we see duplicate axis in a PartitionSpec resulting from a sharding rule.
Previously it was:

`ValueError: A single NamedSharding spec specification can map every mesh axis to at most one positional dimension, but PartitionSpec('x', 'x') has duplicate entries for x`

Now it is:

`TypeError: dot_general operation with inputs: i64[8@x,2], i64[2,8@x] produces an illegally sharded result: i64[8@x,8@x]`

PiperOrigin-RevId: 736657644
2025-03-13 15:24:10 -07:00
Peter Hawkins
1507754408 Precompute the __hash__ of AbstractMesh.
We use this frequently and it saves time to precompute it.

PiperOrigin-RevId: 736650750
2025-03-13 15:01:31 -07:00
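The pattern in 1507754408 is a common one for immutable objects that are hashed often. The class below is a hypothetical stand-in, not `AbstractMesh` itself: it computes the hash once in `__init__` and returns the stored value from `__hash__`, which is only safe because the fields feeding the hash are never mutated afterwards.

```python
class MeshKey:
    """Hypothetical immutable key with a hash computed once at construction."""

    __slots__ = ("axis_names", "axis_sizes", "_hash")

    def __init__(self, axis_names, axis_sizes):
        self.axis_names = tuple(axis_names)
        self.axis_sizes = tuple(axis_sizes)
        # Computed eagerly; valid only while the fields above stay unchanged.
        self._hash = hash((self.axis_names, self.axis_sizes))

    def __hash__(self):
        return self._hash

    def __eq__(self, other):
        if not isinstance(other, MeshKey):
            return NotImplemented
        return (self.axis_names, self.axis_sizes) == (other.axis_names, other.axis_sizes)


k1 = MeshKey(("x", "y"), (2, 4))
k2 = MeshKey(("x", "y"), (2, 4))
cache = {k1: "compiled"}
hit = cache[k2]  # equal keys hash identically, so the lookup succeeds
```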
Zac Mustin
acd6c40f2f Remove obsolete fallback for cost analysis.
This fallback does not seem to be needed as all executables have a cost-analysis implementation.

PiperOrigin-RevId: 736647203
2025-03-13 14:49:40 -07:00
Yash Katariya
e1b62cede1 Raise an error if jax.config.update('jax_num_cpu_devices', val) is called after backend is initialized
PiperOrigin-RevId: 736646012
2025-03-13 14:45:53 -07:00