Mirror of https://github.com/ROCm/jax.git, synced 2025-04-15 19:36:06 +00:00
[Pallas] Add stride in Pallas dynamic slice and support strided load/store.
PiperOrigin-RevId: 615940113
This commit is contained in:
parent 1cef1d9503
commit 2048e3c226
@@ -55,7 +55,6 @@ from jax._src.util import safe_map
 from jax._src.util import safe_zip
 from jax._src.util import split_list
 from jax._src.util import unzip2
-from jax._src.util import unzip3
 from jax.experimental.mosaic.dialects import tpu
 import jax.numpy as jnp
 import numpy as np
@@ -746,47 +745,71 @@ def _maybe_cast_to_index(cast_to_index, x):
     return _make_index(x)
   return _ensure_mlir_value(x, aval=jax_core.ShapedArray((), jnp.int32))


-def _index_to_start_size(idx: tuple[indexing.Slice | int | ir.Value, ...],
-                         cast_to_index: bool) -> tuple[ir.Value, int, bool]:
+def _index_to_start_size_stride(
+    idx: tuple[indexing.Slice | int | ir.Value, ...], cast_to_index: bool
+) -> tuple[ir.Value, int, int, bool]:
   assert not isinstance(idx, slice)
   if isinstance(idx, indexing.Slice):
     start = _maybe_cast_to_index(cast_to_index, idx.start)
     size = idx.size
+    stride = idx.stride
     squeeze = False
   elif isinstance(idx, int):
     start = _maybe_cast_to_index(cast_to_index, idx)
     size = 1
+    stride = 1
     squeeze = True
   else:
     if np.shape(idx):
       raise ValueError(f"Can only use ()-shaped and slice indexing: {idx}")
     start = _maybe_cast_to_index(cast_to_index, idx)
     size = 1
+    stride = 1
     squeeze = True
-  return start, size, squeeze
+  return start, size, stride, squeeze


-def _indexer_to_start_size(
-    indexer: NDIndexer, ref_block_shape: tuple[int | pl_core.Mapped, ...], *,
+def _indexer_to_start_size_stride(
+    indexer: NDIndexer,
+    ref_block_shape: tuple[int | pl_core.Mapped, ...],
+    *,
     cast_to_index: bool,
-) -> tuple[tuple[ir.Value, ...], tuple[int, ...], tuple[bool, ...],
-           tuple[int | pl_core.Mapped, ...]]:
+) -> tuple[
+    tuple[ir.Value, ...],
+    tuple[int, ...],
+    tuple[int, ...],
+    tuple[bool, ...],
+    tuple[int | pl_core.Mapped, ...],
+]:
   indices_iter = iter(indexer.indices)
-  starts, sizes, squeeze_dims = unzip3(
-      (
-          _maybe_cast_to_index(cast_to_index, 0),
-          1,
-          True,
-      )
-      if s is pl_core.mapped
-      else _index_to_start_size(next(indices_iter), cast_to_index)
-      for s in ref_block_shape
-  )
+  starts, sizes, strides, squeeze_dims = [], [], [], []
+  for s in ref_block_shape:
+    start, size, stride, squeeze_dim = (
+        (
+            _maybe_cast_to_index(cast_to_index, 0),
+            1,
+            1,
+            True,
+        )
+        if s is pl_core.mapped
+        else _index_to_start_size_stride(next(indices_iter), cast_to_index)
+    )
+    starts.append(start)
+    sizes.append(size)
+    strides.append(stride)
+    squeeze_dims.append(squeeze_dim)
   next_index = next(indices_iter, None)
   assert next_index is None, (indexer.indices, ref_block_shape)
   new_ref_block_shape = tuple(s for s, squeeze in zip(sizes, squeeze_dims)
                               if not squeeze)
-  return tuple(starts), tuple(sizes), tuple(squeeze_dims), new_ref_block_shape
+  return (
+      tuple(starts),
+      tuple(sizes),
+      tuple(strides),
+      tuple(squeeze_dims),
+      new_ref_block_shape,
+  )


 def _slice_memref(ref: ir.Value, ref_aval: state.AbstractRef,
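The two helpers above now thread a per-dimension stride through the lowering: every index is decomposed into (start, size, stride, squeeze). A minimal pure-Python sketch of that decomposition, leaving out the MLIR ir.Value plumbing and the mapped-dimension case (the SimpleSlice stand-in is an assumption for illustration, not the real indexing.Slice):

import dataclasses

@dataclasses.dataclass
class SimpleSlice:  # illustrative stand-in for indexing.Slice
  start: int
  size: int
  stride: int = 1

def index_to_start_size_stride(idx):
  # A slice keeps its dimension and contributes its own (start, size, stride).
  if isinstance(idx, SimpleSlice):
    return idx.start, idx.size, idx.stride, False
  # A scalar index reads one element at stride 1 and squeezes the dimension.
  return idx, 1, 1, True

assert index_to_start_size_stride(SimpleSlice(2, 4, 3)) == (2, 4, 3, False)
assert index_to_start_size_stride(7) == (7, 1, 1, True)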
@@ -796,9 +819,15 @@ def _slice_memref(ref: ir.Value, ref_aval: state.AbstractRef,
                   tuple[int | pl_core.Mapped, ...]]:
   assert ref_block_shape is not None
   target_shape = indexer.get_indexer_shape()
-  starts, sizes, squeeze_dims, ref_block_shape = _indexer_to_start_size(
-      indexer, ref_block_shape, cast_to_index=False,
+  starts, sizes, strides, squeeze_dims, ref_block_shape = (
+      _indexer_to_start_size_stride(
+          indexer,
+          ref_block_shape,
+          cast_to_index=False,
+      )
   )
+  if not all((s is None or s == 1) for s in strides):
+    raise NotImplementedError("Strided slices of references are unsupported.")
   target_ref_ty = ir.MemRefType.get(
       tuple(sizes), _dtype_to_ir_type(ref_aval.dtype),
       memory_space=ref.type.memory_space)
@@ -846,14 +875,21 @@ def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_):
       for a in idx_aval.indices
   ):
     raise ValueError("Cannot do int indexing on TPU")
-  starts, sizes, _, _ = _indexer_to_start_size(
-      idx, ref_block_shape, cast_to_index=True,
+  starts, sizes, strides, _, _ = _indexer_to_start_size_stride(
+      idx,
+      ref_block_shape,
+      cast_to_index=True,
   )
+  need_stride = not all((s is None or s == 1) for s in strides)
   load_aval = jax_core.ShapedArray(sizes, dtype=ref_aval.dtype)
   if is_smem_load:
     if ctx.avals_out[0].shape:
       raise ValueError("Can only load scalars from SMEM")
     return memref.LoadOp(ref, starts).result
-  load_val = vector.LoadOp(aval_to_ir_type(load_aval), ref, starts).result
+  if need_stride:
+    load_val = tpu.StridedLoadOp(
+        aval_to_ir_type(load_aval), ref, starts, strides
+    ).result
+  else:
+    load_val = vector.LoadOp(aval_to_ir_type(load_aval), ref, starts).result
   if load_aval == aval_out:
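With need_stride computed, vector loads whose indexer has any non-unit stride are routed to tpu.StridedLoadOp; unit-stride loads keep the old vector.LoadOp path. A sketch of a kernel that would exercise the new path (the 1-D shapes are chosen for illustration and may not satisfy real TPU layout and tiling constraints):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def strided_load_kernel(x_ref, o_ref):
  # pl.ds(0, 8, 2) reads elements 0, 2, ..., 14; after this change the
  # Mosaic lowering emits tpu.StridedLoadOp instead of vector.LoadOp.
  o_ref[...] = x_ref[pl.ds(0, 8, 2)]

x = jnp.arange(16, dtype=jnp.float32)
out = pl.pallas_call(
    strided_load_kernel,
    out_shape=jax.ShapeDtypeStruct((8,), jnp.float32),
)(x)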
@@ -896,10 +932,12 @@ def _masked_swap_lowering_rule(
     raise NotImplementedError(
         "Indexing into a ()-shaped Ref not yet supported on TPU.")

-  starts, _, _, _ = _indexer_to_start_size(
-      idx, ref_block_shape, cast_to_index=True,
+  starts, _, strides, _, _ = _indexer_to_start_size_stride(
+      idx,
+      ref_block_shape,
+      cast_to_index=True,
   )
-
+  need_stride = not all((s is None or s == 1) for s in strides)
   if is_smem_store:
     if val_aval.shape:
       raise ValueError("Can only store scalars to SMEM")
@@ -918,7 +956,10 @@ def _masked_swap_lowering_rule(
   mem_aval = aval_out.update(shape=tuple(mem_slice_shape))
   mem_aval_vec_type = ir.VectorType.get(mem_aval.shape,
                                         _dtype_to_ir_type(mem_aval.dtype))
-  result = vector.LoadOp(mem_aval_vec_type, ref, starts).result
+  if need_stride:
+    result = tpu.StridedLoadOp(mem_aval_vec_type, ref, starts, strides).result
+  else:
+    result = vector.LoadOp(mem_aval_vec_type, ref, starts).result
   if mem_aval != aval_out:
     # We are slicing a scalar so provided dummy 1 indices
     result_vec_type = ir.VectorType.get(aval_out.shape,
@@ -927,7 +968,10 @@ def _masked_swap_lowering_rule(
     val_vec_type = ir.VectorType.get(mem_aval.shape,
                                      _dtype_to_ir_type(mem_aval.dtype))
     val = vector.ShapeCastOp(val_vec_type, val).result
-  vector.StoreOp(val, ref, starts)
+  if need_stride:
+    tpu.StridedStoreOp(val, ref, starts, strides)
+  else:
+    vector.StoreOp(val, ref, starts)
   return result
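The store side of the swap mirrors the load side: the old value is read and the new value written with tpu.StridedLoadOp/tpu.StridedStoreOp whenever any stride is non-unit, and with the existing vector ops otherwise. A corresponding sketch, with the same illustrative caveats as the load example above:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def strided_store_kernel(x_ref, o_ref):
  o_ref[...] = jnp.zeros(o_ref.shape, o_ref.dtype)
  # Writes x into every other element of o_ref via tpu.StridedStoreOp.
  o_ref[pl.ds(0, 8, 2)] = x_ref[...]

x = jnp.arange(8, dtype=jnp.float32)
out = pl.pallas_call(
    strided_store_kernel,
    out_shape=jax.ShapeDtypeStruct((16,), jnp.float32),
)(x)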
@@ -33,16 +33,19 @@ class Slice:
   """Represents a slice with a dynamic start index and a fixed size."""
   start: Any
   size: int
+  stride: int = 1

   def __post_init__(self):
     if self.size < 0:
       raise ValueError("`size` must not be negative.")
+    if self.stride < 1:
+      raise ValueError("`stride` must be >= 1.")

   def tree_flatten(self):
     # If `start` is statically known, we treat it as static information
     if isinstance(self.start, int):
-      return (), (self.start, self.size)
-    return (self.start,), (self.size,)
+      return (), (self.start, self.size, self.stride)
+    return (self.start,), (self.size, self.stride)

   @classmethod
   def tree_unflatten(cls, aux_data, children) -> Slice:
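Because `stride` defaults to 1, existing `Slice(start, size)` call sites keep their behavior, and the stride can be validated eagerly in `__post_init__` because it is always static: only `start` may be a traced value, as the flattening rule shows. A small illustration (assuming `Slice` is re-exported from `jax.experimental.pallas`, which is an assumption of this sketch):

import jax
from jax.experimental import pallas as pl  # assumed re-export of Slice

# Static start: the whole (start, size, stride) triple is static aux_data,
# so the pytree has no leaves.
leaves, _ = jax.tree_util.tree_flatten(pl.Slice(2, 4, 3))
assert leaves == []

# A traced start would instead be the single leaf, with (size, stride)
# kept static.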
@@ -51,21 +54,30 @@ class Slice:
   @classmethod
   def from_slice(cls, slc: slice, size: int) -> Slice:
     start, stop, step = slc.indices(size)
-    if step != 1:
-      raise ValueError(f"slice must have a step of 1 (found: {step})")
-    return cls(start, max(stop - start, 0))
+    if step < 1:
+      raise ValueError(f"slice must have a step >= 1 (found: {step})")
+    return cls(start, max((stop - start + step - 1) // step, 0), step)


-def dslice(start: int | Array | None, size: int | None = None
-           ) -> slice | Slice:
+def dslice(
+    start: int | Array | None,
+    size: int | None = None,
+    stride: int | None = None,
+) -> slice | Slice:
   """Constructs a `Slice` from a start and a size."""
   if start is None:
     return slice(None)
+  if stride is None:
+    stride = 1
+  if not isinstance(stride, int):
+    raise ValueError("Non-static stride in `dslice`")
   if size is None:
     if not isinstance(start, int):
       raise ValueError("Non-static `dslice`")
-    return Slice(0, start)
-  return Slice(start, size)
+    return Slice(0, start, stride)
+  return Slice(start, size, stride)


 ds = dslice  # Handy alias
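`from_slice` now sizes the slice by ceiling division: a slice over [start, stop) with step k touches ceil((stop - start) / k) elements. Checking the arithmetic with the same expression the code uses:

start, stop, step = slice(1, 10, 3).indices(12)   # -> (1, 10, 3)
size = max((stop - start + step - 1) // step, 0)  # ceil(9 / 3) = 3
assert size == len(range(1, 10, 3))               # elements 1, 4, 7

# dslice grows a matching argument: pl.dslice(2, 4, 2) (or the pl.ds alias)
# describes elements 2, 4, 6, 8 of the indexed dimension.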
@@ -113,9 +125,10 @@ class NDIndexer:
       if value := _maybe_concretize(start):
         if value >= s:
           raise ValueError(f"Out of bound slice: start={value}, dim={s}.")
-        if value + idx.size > s:
+        if value + (idx.size - 1) * idx.stride >= s:
           raise ValueError(
-              f"Out of bound slice: start={value}, size={idx.size}, dim={s}."
+              f"Out of bound slice: start={value}, size={idx.size},"
+              f" stride={idx.stride}, dim={s}."
           )
         continue
     # The shape of indexer integers should be broadcastable up to the
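The tightened bounds check accounts for the stride: the last element a strided slice touches is start + (size - 1) * stride, and that index must stay below the dimension. Working through the new inequality:

start, size, stride, dim = 2, 4, 3, 12
last = start + (size - 1) * stride      # 2 + 3 * 3 = 11, final touched index
assert not (last >= dim)                # 11 < 12: in bounds
# With dim = 11 the same slice fails the new check (11 >= 11), whereas the
# old test `start + size > dim` (6 > 11) would have wrongly accepted it.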