jax.numpy: implement scalar boolean indexing

This commit is contained in:
Jake VanderPlas 2024-02-08 14:11:15 -08:00
parent 805ed852dc
commit bbfd4f2c26
6 changed files with 56 additions and 31 deletions

View File

@ -12,6 +12,7 @@ Remember to align the itemized text with the first line of an item within a list
* Added [CUDA Array
Interface](https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html)
import support (requires jaxlib 0.4.24).
* JAX arrays now support NumPy-style scalar boolean indexing, e.g. `x[True]` or `x[False]`.
* Deprecations & Removals
* {func}`jax.numpy.linalg.solve` now shows a deprecation warning for batched 1D

View File

@ -308,13 +308,6 @@ class ArrayImpl(basearray.Array):
from jax._src.numpy import lax_numpy
self._check_if_deleted()
if isinstance(idx, tuple):
num_idx = sum(e is not None and e is not Ellipsis for e in idx)
if num_idx > self.ndim:
raise IndexError(
f"Too many indices for array: array has ndim of {self.ndim}, but "
f"was indexed with {num_idx} non-None/Ellipsis indices.")
if isinstance(self.sharding, PmapSharding):
if config.pmap_no_rank_reduction.value:
cidx = idx if isinstance(idx, tuple) else (idx,)

View File

