`max_concurrent_steps` is an upper bound: we no longer guarantee that it accurately
reflects the actual number of steps when the grid has dynamic bounds.
PiperOrigin-RevId: 746036125
Without this, `emit_pipeline_warp_specialized` would leave the barriers in a bad
state, causing deadlocks or crashes when it was called multiple times in sequence.
PiperOrigin-RevId: 746022784
improve layout, information
add dummy import to hopefully fix build issue
parse help text for markdown
whoops didn't mean to do it twice
jax prefix text no longer applies here
two space indents
address definition list ending without blank line error
provide deprecation mechanism
document context managers if they exist
remove mention of context manager
try and fix formatting
improve formatting, fail to fix warnings
fail to fix bug, make better looking anyway
okay bug was in the parsing of help text to rst, some of which does not parse
wow, found the bug, turns out help strings were not valid rst
This allows unreserving the barrier once it is no longer needed and is consistent
with how resource estimation works, e.g. for `cond`.
PiperOrigin-RevId: 745483567
Description:
- Disable second-order VJP tests in `RunStateHypothesisTest.test_vjp` if `JAX_SKIP_SLOW_TESTS=true` to reduce test execution time, especially for the TSAN CI job, where this test takes ~700 seconds to pass with the recent CPython 3.13
- Removed optional deps for 3.14
The main changes here are:
* Don't take the `_efficient_transpose_rewrite` transformation path anymore. In other words, `RewriteTrace` and all the rewriting machinery is dead.
* Wherever we were explicitly setting `check_rep=False` internally, as in `_prim_applier`, `_match`, `_unmatch`, `_shard_map_partial_eval`, and `_shard_map_partial_eval_custom` (for remat), don't do that anymore. Instead, pass through the caller's `check_rep` value so that it can be True if the user hasn't passed `check_rep=False`.
* Introduce an internal `_check_rep` context manager and set it wherever `extend_axis_env_nd` is used so that if `check_rep=False` on `shard_map`, JAX will set `vma` in `ShapedArray` to empty `frozenset`.
* Because of the second point above, if `check_rep=True`, we can't internally set the `in_specs` and `out_specs` of shmap to all manual axes of the mesh on the 0th dim. They need to be whatever the argument was varying on.
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 745276474
Now that db11efab3b has landed, we're free to split up xla_extension without creating binary size problems or having to be quite so careful about cross-module dependencies. Here weakref_lru_cache has absolutely nothing to do with XLA.
There's no reason weakref_lru_cache is in the same Python extension as everything else.
PiperOrigin-RevId: 745271825
fixes #27683
In b7715e279, specifically this line:
b7715e279d (diff-8a1ad6e3b750565d66d30dbf4c9df0825bf5e87c4721e3352f44efbfb8b4a29cR193)
we started ignoring the value dtype completely when it was weakly typed. But that could lead to surprising implicit bitcasts like in #27683. A repro looks like:
```python
import jax.numpy as jnp
from jax._src import core
v = core.mutable_array(jnp.array([0, 0, 0]))
v[...] += 1.0
print(v) # MutableArray([1065353216, 1065353216, 1065353216], dtype=int32)
```
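The value `1065353216` in the repro is what the IEEE 754 float32 bit pattern of `1.0` looks like when it is reinterpreted as an int32. A minimal standard-library sketch of that reinterpretation (the helper name `float32_bits_as_int32` is hypothetical and only for illustration; it is not part of JAX):

```python
import struct

def float32_bits_as_int32(x: float) -> int:
    """Reinterpret the float32 encoding of x as a signed 32-bit integer."""
    # Pack as little-endian float32, then unpack the same 4 bytes as int32.
    return struct.unpack("<i", struct.pack("<f", x))[0]

# float32 1.0 is encoded as 0x3F800000.
print(float32_bits_as_int32(1.0))       # 1065353216
print(hex(float32_bits_as_int32(1.0)))  # 0x3f800000
```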
We can't easily just drop this behavior because it seems many GPU x64 tests depend on it.
So in this change we're trying to
1. do the casting outside the bind, so that in jaxpr typechecking we can assert the value to assign has to match the ref dtype;
2. make that casting more restrictive, supporting only casts on weak-typed values between different precisions of floats or ints; and
3. do an ordinary cast rather than a bitcast.
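As a rough sketch of the restriction in item 2, the rule can be modeled in plain Python as follows. This is not the actual implementation: the `KINDS` table and `implicit_cast_allowed` are hypothetical names, unsigned and other dtypes are omitted, and it assumes "between different precisions of floats or ints" means float-to-float or int-to-int only.

```python
# Hypothetical model of the rule; not JAX's actual code.
KINDS = {
    "int8": "int", "int16": "int", "int32": "int", "int64": "int",
    "bfloat16": "float", "float16": "float",
    "float32": "float", "float64": "float",
}

def implicit_cast_allowed(value_dtype: str, ref_dtype: str, weak_type: bool) -> bool:
    """Return True if a value may be implicitly cast to the ref's dtype."""
    if value_dtype == ref_dtype:
        return True   # dtypes already match, nothing to cast
    if not weak_type:
        return False  # strongly typed values must match the ref dtype
    # Weakly typed values: only precision changes within the same kind.
    return KINDS[value_dtype] == KINDS[ref_dtype]

# Weak float into a bfloat16 ref: a precision change, so allowed.
print(implicit_cast_allowed("float32", "bfloat16", weak_type=True))  # True
# Weak float into an int32 ref (the situation in the repro): rejected.
print(implicit_cast_allowed("float32", "int32", weak_type=True))     # False
```

Under this model, a permitted cast would then be an ordinary value-preserving conversion performed before the bind, not a reinterpretation of bits.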
I left a TODO to change this behavior, since it seems a bit ad-hoc. But we may not want to remove all implicit casting; for example, it's probably reasonable to support implicit casting of Python builtin numeric types when we don't lose any precision, e.g.
```python
v = core.mutable_array(jnp.array(0, dtype='bfloat16'))
v[...] += 1.0 # don't error!
```
But we can do that with special-purpose carve-outs for Python builtin numeric types. I left one way to do it in a comment.
PiperOrigin-RevId: 745198669
After this change we no longer skip tests that required `RunState`. This necessitated a small fix in the Pallas lowering of `while` and also enabling multiple i32 register bundling in the `optimization_barrier` lowering.
PiperOrigin-RevId: 745065173