10094 Commits

Author SHA1 Message Date
jax authors
7b7d36a8e6 Add a 2D test in memories_test.
PiperOrigin-RevId: 746295338
2025-04-10 21:32:56 -07:00
Ayaka
9f5f6edb85 [Pallas] Fix integer array indexing
Fixes https://github.com/google/jax/issues/22783

jax-fixit

PiperOrigin-RevId: 746260869
2025-04-10 19:10:35 -07:00
jax authors
907725dfd7 Merge pull request #27937 from mattjj:while-readonly-carry-optimization
PiperOrigin-RevId: 746250385
2025-04-10 18:29:49 -07:00
Matthew Johnson
6e52b1e95b optimize while_loop by moving readonly carry components to be consts
also fix a bug in ordered effects in cond_fun lowering

fixes google/flax#4700
2025-04-11 00:48:52 +00:00
Peter Hawkins
b352763a17 Fix Pallas tests so they work with JAX_TEST_NUM_THREADS >= 1.
PiperOrigin-RevId: 746226562
2025-04-10 16:57:34 -07:00
Christos Perivolaropoulos
41a8805d96 [pallas:mgpu] Return types allowed in mgpu.inline_mgpu.
PiperOrigin-RevId: 746217405
2025-04-10 16:28:34 -07:00
Peter Hawkins
cf8a52463c Update test shardings.
This change primarily reduces sharding, although in a few cases it also increases shardings. It is harmful to performance to overshard tests since there's a startup and teardown cost to each test run.

In a few cases, change tests to be non-accelerator tests.

PiperOrigin-RevId: 746164539
2025-04-10 14:01:15 -07:00
Justin Fu
92be510f0b [Mosaic GPU] Implement warp-level thread semantics.
Adds a new WarpMesh object which when used in conjunction with core_map, allows the user to drop into warp-level code rather than programming at the warpgroup level.

PiperOrigin-RevId: 746163942
2025-04-10 13:59:22 -07:00
Justin Fu
2807ae4e34 [Pallas] Fix ()-shaped vectors being materialized in Pallas lowering.
This fixes some non-intuitive errors where scalar-shaped values in VREGs were being used in operations that expected SREGs.

PiperOrigin-RevId: 746146037
2025-04-10 13:13:30 -07:00
Justin Fu
7117aa03fa [Mosaic GPU] Skip WGMMA with cluster example on non H100 GPUs.
PiperOrigin-RevId: 746140286
2025-04-10 12:57:24 -07:00
Dougal Maclaurin
5f5e742368 Mark as thread-unsafe tests that modify possibly-cached jaxprs in-place.
PiperOrigin-RevId: 746112248
2025-04-10 11:40:42 -07:00
Dan Foreman-Mackey
f3115d32a2 Fix dtype failures in JaxGroupedQueryAttentionReferenceTest.
PiperOrigin-RevId: 746097962
2025-04-10 11:04:43 -07:00
jax authors
9011d66a29 Merge pull request #27903 from mattjj:pvary-errors
PiperOrigin-RevId: 746070501
2025-04-10 09:56:16 -07:00
Yash Katariya
6c0ac7a503 Do a pvary in dynamic_slice_transpose_rule so that the zeros are varying with the correct vma as the operands were.
PiperOrigin-RevId: 746065965
2025-04-10 09:43:17 -07:00
Kostiantyn Liepieshov
c730bbda74 fix bug in export_module when no mesh axes are empty for shardy.
If mesh axes are empty, we are setting mesh as None, resulting in an error in
this test.

This fix provides an empty mesh, when no mesh axes in dumped module are empty.

PiperOrigin-RevId: 746058506
2025-04-10 09:21:58 -07:00
jax authors
6dd576acd5 Add unit tests for the grouped query attention reference implementation
PiperOrigin-RevId: 746057793
2025-04-10 09:19:58 -07:00
jax authors
e287c7ffc0 Minor adjustments in error messages in launch_context.py
PiperOrigin-RevId: 746056606
2025-04-10 09:16:14 -07:00
Adam Paszke
7e2148b800 [Pallas:MGPU] Don't assume we'll be running at least max_concurrent_steps in the memory WG
max_concurrent_steps is an upper bound: we no longer guarantee that it accurately
reflects the actual number of steps when the grid has dynamic bounds

