In this PR, only jit and control flows are supported. Support for vmap and multi-device environments will be added in subsequent PRs.
PiperOrigin-RevId: 726920440
We don't support Windows GPU builds right now and skip all the tests,
but at the moment they can't even skip because of the import failure.
PiperOrigin-RevId: 726917651
The corresponding Triton op is restricted to `jnp.stack([x, y], axis=-1)`,
so the lowering only supports that case for now.
See #25321.
PiperOrigin-RevId: 726881284
The PTX guide talks about a few layouts by assigning them different
letters, which do not have an obvious meaning. We redefine the layout
by parameterizing it with a 2D tile size which, as far as I can tell,
is sufficient to represent all layouts we care about.
PiperOrigin-RevId: 726833412
Now all internal uses of lu.wrap_init and core.Jaxpr are with actual
debug info. This enables us to clean up the type declarations and
to remove the checks whether debug_info is present.
For usage outside of the JAX internals, we change
`jax.extend.linear_util.wrap_init` to be usable without debug_info,
for temporary backwards compatibility. We emit a deprecation
warning and fill-in some fake debugging info.
See https://github.com/jax-ml/jax/issues/26480 for more details.
PiperOrigin-RevId: 726770483
For future reference, this can be done via
python -m mypy jax --warn-unused-ignores > /tmp/unused.txt
while IFS=: read file line rest; do
echo "$file:$line";
gsed -i "${line}s/ *\# type: ignore\(\[[^]]*\]\)*//" "$file"
done < /tmp/unused.txt
As discovered in https://github.com/jax-ml/jax/issues/26216, for non-standard dtypes, calling `np.array` on a JAX array will unnecessarily cache the constructed `_npy_value` even when a copy isn't required. This change updates the logic to only save the cached value when it is a copy.
This fixes https://github.com/jax-ml/jax/issues/26216 by making the behavior consistent across dtypes, but we probably also want to expose a mechanism for clearing this cached value regardless.
PiperOrigin-RevId: 726522955
I'm working on implementing sharding logic across all of `lax.linalg`, and I've found that the previous implementation of this loop using explicit broadcasted iotas was confounding the partitioner, but this version using vmap batch partitions properly and I don't anticipate any performance differences.
PiperOrigin-RevId: 726518677
This requires unrolling into two instructions sequences over n, since the largest
tcgen05.mma instructions can only handle n=256.
PiperOrigin-RevId: 726496900
It happens rarely, but this test seems to be flaky, probably because we don't
properly synchronize the memory accesses somehow. It's not important, so we
just avoid the data races now.
PiperOrigin-RevId: 726489606
I initially wanted to upgrade to 1.15, but it seems to have a bug in how
ternary expressions are type checked. For example,
def f(x: int) -> str: ...
def g(x: int) -> str: ...
callback = f if ... else g # has type object!
Apparently it's a requirement that LLVM has, but it only complains about it when compiled with
assertions enabled, so it went unnoticed for a while.
PiperOrigin-RevId: 726468259
We only export symbols that being with `mlir` and a few other prefixes, so this renames our C API functions for consistency with that.
PiperOrigin-RevId: 726468092
This change is raising a better error because doing `NamedSharding(empty_mesh, P('x'))` will raise an error on construction but it is uglier than the current error added in this change.
PiperOrigin-RevId: 726253654