Fix indexing corner case with empty ellipses

This commit is contained in:
Jake VanderPlas 2024-12-03 17:20:40 -08:00
parent 40122f7c03
commit f6f4ef06cd
2 changed files with 16 additions and 5 deletions

View File

@ -11971,6 +11971,14 @@ def _int(aval):
def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
normalize_indices: bool = True) -> _Indexer:
# Check whether advanced indices are contiguous. We must do this before
# removing ellipses (https://github.com/jax-ml/jax/issues/25109)
# If advanced idexing axes do not appear contiguously, NumPy semantics
# move the advanced axes to the front.
is_advanced, = np.nonzero([isinstance(e, (int, Sequence, Array, np.ndarray))
or isscalar(e) for e in idx])
advanced_axes_are_contiguous = np.all(np.diff(is_advanced) == 1)
# Remove ellipses and add trailing slice(None)s.
idx = _canonicalize_tuple_index(len(x_shape), idx)
@ -11987,10 +11995,6 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
# Check for advanced indexing:
# https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing
# Do the advanced indexing axes appear contiguously? If not, NumPy semantics
# move the advanced axes to the front.
advanced_axes_are_contiguous = False
advanced_indexes: Sequence[Array | np.ndarray] | None = None
# The positions of the advanced indexing axes in `idx`.
@ -12009,7 +12013,6 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
advanced_pairs = ((_normalize_index(e, x_shape[j]), i, j)
for e, i, j in advanced_pairs)
advanced_indexes, idx_advanced_axes, x_advanced_axes = zip(*advanced_pairs)
advanced_axes_are_contiguous = bool(np.all(np.diff(idx_advanced_axes) == 1))
x_axis = 0 # Current axis in x.
y_axis = 0 # Current axis in y, before collapsing. See below.

View File

@ -399,6 +399,14 @@ MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS = [
IndexSpec(shape=(3, 4), indexer=(Ellipsis, np.array(1, dtype=np.int32)),
out_shape=(3,)),
]),
("EllipsisWithArrayIndices", [
IndexSpec(shape=(3, 4, 5), indexer=(np.array([0, 1]), ..., np.array([0, 1])),
out_shape=(2, 4)),
IndexSpec(shape=(3, 4, 5), indexer=(slice(None), np.array([0, 1]), ..., np.array([0, 1])),
out_shape=(2, 3)),
IndexSpec(shape=(3, 4, 5), indexer=(slice(None), ..., np.array([0, 1]), np.array([0, 1])),
out_shape=(3, 2)),
]),
]