Slightly rearranged NDIndexer.from_indices_shape and added missing tests

This commit is contained in:
Sergei Lebedev 2024-05-29 16:20:01 +01:00
parent 7c471e2533
commit befa10c1d7
2 changed files with 81 additions and 52 deletions

View File

@ -17,7 +17,7 @@
from __future__ import annotations
import dataclasses
from typing import Any, Union, List
from typing import Any, Union
from jax._src import core
from jax._src import tree_util
@ -30,7 +30,12 @@ import numpy as np
@tree_util.register_pytree_node_class
@dataclasses.dataclass
class Slice:
"""Represents a slice with a dynamic start index and a fixed size."""
"""A slice with a start index and a size.
Both start index and size can either be static, i.e. known at tracing
and compilation time, or dynamic.
"""
start: int | Array
size: int | Array
stride: int = 1
@ -78,7 +83,15 @@ def dslice(
size: int | Array | None = None,
stride: int | None = None,
) -> slice | Slice:
"""Constructs a `Slice` from a start and a size."""
"""Constructs a ``Slice`` from a start index and a size.
The semantics of ``dslice`` mirror those of the builtin ``slice`` type:
* ``dslice(None)`` is ``:``
* ``dslice(j)`` is ``:j``
* ``dslice(i, j)`` is ``i:i+j``
* ``dslice(i, j, stride)`` is ``i:i+j:stride``
"""
if start is None:
return slice(None)
if stride is None:
@ -123,6 +136,7 @@ class NDIndexer:
indices: tuple[DimIndexer, ...]
shape: tuple[int, ...]
int_indexer_shape: tuple[int, ...]
# Off by default to avoid doing validation during pytree operations.
validate: bool = False
def __post_init__(self):
@ -181,53 +195,60 @@ class NDIndexer:
def tree_unflatten(cls, data, flat_idx):
idx_tree, shape, int_indexer_shape = data
indices = tree_util.tree_unflatten(idx_tree, flat_idx)
return NDIndexer(tuple(indices), shape, int_indexer_shape)
return cls(tuple(indices), shape, int_indexer_shape)
@classmethod
def from_indices_shape(cls, indices, shape) -> NDIndexer:
    """Builds an ``NDIndexer`` for ``indices`` applied to an array of ``shape``.

    A non-tuple ``indices`` is treated as a single index. At most one ``...``
    is accepted; it is expanded into as many full slices as are needed to
    make ``indices`` as long as ``shape``. Missing trailing indices are
    padded with full slices and builtin ``slice`` objects are promoted to
    ``Slice``. Every index that is not a ``Slice`` is treated as an integer
    indexer, and all integer indexers are broadcast to one common shape.

    The result is constructed with ``validate=True``.

    Raises:
      ValueError: if ``indices`` contains more than one ellipsis, is longer
        than ``shape``, or the integer indexers cannot be broadcast together.
    """
    if not isinstance(indices, tuple):
        # TODO(slebedev): Consider requiring `indices` to be a Sequence.
        indices = (indices,)
    indices = list(indices)

    # Locate `...` by identity. `list.index(...)` compares with `==`, and
    # `==` on an array-valued indexer that precedes the ellipsis is
    # elementwise, so it may not yield a single truth value.
    ellipsis_positions = [p for p, idx in enumerate(indices) if idx is ...]
    if ellipsis_positions:
        if len(ellipsis_positions) > 1:
            raise ValueError("Only one ellipsis is supported.")
        # Expand ... so that `indices` has the same length as `shape`.
        # `len(indices)` still counts the ellipsis itself, hence the `+ 1`.
        (ip,) = ellipsis_positions
        indices[ip:ip + 1] = [slice(None)] * (len(shape) - len(indices) + 1)

    if len(indices) > len(shape):
        indices = tuple(indices)
        raise ValueError("`indices` must not be longer than `shape`: "
                         f"{indices=}, {shape=}")
    # Pad `indices` to have the same length as `shape` (a no-op when the
    # lengths already agree).
    indices.extend([slice(None)] * (len(shape) - len(indices)))

    # Promote all builtin `slice`s to `Slice`.
    indices = tuple(
        Slice.from_slice(i, s) if isinstance(i, slice) else i
        for i, s in zip(indices, shape))

    is_int_indexing = [not isinstance(i, Slice) for i in indices]
    if any(is_int_indexing):
        other_indexers, int_indexers = partition_list(is_int_indexing, indices)
        indexer_shapes = tuple(core.get_aval(i).shape for i in int_indexers)
        try:
            int_indexer_shape = np.broadcast_shapes(*indexer_shapes)
        except ValueError as e:
            # Raise a nicer error than the NumPy one.
            raise ValueError(
                f"Cannot broadcast shapes for indexing: {indexer_shapes}") from e

        # Here we use the `broadcast_to` primitive instead of composing lax
        # primitives together because it is easier to lower in targets like
        # Triton/Mosaic.
        #
        # The local import avoids a circular dependency between primitives
        # and this module.
        from jax._src.state import primitives as sp  # pytype: disable=import-error
        int_indexers = [
            sp.broadcast_to(i, int_indexer_shape) for i in int_indexers
        ]
        indices = tuple(
            merge_lists(is_int_indexing, other_indexers, int_indexers))
    else:
        # No integer indexers at all: nothing to broadcast.
        int_indexer_shape = ()

    return cls(indices, shape, int_indexer_shape, validate=True)
