Key features:
* ***Support mixed prefill and decode*** to increase inference throughput. (e.g., ***5x*** speedup compared to the padded Multi-Queries Paged Attention implementation for llama-3-8b.)
* ***No explicit `swapaxes`*** for `seq_len` and `num_head` in the pre/post kernel. The kernel takes `num_head` in the 2nd-minor dimension, as it naturally is; we fold the swapaxes into strided loads/stores in the kernel and transpose on the fly.
* ***No GMM (Grouped Matmul) metadata required!*** We calculate the metadata on the fly in the kernel. This alone gives a ***10%*** speedup!
* ***Increase MXU utilization 8x in GQA*** by grouping the q heads that share a KV head for the MXU in decode.
* ***Minimize recompilation:*** the only factors that can cause recompilation are the model specs, `max_num_batched_tokens`, and `max_num_seqs` in the mixed-engine setting.
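The GQA decode grouping above can be sketched in plain NumPy (the shapes here are hypothetical, not the kernel's actual tiling): instead of issuing one skinny matvec per query head during decode, the q heads that share a KV head are folded into the row dimension of a single matmul, so each MXU tile is `group` times taller.

```python
import numpy as np

# Hypothetical decode-step shapes: 4 query heads share each KV head (GQA).
num_q_heads, num_kv_heads, head_dim, seq_len = 8, 2, 128, 16
group = num_q_heads // num_kv_heads  # q heads per KV head

rng = np.random.default_rng(0)
q = rng.standard_normal((num_q_heads, head_dim), dtype=np.float32)  # 1 new token
k = rng.standard_normal((num_kv_heads, seq_len, head_dim), dtype=np.float32)

# Ungrouped: one (1, head_dim) @ (head_dim, seq_len) matmul per q head,
# leaving most MXU rows idle during decode.
ungrouped = np.stack([q[h] @ k[h // group].T for h in range(num_q_heads)])

# Grouped: fold the q heads sharing a KV head into the matmul's row
# dimension: one (group, head_dim) @ (head_dim, seq_len) matmul per KV head.
grouped = (q.reshape(num_kv_heads, group, head_dim)
           @ k.transpose(0, 2, 1)).reshape(num_q_heads, seq_len)

assert np.allclose(ungrouped, grouped, atol=1e-5)
```

Both forms compute the same q·k logits; the grouped form just presents the MXU with a taller left operand.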
PiperOrigin-RevId: 734269519
Without this fix, lowerings of ops within the `for` body are always appended at the end, even if they have users earlier in the body. This caused an `operand #0 does not dominate this use` error.
The fix was tested in the upcoming (but not yet submitted) `test_realistic_matmul` in Pallas with Workgroup semantics.
PiperOrigin-RevId: 734157829
We don't have many Blackwell kernels yet, so let's begin the deprecation there!
Small tiles have clearer semantics when it comes to transposes too, which allows
us to enable more test cases.
PiperOrigin-RevId: 733786884
This makes the code path uniform for LHS/RHS and greatly clarifies the
magical computation of LBO/SBO. This change should make it significantly
easier for us to enable small tile support for the LHS.
PiperOrigin-RevId: 733737302
The original change was rolled back because there were real-world use cases of `custom_vjp` where the fwd function had the wrong signature. To preserve backwards compatibility, we shouldn't resolve the input arguments to fwd using fwd's signature. Instead, we can simply ignore the signature, because `custom_vjp` handles the resolution before we ever get here.
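For context, a minimal `custom_vjp` sketch (not the failing real-world case, which had a mismatched fwd signature):

```python
import jax

@jax.custom_vjp
def f(x):
    return x * x

# By the time lowering sees fwd, custom_vjp has already resolved the input
# arguments, so fwd's own signature need not be consulted.
def f_fwd(x):
    return f(x), x  # primal output plus residual for the backward pass

def f_bwd(res, g):
    return (2.0 * res * g,)  # d(x*x)/dx = 2x

f.defvjp(f_fwd, f_bwd)
assert jax.grad(f)(3.0) == 6.0
```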
Reverts 1f3176636d304398b00a7d2cb0933859618affd8
PiperOrigin-RevId: 733643149
In this case, the example boils down to:
```
inp1 = f32[16@x, 4]
inp2 = f32[4]
def f(x: f32[4], y: f32[4]):
  return jnp.concat([x, y], axis=-1)
vmap(f, in_axes=(0, None))(inp1, inp2)
```
This example was breaking in the concat batching rule because we didn't broadcast with the right sharding.
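The unsharded version of the same pattern runs like this (the `@x` sharding annotation from the example above is omitted):

```python
import jax
import jax.numpy as jnp

def f(x, y):
    return jnp.concatenate([x, y], axis=-1)

inp1 = jnp.zeros((16, 4), jnp.float32)  # batched along axis 0
inp2 = jnp.zeros((4,), jnp.float32)     # unbatched

# Inside the concat batching rule, the unbatched operand `y` must be
# broadcast to the batch size (with the right sharding, in the sharded
# case) before the concatenation.
out = jax.vmap(f, in_axes=(0, None))(inp1, inp2)
assert out.shape == (16, 8)
```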
PiperOrigin-RevId: 733536944
Unfortunately, the old detection code doesn't guarantee that `epath` is
installed:
```
[utM] In [7]: importlib.util.find_spec("etils.epath")
Out[7]: ModuleSpec(name='etils.epath',
loader=<_frozen_importlib_external.SourceFileLoader object at
0x73b8492a7230>,
origin='/home/neil/src/cmm/.venv/lib/python3.12/site-packages/etils/epath/__init__.py',
submodule_search_locations=['/home/neil/src/cmm/.venv/lib/python3.12/site-packages/etils/epath'])
[utM] In [8]: import etils.epath
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
Cell In[8], line 1
----> 1 import etils.epath
...
ModuleNotFoundError: No module named 'importlib_resources'
```
This happened every time I ran jax with a clean environment.
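A more robust detection is to simply attempt the import, since `find_spec` only proves the module's source can be located, not that it will import (the helper name here is illustrative, not JAX's actual function):

```python
def epath_installed() -> bool:
    # importlib.util.find_spec("etils.epath") can succeed even when the
    # import fails, because a transitive dependency (here,
    # importlib_resources) may be missing. Importing is the real test.
    try:
        import etils.epath  # noqa: F401
    except ImportError:
        return False
    return True
```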
This is the ordering we want for a proper release of generic SMEM stores
into the async proxy. The old order was problematic: once the warpgroup
barrier was complete, some warps could get deselected before they get to
the fence. As long as the first warp kept making progress, it could go
through the fence alone and start issuing TMA copies before the other warps
have synchronized with the async proxy.
I have not observed this problem in any of our kernels so far, but this
order seems safer to me.
PiperOrigin-RevId: 733333814
Triton seems to support both 2D and 3D operands now, the latter case being a
batched matmul. We need more changes in the lowering to support 3D, so I will
leave it out of scope here.
Fixes #26013.
PiperOrigin-RevId: 733293299
When there are multiple dimensions, NumPy's semantics are as if the padding is applied to each dimension in order.
We lacked test coverage for this case because the constant pad-width values `((0, 2),)` and `(0, 2)` were handled by different code paths.
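In NumPy terms, the two spellings are equivalent:

```python
import numpy as np

a = np.ones((2, 3))

# A single (before, after) pair is broadcast to every dimension...
b1 = np.pad(a, (0, 2))
# ...and so is a length-1 sequence of pairs, although the two spellings
# were handled by different code paths.
b2 = np.pad(a, ((0, 2),))

assert b1.shape == (4, 5)
assert np.array_equal(b1, b2)
```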
Fixes https://github.com/jax-ml/jax/issues/26888
Now the small tiling is always `(8, swizzle // bytewidth(dtype))`, no matter whether the input
is transposed or not. This should simplify the follow-up refactoring of the code and make it easier
to enable small tiling for LHS too.
PiperOrigin-RevId: 732933005
This tests saving a module with one set of axis names, but loading it with another set of axis names.
This does also test the custom calls:
- `@Sharding`
- `@xla.sdy.GlobalToLocalShape`
- `@xla.sdy.LocalToGlobalShape`
But note that there are a bunch of other custom calls that will be tested in the Shardy and XLA codebases. The testing utilities used here don't allow me to set `out_shardings`, for example. So JAX can rely on the existence of those tests as stability guarantees, just like for StableHLO.
PiperOrigin-RevId: 732893432
This CL only supports lowering a module with the exact same mesh, and loading it with either the exact same mesh or different meshes.
Note that we will be introducing some restrictions under Shardy for JAX export:
- You can only lower/save the module with meshes all of the same shape, but different axis names (this PR is right now only allowing the same axis names, but this will be relaxed in a follow-up)
- When loading the module, just like with GSPMD, you can use a different mesh with a different mesh shape and axis names. However, like with the restriction in the previous point, all shardings must use the same axis shapes, but can use different axis names (again this will be relaxed in a follow-up)
We may remove the restriction of having to use the exact same mesh shapes at export saving time and at export loading time in the future. But for now we will keep this restriction while no one is using Shardy with JAX export.
PiperOrigin-RevId: 732878916
Add a mechanism for using the same Var names for Vars that
are aliased. In this PR, we use this for `pjit`, such that the
following `print(jax.make_jaxpr(lambda a: jax.jit(lambda a: a + 1)(a))(0.))`
prints:
```
{ lambda ; a:f32[]. let
b:f32[] = pjit[
name=<lambda>
jaxpr={ lambda ; a:f32[]. let b:f32[] = add a 1.0 in (b,) }
] a
in (b,) }
```
instead of the previous:
```
{ lambda ; a:f32[]. let
b:f32[] = pjit[
name=<lambda>
jaxpr={ lambda ; c:f32[]. let d:f32[] = add c 1.0 in (d,) }
] a
in (b,) }
```
The same mechanism could be used for other higher-order primitives,
e.g., cond, and others.
Also add some typing declarations and rename APIs to use "shared jaxpr"
in lieu of "top-level jaxpr" for those Jaxprs that are used multiple
times and are printed first. I presume that the term "top-level jaxpr"
was picked because these are printed first at top-level. But this is
confusing, because they are really subjaxprs. In fact, there was already
a function `core.pp_toplevel_jaxpr` for printing the top-level Jaxpr,
and there was also `core.pp_top_level_jaxpr` (which is now named
`core.pp_shared_jaxpr`).
that it more closely matches the CDF for low-probability events (less than
2**-nmant).
Because -log(-log(x)) is more sensitive close to 1 than to 0, we must use
-log(-log1p(-x)) instead to make better use of the extra floating-point range around 0.
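A quick float32 illustration of why `log1p` matters here:

```python
import numpy as np

x = np.float32(2.0 ** -30)  # uniform draw near 0, i.e. U = 1 - x near 1

with np.errstate(divide="ignore"):
    # Naive form: 1 - x rounds to exactly 1.0 in float32, so the
    # low-probability tail collapses: log(1.0) = 0 and -log(0) = inf.
    naive = -np.log(-np.log(np.float32(1.0) - x))

# log1p(-x) keeps the extra floating-point resolution around 0.
better = -np.log(-np.log1p(-x))

assert np.isinf(naive)
assert np.isfinite(better)
```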
PiperOrigin-RevId: 732757388
This change detects a situation where a gmem_memref is read via `async_load` and directly used in a wgmma. In such cases, we insert a cast before the load to add tile, transpose, and swizzle transformations.
PiperOrigin-RevId: 732618760