I initially wanted to upgrade to 1.15, but it seems to have a bug in how
ternary expressions are type checked. For example,
    def f(x: int) -> str: ...
    def g(x: int) -> str: ...
    callback = f if ... else g  # has type object!
The helpers `_fragmented_array_to_ir` and `_fragmented_array_from_ir` in
`dialect_lowering.py` have been modified so that a fragmented array's
signedness no longer appears in its IR representation.
This is because signedness is a reflection of how we make use of the value,
and not an inherent property of it. The appropriate signedness value to use
to reload a fragmented array from IR must be provided by the caller.
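
As a rough illustration of why signedness belongs to the use site rather than
to the stored value, here is a plain-Python sketch. It is not the Mosaic GPU
code: `load_values` and its `is_signed` parameter are hypothetical names made
up for this example. The same bytes decode differently depending on what the
caller asks for.

```python
def load_values(raw: bytes, *, is_signed: bool) -> list[int]:
    """Decode each byte as an 8-bit integer using the caller's signedness.

    Hypothetical helper: the stored bytes carry no signedness of their own.
    """
    return [
        int.from_bytes(raw[i:i + 1], "little", signed=is_signed)
        for i in range(len(raw))
    ]


# One stored representation...
raw = bytes([0xFF, 0x7F, 0x80])

# ...and two interpretations, chosen by the caller at reload time.
unsigned_view = load_values(raw, is_signed=False)
signed_view = load_values(raw, is_signed=True)

print(unsigned_view)  # [255, 127, 128]
print(signed_view)    # [-1, 127, -128]
```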
PiperOrigin-RevId: 726030853
This follows in a series, starting with #26078 and #26313, adding debug_info to more calls to lu.wrap_init.
These are some leftover changes, in particular those needed when
running with `JAX_USE_DIRECT_LINEARIZE=1`.
This instruction is particularly useful for collective MMA, since it lets us
easily report on the progress of async copies from both blocks in the single
block that will be performing the MMA.
PiperOrigin-RevId: 725618793
The code for allocation is uninteresting and it's the only set of primitives
that is executed by a single warp (other TMA APIs have single-thread or
warpgroup issue granularity).
PiperOrigin-RevId: 725583720
This follows in a series, starting with #26078 and #26313, adding debug_info to more calls to lu.wrap_init.
Here I changed the `custom_jvp_call` to replace the parameter
`jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun`
that can carry debug info).
Also fixed uses in shard_map, checkify, sparse, attrs, and jax2tf.
This follows in a series, starting with #26078 and #26313, adding debug_info to more calls to lu.wrap_init.
Fixes in jet, stateful code, key_reuse, ode, pallas, tests.
It would be good to add smaller tests that verify reads and writes to TMEM,
since we depend on it here, but that will come later.
PiperOrigin-RevId: 724328602
This follows #26078 and #26313, adding `debug_info` to
more calls to `lu.wrap_init`.
As part of this I have changed the primitives `custom_vjp_call_jaxpr`
and `custom_lin` to take the `bwd` parameter as a `lu.WrappedFun`,
which carries debug info. Previously, this was a `Callable`, but in
almost all cases it was really `lu.WrappedFun.call_wrapped`.
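
To make the motivation concrete, here is a minimal sketch in plain Python. The
`DebugInfo` and `WrappedFun` classes below are simplified, hypothetical
stand-ins written for this example, not the real `lu.WrappedFun` API. The
point is only that passing the bound method hides the debug info from the
callee, while passing the wrapper keeps it reachable.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass(frozen=True)
class DebugInfo:
    """Hypothetical stand-in for the debug info attached to a function."""
    traced_for: str
    func_name: str


@dataclass(frozen=True)
class WrappedFun:
    """Hypothetical stand-in: a callable bundled with optional debug info."""
    f: Callable[..., Any]
    debug_info: Optional[DebugInfo] = None

    def call_wrapped(self, *args: Any, **kwargs: Any) -> Any:
        return self.f(*args, **kwargs)


def bwd(residual: float, cotangent: float) -> tuple:
    # Toy backward function used only for illustration.
    return (residual * cotangent,)


wrapped = WrappedFun(bwd, DebugInfo(traced_for="custom_vjp", func_name="bwd"))

# Old style: hand over the bound method. The callee sees only a callable.
as_callable = wrapped.call_wrapped
print(getattr(as_callable, "debug_info", None))  # None

# New style: hand over the wrapper itself. Debug info stays accessible.
print(wrapped.debug_info.func_name)    # bwd
print(wrapped.call_wrapped(2.0, 3.0))  # (6.0,)
```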
Hopper and Blackwell MMA instructions can share a lot of the same logic, which is
why I ended up splitting out a large fraction of WGMMA implementation into a common
utility. This should be an NFC for WGMMA, but it allows us to concisely implement
unrolling of MMAs of different sizes into a number of tcgen05.mma instructions.
PiperOrigin-RevId: 723544349
The previous example implementation loaded TMEM in a layout that was very hard to
efficiently store into SMEM or GMEM. With the new TMEMRef abstraction, we can implement
loads that yield a FragmentedArray with a new tiled layout that allows for efficient
swizzled stores to SMEM.
The new layout is very similar to the one we've been using for WGMMA on Hopper, only the
initial row tiling is increased to 128 (making each warp hold 32 rows, not 16 as previously).
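
The rows-per-warp numbers follow from simple arithmetic, sketched below. This
assumes the initial row tile is split evenly across the 4 warps of a
warpgroup, and that the Hopper WGMMA layout uses an initial row tiling of 64.
That figure is inferred from the "16 rows" above, not stated in this change.
`rows_per_warp` is a hypothetical helper for this example.

```python
WARPS_PER_WARPGROUP = 4  # a warpgroup is 4 warps (128 threads)


def rows_per_warp(row_tiling: int) -> int:
    """Rows held by each warp when a row tile is split evenly across warps."""
    assert row_tiling % WARPS_PER_WARPGROUP == 0
    return row_tiling // WARPS_PER_WARPGROUP


print(rows_per_warp(64))   # 16 (assumed WGMMA row tiling)
print(rows_per_warp(128))  # 32 (row tiling used by the new TMEM layout)
```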
PiperOrigin-RevId: 723506876
The previous implementation depended on LLVM intrinsics that have not been submitted
yet. This replaces them with inline PTX (as far as I can tell there's no downside to
that) that's encapsulated into convenience functions.
PiperOrigin-RevId: 723498248
There's no need to require extra arguments. This makes our calling convention
saner since the logical dimension order stays the same (e.g. for B it's always
k before n in the shape), only the in-memory representation changes.
Other than the API change, this is an NFC.
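
A small sketch of the idea, in plain Python rather than the actual MMA code:
the logical shape of B is always `(k, n)`, and only the strides, that is the
in-memory representation, depend on the storage order. The helper names and
the `"row_major"` / `"col_major"` labels are made up for this example.

```python
def element_strides(shape: tuple, order: str) -> tuple:
    """Element strides for a 2D logical shape (rows, cols)."""
    rows, cols = shape
    if order == "row_major":
        return (cols, 1)
    if order == "col_major":
        return (1, rows)
    raise ValueError(f"unknown order: {order}")


def linear_offset(index: tuple, strides: tuple) -> int:
    """Position of a logical index in the flat buffer."""
    return index[0] * strides[0] + index[1] * strides[1]


k, n = 4, 8
b_shape = (k, n)  # logical order is always k before n

row_major = element_strides(b_shape, "row_major")
col_major = element_strides(b_shape, "col_major")
print(row_major)  # (8, 1)
print(col_major)  # (1, 4)

# The same logical element B[1, 2] lives at a different offset in each layout.
print(linear_offset((1, 2), row_major))  # 10
print(linear_offset((1, 2), col_major))  # 9
```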
PiperOrigin-RevId: 723449720
* When current_mesh is Manual and aval mesh is Auto
* When current mesh is set and aval mesh is unset
* Final style primitives skip this canonicalization and they are free to add it in their own `bind` method.
* `mesh_cast` is skipped from this canonicalization to avoid recursion errors.
This is required to make sure that once we reach the abstract_eval rule and check_jaxpr, everything is properly typed in JAX's type system.
`Auto` is a bit more permissive right now because we need to keep the current code at HEAD working, but `Explicit` and `Manual` are very strict.
PiperOrigin-RevId: 722868091