Now all internal uses of lu.wrap_init and core.Jaxpr are with actual
debug info. This enables us to clean up the type declarations and
to remove the checks whether debug_info is present.
For usage outside of the JAX internals, we change
`jax.extend.linear_util.wrap_init` to be usable without debug_info,
for temporary backwards compatibility. We emit a deprecation
warning and fill-in some fake debugging info.
See https://github.com/jax-ml/jax/issues/26480 for more details.
PiperOrigin-RevId: 726770483
For future reference, this can be done via
python -m mypy jax --warn-unused-ignores > /tmp/unused.txt
while IFS=: read file line rest; do
echo "$file:$line";
gsed -i "${line}s/ *\# type: ignore\(\[[^]]*\]\)*//" "$file"
done < /tmp/unused.txt
As discovered in https://github.com/jax-ml/jax/issues/26216, for non-standard dtypes, calling `np.array` on a JAX array will unnecessarily cache the constructed `_npy_value` even when a copy isn't required. This change updates the logic to only save the cached value when it is a copy.
This fixes https://github.com/jax-ml/jax/issues/26216 by making the behavior consistent across dtypes, but we probably also want to expose a mechanism for clearing this cached value regardless.
PiperOrigin-RevId: 726522955
I'm working on implementing sharding logic across all of `lax.linalg`, and I've found that the previous implementation of this loop using explicit broadcasted iotas was confounding the partitioner, but this version using vmap batch partitions properly and I don't anticipate any performance differences.
PiperOrigin-RevId: 726518677
This requires unrolling into two instructions sequences over n, since the largest
tcgen05.mma instructions can only handle n=256.
PiperOrigin-RevId: 726496900
It happens rarely, but this test seems to be flaky, probably because we don't
properly synchronize the memory accesses somehow. It's not important, so we
just avoid the data races now.
PiperOrigin-RevId: 726489606
I initially wanted to upgrade to 1.15, but it seems to have a bug in how
ternary expressions are type checked. For example,
def f(x: int) -> str: ...
def g(x: int) -> str: ...
callback = f if ... else g # has type object!
Apparently it's a requirement that LLVM has, but it only complains about it when compiled with
assertions enabled, so it went unnoticed for a while.
PiperOrigin-RevId: 726468259
We only export symbols that being with `mlir` and a few other prefixes, so this renames our C API functions for consistency with that.
PiperOrigin-RevId: 726468092
This change is raising a better error because doing `NamedSharding(empty_mesh, P('x'))` will raise an error on construction but it is uglier than the current error added in this change.
PiperOrigin-RevId: 726253654
* get_aval is not context dependent
* canonicalization does not happen for avals on an empty mesh
* jax.jit does not set abstract mesh context anymore before tracing
* sharding checks have been relaxed for all modes (`Auto`, `Explicit` and `Manual`). This means that `f = lambda x, y: x * y; f(explicit_sharded_arr, np_array)` will be allowed without inserting any mesh_casts even in `Explicit` sharding mode
* Even if use_mesh is not used in explicit sharding mode, computation follows data works!
* Higher order primitives skip canonicalization (pjit_p, while_p, cond_p, for_loop_p, scan_p)
* Check in partial_eval which compares jaxpr_known.outvars == jaxpr.out_avals has been relaxed to not check shardings if any one of the aval has an empty mesh.
As mentioned in https://github.com/jax-ml/jax/issues/26474 we need to relax the typing and sharding rule checks because if we insert `mesh_cast`s, those lead to creation of unnecessary residuals (for literals, numpy arrays, basically anything that has an empty mesh) which is not good.
PiperOrigin-RevId: 726097292
Unlike continuous workflows, when testing nightly/release artifacts, we want to download and install the `jax` wheels found in the GCS bucket instead of installing it from HEAD.
It looks like `env` setting in the calling workflow isn't passed over to the called workflows so we define a new workflow input, `install-jax-current-commit`, to control the `jax` install behavior.
PiperOrigin-RevId: 726086522