def get_indexer_shape(self) -> tuple[int | Array, ...]:
_, slice_indexers, _ = unpack_ndindexer(self)

View File

@ -19,6 +19,7 @@ from __future__ import annotations
import unittest
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax._src import test_util as jtu
from jax._src import util
@ -99,25 +100,18 @@ class IndexerTest(jtu.JaxTestCase):
def test_invalid_ndindexer(self):
    # Three indices against a two-dimensional shape must be rejected, and
    # the error has to name the actual problem rather than be any ValueError.
    indices = (0, 0, 0)
    shape = (5, 5)
    with self.assertRaisesRegex(
        ValueError, "`indices` must not be longer than `shape`"
    ):
        _ = NDIndexer.from_indices_shape(indices, shape)
@parameterized.parameters(
    # Integer index past the end of the first dimension.
    ((4, 0), (3, 5)),
    # Builtin slice whose start is out of bounds.
    ((slice(3, 2), 0), (3, 5)),
    # `Slice` whose start + size runs past the end of the dimension.
    ((Slice(2, 2), 0), (3, 5)),
)
def test_invalid_ndindexer_oob(self, indices, shape):
    # Each case indexes outside `shape`; all must fail with the same
    # out-of-bounds error.
    with self.assertRaisesRegex(ValueError, "Out of bound"):
        _ = NDIndexer.from_indices_shape(indices, shape)
def test_ndindexer_with_padding(self):
@ -126,6 +120,12 @@ class IndexerTest(jtu.JaxTestCase):
indexer = NDIndexer.from_indices_shape(indices, shape)
self.assertTupleEqual(indexer.get_indexer_shape(), shape)
def test_ndindexer_with_ellipsis(self):
    # `...` fills in for the leading dimension and the trailing integer
    # index removes the last one, leaving a single axis of size 5.
    nd_indexer = NDIndexer.from_indices_shape((..., 4), (5, 5))
    expected_shape = (5,)
    self.assertTupleEqual(nd_indexer.get_indexer_shape(), expected_shape)
def test_ndindexer_with_slices(self):
indices = (slice(2, 3), slice(4, 7))
shape = (5, 6)
@ -154,6 +154,14 @@ class IndexerTest(jtu.JaxTestCase):
indexer = NDIndexer.from_indices_shape(indices, shape)
self.assertTupleEqual(indexer.get_indexer_shape(), (10, 20))
def test_ndindexer_with_arrays_and_invalid_broadcasting(self):
    # The two integer indexers have shapes (1, 10) and (1, 20), which
    # cannot be broadcast against each other.
    indices = (np.arange(10)[None], np.arange(20)[None, :])
    shape = (5, 5)
    with self.assertRaisesRegex(
        ValueError, "Cannot broadcast shapes for indexing"
    ):
        # The call is expected to raise, so its result is discarded.
        _ = NDIndexer.from_indices_shape(indices, shape)
def test_indexer_with_all_types(self):
indices = (0, slice(10), np.arange(5))
shape = (2, 3, 4)