mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Improve error when indexing with too many indices
This commit is contained in:
parent
7d8b8578b5
commit
9a88ecb244
@ -8286,13 +8286,14 @@ def _is_scalar(x):
|
||||
return np.isscalar(x) or (isinstance(x, (np.ndarray, Array))
|
||||
and np.ndim(x) == 0)
|
||||
|
||||
def _canonicalize_tuple_index(arr_ndim, idx, array_name='array'):
|
||||
def _canonicalize_tuple_index(arr_ndim, idx):
|
||||
"""Helper to remove Ellipsis and add in the implicit trailing slice(None)."""
|
||||
num_dimensions_consumed = sum(not (e is None or e is Ellipsis or isinstance(e, bool)) for e in idx)
|
||||
if num_dimensions_consumed > arr_ndim:
|
||||
index_or_indices = "index" if num_dimensions_consumed == 1 else "indices"
|
||||
raise IndexError(
|
||||
f"Too many indices for {array_name}: {num_dimensions_consumed} "
|
||||
f"non-None/Ellipsis indices for dim {arr_ndim}.")
|
||||
f"Too many indices: {arr_ndim}-dimensional array indexed "
|
||||
f"with {num_dimensions_consumed} regular {index_or_indices}.")
|
||||
ellipses = (i for i, elt in enumerate(idx) if elt is Ellipsis)
|
||||
ellipsis_index = next(ellipses, None)
|
||||
if ellipsis_index is not None:
|
||||
|
@ -1206,7 +1206,11 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
def testWrongNumberOfIndices(self):
|
||||
with self.assertRaisesRegex(
|
||||
IndexError,
|
||||
"Too many indices for array: 2 non-None/Ellipsis indices for dim 1."):
|
||||
"Too many indices: 0-dimensional array indexed with 1 regular index."):
|
||||
jnp.array(1)[0]
|
||||
with self.assertRaisesRegex(
|
||||
IndexError,
|
||||
"Too many indices: 1-dimensional array indexed with 2 regular indices."):
|
||||
jnp.zeros(3)[:, 5]
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user