Improved errors when indexing with floats

This commit is contained in:
Jake VanderPlas 2025-02-28 15:04:07 -08:00
parent 70024d2201
commit b2c45b8eb9
2 changed files with 26 additions and 5 deletions

View File

@ -751,13 +751,33 @@ def merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx):
def _int(aval):
return not aval.shape and dtypes.issubdtype(aval.dtype, np.integer)
def _aval_or_none(x):
try:
return core.get_aval(x)
except:
return None
def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
normalize_indices: bool = True) -> _Indexer:
# Convert sequences to arrays
idx = tuple(lax_numpy.asarray(i, dtype=None if i else int)
if isinstance(i, Sequence) else i for i in idx)
abstract_idx = [_aval_or_none(i) for i in idx]
float_indices = [(i, val, aval) for i, (val, aval) in enumerate(zip(idx, abstract_idx))
if aval is not None and dtypes.issubdtype(aval, np.inexact)]
# Check for float or complex indices:
if float_indices:
i, val, aval = float_indices[0]
msg = ("Indexer must have integer or boolean type, got indexer "
"with type {} at position {}, indexer value {}")
raise TypeError(msg.format(aval.dtype.name, i, val))
# Check whether advanced indices are contiguous. We must do this before
# removing ellipses (https://github.com/jax-ml/jax/issues/25109)
# If advanced idexing axes do not appear contiguously, NumPy semantics
# move the advanced axes to the front.
is_advanced, = np.nonzero([isinstance(e, (int, Sequence, Array, np.ndarray))
is_advanced, = np.nonzero([isinstance(e, (int, np.integer, Array, np.ndarray))
or lax_numpy.isscalar(e) for e in idx])
advanced_axes_are_contiguous = np.all(np.diff(is_advanced) == 1)
@ -862,11 +882,8 @@ def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
gather_slice_shape.append(1)
continue
try:
abstract_i = core.get_aval(i)
except TypeError:
abstract_i = None
# Handle basic int indexes.
abstract_i = _aval_or_none(i)
if isinstance(abstract_i, core.ShapedArray) and _int(abstract_i):
if core.definitely_equal(x_shape[x_axis], 0):
# XLA gives error when indexing into an axis of size 0

View File

@ -1120,6 +1120,10 @@ class IndexingTest(jtu.JaxTestCase):
jnp.zeros(2).at[0.].add(1.)
with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
jnp.zeros(2).at[0.].set(1.)
with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
jnp.zeros((2, 2))[jnp.arange(2), 1.0]
with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
jnp.zeros((2, 2))[jnp.arange(2), 1 + 1j]
def testStrIndexingError(self):
msg = "JAX does not support string indexing"