mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Improve error for indexing with string
This commit is contained in:
parent
021fadfcbc
commit
dd8033bdd4
@ -4037,6 +4037,8 @@ def _split_index_for_jit(idx, shape):
|
||||
"""
|
||||
# Convert list indices to tuples in cases (deprecated by NumPy.)
|
||||
idx = _eliminate_deprecated_list_indexing(idx)
|
||||
if any(isinstance(i, str) for i in idx):
|
||||
raise TypeError(f"JAX does not support string indexing; got {idx=}")
|
||||
|
||||
# Expand any (concrete) boolean indices. We can then use advanced integer
|
||||
# indexing logic to handle them.
|
||||
@ -4304,7 +4306,7 @@ def _eliminate_deprecated_list_indexing(idx):
|
||||
# non-tuple sequence containing slice objects, [Ellipses, or newaxis
|
||||
# objects]". Detects this and raises a TypeError.
|
||||
if not isinstance(idx, tuple):
|
||||
if isinstance(idx, Sequence) and not isinstance(idx, (Array, np.ndarray)):
|
||||
if isinstance(idx, Sequence) and not isinstance(idx, (Array, np.ndarray, str)):
|
||||
# As of numpy 1.16, some non-tuple sequences of indices result in a warning, while
|
||||
# others are converted to arrays, based on a set of somewhat convoluted heuristics
|
||||
# (See https://github.com/numpy/numpy/blob/v1.19.2/numpy/core/src/multiarray/mapping.c#L179-L343)
|
||||
|
@ -929,6 +929,13 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
|
||||
jnp.zeros(2).at[0.].set(1.)
|
||||
|
||||
def testStrIndexingError(self):
|
||||
msg = "JAX does not support string indexing"
|
||||
with self.assertRaisesRegex(TypeError, msg):
|
||||
jnp.zeros(2)['abc']
|
||||
with self.assertRaisesRegex(TypeError, msg):
|
||||
jnp.zeros(2)[:, 'abc']
|
||||
|
||||
def testIndexingPositionalArgumentWarning(self):
|
||||
x = jnp.arange(4)
|
||||
with self.assertWarnsRegex(
|
||||
|
Loading…
x
Reference in New Issue
Block a user