Currently the serialized data and its length are being carried in two separate arrays, a fixed-with bytes array (with a hard-coded max size) and a unit32 array respectively.
PiperOrigin-RevId: 734299259
Key features:
* ***Support mixed prefill and decode*** to increase throughput for inference. (eg., ***5x*** speedup compared to padded Muti-Queries Paged Attention implementation for llama-3-8b.)
* ***No explicit `swapaxes`*** for `seq_len` and `num_head` in pre/post kernel. The kernel takes `num_head` in 2nd minor as it naturally was. We fold swapaxes to strided load/store in the kernel and apply transpose on the fly.
* ***No GMM (Grouped Matmul) Metadata required!*** We calculate the metadata on the fly in the kernel. This can speed up ***10%***!
* ***Increase MXU utilization 8x in GQA*** by grouping shared q heads for MXU in decode.
* ***Minimize recompilation:*** The only factors can cause recompilation are model specs, `max_num_batched_tokens` and `max_num_seqs` in the setting of mixed engine.
PiperOrigin-RevId: 734269519
Without this fix, lowerings of ops within the `for` body are always appended at the end, even if they have users earlier in the body. This caused an `operand #0 does not dominate this use` error.
The fix was tested in the upcoming (but not yet submitted) `test_realistic_matmul` in Pallas with Workgroup semantics.
PiperOrigin-RevId: 734157829
This change allows us to get rid of extra env vars which used to control whether to install `jax` at head. Now, `jax` will be be built and consumed in the same way as the other wheels in the continuous jobs.
PiperOrigin-RevId: 734123590
We don't have many Blackwell kernels yet, so let's begin the deprecation there!
Small tiles have clearer semantics when it comes to transposes too, which allows
us to enable more test cases.
PiperOrigin-RevId: 733786884
This makes the code path uniform for LHS/RHS and greatly clarifies the
magical computation of LBO/SBO. This change should make it significantly
easier for us to enable small tile support for the LHS.
PiperOrigin-RevId: 733737302
The original change was rolled back because there were real world use cases of custom_vjp where the fwd function had the wrong signature. To preserve backwards compatibility, we shouldn't resolve the input arguments to fwd using fwds signature. Instead, we can just ignore the signature because custom_vjp handles the resolution before we ever get here.
Reverts 1f3176636d304398b00a7d2cb0933859618affd8
PiperOrigin-RevId: 733643149
In this case, the example boils down to:
```
inp1 = f32[16@x, 4]
inp2 = f32[4]
def f(x: f32[4], y: f32[4])
return jnp.concat([x, y], axis=-1)
vmap(f, in_axes=(0, None))(inp1)
```
This example was breaking in concat batching rule because we didn't broadcast with the right sharding.
PiperOrigin-RevId: 733536944
List of changes:
1. Allow us to build a RC wheel when building release artifacts. This is done by modifying the build CLI to use the new JAX build rule and passing in the build options that control the wheel tag. A new build argument `use_new_wheel_build_rule` is introduced to the build CLI to avoid breaking anyone that uses the CLI and the old build rule. Note that this option will go way in the future when the build CLI migrates fully to the new build rule.
2. Change the upload script to upload both rc and release tagged wheels (changes internal)
PiperOrigin-RevId: 733464219
Unfortunately, the old detection code doesn't guarantee that `epath` is
installed:
```
[utM] In [7]: importlib.util.find_spec("etils.epath")
Out[7]: ModuleSpec(name='etils.epath',
loader=<_frozen_importlib_external.SourceFileLoader object at
0x73b8492a7230>,
origin='/home/neil/src/cmm/.venv/lib/python3.12/site-packages/etils/epath/__init__.py',
submodule_search_locations=['/home/neil/src/cmm/.venv/lib/python3.12/site-packages/etils/epath'])
[utM] In [8]: import etils.epath
---------------------------------------------------------------------------
ModuleNotFoundError Traceback (most recent
call last)
Cell In[8], line 1
----> 1 import etils.epath
...
ModuleNotFoundError: No module named 'importlib_resources'
```
This happened every time I ran jax with a clean environment.