My motivation here is to fix the plugin support for batch partitionable custom calls. Since plugin support for custom call partitioners is provided via register_plugin_callback in xla_bridge, instead of xla_client itself, it's much more straightforward to register the custom calls in JAX.
It would be possible to refactor things differently, but it actually seems like a reasonable choice to use the supported APIs from `jax.ffi` instead of `xla_client` so that we can take advantage of any new features we might add there in the future.
This is all still a little bit brittle and I'd eventually like to migrate to a version where the XLA FFI library provides a mechanism for exporting handlers, but this change is still compatible with any future changes like that.
PiperOrigin-RevId: 735381736
Surprisingly, the bug was tracked down to #26111 aka cl/730939406, specifically
the new implementation of reset_name_stack in source_info_util.py.
To repro, use the before-this-commit implementation of reset_name_stack (left
commented-out in the file), and run
```
JAX_USE_DIRECT_LINEARIZE=1 python tests/name_stack_test.py NameStackTransformationTest.test_nested_jit_stack
```
This feature is necessary to fix the SMEM->GMEM waiting behavior in
`emit_pipeline`, which used a pessimistic condition prior to this change,
since every copy was its own commit group.
PiperOrigin-RevId: 734553668
Also improve dynamic_update_slice sharding error by printing `aval.str_short()` instead of full sharding because it's concise and gives more info than the current error (i.e. it adds shape too to the error message)
Also make some formatting changes in scan lowering to make it easier to debug.
PiperOrigin-RevId: 734542862
The difficulty here is that our register tiling is based on the (64, 8)
shape, while the memory tiling is now (8, swizzle // bytewidth). Before,
we would assume that each register tile fits neatly within a single
memory tile, but now it is obviously not the case. Luckily, it wasn't
too hard to add.
PiperOrigin-RevId: 734517000
Currently the serialized data and its length are being carried in two separate arrays, a fixed-with bytes array (with a hard-coded max size) and a unit32 array respectively.
PiperOrigin-RevId: 734299259
Key features:
* ***Support mixed prefill and decode*** to increase throughput for inference. (eg., ***5x*** speedup compared to padded Muti-Queries Paged Attention implementation for llama-3-8b.)
* ***No explicit `swapaxes`*** for `seq_len` and `num_head` in pre/post kernel. The kernel takes `num_head` in 2nd minor as it naturally was. We fold swapaxes to strided load/store in the kernel and apply transpose on the fly.
* ***No GMM (Grouped Matmul) Metadata required!*** We calculate the metadata on the fly in the kernel. This can speed up ***10%***!
* ***Increase MXU utilization 8x in GQA*** by grouping shared q heads for MXU in decode.
* ***Minimize recompilation:*** The only factors can cause recompilation are model specs, `max_num_batched_tokens` and `max_num_seqs` in the setting of mixed engine.
PiperOrigin-RevId: 734269519
Without this fix, lowerings of ops within the `for` body are always appended at the end, even if they have users earlier in the body. This caused an `operand #0 does not dominate this use` error.
The fix was tested in the upcoming (but not yet submitted) `test_realistic_matmul` in Pallas with Workgroup semantics.
PiperOrigin-RevId: 734157829
This change allows us to get rid of extra env vars which used to control whether to install `jax` at head. Now, `jax` will be be built and consumed in the same way as the other wheels in the continuous jobs.
PiperOrigin-RevId: 734123590