mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Fix indexing corner case with empty ellipses
This commit is contained in:
parent
40122f7c03
commit
f6f4ef06cd
@ -11971,6 +11971,14 @@ def _int(aval):
|
||||
|
||||
def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
|
||||
normalize_indices: bool = True) -> _Indexer:
|
||||
# Check whether advanced indices are contiguous. We must do this before
|
||||
# removing ellipses (https://github.com/jax-ml/jax/issues/25109)
|
||||
# If advanced idexing axes do not appear contiguously, NumPy semantics
|
||||
# move the advanced axes to the front.
|
||||
is_advanced, = np.nonzero([isinstance(e, (int, Sequence, Array, np.ndarray))
|
||||
or isscalar(e) for e in idx])
|
||||
advanced_axes_are_contiguous = np.all(np.diff(is_advanced) == 1)
|
||||
|
||||
# Remove ellipses and add trailing slice(None)s.
|
||||
idx = _canonicalize_tuple_index(len(x_shape), idx)
|
||||
|
||||
@ -11987,10 +11995,6 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
|
||||
# Check for advanced indexing:
|
||||
# https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing
|
||||
|
||||
# Do the advanced indexing axes appear contiguously? If not, NumPy semantics
|
||||
# move the advanced axes to the front.
|
||||
advanced_axes_are_contiguous = False
|
||||
|
||||
advanced_indexes: Sequence[Array | np.ndarray] | None = None
|
||||
|
||||
# The positions of the advanced indexing axes in `idx`.
|
||||
@ -12009,7 +12013,6 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
|
||||
advanced_pairs = ((_normalize_index(e, x_shape[j]), i, j)
|
||||
for e, i, j in advanced_pairs)
|
||||
advanced_indexes, idx_advanced_axes, x_advanced_axes = zip(*advanced_pairs)
|
||||
advanced_axes_are_contiguous = bool(np.all(np.diff(idx_advanced_axes) == 1))
|
||||
|
||||
x_axis = 0 # Current axis in x.
|
||||
y_axis = 0 # Current axis in y, before collapsing. See below.
|
||||
|
@ -399,6 +399,14 @@ MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS = [
|
||||
IndexSpec(shape=(3, 4), indexer=(Ellipsis, np.array(1, dtype=np.int32)),
|
||||
out_shape=(3,)),
|
||||
]),
|
||||
("EllipsisWithArrayIndices", [
|
||||
IndexSpec(shape=(3, 4, 5), indexer=(np.array([0, 1]), ..., np.array([0, 1])),
|
||||
out_shape=(2, 4)),
|
||||
IndexSpec(shape=(3, 4, 5), indexer=(slice(None), np.array([0, 1]), ..., np.array([0, 1])),
|
||||
out_shape=(2, 3)),
|
||||
IndexSpec(shape=(3, 4, 5), indexer=(slice(None), ..., np.array([0, 1]), np.array([0, 1])),
|
||||
out_shape=(3, 2)),
|
||||
]),
|
||||
]
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user