diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py
index 2a0caaa95..984692f3f 100644
--- a/jax/_src/pallas/mosaic/lowering.py
+++ b/jax/_src/pallas/mosaic/lowering.py
@@ -55,7 +55,6 @@ from jax._src.util import safe_map
 from jax._src.util import safe_zip
 from jax._src.util import split_list
 from jax._src.util import unzip2
-from jax._src.util import unzip3
 from jax.experimental.mosaic.dialects import tpu
 import jax.numpy as jnp
 import numpy as np
@@ -746,47 +745,71 @@ def _maybe_cast_to_index(cast_to_index, x):
     return _make_index(x)
   return _ensure_mlir_value(x, aval=jax_core.ShapedArray((), jnp.int32))

-def _index_to_start_size(idx: tuple[indexing.Slice | int | ir.Value, ...],
-                         cast_to_index: bool) -> tuple[ir.Value, int, bool]:
+
+def _index_to_start_size_stride(
+    idx: tuple[indexing.Slice | int | ir.Value, ...], cast_to_index: bool
+) -> tuple[ir.Value, int, int, bool]:
   assert not isinstance(idx, slice)
   if isinstance(idx, indexing.Slice):
     start = _maybe_cast_to_index(cast_to_index, idx.start)
     size = idx.size
+    stride = idx.stride
     squeeze = False
   elif isinstance(idx, int):
     start = _maybe_cast_to_index(cast_to_index, idx)
     size = 1
+    stride = 1
     squeeze = True
   else:
     if np.shape(idx):
       raise ValueError(f"Can only use ()-shaped and slice indexing: {idx}")
     start = _maybe_cast_to_index(cast_to_index, idx)
     size = 1
+    stride = 1
     squeeze = True
-  return start, size, squeeze
+  return start, size, stride, squeeze


-def _indexer_to_start_size(
-    indexer: NDIndexer, ref_block_shape: tuple[int | pl_core.Mapped, ...], *,
+def _indexer_to_start_size_stride(
+    indexer: NDIndexer,
+    ref_block_shape: tuple[int | pl_core.Mapped, ...],
+    *,
     cast_to_index: bool,
-) -> tuple[tuple[ir.Value, ...], tuple[int, ...], tuple[bool, ...],
-           tuple[int | pl_core.Mapped, ...]]:
+) -> tuple[
+    tuple[ir.Value, ...],
+    tuple[int, ...],
+    tuple[int, ...],
+    tuple[bool, ...],
+    tuple[int | pl_core.Mapped, ...],
+]:
   indices_iter = iter(indexer.indices)
-  starts, sizes, squeeze_dims = unzip3(
-      (
-          _maybe_cast_to_index(cast_to_index, 0),
-          1,
-          True,
-      )
-      if s is pl_core.mapped
-      else _index_to_start_size(next(indices_iter), cast_to_index)
-      for s in ref_block_shape
-  )
+  starts, sizes, strides, squeeze_dims = [], [], [], []
+  for s in ref_block_shape:
+    start, size, stride, squeeze_dim = (
+        (
+            _maybe_cast_to_index(cast_to_index, 0),
+            1,
+            1,
+            True,
+        )
+        if s is pl_core.mapped
+        else _index_to_start_size_stride(next(indices_iter), cast_to_index)
+    )
+    starts.append(start)
+    sizes.append(size)
+    strides.append(stride)
+    squeeze_dims.append(squeeze_dim)
   next_index = next(indices_iter, None)
   assert next_index is None, (indexer.indices, ref_block_shape)
   new_ref_block_shape = tuple(s for s, squeeze in zip(sizes, squeeze_dims)
                               if not squeeze)
-  return tuple(starts), tuple(sizes), tuple(squeeze_dims), new_ref_block_shape
+  return (
+      tuple(starts),
+      tuple(sizes),
+      tuple(strides),
+      tuple(squeeze_dims),
+      new_ref_block_shape,
+  )


 def _slice_memref(ref: ir.Value, ref_aval: state.AbstractRef,
@@ -796,9 +819,15 @@ def _slice_memref(ref: ir.Value, ref_aval: state.AbstractRef,
                              tuple[int | pl_core.Mapped, ...]]:
   assert ref_block_shape is not None
   target_shape = indexer.get_indexer_shape()
-  starts, sizes, squeeze_dims, ref_block_shape = _indexer_to_start_size(
-      indexer, ref_block_shape, cast_to_index=False,
+  starts, sizes, strides, squeeze_dims, ref_block_shape = (
+      _indexer_to_start_size_stride(
+          indexer,
+          ref_block_shape,
+          cast_to_index=False,
+      )
   )
+  if not all((s is None or s == 1) for s in strides):
+    raise NotImplementedError("Strided slices of references are unsupported.")
   target_ref_ty = ir.MemRefType.get(
       tuple(sizes), _dtype_to_ir_type(ref_aval.dtype),
       memory_space=ref.type.memory_space)
@@ -846,14 +875,21 @@ def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_):
       for a in idx_aval.indices
   ):
     raise ValueError("Cannot do int indexing on TPU")
-  starts, sizes, _, _ = _indexer_to_start_size(
-      idx, ref_block_shape, cast_to_index=True,
+  starts, sizes, strides, _, _ = _indexer_to_start_size_stride(
+      idx,
+      ref_block_shape,
+      cast_to_index=True,
   )
+  need_stride = not all((s is None or s == 1) for s in strides)
   load_aval = jax_core.ShapedArray(sizes, dtype=ref_aval.dtype)
   if is_smem_load:
     if ctx.avals_out[0].shape:
       raise ValueError("Can only load scalars from SMEM")
     return memref.LoadOp(ref, starts).result
+  if need_stride:
+    load_val = tpu.StridedLoadOp(
+        aval_to_ir_type(load_aval), ref, starts, strides
+    ).result
   else:
     load_val = vector.LoadOp(aval_to_ir_type(load_aval), ref, starts).result
   if load_aval == aval_out:
@@ -896,10 +932,12 @@ def _masked_swap_lowering_rule(
     raise NotImplementedError(
         "Indexing into a ()-shaped Ref not yet supported on TPU.")

-  starts, _, _, _ = _indexer_to_start_size(
-      idx, ref_block_shape, cast_to_index=True,
+  starts, _, strides, _, _ = _indexer_to_start_size_stride(
+      idx,
+      ref_block_shape,
+      cast_to_index=True,
   )
-
+  need_stride = not all((s is None or s == 1) for s in strides)
   if is_smem_store:
     if val_aval.shape:
       raise ValueError("Can only store scalars to SMEM")
@@ -918,7 +956,10 @@ def _masked_swap_lowering_rule(
   mem_aval = aval_out.update(shape=tuple(mem_slice_shape))
   mem_aval_vec_type = ir.VectorType.get(mem_aval.shape,
                                         _dtype_to_ir_type(mem_aval.dtype))
-  result = vector.LoadOp(mem_aval_vec_type, ref, starts).result
+  if need_stride:
+    result = tpu.StridedLoadOp(mem_aval_vec_type, ref, starts, strides).result
+  else:
+    result = vector.LoadOp(mem_aval_vec_type, ref, starts).result
   if mem_aval != aval_out:
     # We are slicing a scalar so provided dummy 1 indices
     result_vec_type = ir.VectorType.get(aval_out.shape,
@@ -927,7 +968,10 @@ def _masked_swap_lowering_rule(
     val_vec_type = ir.VectorType.get(mem_aval.shape,
                                      _dtype_to_ir_type(mem_aval.dtype))
     val = vector.ShapeCastOp(val_vec_type, val).result
-  vector.StoreOp(val, ref, starts)
+  if need_stride:
+    tpu.StridedStoreOp(val, ref, starts, strides)
+  else:
+    vector.StoreOp(val, ref, starts)
   return result


diff --git a/jax/_src/state/indexing.py b/jax/_src/state/indexing.py
index 175d88398..60b99f100 100644
--- a/jax/_src/state/indexing.py
+++ b/jax/_src/state/indexing.py
@@ -33,16 +33,19 @@ class Slice:
   """Represents a slice with a dynamic start index and a fixed size."""
   start: Any
   size: int
+  stride: int = 1

   def __post_init__(self):
     if self.size < 0:
       raise ValueError("`size` must not be negative.")
+    if self.stride < 1:
+      raise ValueError("`stride` must be >= 1.")

   def tree_flatten(self):
     # If `start` is statically known, we treat it as static information
     if isinstance(self.start, int):
-      return (), (self.start, self.size)
-    return (self.start,), (self.size,)
+      return (), (self.start, self.size, self.stride)
+    return (self.start,), (self.size, self.stride)

   @classmethod
   def tree_unflatten(cls, aux_data, children) -> Slice:
@@ -51,21 +54,30 @@ class Slice:
   @classmethod
   def from_slice(cls, slc: slice, size: int) -> Slice:
     start, stop, step = slc.indices(size)
-    if step != 1:
-      raise ValueError(f"slice must have a step of 1 (found: {step})")
-    return cls(start, max(stop - start, 0))
+    if step < 1:
+      raise ValueError(f"slice must have a step >= 1 (found: {step})")
+    return cls(start, max((stop - start + step - 1) // step, 0), step)


-def dslice(start: int | Array | None, size: int | None = None
-           ) -> slice | Slice:
+def dslice(
+    start: int | Array | None,
+    size: int | None = None,
+    stride: int | None = None,
+) -> slice | Slice:
   """Constructs a `Slice` from a start and a size."""
   if start is None:
     return slice(None)
+  if stride is None:
+    stride = 1
+  if not isinstance(stride, int):
+    raise ValueError("Non-static stride in `dslice`")
   if size is None:
     if not isinstance(start, int):
       raise ValueError("Non-static `dslice`")
-    return Slice(0, start)
-  return Slice(start, size)
+    return Slice(0, start, stride)
+  return Slice(start, size, stride)
+
+
 ds = dslice  # Handy alias


@@ -113,9 +125,10 @@ class NDIndexer:
         if value := _maybe_concretize(start):
           if value >= s:
             raise ValueError(f"Out of bound slice: start={value}, dim={s}.")
-          if value + idx.size > s:
+          if value + (idx.size - 1) * idx.stride >= s:
             raise ValueError(
-                f"Out of bound slice: start={value}, size={idx.size}, dim={s}."
+                f"Out of bound slice: start={value}, size={idx.size},"
+                f" stride={idx.stride}, dim={s}."
             )
         continue
       # The shape of indexer integers should be broadcastable up to the
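
Usage sketch (not part of the patch): with the `stride` argument that this diff adds to `Slice`/`dslice`, a Pallas TPU kernel can read every n-th row of a block, and the load lowering above turns the strided read into `tpu.StridedLoadOp`. The example assumes the patch is applied and that `pl.ds` is the `dslice` alias shown in the diff; the kernel and array names are illustrative only.

# Illustrative sketch, assuming this patch is applied; names are hypothetical.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def strided_rows_kernel(x_ref, o_ref):
  # pl.ds(start, size, stride): Slice now carries a stride, so this reads
  # rows 0, 2, 4, 6 of the 8-row block via the strided load path.
  o_ref[...] = x_ref[pl.ds(0, 4, 2), :]

x = jnp.arange(8 * 128, dtype=jnp.float32).reshape(8, 128)
out = pl.pallas_call(
    strided_rows_kernel,
    out_shape=jax.ShapeDtypeStruct((4, 128), jnp.float32),
)(x)

Note that slicing a Ref itself (as opposed to loading from it) still rejects strides: `_slice_memref` raises `NotImplementedError("Strided slices of references are unsupported.")`, so the stride must appear in the load/store indexer rather than in a Ref view.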