PiperOrigin-RevId: 746036125
2025-04-10 08:12:46 -07:00
jax authors
95f1207fbf Merge pull request #27843 from dfm:lin-call-jvp
PiperOrigin-RevId: 746026464
2025-04-10 07:43:50 -07:00
Adam Paszke
ec59178d29 [Pallas:MGPU] Make sure to await all arrivals on consumed barriers
Without this, `emit_pipeline_warp_specialized` would leave the barriers in a bad
state, causing deadlocks or crashes when it was called multiple times in sequence.

PiperOrigin-RevId: 746022784
2025-04-10 07:32:30 -07:00
Dan Foreman-Mackey
e1aa83ad67 Add JVP rule for linear_call. 2025-04-10 09:12:01 -04:00
jax authors
f7a2760822 Merge pull request #27831 from dfm:linear-call-recursion
PiperOrigin-RevId: 745992513
2025-04-10 05:50:34 -07:00
jax authors
cf268a7f6a Merge pull request #27895 from jakevdp:random-test-refactor
PiperOrigin-RevId: 745813055
2025-04-09 18:43:46 -07:00
Matthew Johnson
892cb65308 [shard-map] good errors for pvary issues
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2025-04-10 01:25:49 +00:00
Yash Katariya
93202140a4 Make printing work with shard_map after vma has been switched on
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 745796269
2025-04-09 17:33:29 -07:00
Jake VanderPlas
c5bd13bbfc Refactor random_lax_test.py 2025-04-09 15:33:04 -07:00
Yash Katariya
75e4279e32 Set jax_varying_axes_in_types to True by default.
PiperOrigin-RevId: 745739477
2025-04-09 14:40:31 -07:00
Kevin Gleason
7c1595ac61 Skip jax/tests:unary_ops_accuracy_test when running with older versions of StableHLO.
PiperOrigin-RevId: 745717137
2025-04-09 13:42:25 -07:00
jax authors
c418495b13 Merge pull request #27886 from mattjj:26763
PiperOrigin-RevId: 745704582
2025-04-09 13:09:11 -07:00
jax authors
2863b48801 Merge pull request #27759 from mattjj:vmappable-bind-fix
PiperOrigin-RevId: 745692664
2025-04-09 12:37:03 -07:00
Matthew Johnson
2b3839d248 [shard-map] make shard_map work with custom_jvp symbolic zeros
also resolves a TODO(mattjj,frostig) from #14570 to make vmap-of-custom_jvp not instantiate symbolic zeros

fixes #26763
2025-04-09 19:28:37 +00:00
jax authors
038566713b Merge pull request #27853 from carlosgmartin:merge_tuple_update_tuple_replace
PiperOrigin-RevId: 745683630
2025-04-09 12:11:52 -07:00
Daniel Suo
e750d7ead6 Add option for debug print to be called on partitioned arguments rather than all-gather and print the entire logical arguments.
PiperOrigin-RevId: 745678930
2025-04-09 12:00:00 -07:00
jax authors
21a4429a41 Merge pull request #27879 from mattjj:shmap-fix-5
PiperOrigin-RevId: 745647632
2025-04-09 10:39:49 -07:00
carlosgmartin
b6a46310d1 Merge tuple_replace and tuple_update in jax._src.util. 2025-04-09 12:50:42 -04:00
Matthew Johnson
8383af0145 [shard-map] fix another bug where we incorrectly handled None in check_rep
this was essentially another instance of the #27797 fix

