Improve error when indexing with too many indices

This commit is contained in:
Jake VanderPlas 2024-07-31 13:57:48 -07:00
parent 7d8b8578b5
commit 9a88ecb244
2 changed files with 9 additions and 4 deletions

View File

@ -8286,13 +8286,14 @@ def _is_scalar(x):
return np.isscalar(x) or (isinstance(x, (np.ndarray, Array))
and np.ndim(x) == 0)
def _canonicalize_tuple_index(arr_ndim, idx, array_name='array'):
def _canonicalize_tuple_index(arr_ndim, idx):
"""Helper to remove Ellipsis and add in the implicit trailing slice(None)."""
num_dimensions_consumed = sum(not (e is None or e is Ellipsis or isinstance(e, bool)) for e in idx)
if num_dimensions_consumed > arr_ndim:
index_or_indices = "index" if num_dimensions_consumed == 1 else "indices"
raise IndexError(
f"Too many indices for {array_name}: {num_dimensions_consumed} "
f"non-None/Ellipsis indices for dim {arr_ndim}.")
f"Too many indices: {arr_ndim}-dimensional array indexed "
f"with {num_dimensions_consumed} regular {index_or_indices}.")
ellipses = (i for i, elt in enumerate(idx) if elt is Ellipsis)
ellipsis_index = next(ellipses, None)
if ellipsis_index is not None:

View File

@ -1206,7 +1206,11 @@ class IndexingTest(jtu.JaxTestCase):
def testWrongNumberOfIndices(self):
with self.assertRaisesRegex(
IndexError,
"Too many indices for array: 2 non-None/Ellipsis indices for dim 1."):
"Too many indices: 0-dimensional array indexed with 1 regular index."):
jnp.array(1)[0]
with self.assertRaisesRegex(
IndexError,
"Too many indices: 1-dimensional array indexed with 2 regular indices."):
jnp.zeros(3)[:, 5]