diff --git a/jax/_src/state/indexing.py b/jax/_src/state/indexing.py index 7ed504385..df340c032 100644 --- a/jax/_src/state/indexing.py +++ b/jax/_src/state/indexing.py @@ -17,7 +17,7 @@ from __future__ import annotations import dataclasses -from typing import Any, Union, List +from typing import Any, Union from jax._src import core from jax._src import tree_util @@ -30,7 +30,12 @@ import numpy as np @tree_util.register_pytree_node_class @dataclasses.dataclass class Slice: - """Represents a slice with a dynamic start index and a fixed size.""" + """A slice with a start index and a size. + + Both start index and size can either be static, i.e. known at tracing + and compilation time, or dynamic. + """ + start: int | Array size: int | Array stride: int = 1 @@ -78,7 +83,15 @@ def dslice( size: int | Array | None = None, stride: int | None = None, ) -> slice | Slice: - """Constructs a `Slice` from a start and a size.""" + """Constructs a ``Slice`` from a start index and a size. + + The semantics of ``dslice`` mirror those of the builtin ``slice`` type: + + * ``dslice(None)`` is ``:`` + * ``dslice(j)`` is ``:j`` + * ``dslice(i, j)`` is ``i:i+j`` + * ``dslice(i, j, stride)`` is ``i:i+j:stride`` + """ if start is None: return slice(None) if stride is None: @@ -123,6 +136,7 @@ class NDIndexer: indices: tuple[DimIndexer, ...] shape: tuple[int, ...] int_indexer_shape: tuple[int, ...] + # Off by default to avoid doing validation during pytree operations. validate: bool = False def __post_init__(self): @@ -181,53 +195,60 @@ class NDIndexer: def tree_unflatten(cls, data, flat_idx): idx_tree, shape, int_indexer_shape = data indices = tree_util.tree_unflatten(idx_tree, flat_idx) - return NDIndexer(tuple(indices), shape, int_indexer_shape) + return cls(tuple(indices), shape, int_indexer_shape) @classmethod def from_indices_shape(cls, indices, shape) -> NDIndexer: if not isinstance(indices, tuple): + # TODO(slebedev): Consider requiring `indices` to be a Sequence. indices = (indices,) - if len(indices) == 1 and indices[0] is ...: - indices = (slice(None),) * len(shape) - if any(idx is ... for idx in indices): - new_indices : List[Any] = [] - num_ellipsis = sum(1 for idx in indices if idx is ...) + + indices = list(indices) + if num_ellipsis := sum(idx is ... for idx in indices): if num_ellipsis > 1: raise ValueError("Only one ellipsis is supported.") - for idx in indices: - if idx is ...: - expand = (slice(None),) * (len(shape) - len(indices) + 1) - new_indices.extend(expand) - else: - new_indices.append(idx) - indices = tuple(new_indices) + # Expand ... so that `indices` has the same length as `shape`. + ip = indices.index(...) + indices[ip:ip+1] = [slice(None)] * (len(shape) - len(indices) + 1) if len(indices) > len(shape): + indices = tuple(indices) raise ValueError("`indices` must not be longer than `shape`: " f"{indices=}, {shape=}") - # Pad out indices with slice(None) - indices = [*indices, *[slice(None)] * (len(shape) - len(indices))] - # Convert all `slice`s to `Slice`s - indices = tuple(Slice.from_slice(i, s) if isinstance(i, slice) - else i for i, s in zip(indices, shape)) + elif len(indices) < len(shape): + # Pad `indices` to have the same length as `shape`. + indices.extend([slice(None)] * (len(shape) - len(indices))) + + # Promote all builtin `slice`s to `Slice`. + indices = tuple( + Slice.from_slice(i, s) if isinstance(i, slice) else i + for i, s in zip(indices, shape)) + is_int_indexing = [not isinstance(i, Slice) for i in indices] - other_indexers, int_indexers = partition_list(is_int_indexing, indices) - indexer_shapes = [core.get_aval(i).shape for i in int_indexers] - if indexer_shapes: + if any(is_int_indexing): + other_indexers, int_indexers = partition_list(is_int_indexing, indices) + indexer_shapes = tuple(core.get_aval(i).shape for i in int_indexers) try: - bcast_shape = np.broadcast_shapes(*indexer_shapes) + int_indexer_shape = np.broadcast_shapes(*indexer_shapes) except ValueError as e: # Raise a nicer error than the NumPy one. - raise ValueError("Cannot broadcast shapes for indexing: " - f"{tuple(a for a in indexer_shapes)}") from e + raise ValueError( + f"Cannot broadcast shapes for indexing: {indexer_shapes}") from e + + # Here we use the `broadcast_to` primitive instead of composing lax + # primitives together because it is easier to lower in targets like + # Triton/Mosaic. + # + # The local import avoids a circular dependency between primitives + # and this module. + from jax._src.state import primitives as sp # pytype: disable=import-error + int_indexers = [ + sp.broadcast_to(i, int_indexer_shape) for i in int_indexers + ] + indices = tuple(merge_lists(is_int_indexing, other_indexers, int_indexers)) else: - bcast_shape = () - # Here we use the `broadcast_to` primitive instead of composing lax - # primitives together because it is easier to lower in targets like - # Triton/Mosaic. - from jax._src.state import primitives as sp # pytype: disable=import-error - int_indexers = [sp.broadcast_to(i, bcast_shape) for i in int_indexers] - indices = merge_lists(is_int_indexing, other_indexers, int_indexers) - return NDIndexer(tuple(indices), shape, bcast_shape, validate=True) + int_indexer_shape = () + + return cls(indices, shape, int_indexer_shape, validate=True) def get_indexer_shape(self) -> tuple[int | Array, ...]: _, slice_indexers, _ = unpack_ndindexer(self) diff --git a/tests/pallas/indexing_test.py b/tests/pallas/indexing_test.py index 799d052dc..c11fca350 100644 --- a/tests/pallas/indexing_test.py +++ b/tests/pallas/indexing_test.py @@ -19,6 +19,7 @@ from __future__ import annotations import unittest from absl.testing import absltest +from absl.testing import parameterized import jax from jax._src import test_util as jtu from jax._src import util @@ -99,25 +100,18 @@ class IndexerTest(jtu.JaxTestCase): def test_invalid_ndindexer(self): indices = (0, 0, 0) shape = (5, 5) - with self.assertRaises(ValueError): + with self.assertRaisesRegex( + ValueError, "`indices` must not be longer than `shape`" + ): _ = NDIndexer.from_indices_shape(indices, shape) - def test_invalid_ndindexer_oob_int(self): - indices = (4, 0) - shape = (3, 5) - with self.assertRaises(ValueError): - _ = NDIndexer.from_indices_shape(indices, shape) - - def test_invalid_ndindexer_oob_slice_start(self): - indices = (slice(3, 2), 0) - shape = (3, 5) - with self.assertRaises(ValueError): - _ = NDIndexer.from_indices_shape(indices, shape) - - def test_invalid_ndindexer_oob_slice_end(self): - indices = (Slice(2, 2), 0) - shape = (3, 5) - with self.assertRaises(ValueError): + @parameterized.parameters( + ((4, 0), (3, 5)), + ((slice(3, 2), 0), (3, 5)), + ((Slice(2, 2), 0), (3, 5)), + ) + def test_invalid_ndindexer_oob(self, indices, shape): + with self.assertRaisesRegex(ValueError, "Out of bound"): _ = NDIndexer.from_indices_shape(indices, shape) def test_ndindexer_with_padding(self): @@ -126,6 +120,12 @@ class IndexerTest(jtu.JaxTestCase): indexer = NDIndexer.from_indices_shape(indices, shape) self.assertTupleEqual(indexer.get_indexer_shape(), shape) + def test_ndindexer_with_ellipsis(self): + indices = (..., 4) + shape = (5, 5) + indexer = NDIndexer.from_indices_shape(indices, shape) + self.assertTupleEqual(indexer.get_indexer_shape(), (5,)) + def test_ndindexer_with_slices(self): indices = (slice(2, 3), slice(4, 7)) shape = (5, 6) @@ -154,6 +154,14 @@ class IndexerTest(jtu.JaxTestCase): indexer = NDIndexer.from_indices_shape(indices, shape) self.assertTupleEqual(indexer.get_indexer_shape(), (10, 20)) + def test_ndindexer_with_arrays_and_invalid_broadcasting(self): + indices = (np.arange(10)[None], np.arange(20)[None, :]) + shape = (5, 5) + with self.assertRaisesRegex( + ValueError, "Cannot broadcast shapes for indexing" + ): + indexer = NDIndexer.from_indices_shape(indices, shape) + def test_indexer_with_all_types(self): indices = (0, slice(10), np.arange(5)) shape = (2, 3, 4)