[Pallas TPU] Add support for dynamic sized (tile aligned) DMAs
This change adds limited support for dynamically sized DMAs. Specifically, you can use `pl.ds(start, size)` where `size` is a variable in your kernel. This slice can be used in a view that is used in a DMA. For example:

```python
def kernel(size_smem_ref, x_hbm_ref, _, o_hbm_ref, sem):
  size = size_smem_ref[0]
  pltpu.async_copy(
      x_hbm_ref.at[pl.ds(0, size)],
      o_hbm_ref.at[pl.ds(0, size)],
      sem).wait()
```

We're doing this because we want to add limited support for dynamic shapes to Pallas, something that is usually hard to do in XLA. We're augmenting the `Slice` class to support dynamic sizes, which are then plumbed down into indexing primitives like `pl.load` and `pltpu.async_copy`. However, values and `Ref`s in Pallas kernels still have a static shape requirement, so we can't do dynamic loads/stores. As a result, we won't touch any abstract evaluation rules (since they will still all consume and return `ShapedArray`s). We can, however, do dynamically sized DMAs between statically shaped `Ref`s. While this isn't arbitrary dynamic shapes, we hope this enables some new interesting kernels.

PiperOrigin-RevId: 618322737
parent 6ffd55c405
commit 6f0737b46f
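Before the diff itself, here is a minimal sketch of what the augmented `Slice` pytree behavior looks like from user code. It is not part of the patch; it assumes `pl.ds` constructs the `Slice` shown in the indexing changes below and that `Slice` stays registered as a pytree node. A static start/size pair is carried as aux data, while a dynamic size becomes a traceable leaf:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

# pl.ds builds an indexing.Slice; with this change its `size` may be an Array.
static_slc = pl.ds(0, 128)             # start and size are Python ints
dynamic_slc = pl.ds(0, jnp.int32(64))  # size is an Array, i.e. dynamic

# Static fields are carried as pytree aux data; dynamic ones become leaves,
# so a traced size value survives flattening on its way into the lowering.
print(jax.tree_util.tree_leaves(static_slc))   # expect: [] (fully static)
print(jax.tree_util.tree_leaves(dynamic_slc))  # expect: one int32 leaf (the size)
```

Keeping the dynamic size as a pytree leaf is what lets `pltpu.async_copy` receive the runtime value while the surrounding `Ref`s keep their static shapes.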
@@ -733,7 +733,7 @@ def _maybe_cast_to_index(cast_to_index, x):

 def _index_to_start_size_stride(
     idx: tuple[indexing.Slice | int | ir.Value, ...], cast_to_index: bool
-) -> tuple[ir.Value, int, int, bool]:
+) -> tuple[ir.Value, int | ir.Value, int, bool]:
   assert not isinstance(idx, slice)
   if isinstance(idx, indexing.Slice):
     start = _maybe_cast_to_index(cast_to_index, idx.start)
@@ -762,7 +762,7 @@ def _indexer_to_start_size_stride(
     cast_to_index: bool,
 ) -> tuple[
     tuple[ir.Value, ...],
-    tuple[int, ...],
+    tuple[int | ir.Value, ...],
     tuple[int, ...],
     tuple[bool, ...],
     tuple[int | pl_core.Mapped, ...],
@@ -800,7 +800,7 @@ def _indexer_to_start_size_stride(
 def _slice_memref(ref: ir.Value, ref_aval: state.AbstractRef,
                   indexer: NDIndexer,
                   ref_block_shape: tuple[int | pl_core.Mapped, ...]
-                  ) -> tuple[ir.Value, state.AbstractRef, tuple[int | pl_core.Mapped, ...],
+                  ) -> tuple[ir.Value, tuple[int | pl_core.Mapped, ...],
                              tuple[int | pl_core.Mapped, ...]]:
   assert ref_block_shape is not None
   target_shape = indexer.get_indexer_shape()
@@ -813,26 +813,28 @@ def _slice_memref(ref: ir.Value, ref_aval: state.AbstractRef,
   )
   if not all((s is None or s == 1) for s in strides):
     raise NotImplementedError("Strided slices of references are unsupported.")
+  dynamic_sizes = tuple(s for s in sizes if isinstance(s, ir.Value))
+  ir_dynamic_size = ir.ShapedType.get_dynamic_size()
+  static_sizes = tuple(s if not isinstance(s, ir.Value)
+                       else ir_dynamic_size for s in sizes)
   target_ref_ty = ir.MemRefType.get(
-      tuple(sizes), _dtype_to_ir_type(ref_aval.dtype),
+      static_sizes, _dtype_to_ir_type(ref_aval.dtype),
       memory_space=ref.type.memory_space)
-  inner_aval = ref_aval.inner_aval
-  out_aval = ref_aval.update(inner_aval=inner_aval.update(shape=target_shape))
-  out = tpu.MemRefSliceOp(target_ref_ty, ref, starts, []).result
+  out = tpu.MemRefSliceOp(target_ref_ty, ref, starts, dynamic_sizes).result
   if any(squeeze_dims):
     # We need to squeeze out some dimensions
     squeezed_ref_ty = ir.MemRefType.get(
         tuple(target_shape), _dtype_to_ir_type(ref_aval.dtype),
         memory_space=ref.type.memory_space)
     out = tpu.MemRefSqueezeOp(squeezed_ref_ty, out).result
-  return out, out_aval, ref_block_shape
+  return out, ref_block_shape


 def _index_ref(ref, ref_aval, ref_block_shape, indexers):
   for indexer in indexers:
-    ref, ref_aval, ref_block_shape = _slice_memref(ref, ref_aval, indexer,
-                                                   ref_block_shape)
-  return ref, ref_aval, ref_block_shape
+    ref, ref_block_shape = _slice_memref(ref, ref_aval, indexer,
+                                         ref_block_shape)
+  return ref, ref_block_shape


 def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_):
@@ -846,7 +848,7 @@ def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_):
     raise NotImplementedError

   ref_block_shape, *_ = ctx.block_shapes
-  ref, _, ref_block_shape = _index_ref(
+  ref, ref_block_shape = _index_ref(
       ref, ref_aval, ref_block_shape, slice_indexers)
   ref_type = ir.MemRefType(ref.type)
   is_smem_load = str(ref_type.memory_space) == "#tpu.memory_space<smem>"
@@ -900,7 +902,7 @@ def _masked_swap_lowering_rule(
     raise NotImplementedError

   ref_block_shape, *_ = ctx.block_shapes
-  ref, _, ref_block_shape = _index_ref(
+  ref, ref_block_shape = _index_ref(
       ref, ref_aval, ref_block_shape, slice_indexers)

   ref_type = ir.MemRefType(ref.type)
@@ -2059,7 +2061,7 @@ def _semaphore_signal_lowering_rule(
 ):
   sem_aval, _, _, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in)
   sem, indexers, value, device_id = tree_util.tree_unflatten(args_tree, args)
-  sem, _, _ = _index_ref(sem, sem_aval, sem_aval.shape, indexers)
+  sem, _ = _index_ref(sem, sem_aval, sem_aval.shape, indexers)
   if device_id is not None:
     device_id = _device_id_to_logical(ctx, device_id, device_id_type)
   return tpu.SemaphoreSignalOp(sem, value, device_id=device_id).results
@@ -2072,7 +2074,7 @@ lowering_rules[tpu_primitives.semaphore_signal_p] = (
 def _semaphore_wait_lowering_rule(ctx: LoweringRuleContext, *args, args_tree):
   sem_aval, _, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in)
   sem, indexers, value = tree_util.tree_unflatten(args_tree, args)
-  sem, _, _ = _index_ref(sem, sem_aval, sem_aval.shape, indexers)
+  sem, _ = _index_ref(sem, sem_aval, sem_aval.shape, indexers)
   return tpu.SemaphoreWaitOp(sem, value).results
 lowering_rules[tpu_primitives.semaphore_wait_p] = _semaphore_wait_lowering_rule

@@ -2094,16 +2096,16 @@ def _dma_start_lowering_rule(ctx: LoweringRuleContext, *args, tree,
   )
   block_shapes = tree_util.tree_unflatten(tree, ctx.block_shapes)
   src_ref_block_shape, dst_ref_block_shape = block_shapes[0], block_shapes[2]
-  src_ref, _, _ = _index_ref(
+  src_ref, _ = _index_ref(
       src_ref, src_ref_aval, src_ref_block_shape, src_indexers
   )
   if src_sem is not None:
-    src_sem, _, _ = _index_ref(
+    src_sem, _ = _index_ref(
         src_sem, src_sem_aval, src_sem_aval.shape, src_sem_indexers)
-  dst_ref, _, _ = _index_ref(
+  dst_ref, _ = _index_ref(
       dst_ref, dst_ref_aval, dst_ref_block_shape, dst_indexers
   )
-  sem, _, _ = _index_ref(sem, sem_aval, sem_aval.shape, sem_indexers)
+  sem, _ = _index_ref(sem, sem_aval, sem_aval.shape, sem_indexers)
   if device_id is not None:
     device_id = _device_id_to_logical(ctx, device_id, device_id_type)
   return tpu.EnqueueDMAOp(src_ref, dst_ref, sem, source_semaphore=src_sem,
@@ -2118,10 +2120,10 @@ def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree,
   sem_aval, _, ref_aval, _ = tree_util.tree_unflatten(tree, ctx.avals_in)
   block_shapes = tree_util.tree_unflatten(tree, ctx.block_shapes)
   ref_block_shape = block_shapes[2]
-  ref, _, _ = _index_ref(
+  ref, _ = _index_ref(
       ref, ref_aval, ref_block_shape, indexers
   )
-  sem, _, _ = _index_ref(sem, sem_aval, sem_aval.shape, sem_indexers)
+  sem, _ = _index_ref(sem, sem_aval, sem_aval.shape, sem_indexers)
   return tpu.WaitDMAOp(sem, ref).results
 lowering_rules[tpu_primitives.dma_wait_p] = _dma_wait_lowering_rule

@@ -188,7 +188,10 @@ def _convert_to_array_indexer(indexer: indexing.NDIndexer

 def _maybe_convert_to_dynamic_slice(
     indexer: indexing.NDIndexer,
-) -> tuple[tuple[Array | int, ...], tuple[int, ...], tuple[int, ...]] | None:
+) -> (
+    tuple[tuple[Array | int, ...], tuple[Array | int, ...], tuple[int, ...]]
+    | None
+):
   # An NDIndexer only corresponds to a `dynamic_slice` or `dynamic_update_slice`
   # if each of the indexers is a `Slice` or a ()-shaped value.
   if not all(isinstance(i, indexing.Slice) or not np.shape(i)
@@ -31,25 +31,39 @@ import numpy as np
 @dataclasses.dataclass
 class Slice:
   """Represents a slice with a dynamic start index and a fixed size."""
-  start: Any
-  size: int
+  start: int | Array
+  size: int | Array
   stride: int = 1

   def __post_init__(self):
     if self.size < 0:
       raise ValueError("`size` must not be negative.")
     if self.stride < 1:
       raise ValueError("`stride` must be >= 1.")

+  @property
+  def is_dynamic_start(self):
+    return not isinstance(self.start, int)
+
+  @property
+  def is_dynamic_size(self):
+    return not isinstance(self.size, int)
+
   def tree_flatten(self):
     # If `start` is statically known, we treat it as static information
-    if isinstance(self.start, int):
-      return (), (self.start, self.size, self.stride)
-    return (self.start,), (self.size, self.stride)
+    xs = ()
+    data = ()
+    xs += (self.start,) if self.is_dynamic_start else (None,)
+    data += (None,) if self.is_dynamic_start else (self.start,)
+    xs += (self.size,) if self.is_dynamic_size else (None,)
+    data += (None,) if self.is_dynamic_size else (self.size,)
+    data += (self.stride,)
+    return xs, data

   @classmethod
   def tree_unflatten(cls, aux_data, children) -> Slice:
-    return cls(*children, *aux_data)
+    start, size = [
+        a if a is not None else b for a, b in zip(children, aux_data[:2])
+    ]
+    return cls(start, size, aux_data[2])

   @classmethod
   def from_slice(cls, slc: slice, size: int) -> Slice:
@@ -61,7 +75,7 @@ class Slice:

 def dslice(
     start: int | Array | None,
-    size: int | None = None,
+    size: int | Array | None = None,
     stride: int | None = None,
 ) -> slice | Slice:
   """Constructs a `Slice` from a start and a size."""
@@ -154,6 +168,10 @@ class NDIndexer:
           f" {self.int_indexer_shape=}"
       ) from e

+  @property
+  def is_dynamic_size(self):
+    return any(isinstance(i, Slice) and i.is_dynamic_size for i in self.indices)
+
   def tree_flatten(self):
     flat_idx, idx_tree = tree_util.tree_flatten(self.indices)
     return flat_idx, (idx_tree, self.shape, self.int_indexer_shape)
@@ -202,7 +220,7 @@ class NDIndexer:
     indices = merge_lists(is_int_indexing, other_indexers, int_indexers)
     return NDIndexer(tuple(indices), shape, bcast_shape, validate=True)

-  def get_indexer_shape(self) -> tuple[int, ...]:
+  def get_indexer_shape(self) -> tuple[int | Array, ...]:
     _, slice_indexers, _ = unpack_ndindexer(self)
     slice_shape = [s.size for s in slice_indexers]
     # In NDIndexers, the int_indexer_shape is *always* at the front of the
@@ -148,11 +148,12 @@ def ref_addupdate(ref_or_view: AbstractRef, idx: Indexer | None, x: Array) -> No


 def _shape_after_indexing(
-    shape: tuple[int, ...], indexers: tuple[indexing.NDIndexer, ...]
-) -> tuple[int, ...]:
+    shape: tuple[int | Array, ...], indexers: tuple[indexing.NDIndexer, ...]
+) -> tuple[int | Array, ...]:
   for indexer in indexers:
     # Run some simple checks that all the indexers have consistent shapes
-    assert indexer.shape == shape, (indexer.shape, shape)
+    if not indexer.is_dynamic_size:
+      assert indexer.shape == shape, (indexer.shape, shape)
     shape = indexer.get_indexer_shape()
   return shape

@@ -239,12 +240,26 @@ def _pp_slice(context: core.JaxprPpContext, dim, slc: indexing.Slice
   start, size = slc.start, slc.size
   if isinstance(start, core.Var):
     start_str = core.pp_var(start, context)
-    end_str = f'{start_str}+{size}'
+    size_str = (
+        core.pp_var(size, context)
+        if isinstance(size, core.Var)
+        else str(size)
+    )
+    return f'{start_str}:{start_str}+{size_str}'
   else:
-    start_str = '' if start == 0 else str(start)
-    end = start + size
-    end_str = '' if end == dim else str(end)
-  return f'{start_str}:{end_str}'
+    start_str = str(start)
+    if start == 0:
+      start_str = ''
+    if isinstance(size, core.Var):
+      size_str = core.pp_var(size, context)
+      if start_str:
+        return f'{start_str}:{start_str}+{size_str}'
+      else:
+        return f':{size_str}'
+    else:
+      end = start + size
+      end_str = '' if end == dim else str(end)
+      return f'{start_str}:{end_str}'

 def pp_indexer(context: core.JaxprPpContext,indexer: indexing.NDIndexer
                ) -> pp.Doc:
@@ -95,7 +95,11 @@ class RefView:
   indexers: tuple[indexing.NDIndexer, ...]

   @property
-  def shape(self) -> tuple[int, ...]:
+  def is_dynamic_size(self):
+    return self.indexers[-1].is_dynamic_size
+
+  @property
+  def shape(self) -> tuple[int | Array, ...]:
     assert (
         len(self.indexers) > 0
     ), "Should not be able to create a trivial RefView"
@@ -2061,5 +2061,69 @@ class PallasCallPipelineTest(parameterized.TestCase):
    )


+class PallasCallDynamicDMATest(PallasTPUTest):
+
+  def setUp(self):
+    super().setUp()
+    if not jtu.is_device_tpu_at_least(4):
+      self.skipTest('DMAs not supported on TPU generations <= 3')
+
+  def test_simple_tile_aligned_dynamic_size_dma(self):
+
+    def kernel(size_smem_ref, x_hbm_ref, _, o_hbm_ref, sem):
+      size = size_smem_ref[0]
+      pltpu.async_copy(
+          x_hbm_ref.at[pl.ds(0, size)],
+          o_hbm_ref.at[pl.ds(0, size)], sem).wait()
+
+    x = jnp.tile(jnp.arange(8, dtype=jnp.int32)[:, None, None], [1, 8, 128])
+    o = jnp.zeros((8, 8, 128), dtype=jnp.int32)
+    size = jnp.array([4], dtype=jnp.int32)
+
+    out = pl.pallas_call(
+        kernel,
+        grid_spec=pltpu.PrefetchScalarGridSpec(
+            num_scalar_prefetch=0,
+            in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM),
+                      pl.BlockSpec(memory_space=pltpu.ANY),
+                      pl.BlockSpec(memory_space=pltpu.ANY)],
+            out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
+            scratch_shapes=[pltpu.SemaphoreType.DMA]
+        ),
+        out_shape=o,
+        input_output_aliases={2: 0},
+    )(size, x, o)
+    expected = o.at[:4].set(x.at[:4].get())
+    np.testing.assert_array_equal(out, expected)
+
+  def test_simple_dynamic_size_dma(self):
+    self.skipTest("doesn't work yet.")
+    def kernel(size_smem_ref, x_hbm_ref, _, o_hbm_ref, sem):
+      size = size_smem_ref[0]
+      pltpu.async_copy(
+          x_hbm_ref.at[pl.ds(0, size)],
+          o_hbm_ref.at[pl.ds(0, size)], sem).wait()
+
+    x = jnp.arange(8, dtype=jnp.int32)
+    o = jnp.zeros(8, dtype=jnp.int32)
+    size = jnp.array([4], dtype=jnp.int32)
+
+    out = pl.pallas_call(
+        kernel,
+        grid_spec=pltpu.PrefetchScalarGridSpec(
+            num_scalar_prefetch=0,
+            in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM),
+                      pl.BlockSpec(memory_space=pltpu.ANY),
+                      pl.BlockSpec(memory_space=pltpu.ANY)],
+            out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
+            scratch_shapes=[pltpu.SemaphoreType.DMA]
+        ),
+        out_shape=o,
+        input_output_aliases={2: 0},
+    )(size, x, o)
+    expected = o.at[:4].set(x.at[:4].get())
+    np.testing.assert_array_equal(out, expected)
+
+
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())