mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Slightly rearranged NDIndexer.from_indices_shape and added missing tests
This commit is contained in:
parent
7c471e2533
commit
befa10c1d7
@ -17,7 +17,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from typing import Any, Union, List
|
||||
from typing import Any, Union
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import tree_util
|
||||
@ -30,7 +30,12 @@ import numpy as np
|
||||
@tree_util.register_pytree_node_class
|
||||
@dataclasses.dataclass
|
||||
class Slice:
|
||||
"""Represents a slice with a dynamic start index and a fixed size."""
|
||||
"""A slice with a start index and a size.
|
||||
|
||||
Both start index and size can either be static, i.e. known at tracing
|
||||
and compilation time, or dynamic.
|
||||
"""
|
||||
|
||||
start: int | Array
|
||||
size: int | Array
|
||||
stride: int = 1
|
||||
@ -78,7 +83,15 @@ def dslice(
|
||||
size: int | Array | None = None,
|
||||
stride: int | None = None,
|
||||
) -> slice | Slice:
|
||||
"""Constructs a `Slice` from a start and a size."""
|
||||
"""Constructs a ``Slice`` from a start index and a size.
|
||||
|
||||
The semantics of ``dslice`` mirror those of the builtin ``slice`` type:
|
||||
|
||||
* ``dslice(None)`` is ``:``
|
||||
* ``dslice(j)`` is ``:j``
|
||||
* ``dslice(i, j)`` is ``i:i+j``
|
||||
* ``dslice(i, j, stride)`` is ``i:i+j:stride``
|
||||
"""
|
||||
if start is None:
|
||||
return slice(None)
|
||||
if stride is None:
|
||||
@ -123,6 +136,7 @@ class NDIndexer:
|
||||
indices: tuple[DimIndexer, ...]
|
||||
shape: tuple[int, ...]
|
||||
int_indexer_shape: tuple[int, ...]
|
||||
# Off by default to avoid doing validation during pytree operations.
|
||||
validate: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
@ -181,53 +195,60 @@ class NDIndexer:
|
||||
def tree_unflatten(cls, data, flat_idx):
|
||||
idx_tree, shape, int_indexer_shape = data
|
||||
indices = tree_util.tree_unflatten(idx_tree, flat_idx)
|
||||
return NDIndexer(tuple(indices), shape, int_indexer_shape)
|
||||
return cls(tuple(indices), shape, int_indexer_shape)
|
||||
|
||||
@classmethod
|
||||
def from_indices_shape(cls, indices, shape) -> NDIndexer:
|
||||
if not isinstance(indices, tuple):
|
||||
# TODO(slebedev): Consider requiring `indices` to be a Sequence.
|
||||
indices = (indices,)
|
||||
if len(indices) == 1 and indices[0] is ...:
|
||||
indices = (slice(None),) * len(shape)
|
||||
if any(idx is ... for idx in indices):
|
||||
new_indices : List[Any] = []
|
||||
num_ellipsis = sum(1 for idx in indices if idx is ...)
|
||||
|
||||
indices = list(indices)
|
||||
if num_ellipsis := sum(idx is ... for idx in indices):
|
||||
if num_ellipsis > 1:
|
||||
raise ValueError("Only one ellipsis is supported.")
|
||||
for idx in indices:
|
||||
if idx is ...:
|
||||
expand = (slice(None),) * (len(shape) - len(indices) + 1)
|
||||
new_indices.extend(expand)
|
||||
else:
|
||||
new_indices.append(idx)
|
||||
indices = tuple(new_indices)
|
||||
# Expand ... so that `indices` has the same length as `shape`.
|
||||
ip = indices.index(...)
|
||||
indices[ip:ip+1] = [slice(None)] * (len(shape) - len(indices) + 1)
|
||||
if len(indices) > len(shape):
|
||||
indices = tuple(indices)
|
||||
raise ValueError("`indices` must not be longer than `shape`: "
|
||||
f"{indices=}, {shape=}")
|
||||
# Pad out indices with slice(None)
|
||||
indices = [*indices, *[slice(None)] * (len(shape) - len(indices))]
|
||||
# Convert all `slice`s to `Slice`s
|
||||
indices = tuple(Slice.from_slice(i, s) if isinstance(i, slice)
|
||||
else i for i, s in zip(indices, shape))
|
||||
elif len(indices) < len(shape):
|
||||
# Pad `indices` to have the same length as `shape`.
|
||||
indices.extend([slice(None)] * (len(shape) - len(indices)))
|
||||
|
||||
# Promote all builtin `slice`s to `Slice`.
|
||||
indices = tuple(
|
||||
Slice.from_slice(i, s) if isinstance(i, slice) else i
|
||||
for i, s in zip(indices, shape))
|
||||
|
||||
is_int_indexing = [not isinstance(i, Slice) for i in indices]
|
||||
if any(is_int_indexing):
|
||||
other_indexers, int_indexers = partition_list(is_int_indexing, indices)
|
||||
indexer_shapes = [core.get_aval(i).shape for i in int_indexers]
|
||||
if indexer_shapes:
|
||||
indexer_shapes = tuple(core.get_aval(i).shape for i in int_indexers)
|
||||
try:
|
||||
bcast_shape = np.broadcast_shapes(*indexer_shapes)
|
||||
int_indexer_shape = np.broadcast_shapes(*indexer_shapes)
|
||||
except ValueError as e:
|
||||
# Raise a nicer error than the NumPy one.
|
||||
raise ValueError("Cannot broadcast shapes for indexing: "
|
||||
f"{tuple(a for a in indexer_shapes)}") from e
|
||||
else:
|
||||
bcast_shape = ()
|
||||
raise ValueError(
|
||||
f"Cannot broadcast shapes for indexing: {indexer_shapes}") from e
|
||||
|
||||
# Here we use the `broadcast_to` primitive instead of composing lax
|
||||
# primitives together because it is easier to lower in targets like
|
||||
# Triton/Mosaic.
|
||||
#
|
||||
# The local import avoids a circular dependency between primitives
|
||||
# and this module.
|
||||
from jax._src.state import primitives as sp # pytype: disable=import-error
|
||||
int_indexers = [sp.broadcast_to(i, bcast_shape) for i in int_indexers]
|
||||
indices = merge_lists(is_int_indexing, other_indexers, int_indexers)
|
||||
return NDIndexer(tuple(indices), shape, bcast_shape, validate=True)
|
||||
int_indexers = [
|
||||
sp.broadcast_to(i, int_indexer_shape) for i in int_indexers
|
||||
]
|
||||
indices = tuple(merge_lists(is_int_indexing, other_indexers, int_indexers))
|
||||
else:
|
||||
int_indexer_shape = ()
|
||||
|
||||
return cls(indices, shape, int_indexer_shape, validate=True)
|
||||
|
||||
def get_indexer_shape(self) -> tuple[int | Array, ...]:
|
||||
_, slice_indexers, _ = unpack_ndindexer(self)
|
||||
|
@ -19,6 +19,7 @@ from __future__ import annotations
|
||||
import unittest
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import jax
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import util
|
||||
@ -99,25 +100,18 @@ class IndexerTest(jtu.JaxTestCase):
|
||||
def test_invalid_ndindexer(self):
|
||||
indices = (0, 0, 0)
|
||||
shape = (5, 5)
|
||||
with self.assertRaises(ValueError):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "`indices` must not be longer than `shape`"
|
||||
):
|
||||
_ = NDIndexer.from_indices_shape(indices, shape)
|
||||
|
||||
def test_invalid_ndindexer_oob_int(self):
|
||||
indices = (4, 0)
|
||||
shape = (3, 5)
|
||||
with self.assertRaises(ValueError):
|
||||
_ = NDIndexer.from_indices_shape(indices, shape)
|
||||
|
||||
def test_invalid_ndindexer_oob_slice_start(self):
|
||||
indices = (slice(3, 2), 0)
|
||||
shape = (3, 5)
|
||||
with self.assertRaises(ValueError):
|
||||
_ = NDIndexer.from_indices_shape(indices, shape)
|
||||
|
||||
def test_invalid_ndindexer_oob_slice_end(self):
|
||||
indices = (Slice(2, 2), 0)
|
||||
shape = (3, 5)
|
||||
with self.assertRaises(ValueError):
|
||||
@parameterized.parameters(
|
||||
((4, 0), (3, 5)),
|
||||
((slice(3, 2), 0), (3, 5)),
|
||||
((Slice(2, 2), 0), (3, 5)),
|
||||
)
|
||||
def test_invalid_ndindexer_oob(self, indices, shape):
|
||||
with self.assertRaisesRegex(ValueError, "Out of bound"):
|
||||
_ = NDIndexer.from_indices_shape(indices, shape)
|
||||
|
||||
def test_ndindexer_with_padding(self):
|
||||
@ -126,6 +120,12 @@ class IndexerTest(jtu.JaxTestCase):
|
||||
indexer = NDIndexer.from_indices_shape(indices, shape)
|
||||
self.assertTupleEqual(indexer.get_indexer_shape(), shape)
|
||||
|
||||
def test_ndindexer_with_ellipsis(self):
|
||||
indices = (..., 4)
|
||||
shape = (5, 5)
|
||||
indexer = NDIndexer.from_indices_shape(indices, shape)
|
||||
self.assertTupleEqual(indexer.get_indexer_shape(), (5,))
|
||||
|
||||
def test_ndindexer_with_slices(self):
|
||||
indices = (slice(2, 3), slice(4, 7))
|
||||
shape = (5, 6)
|
||||
@ -154,6 +154,14 @@ class IndexerTest(jtu.JaxTestCase):
|
||||
indexer = NDIndexer.from_indices_shape(indices, shape)
|
||||
self.assertTupleEqual(indexer.get_indexer_shape(), (10, 20))
|
||||
|
||||
def test_ndindexer_with_arrays_and_invalid_broadcasting(self):
|
||||
indices = (np.arange(10)[None], np.arange(20)[None, :])
|
||||
shape = (5, 5)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Cannot broadcast shapes for indexing"
|
||||
):
|
||||
indexer = NDIndexer.from_indices_shape(indices, shape)
|
||||
|
||||
def test_indexer_with_all_types(self):
|
||||
indices = (0, slice(10), np.arange(5))
|
||||
shape = (2, 3, 4)
|
||||
|
Loading…
x
Reference in New Issue
Block a user