From b2c45b8eb9f5e9db0448cf19d333cebab41b35aa Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 28 Feb 2025 15:04:07 -0800 Subject: [PATCH] Improved errors when indexing with floats --- jax/_src/numpy/indexing.py | 27 ++++++++++++++++++++++----- tests/lax_numpy_indexing_test.py | 4 ++++ 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/jax/_src/numpy/indexing.py b/jax/_src/numpy/indexing.py index a402fb8dc..5d59bb53b 100644 --- a/jax/_src/numpy/indexing.py +++ b/jax/_src/numpy/indexing.py @@ -751,13 +751,33 @@ def merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx): def _int(aval): return not aval.shape and dtypes.issubdtype(aval.dtype, np.integer) +def _aval_or_none(x): + try: + return core.get_aval(x) + except: + return None + def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any], normalize_indices: bool = True) -> _Indexer: + # Convert sequences to arrays + idx = tuple(lax_numpy.asarray(i, dtype=None if i else int) + if isinstance(i, Sequence) else i for i in idx) + abstract_idx = [_aval_or_none(i) for i in idx] + float_indices = [(i, val, aval) for i, (val, aval) in enumerate(zip(idx, abstract_idx)) + if aval is not None and dtypes.issubdtype(aval, np.inexact)] + + # Check for float or complex indices: + if float_indices: + i, val, aval = float_indices[0] + msg = ("Indexer must have integer or boolean type, got indexer " + "with type {} at position {}, indexer value {}") + raise TypeError(msg.format(aval.dtype.name, i, val)) + # Check whether advanced indices are contiguous. We must do this before # removing ellipses (https://github.com/jax-ml/jax/issues/25109) # If advanced idexing axes do not appear contiguously, NumPy semantics # move the advanced axes to the front. - is_advanced, = np.nonzero([isinstance(e, (int, Sequence, Array, np.ndarray)) + is_advanced, = np.nonzero([isinstance(e, (int, np.integer, Array, np.ndarray)) or lax_numpy.isscalar(e) for e in idx]) advanced_axes_are_contiguous = np.all(np.diff(is_advanced) == 1) @@ -862,11 +882,8 @@ def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any], gather_slice_shape.append(1) continue - try: - abstract_i = core.get_aval(i) - except TypeError: - abstract_i = None # Handle basic int indexes. + abstract_i = _aval_or_none(i) if isinstance(abstract_i, core.ShapedArray) and _int(abstract_i): if core.definitely_equal(x_shape[x_axis], 0): # XLA gives error when indexing into an axis of size 0 diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index 04225d6c5..63a725ad3 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -1120,6 +1120,10 @@ class IndexingTest(jtu.JaxTestCase): jnp.zeros(2).at[0.].add(1.) with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR): jnp.zeros(2).at[0.].set(1.) + with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR): + jnp.zeros((2, 2))[jnp.arange(2), 1.0] + with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR): + jnp.zeros((2, 2))[jnp.arange(2), 1 + 1j] def testStrIndexingError(self): msg = "JAX does not support string indexing"