[Pallas TPU] Add support for dynamic sized (tile aligned) DMAs

This change adds limited support for dynamically sized DMAs. Specifically, you can use `pl.ds(start, size)` where `size` is a variable in your kernel. The resulting slice can be used in a `Ref` view that is passed to a DMA. For example:

```python
def kernel(size_smem_ref, x_hbm_ref, _, o_hbm_ref, sem):
  size = size_smem_ref[0]
  pltpu.async_copy(
    x_hbm_ref.at[pl.ds(0, size)],
    o_hbm_ref.at[pl.ds(0, size)], sem).wait()
```

We're doing this because we want to add limited support for dynamic shapes to Pallas, something that is usually hard to do in XLA.
We're augmenting the `Slice` class to support dynamic sizes, which are then plumbed down into indexing primitives like `pl.load` and `pltpu.async_copy`.
However, values and `Ref`s in Pallas kernels still have a static shape requirement, so we can't do dynamic loads/stores. As a result, we don't touch any abstract evaluation rules (they still all consume and return `ShapedArray`s). We can, however, do dynamically sized DMAs between statically shaped `Ref`s. While this isn't arbitrary dynamic shapes, we hope it enables some new, interesting kernels.
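
For reference, here is roughly how such a kernel is driven end to end. This is a minimal sketch that mirrors the `test_simple_tile_aligned_dynamic_size_dma` test added below; the `(8, 8, 128)` shapes and the scalar `size` operand in SMEM are illustrative choices, and it needs a TPU v4 or later to run:

```python
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu


def kernel(size_smem_ref, x_hbm_ref, _, o_hbm_ref, sem):
  # Read the (tile aligned) number of rows to copy from SMEM, then DMA that
  # dynamically sized prefix from the input HBM ref into the output HBM ref.
  size = size_smem_ref[0]
  pltpu.async_copy(
      x_hbm_ref.at[pl.ds(0, size)],
      o_hbm_ref.at[pl.ds(0, size)], sem).wait()


x = jnp.tile(jnp.arange(8, dtype=jnp.int32)[:, None, None], [1, 8, 128])
o = jnp.zeros((8, 8, 128), dtype=jnp.int32)
size = jnp.array([4], dtype=jnp.int32)

out = pl.pallas_call(
    kernel,
    grid_spec=pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=0,
        in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM),
                  pl.BlockSpec(memory_space=pltpu.ANY),
                  pl.BlockSpec(memory_space=pltpu.ANY)],
        out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
        scratch_shapes=[pltpu.SemaphoreType.DMA],
    ),
    out_shape=o,
    input_output_aliases={2: 0},  # alias the `o` input into the output
)(size, x, o)

# Only the first `size` rows are copied; the remaining rows keep the aliased
# output's original (zero) contents.
np.testing.assert_array_equal(out, o.at[:4].set(x[:4]))
```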

PiperOrigin-RevId: 618322737
Sharad Vikram 2024-03-22 16:58:45 -07:00 committed by jax authors
parent 6ffd55c405
commit 6f0737b46f
6 changed files with 147 additions and 41 deletions


```diff
@@ -733,7 +733,7 @@ def _maybe_cast_to_index(cast_to_index, x):
 def _index_to_start_size_stride(
     idx: tuple[indexing.Slice | int | ir.Value, ...], cast_to_index: bool
-) -> tuple[ir.Value, int, int, bool]:
+) -> tuple[ir.Value, int | ir.Value, int, bool]:
   assert not isinstance(idx, slice)
   if isinstance(idx, indexing.Slice):
     start = _maybe_cast_to_index(cast_to_index, idx.start)
@@ -762,7 +762,7 @@ def _indexer_to_start_size_stride(
     cast_to_index: bool,
 ) -> tuple[
     tuple[ir.Value, ...],
-    tuple[int, ...],
+    tuple[int | ir.Value, ...],
     tuple[int, ...],
     tuple[bool, ...],
     tuple[int | pl_core.Mapped, ...],
@@ -800,7 +800,7 @@ def _indexer_to_start_size_stride(
 def _slice_memref(ref: ir.Value, ref_aval: state.AbstractRef,
                   indexer: NDIndexer,
                   ref_block_shape: tuple[int | pl_core.Mapped, ...]
-                  ) -> tuple[ir.Value, state.AbstractRef, tuple[int | pl_core.Mapped, ...],
+                  ) -> tuple[ir.Value, tuple[int | pl_core.Mapped, ...],
                              tuple[int | pl_core.Mapped, ...]]:
   assert ref_block_shape is not None
   target_shape = indexer.get_indexer_shape()
@@ -813,26 +813,28 @@ def _slice_memref(ref: ir.Value, ref_aval: state.AbstractRef,
   )
   if not all((s is None or s == 1) for s in strides):
     raise NotImplementedError("Strided slices of references are unsupported.")
+  dynamic_sizes = tuple(s for s in sizes if isinstance(s, ir.Value))
+  ir_dynamic_size = ir.ShapedType.get_dynamic_size()
+  static_sizes = tuple(s if not isinstance(s, ir.Value)
+                       else ir_dynamic_size for s in sizes)
   target_ref_ty = ir.MemRefType.get(
-      tuple(sizes), _dtype_to_ir_type(ref_aval.dtype),
+      static_sizes, _dtype_to_ir_type(ref_aval.dtype),
       memory_space=ref.type.memory_space)
-  inner_aval = ref_aval.inner_aval
-  out_aval = ref_aval.update(inner_aval=inner_aval.update(shape=target_shape))
-  out = tpu.MemRefSliceOp(target_ref_ty, ref, starts, []).result
+  out = tpu.MemRefSliceOp(target_ref_ty, ref, starts, dynamic_sizes).result
   if any(squeeze_dims):
     # We need to squeeze out some dimensions
     squeezed_ref_ty = ir.MemRefType.get(
         tuple(target_shape), _dtype_to_ir_type(ref_aval.dtype),
         memory_space=ref.type.memory_space)
     out = tpu.MemRefSqueezeOp(squeezed_ref_ty, out).result
-  return out, out_aval, ref_block_shape
+  return out, ref_block_shape
 
 
 def _index_ref(ref, ref_aval, ref_block_shape, indexers):
   for indexer in indexers:
-    ref, ref_aval, ref_block_shape = _slice_memref(ref, ref_aval, indexer,
-                                                   ref_block_shape)
-  return ref, ref_aval, ref_block_shape
+    ref, ref_block_shape = _slice_memref(ref, ref_aval, indexer,
+                                         ref_block_shape)
+  return ref, ref_block_shape
 
 
 def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_):
@@ -846,7 +848,7 @@ def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_):
     raise NotImplementedError
 
   ref_block_shape, *_ = ctx.block_shapes
-  ref, _, ref_block_shape = _index_ref(
+  ref, ref_block_shape = _index_ref(
       ref, ref_aval, ref_block_shape, slice_indexers)
   ref_type = ir.MemRefType(ref.type)
   is_smem_load = str(ref_type.memory_space) == "#tpu.memory_space<smem>"
@@ -900,7 +902,7 @@ def _masked_swap_lowering_rule(
     raise NotImplementedError
 
   ref_block_shape, *_ = ctx.block_shapes
-  ref, _, ref_block_shape = _index_ref(
+  ref, ref_block_shape = _index_ref(
      ref, ref_aval, ref_block_shape, slice_indexers)
 
   ref_type = ir.MemRefType(ref.type)
@@ -2059,7 +2061,7 @@ def _semaphore_signal_lowering_rule(
 ):
   sem_aval, _, _, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in)
   sem, indexers, value, device_id = tree_util.tree_unflatten(args_tree, args)
-  sem, _, _ = _index_ref(sem, sem_aval, sem_aval.shape, indexers)
+  sem, _ = _index_ref(sem, sem_aval, sem_aval.shape, indexers)
   if device_id is not None:
     device_id = _device_id_to_logical(ctx, device_id, device_id_type)
   return tpu.SemaphoreSignalOp(sem, value, device_id=device_id).results
@@ -2072,7 +2074,7 @@ lowering_rules[tpu_primitives.semaphore_signal_p] = (
 def _semaphore_wait_lowering_rule(ctx: LoweringRuleContext, *args, args_tree):
   sem_aval, _, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in)
   sem, indexers, value = tree_util.tree_unflatten(args_tree, args)
-  sem, _, _ = _index_ref(sem, sem_aval, sem_aval.shape, indexers)
+  sem, _ = _index_ref(sem, sem_aval, sem_aval.shape, indexers)
   return tpu.SemaphoreWaitOp(sem, value).results
 
 lowering_rules[tpu_primitives.semaphore_wait_p] = _semaphore_wait_lowering_rule
@@ -2094,16 +2096,16 @@ def _dma_start_lowering_rule(ctx: LoweringRuleContext, *args, tree,
   )
   block_shapes = tree_util.tree_unflatten(tree, ctx.block_shapes)
   src_ref_block_shape, dst_ref_block_shape = block_shapes[0], block_shapes[2]
-  src_ref, _, _ = _index_ref(
+  src_ref, _ = _index_ref(
      src_ref, src_ref_aval, src_ref_block_shape, src_indexers
   )
   if src_sem is not None:
-    src_sem, _, _ = _index_ref(
+    src_sem, _ = _index_ref(
        src_sem, src_sem_aval, src_sem_aval.shape, src_sem_indexers)
-  dst_ref, _, _ = _index_ref(
+  dst_ref, _ = _index_ref(
      dst_ref, dst_ref_aval, dst_ref_block_shape, dst_indexers
   )
-  sem, _, _ = _index_ref(sem, sem_aval, sem_aval.shape, sem_indexers)
+  sem, _ = _index_ref(sem, sem_aval, sem_aval.shape, sem_indexers)
   if device_id is not None:
     device_id = _device_id_to_logical(ctx, device_id, device_id_type)
   return tpu.EnqueueDMAOp(src_ref, dst_ref, sem, source_semaphore=src_sem,
@@ -2118,10 +2120,10 @@ def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree,
   sem_aval, _, ref_aval, _ = tree_util.tree_unflatten(tree, ctx.avals_in)
   block_shapes = tree_util.tree_unflatten(tree, ctx.block_shapes)
   ref_block_shape = block_shapes[2]
-  ref, _, _ = _index_ref(
+  ref, _ = _index_ref(
      ref, ref_aval, ref_block_shape, indexers
   )
-  sem, _, _ = _index_ref(sem, sem_aval, sem_aval.shape, sem_indexers)
+  sem, _ = _index_ref(sem, sem_aval, sem_aval.shape, sem_indexers)
   return tpu.WaitDMAOp(sem, ref).results
 
 lowering_rules[tpu_primitives.dma_wait_p] = _dma_wait_lowering_rule
```


```diff
@@ -188,7 +188,10 @@ def _convert_to_array_indexer(indexer: indexing.NDIndexer
 def _maybe_convert_to_dynamic_slice(
     indexer: indexing.NDIndexer,
-) -> tuple[tuple[Array | int, ...], tuple[int, ...], tuple[int, ...]] | None:
+) -> (
+    tuple[tuple[Array | int, ...], tuple[Array | int, ...], tuple[int, ...]]
+    | None
+):
   # An NDIndexer only corresponds to a `dynamic_slice` or `dynamic_update_slice`
   # if each of the indexers is a `Slice` or a ()-shaped value.
   if not all(isinstance(i, indexing.Slice) or not np.shape(i)
```


```diff
@@ -31,25 +31,39 @@ import numpy as np
 @dataclasses.dataclass
 class Slice:
   """Represents a slice with a dynamic start index and a fixed size."""
-  start: Any
-  size: int
+  start: int | Array
+  size: int | Array
   stride: int = 1
 
   def __post_init__(self):
     if self.size < 0:
       raise ValueError("`size` must not be negative.")
     if self.stride < 1:
       raise ValueError("`stride` must be >= 1.")
 
+  @property
+  def is_dynamic_start(self):
+    return not isinstance(self.start, int)
+
+  @property
+  def is_dynamic_size(self):
+    return not isinstance(self.size, int)
+
   def tree_flatten(self):
     # If `start` is statically known, we treat it as static information
-    if isinstance(self.start, int):
-      return (), (self.start, self.size, self.stride)
-    return (self.start,), (self.size, self.stride)
+    xs = ()
+    data = ()
+    xs += (self.start,) if self.is_dynamic_start else (None,)
+    data += (None,) if self.is_dynamic_start else (self.start,)
+    xs += (self.size,) if self.is_dynamic_size else (None,)
+    data += (None,) if self.is_dynamic_size else (self.size,)
+    data += (self.stride,)
+    return xs, data
 
   @classmethod
   def tree_unflatten(cls, aux_data, children) -> Slice:
-    return cls(*children, *aux_data)
+    start, size = [
+        a if a is not None else b for a, b in zip(children, aux_data[:2])
+    ]
+    return cls(start, size, aux_data[2])
 
   @classmethod
   def from_slice(cls, slc: slice, size: int) -> Slice:
@@ -61,7 +75,7 @@ class Slice:
 def dslice(
     start: int | Array | None,
-    size: int | None = None,
+    size: int | Array | None = None,
     stride: int | None = None,
 ) -> slice | Slice:
   """Constructs a `Slice` from a start and a size."""
@@ -154,6 +168,10 @@ class NDIndexer:
           f" {self.int_indexer_shape=}"
       ) from e
 
+  @property
+  def is_dynamic_size(self):
+    return any(isinstance(i, Slice) and i.is_dynamic_size for i in self.indices)
+
   def tree_flatten(self):
     flat_idx, idx_tree = tree_util.tree_flatten(self.indices)
     return flat_idx, (idx_tree, self.shape, self.int_indexer_shape)
@@ -202,7 +220,7 @@ class NDIndexer:
     indices = merge_lists(is_int_indexing, other_indexers, int_indexers)
     return NDIndexer(tuple(indices), shape, bcast_shape, validate=True)
 
-  def get_indexer_shape(self) -> tuple[int, ...]:
+  def get_indexer_shape(self) -> tuple[int | Array, ...]:
     _, slice_indexers, _ = unpack_ndindexer(self)
     slice_shape = [s.size for s in slice_indexers]
     # In NDIndexers, the int_indexer_shape is *always* at the front of the
```
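
To make the new flattening behavior concrete, here is a minimal sketch of how a `Slice` with a dynamic size now flattens and unflattens. The `jax._src.state.indexing` import path is an assumption (file names are not shown above), and a concrete `jnp` scalar stands in for the traced size a kernel would actually see:

```python
import jax.numpy as jnp
from jax._src.state.indexing import Slice  # assumed module path

# Fully static slice: start, size, and stride all live in the static aux data.
static = Slice(0, 16)
xs, data = static.tree_flatten()
assert xs == (None, None) and data == (0, 16, 1)
assert not static.is_dynamic_size

# Dynamic size: the size becomes a pytree leaf instead of aux data.
size = jnp.asarray(8, dtype=jnp.int32)  # stands in for a value computed in a kernel
dynamic = Slice(0, size)
assert dynamic.is_dynamic_size and not dynamic.is_dynamic_start
xs, data = dynamic.tree_flatten()
assert xs[0] is None and xs[1] is size  # the dynamic size is a leaf
assert data == (0, None, 1)             # start and stride stay static

# Unflattening recombines the static and dynamic halves.
roundtripped = Slice.tree_unflatten(data, xs)
assert roundtripped.start == 0 and roundtripped.stride == 1
```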


```diff
@@ -148,11 +148,12 @@ def ref_addupdate(ref_or_view: AbstractRef, idx: Indexer | None, x: Array) -> No
 def _shape_after_indexing(
-    shape: tuple[int, ...], indexers: tuple[indexing.NDIndexer, ...]
-) -> tuple[int, ...]:
+    shape: tuple[int | Array, ...], indexers: tuple[indexing.NDIndexer, ...]
+) -> tuple[int | Array, ...]:
   for indexer in indexers:
     # Run some simple checks that all the indexers have consistent shapes
-    assert indexer.shape == shape, (indexer.shape, shape)
+    if not indexer.is_dynamic_size:
+      assert indexer.shape == shape, (indexer.shape, shape)
     shape = indexer.get_indexer_shape()
   return shape
@@ -239,12 +240,26 @@ def _pp_slice(context: core.JaxprPpContext, dim, slc: indexing.Slice
   start, size = slc.start, slc.size
   if isinstance(start, core.Var):
     start_str = core.pp_var(start, context)
-    end_str = f'{start_str}+{size}'
+    size_str = (
+        core.pp_var(size, context)
+        if isinstance(size, core.Var)
+        else str(size)
+    )
+    return f'{start_str}:{start_str}+{size_str}'
   else:
-    start_str = '' if start == 0 else str(start)
-    end = start + size
-    end_str = '' if end == dim else str(end)
-  return f'{start_str}:{end_str}'
+    start_str = str(start)
+    if start == 0:
+      start_str = ''
+    if isinstance(size, core.Var):
+      size_str = core.pp_var(size, context)
+      if start_str:
+        return f'{start_str}:{start_str}+{size_str}'
+      else:
+        return f':{size_str}'
+    else:
+      end = start + size
+      end_str = '' if end == dim else str(end)
+      return f'{start_str}:{end_str}'
 
 
 def pp_indexer(context: core.JaxprPpContext,indexer: indexing.NDIndexer
                ) -> pp.Doc:
```


```diff
@@ -95,7 +95,11 @@ class RefView:
   indexers: tuple[indexing.NDIndexer, ...]
 
   @property
-  def shape(self) -> tuple[int, ...]:
+  def is_dynamic_size(self):
+    return self.indexers[-1].is_dynamic_size
+
+  @property
+  def shape(self) -> tuple[int | Array, ...]:
     assert (
         len(self.indexers) > 0
     ), "Should not be able to create a trivial RefView"
```


```diff
@@ -2061,5 +2061,69 @@ class PallasCallPipelineTest(parameterized.TestCase):
     )
 
 
+class PallasCallDynamicDMATest(PallasTPUTest):
+
+  def setUp(self):
+    super().setUp()
+    if not jtu.is_device_tpu_at_least(4):
+      self.skipTest('DMAs not supported on TPU generations <= 3')
+
+  def test_simple_tile_aligned_dynamic_size_dma(self):
+
+    def kernel(size_smem_ref, x_hbm_ref, _, o_hbm_ref, sem):
+      size = size_smem_ref[0]
+      pltpu.async_copy(
+          x_hbm_ref.at[pl.ds(0, size)],
+          o_hbm_ref.at[pl.ds(0, size)], sem).wait()
+
+    x = jnp.tile(jnp.arange(8, dtype=jnp.int32)[:, None, None], [1, 8, 128])
+    o = jnp.zeros((8, 8, 128), dtype=jnp.int32)
+    size = jnp.array([4], dtype=jnp.int32)
+
+    out = pl.pallas_call(
+        kernel,
+        grid_spec=pltpu.PrefetchScalarGridSpec(
+            num_scalar_prefetch=0,
+            in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM),
+                      pl.BlockSpec(memory_space=pltpu.ANY),
+                      pl.BlockSpec(memory_space=pltpu.ANY)],
+            out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
+            scratch_shapes=[pltpu.SemaphoreType.DMA]
+        ),
+        out_shape=o,
+        input_output_aliases={2: 0},
+    )(size, x, o)
+    expected = o.at[:4].set(x.at[:4].get())
+    np.testing.assert_array_equal(out, expected)
+
+  def test_simple_dynamic_size_dma(self):
+    self.skipTest("doesn't work yet.")
+
+    def kernel(size_smem_ref, x_hbm_ref, _, o_hbm_ref, sem):
+      size = size_smem_ref[0]
+      pltpu.async_copy(
+          x_hbm_ref.at[pl.ds(0, size)],
+          o_hbm_ref.at[pl.ds(0, size)], sem).wait()
+
+    x = jnp.arange(8, dtype=jnp.int32)
+    o = jnp.zeros(8, dtype=jnp.int32)
+    size = jnp.array([4], dtype=jnp.int32)
+
+    out = pl.pallas_call(
+        kernel,
+        grid_spec=pltpu.PrefetchScalarGridSpec(
+            num_scalar_prefetch=0,
+            in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM),
+                      pl.BlockSpec(memory_space=pltpu.ANY),
+                      pl.BlockSpec(memory_space=pltpu.ANY)],
+            out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
+            scratch_shapes=[pltpu.SemaphoreType.DMA]
+        ),
+        out_shape=o,
+        input_output_aliases={2: 0},
+    )(size, x, o)
+    expected = o.at[:4].set(x.at[:4].get())
+    np.testing.assert_array_equal(out, expected)
+
+
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())
```