mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Improved errors when indexing with floats
This commit is contained in:
parent
70024d2201
commit
b2c45b8eb9
@ -751,13 +751,33 @@ def merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx):
|
||||
def _int(aval):
|
||||
return not aval.shape and dtypes.issubdtype(aval.dtype, np.integer)
|
||||
|
||||
def _aval_or_none(x):
|
||||
try:
|
||||
return core.get_aval(x)
|
||||
except:
|
||||
return None
|
||||
|
||||
def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
|
||||
normalize_indices: bool = True) -> _Indexer:
|
||||
# Convert sequences to arrays
|
||||
idx = tuple(lax_numpy.asarray(i, dtype=None if i else int)
|
||||
if isinstance(i, Sequence) else i for i in idx)
|
||||
abstract_idx = [_aval_or_none(i) for i in idx]
|
||||
float_indices = [(i, val, aval) for i, (val, aval) in enumerate(zip(idx, abstract_idx))
|
||||
if aval is not None and dtypes.issubdtype(aval, np.inexact)]
|
||||
|
||||
# Check for float or complex indices:
|
||||
if float_indices:
|
||||
i, val, aval = float_indices[0]
|
||||
msg = ("Indexer must have integer or boolean type, got indexer "
|
||||
"with type {} at position {}, indexer value {}")
|
||||
raise TypeError(msg.format(aval.dtype.name, i, val))
|
||||
|
||||
# Check whether advanced indices are contiguous. We must do this before
|
||||
# removing ellipses (https://github.com/jax-ml/jax/issues/25109)
|
||||
# If advanced idexing axes do not appear contiguously, NumPy semantics
|
||||
# move the advanced axes to the front.
|
||||
is_advanced, = np.nonzero([isinstance(e, (int, Sequence, Array, np.ndarray))
|
||||
is_advanced, = np.nonzero([isinstance(e, (int, np.integer, Array, np.ndarray))
|
||||
or lax_numpy.isscalar(e) for e in idx])
|
||||
advanced_axes_are_contiguous = np.all(np.diff(is_advanced) == 1)
|
||||
|
||||
@ -862,11 +882,8 @@ def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
|
||||
gather_slice_shape.append(1)
|
||||
continue
|
||||
|
||||
try:
|
||||
abstract_i = core.get_aval(i)
|
||||
except TypeError:
|
||||
abstract_i = None
|
||||
# Handle basic int indexes.
|
||||
abstract_i = _aval_or_none(i)
|
||||
if isinstance(abstract_i, core.ShapedArray) and _int(abstract_i):
|
||||
if core.definitely_equal(x_shape[x_axis], 0):
|
||||
# XLA gives error when indexing into an axis of size 0
|
||||
|
@ -1120,6 +1120,10 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
jnp.zeros(2).at[0.].add(1.)
|
||||
with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
|
||||
jnp.zeros(2).at[0.].set(1.)
|
||||
with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
|
||||
jnp.zeros((2, 2))[jnp.arange(2), 1.0]
|
||||
with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
|
||||
jnp.zeros((2, 2))[jnp.arange(2), 1 + 1j]
|
||||
|
||||
def testStrIndexingError(self):
|
||||
msg = "JAX does not support string indexing"
|
||||
|
Loading…
x
Reference in New Issue
Block a user