# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains shared logic and abstractions for Pallas indexing ops."""

from __future__ import annotations

import dataclasses
from typing import Any, Sequence, Union

from jax._src import core
from jax._src import tree_util
from jax._src.typing import Array
from jax._src.util import merge_lists
from jax._src.util import partition_list
import numpy as np


@tree_util.register_pytree_node_class
@dataclasses.dataclass
class Slice:
  """A slice with a start index and a size.

  Both start index and size can either be static, i.e. known at tracing
  and compilation time, or dynamic.
  """

  start: int | Array
  size: int | Array
  stride: int = 1

  def __post_init__(self):
    if self.stride < 1:
      raise ValueError("`stride` must be >= 1.")

  @property
  def is_dynamic_start(self):
    return not core.is_dim(self.start)

  @property
  def is_dynamic_size(self):
    return not core.is_dim(self.size)

  def tree_flatten(self):
    # If `start` (or `size`) is statically known, we treat it as static
    # auxiliary data rather than as a pytree leaf.
    xs = ()
    data = ()
    xs += (self.start,) if self.is_dynamic_start else (None,)
    data += (None,) if self.is_dynamic_start else (self.start,)
    xs += (self.size,) if self.is_dynamic_size else (None,)
    data += (None,) if self.is_dynamic_size else (self.size,)
    data += (self.stride,)
    return xs, data
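
  # Example (illustrative; `dyn_size` stands for any traced value):
  #
  #   leaves, aux = Slice(2, dyn_size).tree_flatten()
  #   # leaves == (None, dyn_size); aux == (2, None, 1)
  #
  # Only the dynamic component becomes a pytree leaf; the static start and
  # stride travel as auxiliary data.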

  @classmethod
  def tree_unflatten(cls, aux_data, children) -> Slice:
    start, size = (
        a if a is not None else b for a, b in zip(children, aux_data[:2])
    )
    return cls(start, size, aux_data[2])

  @classmethod
  def from_slice(cls, slc: slice, size: int) -> Slice:
    start, step, size = core.canonicalize_slice(slc, size)
    if step < 1:
      raise ValueError(f"slice must have a step >= 1 (found: {step})")
    return cls(start, size, step)
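
  # Example (illustrative): `Slice.from_slice(slice(2, 10), 16)` yields
  # `Slice(start=2, size=8, stride=1)`; slices with step < 1 are rejected by
  # the check above.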


def dslice(
    start: int | Array | None,
    size: int | Array | None = None,
    stride: int | None = None,
) -> slice | Slice:
  """Constructs a ``Slice`` from a start index and a size.

  The semantics of ``dslice`` mirror those of the builtin ``slice`` type:

  * ``dslice(None)`` is ``:``
  * ``dslice(j)`` is ``:j``
  * ``dslice(i, j)`` is ``i:i+j``
  * ``dslice(i, j, stride)`` is ``i:i+j:stride``
  """
  if start is None:
    return slice(None)
  if stride is None:
    stride = 1
  if not isinstance(stride, int):
    raise ValueError("Non-static stride in `dslice`")
  if size is None:
    if not isinstance(start, int):
      raise ValueError("Non-static `dslice`")
    return Slice(0, start, stride)
  return Slice(start, size, stride)
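

# Examples (illustrative), mirroring the bullets in the docstring above,
# where `i` may be a traced value:
#
#   dslice(None)   # -> slice(None),                      i.e. `:`
#   dslice(8)      # -> Slice(start=0, size=8, stride=1), i.e. `:8`
#   dslice(i, 8)   # -> Slice(start=i, size=8, stride=1), i.e. `i:i + 8`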


ds = dslice  # Handy alias


IntIndexer = Union[int, Array]
DimIndexer = Union[IntIndexer, Slice]


def unpack_ndindexer(indexer: NDIndexer) -> tuple[tuple[bool, ...],
                                                  tuple[Slice, ...],
                                                  tuple[IntIndexer, ...]]:
  is_int_indexing = [not isinstance(i, Slice) for i in indexer.indices]
  slice_indexers, int_indexers = partition_list(
      is_int_indexing, indexer.indices)
  return tuple(is_int_indexing), tuple(slice_indexers), tuple(int_indexers)  # type: ignore
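
# Example (illustrative; `ix` stands for any integer/array indexer): for an
# NDIndexer with indices == (Slice(0, 4), ix), this returns
#
#   (False, True), (Slice(0, 4),), (ix,)
#
# so callers can handle slice and integer indexers separately and later
# recombine them with `merge_lists`.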


def _maybe_concretize(x: Any):
  # This is roughly the same logic as core.concrete_or_error, but we avoid
  # calling that because constructing the ConcretizationTypeError can be
  # expensive as the size of the tracing context (i.e. the jaxpr) grows.
  return core.to_concrete_value(x)


@tree_util.register_pytree_node_class
@dataclasses.dataclass
class NDIndexer:
  indices: tuple[DimIndexer, ...]
  shape: tuple[int, ...]
  int_indexer_shape: tuple[int, ...]
  # Off by default to avoid doing validation during pytree operations.
  validate: bool = False

  def __post_init__(self):
    if not self.validate:
      return
    if len(self.indices) != len(self.shape):
      raise ValueError(
          f"`indices` must be the same length as `Ref` shape: {self}."
      )
    # We validate integer indexing shapes here.
    for idx, s in zip(self.indices, self.shape):
      if isinstance(idx, Slice):
        start = idx.start
        if value := _maybe_concretize(start):
          if value >= s:
            raise ValueError(f"Out of bound slice: start={value}, dim={s}.")
          if size := _maybe_concretize(idx.size):
            if value + (size - 1) * idx.stride >= s:
              raise ValueError(
                  f"Out of bound slice: start={value}, size={size},"
                  f" stride={idx.stride}, dim={s}."
              )
        continue
      # The shape of indexer integers should be broadcastable up to the
      # int_indexer_shape of the whole NDIndexer.
      if not np.shape(idx):
        if (value := _maybe_concretize(idx)) and value >= s:
          raise ValueError(f"Out of bound indexer: idx={value}, dim={s}.")
        # ()-shaped indexers broadcast without issue.
        continue
      # If we don't have a ()-shaped indexer, the rank must match
      # int_indexer_shape.
      if np.ndim(idx) != len(self.int_indexer_shape):
        raise ValueError(
            f"Indexer must have rank {len(self.int_indexer_shape)}: {idx=} vs."
            f" {self.int_indexer_shape=}"
        )
      # Check that the shapes broadcast.
      try:
        np.broadcast_shapes(np.shape(idx), self.int_indexer_shape)
      except ValueError as e:
        raise ValueError(
            f"Could not broadcast integer indexer: {idx=} vs."
            f" {self.int_indexer_shape=}"
        ) from e

  @property
  def is_dynamic_size(self):
    return any(isinstance(i, Slice) and i.is_dynamic_size for i in self.indices)

  def tree_flatten(self):
    flat_idx, idx_tree = tree_util.tree_flatten(self.indices)
    return flat_idx, (idx_tree, self.shape, self.int_indexer_shape)

  @classmethod
  def tree_unflatten(cls, data, flat_idx):
    idx_tree, shape, int_indexer_shape = data
    indices = tree_util.tree_unflatten(idx_tree, flat_idx)
    return cls(tuple(indices), shape, int_indexer_shape)

  @classmethod
  def from_indices_shape(cls, indices, shape) -> NDIndexer:
    if not isinstance(indices, tuple):
      # TODO(slebedev): Consider requiring `indices` to be a Sequence.
      indices = (indices,)

    if num_ellipsis := sum(idx is ... for idx in indices):
      if num_ellipsis > 1:
        raise ValueError("Only one ellipsis is supported.")
      # Expand ... so that `indices` has the same length as `shape`.
      ip = indices.index(...)
      indices = list(indices)
      indices[ip:ip+1] = [slice(None)] * (len(shape) - len(indices) + 1)
      indices = tuple(indices)
    if len(indices) > len(shape):
      raise ValueError("`indices` must not be longer than `shape`: "
                       f"{indices=}, {shape=}")
    elif len(indices) < len(shape):
      # Pad `indices` to have the same length as `shape`.
      indices = (*indices, *[slice(None)] * (len(shape) - len(indices)))

    # Promote all builtin `slice`s to `Slice`.
    indices = tuple(
        Slice.from_slice(i, s) if isinstance(i, slice) else i
        for i, s in zip(indices, shape))

    is_int_indexing = [not isinstance(i, Slice) for i in indices]
    if any(is_int_indexing):
      int_indexers: Sequence[Any]
      other_indexers, int_indexers = partition_list(is_int_indexing, indices)
      indexer_shapes = tuple(core.get_aval(i).shape for i in int_indexers)
      try:
        int_indexer_shape = np.broadcast_shapes(*indexer_shapes)
      except ValueError as e:
        # Raise a nicer error than the NumPy one.
        raise ValueError(
            f"Cannot broadcast shapes for indexing: {indexer_shapes}") from e

      # Here we use the `broadcast_to` primitive instead of composing lax
      # primitives together because it is easier to lower in targets like
      # Triton/Mosaic.
      #
      # The local import avoids a circular dependency between primitives
      # and this module.
      from jax._src.state import primitives as sp  # pytype: disable=import-error
      int_indexers = [
          sp.broadcast_to(i, int_indexer_shape) for i in int_indexers
      ]
      indices = tuple(merge_lists(is_int_indexing, other_indexers, int_indexers))
    else:
      int_indexer_shape = ()

    return cls(indices, shape, int_indexer_shape, validate=True)
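
  # Example (illustrative): `NDIndexer.from_indices_shape((0, slice(None)), (8, 128))`
  # promotes `slice(None)` to `Slice(0, 128)`, treats the leading `0` as an
  # integer indexer, and validates everything against the shape, so
  # `get_indexer_shape()` below returns `(128,)`.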

  def get_indexer_shape(self) -> tuple[int | Array, ...]:
    _, slice_indexers, _ = unpack_ndindexer(self)
    slice_shape = [s.size for s in slice_indexers]
    # In NDIndexers, the int_indexer_shape is *always* at the front of the
    # result.
    return (*self.int_indexer_shape, *slice_shape)
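
  # Example (illustrative): with indices == (ix, Slice(0, 4)), where `ix` has
  # shape (3,), the indexer shape is (3, 4): the broadcast integer-indexer
  # shape comes first, followed by the slice sizes.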

  def transform_shape(self, shape: None | tuple[int | Array, ...]) -> None | tuple[int | Array, ...]:
    del shape  # Unused
    return self.get_indexer_shape()

  def transform_dtype(self, dtype):
    return dtype