Previously, XLA's command buffers (CUDA graphs) were disabled both
during PGLE profile collection and when recompiling with the profile
data. With this change, they are disabled only while collecting the
profile data.
Now that we always use small tiles, we can lay out the tiled dimensions
in arbitrary order, so there's no need to swap them during the TMA.
PiperOrigin-RevId: 742206980
Implemented untile_ref and unswizzle_ref to allow patterns where we need different transform stacks over the same memref. For example, we may want to store reg->smem transposed, then smem->gmem sliced, and maybe do a strided load or print in between for sanity checking:
```
# Store registers transposed
o_smem_swizzled = plgpu.unswizzle_ref(o_smem_raw, swizzle_out)
o_smem_t = o_smem_swizzled.reshape(1, 1, config.block_n, config.block_m)
o_smem_t = plgpu.untile_ref(o_smem_t, (n, m))
o_smem_t = plgpu.transpose_ref(o_smem_t, (1, 0))
o_smem_t[...] = plgpu.layout_cast(regs, plgpu.Layout.WGMMA_TRANSPOSED)
plgpu.commit_smem()
del o_smem_t
# Now we need different transforms on the same smem to slice and async-store to gmem
o_smem = o_smem_raw.reshape(n, m // swizzle_elems, swizzle_elems,)
o_smem = plgpu.unswizzle_ref(o_smem, swizzle_out)
o_smem = plgpu.tile_ref(o_smem, swizzle_out)
o_smem = o_smem.at[...]
plgpu.copy_smem_to_gmem(o_smem, o_ref.at[...],)
```
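For readers unfamiliar with the tile/untile terminology, here is a small NumPy sketch of the shape transformation only. It assumes a 2D array, and the helper names `tile` and `untile` are hypothetical: they are not part of Pallas and say nothing about how `plgpu.tile_ref` or `plgpu.untile_ref` are implemented.

```python
import numpy as np

def tile(x, tiling):
    """Hypothetical helper: view an (n, m) array as (n // tn, m // tm, tn, tm)."""
    tn, tm = tiling
    n, m = x.shape
    assert n % tn == 0 and m % tm == 0, "tiling must divide the shape"
    # Split each dimension into (tile index, index within tile), then
    # move the two tile indices to the front.
    return x.reshape(n // tn, tn, m // tm, tm).transpose(0, 2, 1, 3)

def untile(x, tiling):
    """Hypothetical helper: inverse of `tile`, back to the flat (n, m) shape."""
    tn, tm = tiling
    gn, gm, tn2, tm2 = x.shape
    assert (tn2, tm2) == (tn, tm), "trailing dims must match the tiling"
    return x.transpose(0, 2, 1, 3).reshape(gn * tn, gm * tm)

x = np.arange(8 * 16).reshape(8, 16)
tiled = tile(x, (4, 8))
roundtrip = untile(tiled, (4, 8))
```

Here `tiled[i, j]` is the `(4, 8)` block of `x` starting at row `4 * i` and column `8 * j`, and `untile` undoes the transformation exactly.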
Which in turn lets us write
PiperOrigin-RevId: 742194519
This is part of the removal of support for large MMA tiling in Mosaic GPU.
It should also let us simplify some of the transpose transforms that are
no longer necessary, but I decided to separate this.
PiperOrigin-RevId: 742168801
Imported from GitHub PR https://github.com/jax-ml/jax/pull/27576
This is an experimental extension to attrs. Attrs should be considered both experimental and deprecated.
This PR also includes some fixes for getattr/setattr.
Copybara import of the project:
--
3b1ea1a5f90b28744522670d0498ce5a6b194274 by Matthew Johnson <mattjj@google.com>:
[attrs] experimental appendattr
Merging this change closes #27576
COPYBARA_INTEGRATE_REVIEW=https://github.com/jax-ml/jax/pull/27576 from mattjj:appendattr b93795201b39b8f75890c9228368c994ae1e38e8
PiperOrigin-RevId: 741662724
Specialize it to one shape per aval, since that's the only case that exists.
Remove some pointless assertions using this code.
PiperOrigin-RevId: 741569024