mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Validate shapes for boolean indices
This commit is contained in:
parent
10bbd628e9
commit
08e1c831ba
@ -4922,6 +4922,7 @@ def _unique_axis_sorted_mask(ar, axis):
|
||||
size, *out_shape = aux.shape
|
||||
aux = aux.reshape(size, _prod(out_shape)).T
|
||||
if aux.shape[0] == 0:
|
||||
size = 1
|
||||
perm = zeros(1, dtype=int)
|
||||
else:
|
||||
perm = lexsort(aux[::-1])
|
||||
@ -5005,7 +5006,7 @@ def _rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False):
|
||||
# All supported cases of indexing can be implemented as an XLA gather,
|
||||
# followed by an optional reverse and broadcast_in_dim.
|
||||
arr = asarray(arr)
|
||||
treedef, static_idx, dynamic_idx = _split_index_for_jit(idx)
|
||||
treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
|
||||
return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
|
||||
unique_indices)
|
||||
|
||||
@ -5065,7 +5066,7 @@ _Indexer = collections.namedtuple("_Indexer", [
|
||||
"newaxis_dims",
|
||||
])
|
||||
|
||||
def _split_index_for_jit(idx):
|
||||
def _split_index_for_jit(idx, shape):
|
||||
"""Splits indices into necessarily-static and dynamic parts.
|
||||
|
||||
Used to pass indices into `jit`-ted function.
|
||||
@ -5075,7 +5076,7 @@ def _split_index_for_jit(idx):
|
||||
|
||||
# Expand any (concrete) boolean indices. We can then use advanced integer
|
||||
# indexing logic to handle them.
|
||||
idx = _expand_bool_indices(idx)
|
||||
idx = _expand_bool_indices(idx, shape)
|
||||
|
||||
leaves, treedef = tree_flatten(idx)
|
||||
dynamic = [None] * len(leaves)
|
||||
@ -5328,16 +5329,16 @@ def _eliminate_deprecated_list_indexing(idx):
|
||||
idx = (idx,)
|
||||
return idx
|
||||
|
||||
def _expand_bool_indices(idx):
|
||||
def _expand_bool_indices(idx, shape):
|
||||
"""Converts concrete bool indexes into advanced integer indexes."""
|
||||
out = []
|
||||
for i in idx:
|
||||
for dim_number, i in enumerate(idx):
|
||||
try:
|
||||
abstract_i = core.get_aval(i)
|
||||
except TypeError:
|
||||
abstract_i = None
|
||||
if (isinstance(abstract_i, ShapedArray) and issubdtype(abstract_i.dtype, bool_)
|
||||
or isinstance(i, list) and _all(_is_scalar(e) and issubdtype(_dtype(e), np.bool_) for e in i)):
|
||||
or isinstance(i, list) and i and _all(_is_scalar(e) and issubdtype(_dtype(e), np.bool_) for e in i)):
|
||||
if isinstance(i, list):
|
||||
i = array(i)
|
||||
abstract_i = core.get_aval(i)
|
||||
@ -5346,6 +5347,11 @@ def _expand_bool_indices(idx):
|
||||
# TODO(mattjj): improve this error by tracking _why_ the indices are not concrete
|
||||
raise errors.NonConcreteBooleanIndexError(abstract_i)
|
||||
else:
|
||||
i_shape = _shape(i)
|
||||
expected_shape = shape[len(out): len(out) + _ndim(i)]
|
||||
if i_shape != expected_shape:
|
||||
raise IndexError("boolean index did not match shape of indexed array in index "
|
||||
f"{dim_number}: got {i_shape}, expected {expected_shape}")
|
||||
out.extend(np.where(i))
|
||||
else:
|
||||
out.append(i)
|
||||
|
@ -64,7 +64,7 @@ def _scatter_update(x, idx, y, scatter_op, indices_are_sorted,
|
||||
y = jnp.asarray(y)
|
||||
# XLA gathers and scatters are very similar in structure; the scatter logic
|
||||
# is more or less a transpose of the gather equivalent.
|
||||
treedef, static_idx, dynamic_idx = jnp._split_index_for_jit(idx)
|
||||
treedef, static_idx, dynamic_idx = jnp._split_index_for_jit(idx, x.shape)
|
||||
return _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
|
||||
indices_are_sorted, unique_indices, normalize_indices)
|
||||
|
||||
|
@ -833,6 +833,25 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
expected = np.array([-1])[np.array([False])]
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def testBooleanIndexingShapeMismatch(self):
|
||||
# Regression test for https://github.com/google/jax/issues/7329
|
||||
x = jnp.arange(4)
|
||||
idx = jnp.array([True, False])
|
||||
with self.assertRaisesRegex(IndexError, "boolean index did not match shape.*"):
|
||||
x[idx]
|
||||
|
||||
def testNontrivialBooleanIndexing(self):
|
||||
# Test nontrivial corner case in boolean indexing shape validation
|
||||
rng = jtu.rand_default(self.rng())
|
||||
index = (rng((2, 3), np.bool_), rng((6,), np.bool_))
|
||||
|
||||
args_maker = lambda: [rng((2, 3, 6), np.int32)]
|
||||
np_fun = lambda x: np.asarray(x)[index]
|
||||
jnp_fun = lambda x: jnp.asarray(x)[index]
|
||||
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
def testFloatIndexingError(self):
|
||||
BAD_INDEX_TYPE_ERROR = "Indexer must have integer or boolean type, got indexer with type"
|
||||
with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
|
||||
|
Loading…
x
Reference in New Issue
Block a user