Improve error for indexing with string

This commit is contained in:
Jake VanderPlas 2023-03-20 08:55:16 -07:00
parent 021fadfcbc
commit dd8033bdd4
2 changed files with 10 additions and 1 deletions

View File

@ -4037,6 +4037,8 @@ def _split_index_for_jit(idx, shape):
"""
# Convert list indices to tuples in cases (deprecated by NumPy.)
idx = _eliminate_deprecated_list_indexing(idx)
if any(isinstance(i, str) for i in idx):
raise TypeError(f"JAX does not support string indexing; got {idx=}")
# Expand any (concrete) boolean indices. We can then use advanced integer
# indexing logic to handle them.
@ -4304,7 +4306,7 @@ def _eliminate_deprecated_list_indexing(idx):
# non-tuple sequence containing slice objects, [Ellipses, or newaxis
# objects]". Detects this and raises a TypeError.
if not isinstance(idx, tuple):
if isinstance(idx, Sequence) and not isinstance(idx, (Array, np.ndarray)):
if isinstance(idx, Sequence) and not isinstance(idx, (Array, np.ndarray, str)):
# As of numpy 1.16, some non-tuple sequences of indices result in a warning, while
# others are converted to arrays, based on a set of somewhat convoluted heuristics
# (See https://github.com/numpy/numpy/blob/v1.19.2/numpy/core/src/multiarray/mapping.c#L179-L343)

View File

@ -929,6 +929,13 @@ class IndexingTest(jtu.JaxTestCase):
with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
jnp.zeros(2).at[0.].set(1.)
def testStrIndexingError(self):
msg = "JAX does not support string indexing"
with self.assertRaisesRegex(TypeError, msg):
jnp.zeros(2)['abc']
with self.assertRaisesRegex(TypeError, msg):
jnp.zeros(2)[:, 'abc']
def testIndexingPositionalArgumentWarning(self):
x = jnp.arange(4)
with self.assertWarnsRegex(