Sharad Vikram 6f0737b46f [Pallas TPU] Add support for dynamically sized (tile aligned) DMAs
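This change adds limited support for dynamically sized DMAs. Specifically, you can use `pl.ds(start, size)` where `size` is a runtime value in your kernel (for example, one read from SMEM) rather than a Python constant. The resulting slice can be used to build a Ref view (via `.at[...]`) that is passed to a DMA. For example: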

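```python
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def kernel(size_smem_ref, x_hbm_ref, _, o_hbm_ref, sem):
  # The copy size is a scalar read from SMEM at runtime, not a static constant.
  size = size_smem_ref[0]
  # DMA the first `size` rows of x into o and block until the copy completes.
  # The Refs keep static shapes; only the sliced views are dynamically sized.
  pltpu.async_copy(
      x_hbm_ref.at[pl.ds(0, size)],
      o_hbm_ref.at[pl.ds(0, size)], sem).wait()
```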

We're doing this because we want to add limited support for dynamic shapes to Pallas, something that is usually hard to do in XLA.
We're augmenting the `Slice` class to support dynamic sizes, which are then plumbed down into indexing primitives like `pl.load` and `pltpu.async_copy`.
However, values and Refs in Pallas kernels must still have static shapes, so we can't do dynamically sized loads/stores. As a result, we won't touch any abstract evaluation rules (they will all still consume and return `ShapedArray`s). We can, however, do dynamically sized DMAs between statically shaped Refs. While this falls short of arbitrary dynamic shapes, we hope it enables some interesting new kernels.
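As a rough mental model (not how the TPU lowering works), the copy in the example kernel can be emulated in plain `jax.numpy`: every array keeps a static shape, and only the number of leading rows that get overwritten depends on the runtime `size`. The function name `dynamic_copy_reference` is hypothetical, and the sketch assumes the output buffer starts out holding the values of an aliased input (`o_init`), so rows at or beyond `size` keep their old contents.

```python
import jax
import jax.numpy as jnp


def dynamic_copy_reference(size, x, o_init):
  """Hypothetical pure-JAX model of a dynamically sized DMA.

  Rows [0, size) of `x` are copied into `o_init`; rows >= size keep their
  previous values (assumes the output aliases an input buffer). `size` may
  be a traced value, but every array shape stays static.
  """
  rows = jnp.arange(x.shape[0])          # static length: x.shape[0]
  mask = (rows < size)[:, None]          # (rows, 1), broadcast over columns
  return jnp.where(mask, x, o_init)


# One compilation serves every value of `size`, since shapes never change.
copy_fn = jax.jit(dynamic_copy_reference)

x = jnp.arange(8 * 128, dtype=jnp.float32).reshape(8, 128)
o = jnp.zeros_like(x)
out = copy_fn(jnp.int32(3), x, o)  # first 3 rows come from x, the rest stay 0
```

Because `size` is a traced scalar here, calling `copy_fn` with a different `size` reuses the same compiled function, which mirrors the idea of a dynamically sized transfer between statically shaped buffers.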

PiperOrigin-RevId: 618322737
2024-03-22 16:59:32 -07:00