@ -4613,6 +4613,9 @@ def _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
if isinstance(fill_value, np.ndarray):
fill_value = fill_value.item()
if indexer.scalar_bool_dims:
y = lax.expand_dims(y, indexer.scalar_bool_dims)
# Avoid calling gather if the slice shape is empty, both as a fast path and to
# handle cases like zeros(0)[array([], int32)].
if core.is_empty_shape(indexer.slice_shape):
@ -4657,6 +4660,10 @@ class _Indexer(NamedTuple):
# gathers and eliminated for scatters.
newaxis_dims: Sequence[int]
# Keep track of dimensions with scalar bool indices. These must be inserted
# for gathers before performing other index operations.
scalar_bool_dims: Sequence[int]
def _split_index_for_jit(idx, shape):
"""Splits indices into necessarily-static and dynamic parts.
@ -4705,6 +4712,16 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
# Remove ellipses and add trailing slice(None)s.
idx = _canonicalize_tuple_index(len(x_shape), idx)
# Check for scalar boolean indexing: this requires inserting extra dimensions
# before performing the rest of the logic.
scalar_bool_dims: Sequence[int] = [n for n, i in enumerate(idx) if isinstance(i, bool)]
if scalar_bool_dims:
idx = tuple(np.arange(int(i)) if isinstance(i, bool) else i for i in idx)
x_shape = list(x_shape)
for i in sorted(scalar_bool_dims):
x_shape.insert(i, 1)
x_shape = tuple(x_shape)
# Check for advanced indexing:
# https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing
@ -4805,8 +4822,8 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
# XLA gives error when indexing into an axis of size 0
raise IndexError(f"index is out of bounds for axis {x_axis} with size 0")
i = _normalize_index(i, x_shape[x_axis]) if normalize_indices else i
i = lax.convert_element_type(i, index_dtype)
gather_indices.append((i, len(gather_indices_shape)))
i_converted = lax.convert_element_type(i, index_dtype)
gather_indices.append((i_converted, len(gather_indices_shape)))
collapsed_slice_dims.append(x_axis)
gather_slice_shape.append(1)
start_index_map.append(x_axis)
@ -4893,7 +4910,8 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
dnums=dnums,
gather_indices=gather_indices_array,
unique_indices=advanced_indexes is None,
indices_are_sorted=advanced_indexes is None)
indices_are_sorted=advanced_indexes is None,
scalar_bool_dims=scalar_bool_dims)
def _should_unpack_list_index(x):
"""Helper for _eliminate_deprecated_list_indexing."""
@ -4959,7 +4977,7 @@ def _expand_bool_indices(idx, shape):
# TODO(mattjj): improve this error by tracking _why_ the indices are not concrete
raise errors.NonConcreteBooleanIndexError(abstract_i)
elif _ndim(i) == 0:
raise TypeError("JAX arrays do not support boolean scalar indices")
out.append(bool(i))
else:
i_shape = _shape(i)
start = len(out) + ellipsis_offset - newaxis_offset
@ -5010,10 +5028,10 @@ def _is_scalar(x):
def _canonicalize_tuple_index(arr_ndim, idx, array_name='array'):
"""Helper to remove Ellipsis and add in the implicit trailing slice(None)."""
len_without_none = sum(e is not None and e is not Ellipsis for e in idx)
if len_without_none > arr_ndim:
num_dimensions_consumed = sum(not (e is None or e is Ellipsis or isinstance(e, bool)) for e in idx)
if num_dimensions_consumed > arr_ndim:
raise IndexError(
f"Too many indices for {array_name}: {len_without_none} "
f"Too many indices for {array_name}: {num_dimensions_consumed} "
f"non-None/Ellipsis indices for dim {arr_ndim}.")
ellipses = (i for i, elt in enumerate(idx) if elt is Ellipsis)
ellipsis_index = next(ellipses, None)
@ -5021,10 +5039,10 @@ def _canonicalize_tuple_index(arr_ndim, idx, array_name='array'):
if next(ellipses, None) is not None:
raise IndexError(
f"Multiple ellipses (...) not supported: {list(map(type, idx))}.")
colons = (slice(None),) * (arr_ndim - len_without_none)
colons = (slice(None),) * (arr_ndim - num_dimensions_consumed)
idx = idx[:ellipsis_index] + colons + idx[ellipsis_index + 1:]
elif len_without_none < arr_ndim:
colons = (slice(None),) * (arr_ndim - len_without_none)
elif num_dimensions_consumed < arr_ndim:
colons = (slice(None),) * (arr_ndim - num_dimensions_consumed)
idx = tuple(idx) + colons
return idx

View File

@ -103,6 +103,9 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
idx = jnp._merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
indexer = jnp._index_to_gather(jnp.shape(x), idx,
normalize_indices=normalize_indices)
# TODO(jakevdp): implement scalar boolean logic.
if indexer.scalar_bool_dims:
raise TypeError("Scalar boolean indices are not allowed in scatter.")
# Avoid calling scatter if the slice shape is empty, both as a fast path and
# to handle cases like zeros(0)[array([], int32)].

View File

@ -1,8 +1,5 @@
# Known failures for the array api tests.
# JAX doesn't yet support scalar boolean indexing
array_api_tests/test_array_object.py::test_getitem_masking
# Test suite attempts in-place mutation:
array_api_tests/test_special_cases.py::test_binary
array_api_tests/test_special_cases.py::test_iop

View File

@ -876,15 +876,6 @@ class IndexingTest(jtu.JaxTestCase):
i = np.array([True, True, False])
self.assertRaises(IndexError, lambda: jax.jit(lambda x, i: x[i])(x, i))
def testScalarBooleanIndexingNotImplemented(self):
msg = "JAX arrays do not support boolean scalar indices"
with self.assertRaisesRegex(TypeError, msg):
jnp.arange(4)[True]
with self.assertRaisesRegex(TypeError, msg):
jnp.arange(4)[False]
with self.assertRaisesRegex(TypeError, msg):
jnp.arange(4)[..., True]
def testIssue187(self):
x = jnp.ones((5, 5))
x[[0, 2, 4], [0, 2, 4]] # doesn't crash
@ -1033,6 +1024,29 @@ class IndexingTest(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
shape=[(2, 3, 4, 5)],
idx=[
np.index_exp[True],
np.index_exp[False],
np.index_exp[..., True],
np.index_exp[..., False],
np.index_exp[0, :2, True],
np.index_exp[0, :2, False],
np.index_exp[:2, 0, True],
np.index_exp[:2, 0, False],
np.index_exp[:2, np.array([0, 2]), True],
np.index_exp[np.array([1, 0]), :, True],
np.index_exp[True, :, True, :, np.array(True)],
]
)
def testScalarBooleanIndexing(self, shape, idx):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, np.int32)]
np_fun = lambda x: np.asarray(x)[idx]
jnp_fun = lambda x: jnp.asarray(x)[idx]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
def testFloatIndexingError(self):
BAD_INDEX_TYPE_ERROR = "Indexer must have integer or boolean type, got indexer with type"
with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
@ -1158,8 +1172,7 @@ class IndexingTest(jtu.JaxTestCase):
def testWrongNumberOfIndices(self):
with self.assertRaisesRegex(
IndexError,
"Too many indices for array: array has ndim of 1, "
"but was indexed with 2 non-None/Ellipsis indices"):
"Too many indices for array: 2 non-None/Ellipsis indices for dim 1."):
jnp.zeros(3)[:, 5]