fixes #24762
2025-04-09 16:38:24 +00:00
Yash Katariya
76c6b5b00d More changes for enabling vma by default in JAX
PiperOrigin-RevId: 745621495
2025-04-09 09:33:33 -07:00
Dan Foreman-Mackey
e0cda84d90 Fix linear_call to allow recursive definitions. 2025-04-09 09:45:32 -04:00
Dimitar (Mitko) Asenov
9adc3cc05c [Mosaic GPU] Add a LayoutCast op to the Mosaic GPU mlir dialect.
PiperOrigin-RevId: 745557539
2025-04-09 06:11:15 -07:00
Sergei Lebedev
866e32b329 [pallas:mosaic_gpu] ModuleContext.reserve_barrier is now a context manager
This allows unreserving the barrier once it is no longer needed and is consistent
with how resource estimation works, e.g. for `cond`.

PiperOrigin-RevId: 745483567
2025-04-09 01:45:43 -07:00
Yash Katariya
f95f6a8bdb Rename psum2 to psum_invariant and put it in lax_parallel. We shouldn't expose this to public API and have users use psum instead which will dispatch to psum_invariant when check_rep=True.
PiperOrigin-RevId: 745352875
2025-04-08 17:28:59 -07:00
Yash Katariya
84016bc368 Rename pbroadcast to pvary and expose it as jax.lax.pvary.
PiperOrigin-RevId: 745342103
2025-04-08 16:51:27 -07:00
jax authors
373ac2ef7e Merge pull request #27804 from vfdev-5:ft-adapt-state-test-2
PiperOrigin-RevId: 745341315
2025-04-08 16:48:54 -07:00
jax authors
b8d9e7f427 Merge pull request #27503 from kaixih:enable_doc_scaled_dot
PiperOrigin-RevId: 745322012
2025-04-08 15:50:54 -07:00
vfdev-5
5a340a9781 Disable second order vjp tests in RunStateHypothesisTest.test_vjp if JAX_SKIP_SLOW_TESTS=true
Description:
- Disable second order vjp tests in RunStateHypothesisTest.test_vjp if JAX_SKIP_SLOW_TESTS=true to reduce the test execution time
  - especially for TSAN CI job where this test takes ~700 seconds to pass with the recent 3.13 cpython
- Removed optional deps for 3.14
2025-04-08 21:02:55 +00:00
Yash Katariya
8301c304c1 Make changes to shard_map to prepare for setting varying_axes_in_types to True.
The main changes here are:

* Don't take the `_efficient_transpose_rewrite` transformation path anymore. In other words, `RewriteTrace` and all the rewriting machinery is dead.

* Wherever internally we were setting `check_rep=False` explicitly like `_prim_applier`, `_match`, `_unmatch`, `_shard_map_partial_eval`, `_shard_map_partial_eval_custom` (for remat), don't do that anymore. Instead set `check_rep` to the `check_rep` value so that it can be True if the user hasn't passed `check_rep=False`.

* Introduce an internal `_check_rep` context manager and set it wherever `extend_axis_env_nd` is used so that if `check_rep=False` on `shard_map`, JAX will set `vma` in `ShapedArray` to empty `frozenset`.

* Because of point (2), if `check_rep=True`, we can't set `in_specs` and `out_specs` of shmap internally to all manual axes of the mesh on the 0th dim. It needs to be whatever the argument was varying on.

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 745276474
2025-04-08 13:47:13 -07:00
Dan Foreman-Mackey
2d44f985c3 Finalize deprecation of ffi_call with inline arguments.
PiperOrigin-RevId: 745261995
2025-04-08 13:09:42 -07:00
jax authors
76825a2d45 Merge pull request #27807 from jakevdp:eigvalsh-symmetrize
PiperOrigin-RevId: 745216021
2025-04-08 11:09:58 -07:00
jax authors
b073e8db25 Merge pull request #27836 from jakevdp:fix-repeat
PiperOrigin-RevId: 745215941
2025-04-08 11:07:57 -07:00
jax authors
ef68063497 Merge pull request #27809 from mattjj:26621
PiperOrigin-RevId: 745212009
2025-04-08 10:57:57 -07:00