Add a first benchmark for tracing/lowering pallas splash attention.
Sample results below were taken on a GCP n2d-standard-128 instance with 512 GB RAM and 128 vCPUs (AMD EPYC Milan).
---------------------------------------------------------------------------------
Benchmark                                        Time             CPU   Iterations
---------------------------------------------------------------------------------
test_pallas_mqa_splash_attention_trace        39.8 ms         39.8 ms           19
test_pallas_mqa_splash_attention_lower        42.1 ms         41.9 ms           18
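For readers who want to reproduce this kind of measurement by hand, the following is a minimal sketch of timing the trace and lower stages separately. It is not the benchmark added in this change: `attention` is a hypothetical plain-JAX stand-in for the Pallas splash attention kernel, `time_once` is an illustrative helper, and the shapes are arbitrary. Note that `lower` here re-traces the function, so its time includes tracing.

```python
import time

import jax
import jax.numpy as jnp


def attention(q, k, v):
    # Hypothetical stand-in for the kernel under test: plain multi-query
    # attention, where all query heads share one set of keys and values.
    logits = jnp.einsum("hqd,kd->hqk", q, k)
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("hqk,kd->hqd", weights, v)


def time_once(fn):
    # Illustrative helper: run fn once and return (result, elapsed seconds).
    start = time.perf_counter()
    result = fn()
    return result, time.perf_counter() - start


# Arbitrary small shapes, chosen only for the sketch.
q = jnp.zeros((8, 128, 64), jnp.float32)
k = jnp.zeros((128, 64), jnp.float32)
v = jnp.zeros((128, 64), jnp.float32)

# Tracing: build the jaxpr without compiling or running anything.
jaxpr, trace_s = time_once(lambda: jax.make_jaxpr(attention)(q, k, v))

# Lowering: trace and lower to StableHLO, still without compiling.
lowered, lower_s = time_once(lambda: jax.jit(attention).lower(q, k, v))

print(f"trace: {trace_s * 1e3:.1f} ms")
print(f"lower (including trace): {lower_s * 1e3:.1f} ms")
```

A single timed call is noisy; a benchmark harness that repeats each stage, as in the table above, gives more stable numbers.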
PiperOrigin-RevId: 742259409
Now that we always use small tiles, we can lay out the tiled dimension
in arbitrary order so there's no need to swap them during the TMA.
PiperOrigin-RevId: 742206980
Implemented untile_ref and unswizzle_ref to allow patterns where we need different transform stacks over the same memref. For example, we may want to store reg->smem transposed, then copy smem->gmem sliced, and maybe do a strided load or print in between for sanity checking:
```
# Store registers transposed
o_smem_swizzled = plgpu.unswizzle_ref(o_smem_raw, swizzle_out)
o_smem_t = o_smem_swizzled.reshape(1, 1, config.block_n, config.block_m)
o_smem_t = plgpu.untile_ref(o_smem_t, (n, m))
o_smem_t = plgpu.transpose_ref(o_smem_t, (1, 0))
o_smem_t[...] = plgpu.layout_cast(regs, plgpu.Layout.WGMMA_TRANSPOSED)
plgpu.commit_smem()
del o_smem_t
# Now we need different transforms on the same smem to slice and async-store to gmem
o_smem = o_smem_raw.reshape(n, m // swizzle_elems, swizzle_elems,)
o_smem = plgpu.unswizzle_ref(o_smem, swizzle_out)
o_smem = plgpu.tile_ref(o_smem, swizzle_out)
o_smem = o_smem.at[...]
plgpu.copy_smem_to_gmem(o_smem, o_ref.at[...],)
```
Which in turn lets us write
PiperOrigin-RevId: 742194519
This is part of the removal of support for large MMA tiling in Mosaic GPU.
It should also let us simplify some of the transpose transforms that are
no longer necessary, but I decided to separate this.
PiperOrigin-RevId: 742168801
The GPU-specific deps were added to the backend-independent tests by mistake [here](https://github.com/jax-ml/jax/pull/27113). These tests should pass using `jax` and `jaxlib` wheels only.
PiperOrigin-RevId: 741663266
Imported from GitHub PR https://github.com/jax-ml/jax/pull/27576
This is an experimental extension to attrs. Attrs should be considered both experimental and deprecated.
This PR also includes some fixes for getattr/setattr.
Copybara import of the project:
--
3b1ea1a5f90b28744522670d0498ce5a6b194274 by Matthew Johnson <mattjj@google.com>:
[attrs] experimental appendattr
Merging this change closes #27576
COPYBARA_INTEGRATE_REVIEW=https://github.com/jax-ml/jax/pull/27576 from mattjj:appendattr b93795201b39b8f75890c9228368c994ae1e38e8
PiperOrigin-RevId: 741662724
We use f32 as the dtype inside the kernel. Before writing the result from VMEM to HBM, we convert it to the desired dtype (e.g. bf16), which saves memory bandwidth.
Also made a minor change: sliding window and logit soft capping are now checked in the function that checks the static values.
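The f32-inside, cast-on-write approach described above can be illustrated with a minimal plain-JAX sketch. This is not the kernel from this change: `attention_f32_compute` and its `out_dtype` parameter are hypothetical names, and there is no explicit VMEM/HBM handling here. It only shows the dtype discipline of doing all arithmetic in f32 and converting once on the final write.

```python
import jax
import jax.numpy as jnp


def attention_f32_compute(q, k, v, out_dtype=jnp.bfloat16):
    # Hypothetical helper: upcast the inputs so every intermediate is f32.
    q32, k32, v32 = (x.astype(jnp.float32) for x in (q, k, v))
    logits = q32 @ k32.T
    weights = jax.nn.softmax(logits, axis=-1)
    out = weights @ v32
    # Convert only once, at the end, so the value written out is half the
    # size of the f32 intermediate.
    return out.astype(out_dtype)


q = jnp.ones((16, 32), jnp.bfloat16)
k = jnp.ones((24, 32), jnp.bfloat16)
v = jnp.ones((24, 32), jnp.bfloat16)

out = attention_f32_compute(q, k, v)
print(out.shape, out.dtype)
```

Writing bf16 instead of f32 halves the bytes stored per output element, which is where the bandwidth saving comes from.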
PiperOrigin-RevId: 741660728
Specialize it to one shape per aval, since that's the only case that exists.
Remove some pointless assertions using this code.
PiperOrigin-RevId: 741569024
This is a precautionary measure to prevent conflicts with other packages
using nanobind and registering the same types. We don't want JAX's
nanobind registrations to conflict on, say, XLA types with other
projects.