I expected Mosaic can canonicalize 2 same strided loads to one but it did not. (We will fix this in Mosaic). For now, manually converting to one strided load boosts 20~35% speedup in both v6e and v5e single chip for Meta-Llama-3-8B.
PiperOrigin-RevId: 745294058
Description:
- Disable second order vjp tests in RunStateHypothesisTest.test_vjp if JAX_SKIP_SLOW_TESTS=true to reduce the test execution time
- especially for TSAN CI job where this test takes ~700 seconds to pass with the recent 3.13 cpython
- Removed optional deps for 3.14
The main changes here are:
* Don't take the `_efficient_transpose_rewrite` transformation path anymore. In other words, `RewriteTrace` and all the rewriting machinery is dead.
* Wherever internally we were setting `check_rep=False` explicitly like `_prim_applier`, `_match`, `_unmatch`, `_shard_map_partial_eval`, `_shard_map_partial_eval_custom` (for remat), don't do that anymore. Instead set `check_rep` to the `check_rep` value so that it can be True if the user hasn't passed `check_rep=False`.
* Introduce an internal `_check_rep` context manager and set it wherever `extend_axis_env_nd` is used so that if `check_rep=False` on `shard_map`, JAX will set `vma` in `ShapedArray` to empty `frozenset`.
* Because of point (2), if `check_rep=True`, we can't set `in_specs` and `out_specs` of shmap internally to all manual axes of the mesh on the 0th dim. It needs to be whatever the argument was varying on.
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 745276474
Now that db11efab3b has landed, we're free to split up xla_extension without creating binary size problems or having to be quite so careful about cross-module dependencies. Here weakref_lru_cache has absolutely nothing to do with XLA.
There's no reason weakref_lru_cache is in the same Python extension as everything else.
PiperOrigin-RevId: 745271825
fixes#27683
In b7715e279, specifically this line:
b7715e279d (diff-8a1ad6e3b750565d66d30dbf4c9df0825bf5e87c4721e3352f44efbfb8b4a29cR193)
we started ignoring the value dtype completely when it was weakly typed. But that could lead to surprising implicit bitcasts like in #27683. A repro looks like:
```python
import jax.numpy as jnp
from jax._src import core
v = core.mutable_array(jnp.array([0, 0, 0]))
v[...] += 1.0
print(v) # MutableArray([1065353216, 1065353216, 1065353216], dtype=int32)
```
We can't easily just drop this behavior because it seems many GPU x64 tests depend on it.
So in this change we're trying to
1. do the casting outside the bind, so that in jaxpr typechecking we can assert the value to assign has to match the ref dtype;
2. make that casting more restrictive, supporting only casts on weak-typed values between different precisions of floats or ints; and
3. do an ordinary cast rather than a bitcast.
I left a TODO to change this behavior, since it seems a bit ad-hoc. But we may not want to remove all implicit casting; for example, it's probably reasonable to support implicit casting of Python builtin numeric types when we don't lose any precision, e.g.
```python
v = core.mutable_array(jnp.array(0, dtype='bfloat16'))
v[...] += 1.0 # don't error!
```
But we can do that with special-purpose carve-outs for Python builtin numerictypes. I left one way to do it in a comment.
PiperOrigin-RevId: 745198669
See https://github.com/jax-ml/jax/pull/18711
check_rep uses rep=None to indicate when an argument is a constant, and that's useful specifically when checking the backward pass for integer_pow, which has a multiplication by a constant that didn't get a pbroadcast applied to it. That is, we use rep=None as a special carve-out for constants.
The standard rules were compatible with rep=None, but the rules for higher-order primitives like scan and cond were not. So we had to upgrade them.
After this change we no longer skip tests that required 'RunState`. This necessitated a small fix in the pallas lowering of `while` and also enabling multiple i32 register bundling in the `optimization_barrier` lowering.
PiperOrigin-RevId: 745065173
The ml_dtype requirement in JAX was updated to version 0.5.0+ (on Mar 20, 2025) - commit 4b7ead4
This update allows us to address previous FP8-related TODOs in jaxlib/XLA.
PiperOrigin-RevId: 744943824
But this is a contradiction since layouts apply to device local shape and without knowing the sharding, you can't decide the layout. But there are cases where you don't care what the sharding is, you just want to force a row-major layout (for example). **This API should only be used for those cases**.
PiperOrigin-RevId: 744888557