Validate shapes for boolean indices

This commit is contained in:
Jake VanderPlas 2021-08-03 09:51:52 -07:00
parent 10bbd628e9
commit 08e1c831ba
3 changed files with 32 additions and 7 deletions

View File

@ -4922,6 +4922,7 @@ def _unique_axis_sorted_mask(ar, axis):
size, *out_shape = aux.shape
aux = aux.reshape(size, _prod(out_shape)).T
if aux.shape[0] == 0:
size = 1
perm = zeros(1, dtype=int)
else:
perm = lexsort(aux[::-1])
@ -5005,7 +5006,7 @@ def _rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False):
# All supported cases of indexing can be implemented as an XLA gather,
# followed by an optional reverse and broadcast_in_dim.
arr = asarray(arr)
treedef, static_idx, dynamic_idx = _split_index_for_jit(idx)
treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
unique_indices)
@ -5065,7 +5066,7 @@ _Indexer = collections.namedtuple("_Indexer", [
"newaxis_dims",
])
def _split_index_for_jit(idx):
def _split_index_for_jit(idx, shape):
"""Splits indices into necessarily-static and dynamic parts.
Used to pass indices into `jit`-ted function.
@ -5075,7 +5076,7 @@ def _split_index_for_jit(idx):
# Expand any (concrete) boolean indices. We can then use advanced integer
# indexing logic to handle them.
idx = _expand_bool_indices(idx)
idx = _expand_bool_indices(idx, shape)
leaves, treedef = tree_flatten(idx)
dynamic = [None] * len(leaves)
@ -5328,16 +5329,16 @@ def _eliminate_deprecated_list_indexing(idx):
idx = (idx,)
return idx
def _expand_bool_indices(idx):
def _expand_bool_indices(idx, shape):
"""Converts concrete bool indexes into advanced integer indexes."""
out = []
for i in idx:
for dim_number, i in enumerate(idx):
try:
abstract_i = core.get_aval(i)
except TypeError:
abstract_i = None
if (isinstance(abstract_i, ShapedArray) and issubdtype(abstract_i.dtype, bool_)
or isinstance(i, list) and _all(_is_scalar(e) and issubdtype(_dtype(e), np.bool_) for e in i)):
or isinstance(i, list) and i and _all(_is_scalar(e) and issubdtype(_dtype(e), np.bool_) for e in i)):
if isinstance(i, list):
i = array(i)
abstract_i = core.get_aval(i)
@ -5346,6 +5347,11 @@ def _expand_bool_indices(idx):
# TODO(mattjj): improve this error by tracking _why_ the indices are not concrete
raise errors.NonConcreteBooleanIndexError(abstract_i)
else:
i_shape = _shape(i)
expected_shape = shape[len(out): len(out) + _ndim(i)]
if i_shape != expected_shape:
raise IndexError("boolean index did not match shape of indexed array in index "
f"{dim_number}: got {i_shape}, expected {expected_shape}")
out.extend(np.where(i))
else:
out.append(i)

View File

@ -64,7 +64,7 @@ def _scatter_update(x, idx, y, scatter_op, indices_are_sorted,
y = jnp.asarray(y)
# XLA gathers and scatters are very similar in structure; the scatter logic
# is more or less a transpose of the gather equivalent.
treedef, static_idx, dynamic_idx = jnp._split_index_for_jit(idx)
treedef, static_idx, dynamic_idx = jnp._split_index_for_jit(idx, x.shape)
return _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
indices_are_sorted, unique_indices, normalize_indices)

View File

@ -833,6 +833,25 @@ class IndexingTest(jtu.JaxTestCase):
expected = np.array([-1])[np.array([False])]
self.assertAllClose(ans, expected, check_dtypes=False)
def testBooleanIndexingShapeMismatch(self):
# Regression test for https://github.com/google/jax/issues/7329
x = jnp.arange(4)
idx = jnp.array([True, False])
with self.assertRaisesRegex(IndexError, "boolean index did not match shape.*"):
x[idx]
def testNontrivialBooleanIndexing(self):
# Test nontrivial corner case in boolean indexing shape validation
rng = jtu.rand_default(self.rng())
index = (rng((2, 3), np.bool_), rng((6,), np.bool_))
args_maker = lambda: [rng((2, 3, 6), np.int32)]
np_fun = lambda x: np.asarray(x)[index]
jnp_fun = lambda x: jnp.asarray(x)[index]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
def testFloatIndexingError(self):
BAD_INDEX_TYPE_ERROR = "Indexer must have integer or boolean type, got indexer with type"
with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):