As discovered in https://github.com/jax-ml/jax/issues/26216, for non-standard dtypes, calling `np.array` on a JAX array will unnecessarily cache the constructed `_npy_value` even when a copy isn't required. This change updates the logic to only save the cached value when it is a copy.
This fixes https://github.com/jax-ml/jax/issues/26216 by making the behavior consistent across dtypes, but we probably also want to expose a mechanism for clearing this cached value regardless.
PiperOrigin-RevId: 726522955
I'm working on implementing sharding logic across all of `lax.linalg`, and I've found that the previous implementation of this loop using explicit broadcasted iotas was confounding the partitioner, but this version using vmap batch partitions properly and I don't anticipate any performance differences.
PiperOrigin-RevId: 726518677
This requires unrolling into two instructions sequences over n, since the largest
tcgen05.mma instructions can only handle n=256.
PiperOrigin-RevId: 726496900
It happens rarely, but this test seems to be flaky, probably because we don't
properly synchronize the memory accesses somehow. It's not important, so we
just avoid the data races now.
PiperOrigin-RevId: 726489606
I initially wanted to upgrade to 1.15, but it seems to have a bug in how
ternary expressions are type checked. For example,
def f(x: int) -> str: ...
def g(x: int) -> str: ...
callback = f if ... else g # has type object!
Apparently it's a requirement that LLVM has, but it only complains about it when compiled with
assertions enabled, so it went unnoticed for a while.
PiperOrigin-RevId: 726468259
We only export symbols that being with `mlir` and a few other prefixes, so this renames our C API functions for consistency with that.
PiperOrigin-RevId: 726468092
This change is raising a better error because doing `NamedSharding(empty_mesh, P('x'))` will raise an error on construction but it is uglier than the current error added in this change.
PiperOrigin-RevId: 726253654
* get_aval is not context dependent
* canonicalization does not happen for avals on an empty mesh
* jax.jit does not set abstract mesh context anymore before tracing
* sharding checks have been relaxed for all modes (`Auto`, `Explicit` and `Manual`). This means that `f = lambda x, y: x * y; f(explicit_sharded_arr, np_array)` will be allowed without inserting any mesh_casts even in `Explicit` sharding mode
* Even if use_mesh is not used in explicit sharding mode, computation follows data works!
* Higher order primitives skip canonicalization (pjit_p, while_p, cond_p, for_loop_p, scan_p)
* Check in partial_eval which compares jaxpr_known.outvars == jaxpr.out_avals has been relaxed to not check shardings if any one of the aval has an empty mesh.
As mentioned in https://github.com/jax-ml/jax/issues/26474 we need to relax the typing and sharding rule checks because if we insert `mesh_cast`s, those lead to creation of unnecessary residuals (for literals, numpy arrays, basically anything that has an empty mesh) which is not good.
PiperOrigin-RevId: 726097292
Unlike continuous workflows, when testing nightly/release artifacts, we want to download and install the `jax` wheels found in the GCS bucket instead of installing it from HEAD.
It looks like `env` setting in the calling workflow isn't passed over to the called workflows so we define a new workflow input, `install-jax-current-commit`, to control the `jax` install behavior.
PiperOrigin-RevId: 726086522
The helpers `_fragmented_array_to_ir` and `_fragmented_array_from_ir` in
`dialect_lowering.py` have been modified, such that a fragmented array's
signedness no longer appears in its IR representation.
This is because signedness is a reflection of how we make use of the value,
and not an inherent property of it. The appropriate signedness value to use
to reload a fragmented array from IR must be provided by the caller.
PiperOrigin-RevId: 726030853
This follows in a series, starting with #26078 and #26313, adding debug_info to more calls to lu.wrap_init.
These are some leftover changes, in particular those needed when
running with `JAX_USE_DIRECT_LINEARIZE=1`.
We need to wait for the results to be available before we can copy them to
`smem`, and these instructions are not issued in the lowering of
`mgpu_dialect.wgmma`.
PiperOrigin-RevId: 725989759