From 6f0737b46f24828fbede93477e2e83d7ed7f39d3 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Fri, 22 Mar 2024 16:58:45 -0700 Subject: [PATCH] [Pallas TPU] Add support for dynamic sized (tile aligned) DMAs This change adds limited support for dynamic sized DMAs. Specifically, you can use `pl.ds(start, size)` where `size` is a variable in your kernel. This slice can be used in a view that is used in a DMA. For example: ```python def kernel(size_smem_ref, x_hbm_ref, _, o_hbm_ref, sem): size = size_smem_ref[0] pltpu.async_copy( x_hbm_ref.at[pl.ds(0, size)], o_hbm_ref.at[pl.ds(0, size)], sem).wait() ``` We're doing this because we want to add limited support for dynamic shapes to Pallas, something that is usually hard to do in XLA. We're augmenting the Slice class to support dynamic sizes, which are then plumbed down into indexing primitives like pl.load, and pltpu.async_copy. However, values and Refs in Pallas kernels still have a static shape requirement, so we can't do dynamic loads/stores. As a result, we won't touch any abstract evaluation rules (since they will still all consume and return ShapedArrays). We can, however, do dynamically sized DMAs between statically shaped Refs. While this isn't arbitrary dynamic shapes, we hope this enables some new interesting kernels. PiperOrigin-RevId: 618322737 --- jax/_src/pallas/mosaic/lowering.py | 44 ++++++++++--------- jax/_src/state/discharge.py | 5 ++- jax/_src/state/indexing.py | 38 ++++++++++++----- jax/_src/state/primitives.py | 31 ++++++++++---- jax/_src/state/types.py | 6 ++- tests/pallas/pallas_call_tpu_test.py | 64 ++++++++++++++++++++++++++++ 6 files changed, 147 insertions(+), 41 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index f55f179a6..0069b0946 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -733,7 +733,7 @@ def _maybe_cast_to_index(cast_to_index, x): def _index_to_start_size_stride( idx: tuple[indexing.Slice | int | ir.Value, ...], cast_to_index: bool -) -> tuple[ir.Value, int, int, bool]: +) -> tuple[ir.Value, int | ir.Value, int, bool]: assert not isinstance(idx, slice) if isinstance(idx, indexing.Slice): start = _maybe_cast_to_index(cast_to_index, idx.start) @@ -762,7 +762,7 @@ def _indexer_to_start_size_stride( cast_to_index: bool, ) -> tuple[ tuple[ir.Value, ...], - tuple[int, ...], + tuple[int | ir.Value, ...], tuple[int, ...], tuple[bool, ...], tuple[int | pl_core.Mapped, ...], @@ -800,7 +800,7 @@ def _indexer_to_start_size_stride( def _slice_memref(ref: ir.Value, ref_aval: state.AbstractRef, indexer: NDIndexer, ref_block_shape: tuple[int | pl_core.Mapped, ...] - ) -> tuple[ir.Value, state.AbstractRef, tuple[int | pl_core.Mapped, ...], + ) -> tuple[ir.Value, tuple[int | pl_core.Mapped, ...], tuple[int | pl_core.Mapped, ...]]: assert ref_block_shape is not None target_shape = indexer.get_indexer_shape() @@ -813,26 +813,28 @@ def _slice_memref(ref: ir.Value, ref_aval: state.AbstractRef, ) if not all((s is None or s == 1) for s in strides): raise NotImplementedError("Strided slices of references are unsupported.") + dynamic_sizes = tuple(s for s in sizes if isinstance(s, ir.Value)) + ir_dynamic_size = ir.ShapedType.get_dynamic_size() + static_sizes = tuple(s if not isinstance(s, ir.Value) + else ir_dynamic_size for s in sizes) target_ref_ty = ir.MemRefType.get( - tuple(sizes), _dtype_to_ir_type(ref_aval.dtype), + static_sizes, _dtype_to_ir_type(ref_aval.dtype), memory_space=ref.type.memory_space) - inner_aval = ref_aval.inner_aval - out_aval = ref_aval.update(inner_aval=inner_aval.update(shape=target_shape)) - out = tpu.MemRefSliceOp(target_ref_ty, ref, starts, []).result + out = tpu.MemRefSliceOp(target_ref_ty, ref, starts, dynamic_sizes).result if any(squeeze_dims): # We need to squeeze out some dimensions squeezed_ref_ty = ir.MemRefType.get( tuple(target_shape), _dtype_to_ir_type(ref_aval.dtype), memory_space=ref.type.memory_space) out = tpu.MemRefSqueezeOp(squeezed_ref_ty, out).result - return out, out_aval, ref_block_shape + return out, ref_block_shape def _index_ref(ref, ref_aval, ref_block_shape, indexers): for indexer in indexers: - ref, ref_aval, ref_block_shape = _slice_memref(ref, ref_aval, indexer, - ref_block_shape) - return ref, ref_aval, ref_block_shape + ref, ref_block_shape = _slice_memref(ref, ref_aval, indexer, + ref_block_shape) + return ref, ref_block_shape def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_): @@ -846,7 +848,7 @@ def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_): raise NotImplementedError ref_block_shape, *_ = ctx.block_shapes - ref, _, ref_block_shape = _index_ref( + ref, ref_block_shape = _index_ref( ref, ref_aval, ref_block_shape, slice_indexers) ref_type = ir.MemRefType(ref.type) is_smem_load = str(ref_type.memory_space) == "#tpu.memory_space" @@ -900,7 +902,7 @@ def _masked_swap_lowering_rule( raise NotImplementedError ref_block_shape, *_ = ctx.block_shapes - ref, _, ref_block_shape = _index_ref( + ref, ref_block_shape = _index_ref( ref, ref_aval, ref_block_shape, slice_indexers) ref_type = ir.MemRefType(ref.type) @@ -2059,7 +2061,7 @@ def _semaphore_signal_lowering_rule( ): sem_aval, _, _, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in) sem, indexers, value, device_id = tree_util.tree_unflatten(args_tree, args) - sem, _, _ = _index_ref(sem, sem_aval, sem_aval.shape, indexers) + sem, _ = _index_ref(sem, sem_aval, sem_aval.shape, indexers) if device_id is not None: device_id = _device_id_to_logical(ctx, device_id, device_id_type) return tpu.SemaphoreSignalOp(sem, value, device_id=device_id).results @@ -2072,7 +2074,7 @@ lowering_rules[tpu_primitives.semaphore_signal_p] = ( def _semaphore_wait_lowering_rule(ctx: LoweringRuleContext, *args, args_tree): sem_aval, _, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in) sem, indexers, value = tree_util.tree_unflatten(args_tree, args) - sem, _, _ = _index_ref(sem, sem_aval, sem_aval.shape, indexers) + sem, _ = _index_ref(sem, sem_aval, sem_aval.shape, indexers) return tpu.SemaphoreWaitOp(sem, value).results lowering_rules[tpu_primitives.semaphore_wait_p] = _semaphore_wait_lowering_rule @@ -2094,16 +2096,16 @@ def _dma_start_lowering_rule(ctx: LoweringRuleContext, *args, tree, ) block_shapes = tree_util.tree_unflatten(tree, ctx.block_shapes) src_ref_block_shape, dst_ref_block_shape = block_shapes[0], block_shapes[2] - src_ref, _, _ = _index_ref( + src_ref, _ = _index_ref( src_ref, src_ref_aval, src_ref_block_shape, src_indexers ) if src_sem is not None: - src_sem, _, _ = _index_ref( + src_sem, _ = _index_ref( src_sem, src_sem_aval, src_sem_aval.shape, src_sem_indexers) - dst_ref, _, _ = _index_ref( + dst_ref, _ = _index_ref( dst_ref, dst_ref_aval, dst_ref_block_shape, dst_indexers ) - sem, _, _ = _index_ref(sem, sem_aval, sem_aval.shape, sem_indexers) + sem, _ = _index_ref(sem, sem_aval, sem_aval.shape, sem_indexers) if device_id is not None: device_id = _device_id_to_logical(ctx, device_id, device_id_type) return tpu.EnqueueDMAOp(src_ref, dst_ref, sem, source_semaphore=src_sem, @@ -2118,10 +2120,10 @@ def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree, sem_aval, _, ref_aval, _ = tree_util.tree_unflatten(tree, ctx.avals_in) block_shapes = tree_util.tree_unflatten(tree, ctx.block_shapes) ref_block_shape = block_shapes[2] - ref, _, _ = _index_ref( + ref, _ = _index_ref( ref, ref_aval, ref_block_shape, indexers ) - sem, _, _ = _index_ref(sem, sem_aval, sem_aval.shape, sem_indexers) + sem, _ = _index_ref(sem, sem_aval, sem_aval.shape, sem_indexers) return tpu.WaitDMAOp(sem, ref).results lowering_rules[tpu_primitives.dma_wait_p] = _dma_wait_lowering_rule diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index b7fdaa2b5..9cd9a9b86 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -188,7 +188,10 @@ def _convert_to_array_indexer(indexer: indexing.NDIndexer def _maybe_convert_to_dynamic_slice( indexer: indexing.NDIndexer, -) -> tuple[tuple[Array | int, ...], tuple[int, ...], tuple[int, ...]] | None: +) -> ( + tuple[tuple[Array | int, ...], tuple[Array | int, ...], tuple[int, ...]] + | None +): # An NDIndexer only corresponds to a `dynamic_slice` or `dynamic_update_slice` # if each of the indexers is a `Slice` or a ()-shaped value. if not all(isinstance(i, indexing.Slice) or not np.shape(i) diff --git a/jax/_src/state/indexing.py b/jax/_src/state/indexing.py index 60b99f100..716520c55 100644 --- a/jax/_src/state/indexing.py +++ b/jax/_src/state/indexing.py @@ -31,25 +31,39 @@ import numpy as np @dataclasses.dataclass class Slice: """Represents a slice with a dynamic start index and a fixed size.""" - start: Any - size: int + start: int | Array + size: int | Array stride: int = 1 def __post_init__(self): - if self.size < 0: - raise ValueError("`size` must not be negative.") if self.stride < 1: raise ValueError("`stride` must be >= 1.") + @property + def is_dynamic_start(self): + return not isinstance(self.start, int) + + @property + def is_dynamic_size(self): + return not isinstance(self.size, int) + def tree_flatten(self): # If `start` is statically known, we treat it as static information - if isinstance(self.start, int): - return (), (self.start, self.size, self.stride) - return (self.start,), (self.size, self.stride) + xs = () + data = () + xs += (self.start,) if self.is_dynamic_start else (None,) + data += (None,) if self.is_dynamic_start else (self.start,) + xs += (self.size,) if self.is_dynamic_size else (None,) + data += (None,) if self.is_dynamic_size else (self.size,) + data += (self.stride,) + return xs, data @classmethod def tree_unflatten(cls, aux_data, children) -> Slice: - return cls(*children, *aux_data) + start, size = [ + a if a is not None else b for a, b in zip(children, aux_data[:2]) + ] + return cls(start, size, aux_data[2]) @classmethod def from_slice(cls, slc: slice, size: int) -> Slice: @@ -61,7 +75,7 @@ class Slice: def dslice( start: int | Array | None, - size: int | None = None, + size: int | Array | None = None, stride: int | None = None, ) -> slice | Slice: """Constructs a `Slice` from a start and a size.""" @@ -154,6 +168,10 @@ class NDIndexer: f" {self.int_indexer_shape=}" ) from e + @property + def is_dynamic_size(self): + return any(isinstance(i, Slice) and i.is_dynamic_size for i in self.indices) + def tree_flatten(self): flat_idx, idx_tree = tree_util.tree_flatten(self.indices) return flat_idx, (idx_tree, self.shape, self.int_indexer_shape) @@ -202,7 +220,7 @@ class NDIndexer: indices = merge_lists(is_int_indexing, other_indexers, int_indexers) return NDIndexer(tuple(indices), shape, bcast_shape, validate=True) - def get_indexer_shape(self) -> tuple[int, ...]: + def get_indexer_shape(self) -> tuple[int | Array, ...]: _, slice_indexers, _ = unpack_ndindexer(self) slice_shape = [s.size for s in slice_indexers] # In NDIndexers, the int_indexer_shape is *always* at the front of the diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index 4b55792f7..a11080cb8 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -148,11 +148,12 @@ def ref_addupdate(ref_or_view: AbstractRef, idx: Indexer | None, x: Array) -> No def _shape_after_indexing( - shape: tuple[int, ...], indexers: tuple[indexing.NDIndexer, ...] -) -> tuple[int, ...]: + shape: tuple[int | Array, ...], indexers: tuple[indexing.NDIndexer, ...] +) -> tuple[int | Array, ...]: for indexer in indexers: # Run some simple checks that all the indexers have consistent shapes - assert indexer.shape == shape, (indexer.shape, shape) + if not indexer.is_dynamic_size: + assert indexer.shape == shape, (indexer.shape, shape) shape = indexer.get_indexer_shape() return shape @@ -239,12 +240,26 @@ def _pp_slice(context: core.JaxprPpContext, dim, slc: indexing.Slice start, size = slc.start, slc.size if isinstance(start, core.Var): start_str = core.pp_var(start, context) - end_str = f'{start_str}+{size}' + size_str = ( + core.pp_var(size, context) + if isinstance(size, core.Var) + else str(size) + ) + return f'{start_str}:{start_str}+{size_str}' else: - start_str = '' if start == 0 else str(start) - end = start + size - end_str = '' if end == dim else str(end) - return f'{start_str}:{end_str}' + start_str = str(start) + if start == 0: + start_str = '' + if isinstance(size, core.Var): + size_str = core.pp_var(size, context) + if start_str: + return f'{start_str}:{start_str}+{size_str}' + else: + return f':{size_str}' + else: + end = start + size + end_str = '' if end == dim else str(end) + return f'{start_str}:{end_str}' def pp_indexer(context: core.JaxprPpContext,indexer: indexing.NDIndexer ) -> pp.Doc: diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index 1665e0fed..303e4da0b 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -95,7 +95,11 @@ class RefView: indexers: tuple[indexing.NDIndexer, ...] @property - def shape(self) -> tuple[int, ...]: + def is_dynamic_size(self): + return self.indexers[-1].is_dynamic_size + + @property + def shape(self) -> tuple[int | Array, ...]: assert ( len(self.indexers) > 0 ), "Should not be able to create a trivial RefView" diff --git a/tests/pallas/pallas_call_tpu_test.py b/tests/pallas/pallas_call_tpu_test.py index 6b32f3595..2655cf01f 100644 --- a/tests/pallas/pallas_call_tpu_test.py +++ b/tests/pallas/pallas_call_tpu_test.py @@ -2061,5 +2061,69 @@ class PallasCallPipelineTest(parameterized.TestCase): ) +class PallasCallDynamicDMATest(PallasTPUTest): + + def setUp(self): + super().setUp() + if not jtu.is_device_tpu_at_least(4): + self.skipTest('DMAs not supported on TPU generations <= 3') + + def test_simple_tile_aligned_dynamic_size_dma(self): + + def kernel(size_smem_ref, x_hbm_ref, _, o_hbm_ref, sem): + size = size_smem_ref[0] + pltpu.async_copy( + x_hbm_ref.at[pl.ds(0, size)], + o_hbm_ref.at[pl.ds(0, size)], sem).wait() + + x = jnp.tile(jnp.arange(8, dtype=jnp.int32)[:, None, None], [1, 8, 128]) + o = jnp.zeros((8, 8, 128), dtype=jnp.int32) + size = jnp.array([4], dtype=jnp.int32) + + out = pl.pallas_call( + kernel, + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM), + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.ANY)], + out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + scratch_shapes=[pltpu.SemaphoreType.DMA] + ), + out_shape=o, + input_output_aliases={2: 0}, + )(size, x, o) + expected = o.at[:4].set(x.at[:4].get()) + np.testing.assert_array_equal(out, expected) + + def test_simple_dynamic_size_dma(self): + self.skipTest("doesn't work yet.") + def kernel(size_smem_ref, x_hbm_ref, _, o_hbm_ref, sem): + size = size_smem_ref[0] + pltpu.async_copy( + x_hbm_ref.at[pl.ds(0, size)], + o_hbm_ref.at[pl.ds(0, size)], sem).wait() + + x = jnp.arange(8, dtype=jnp.int32) + o = jnp.zeros(8, dtype=jnp.int32) + size = jnp.array([4], dtype=jnp.int32) + + out = pl.pallas_call( + kernel, + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM), + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.ANY)], + out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + scratch_shapes=[pltpu.SemaphoreType.DMA] + ), + out_shape=o, + input_output_aliases={2: 0}, + )(size, x, o) + expected = o.at[:4].set(x.at[:4].get()) + np.testing.assert_array_equal(out, expected) + + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader())