This change primarily reduces sharding, although in a few cases it also increases shardings. It is harmful to performance to overshard tests since there's a startup and teardown cost to each test run.
In a few cases, change tests to be non-accelerator tests.
PiperOrigin-RevId: 746164539
Adds a new WarpMesh object which when used in conjunction with core_map, allows the user to drop into warp-level code rather than programming at the warpgroup level.
PiperOrigin-RevId: 746163942
This fixes some non-intuitive errors where scalar-shaped values in VREGs were being used in operations that expected SREGs.
PiperOrigin-RevId: 746146037
If mesh axes are empty, we are setting mesh as None, resulting in an error in
this test.
This fix provides an empty mesh, when no mesh axes in dumped module are empty.
PiperOrigin-RevId: 746058506
max_concurrent_steps is an upper bound: we no longer guarantee that it accurately
reflects the actual number of steps when the grid has dynamic bounds
PiperOrigin-RevId: 746036125
Without this, `emit_pipeline_warp_specialized` would leave the barriers in a bad
state, causing deadlocks or crashes when it was called multiple times in sequence.
PiperOrigin-RevId: 746022784
This allows unreserving the barrier once it is no longer needed and is consistent
with how resource estimation works, e.g. for `cond`.
PiperOrigin-RevId: 745483567
Description:
- Disable second order vjp tests in RunStateHypothesisTest.test_vjp if JAX_SKIP_SLOW_TESTS=true to reduce the test execution time
- especially for TSAN CI job where this test takes ~700 seconds to pass with the recent 3.13 cpython
- Removed optional deps for 3.14
The main changes here are:
* Don't take the `_efficient_transpose_rewrite` transformation path anymore. In other words, `RewriteTrace` and all the rewriting machinery is dead.
* Wherever internally we were setting `check_rep=False` explicitly like `_prim_applier`, `_match`, `_unmatch`, `_shard_map_partial_eval`, `_shard_map_partial_eval_custom` (for remat), don't do that anymore. Instead set `check_rep` to the `check_rep` value so that it can be True if the user hasn't passed `check_rep=False`.
* Introduce an internal `_check_rep` context manager and set it wherever `extend_axis_env_nd` is used so that if `check_rep=False` on `shard_map`, JAX will set `vma` in `ShapedArray` to empty `frozenset`.
* Because of point (2), if `check_rep=True`, we can't set `in_specs` and `out_specs` of shmap internally to all manual axes of the mesh on the 0th dim. It needs to be whatever the argument was varying on.
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 745276474