mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
jax.numpy: implement scalar boolean indexing
This commit is contained in:
parent
805ed852dc
commit
bbfd4f2c26
@ -12,6 +12,7 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* Added [CUDA Array
|
||||
Interface](https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html)
|
||||
import support (requires jaxlib 0.4.24).
|
||||
* JAX arrays now support NumPy-style scalar boolean indexing, e.g. `x[True]` or `x[False]`.
|
||||
|
||||
* Deprecations & Removals
|
||||
* {func}`jax.numpy.linalg.solve` now shows a deprecation warning for batched 1D
|
||||
|
@ -308,13 +308,6 @@ class ArrayImpl(basearray.Array):
|
||||
from jax._src.numpy import lax_numpy
|
||||
self._check_if_deleted()
|
||||
|
||||
if isinstance(idx, tuple):
|
||||
num_idx = sum(e is not None and e is not Ellipsis for e in idx)
|
||||
if num_idx > self.ndim:
|
||||
raise IndexError(
|
||||
f"Too many indices for array: array has ndim of {self.ndim}, but "
|
||||
f"was indexed with {num_idx} non-None/Ellipsis indices.")
|
||||
|
||||
if isinstance(self.sharding, PmapSharding):
|
||||
if config.pmap_no_rank_reduction.value:
|
||||
cidx = idx if isinstance(idx, tuple) else (idx,)
|
||||
|
@ -4613,6 +4613,9 @@ def _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
|
||||
if isinstance(fill_value, np.ndarray):
|
||||
fill_value = fill_value.item()
|
||||
|
||||
if indexer.scalar_bool_dims:
|
||||
y = lax.expand_dims(y, indexer.scalar_bool_dims)
|
||||
|
||||
# Avoid calling gather if the slice shape is empty, both as a fast path and to
|
||||
# handle cases like zeros(0)[array([], int32)].
|
||||
if core.is_empty_shape(indexer.slice_shape):
|
||||
@ -4657,6 +4660,10 @@ class _Indexer(NamedTuple):
|
||||
# gathers and eliminated for scatters.
|
||||
newaxis_dims: Sequence[int]
|
||||
|
||||
# Keep track of dimensions with scalar bool indices. These must be inserted
|
||||
# for gathers before performing other index operations.
|
||||
scalar_bool_dims: Sequence[int]
|
||||
|
||||
|
||||
def _split_index_for_jit(idx, shape):
|
||||
"""Splits indices into necessarily-static and dynamic parts.
|
||||
@ -4705,6 +4712,16 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
|
||||
# Remove ellipses and add trailing slice(None)s.
|
||||
idx = _canonicalize_tuple_index(len(x_shape), idx)
|
||||
|
||||
# Check for scalar boolean indexing: this requires inserting extra dimensions
|
||||
# before performing the rest of the logic.
|
||||
scalar_bool_dims: Sequence[int] = [n for n, i in enumerate(idx) if isinstance(i, bool)]
|
||||
if scalar_bool_dims:
|
||||
idx = tuple(np.arange(int(i)) if isinstance(i, bool) else i for i in idx)
|
||||
x_shape = list(x_shape)
|
||||
for i in sorted(scalar_bool_dims):
|
||||
x_shape.insert(i, 1)
|
||||
x_shape = tuple(x_shape)
|
||||
|
||||
# Check for advanced indexing:
|
||||
# https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing
|
||||
|
||||
@ -4805,8 +4822,8 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
|
||||
# XLA gives error when indexing into an axis of size 0
|
||||
raise IndexError(f"index is out of bounds for axis {x_axis} with size 0")
|
||||
i = _normalize_index(i, x_shape[x_axis]) if normalize_indices else i
|
||||
i = lax.convert_element_type(i, index_dtype)
|
||||
gather_indices.append((i, len(gather_indices_shape)))
|
||||
i_converted = lax.convert_element_type(i, index_dtype)
|
||||
gather_indices.append((i_converted, len(gather_indices_shape)))
|
||||
collapsed_slice_dims.append(x_axis)
|
||||
gather_slice_shape.append(1)
|
||||
start_index_map.append(x_axis)
|
||||
@ -4893,7 +4910,8 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
|
||||
dnums=dnums,
|
||||
gather_indices=gather_indices_array,
|
||||
unique_indices=advanced_indexes is None,
|
||||
indices_are_sorted=advanced_indexes is None)
|
||||
indices_are_sorted=advanced_indexes is None,
|
||||
scalar_bool_dims=scalar_bool_dims)
|
||||
|
||||
def _should_unpack_list_index(x):
|
||||
"""Helper for _eliminate_deprecated_list_indexing."""
|
||||
@ -4959,7 +4977,7 @@ def _expand_bool_indices(idx, shape):
|
||||
# TODO(mattjj): improve this error by tracking _why_ the indices are not concrete
|
||||
raise errors.NonConcreteBooleanIndexError(abstract_i)
|
||||
elif _ndim(i) == 0:
|
||||
raise TypeError("JAX arrays do not support boolean scalar indices")
|
||||
out.append(bool(i))
|
||||
else:
|
||||
i_shape = _shape(i)
|
||||
start = len(out) + ellipsis_offset - newaxis_offset
|
||||
@ -5010,10 +5028,10 @@ def _is_scalar(x):
|
||||
|
||||
def _canonicalize_tuple_index(arr_ndim, idx, array_name='array'):
|
||||
"""Helper to remove Ellipsis and add in the implicit trailing slice(None)."""
|
||||
len_without_none = sum(e is not None and e is not Ellipsis for e in idx)
|
||||
if len_without_none > arr_ndim:
|
||||
num_dimensions_consumed = sum(not (e is None or e is Ellipsis or isinstance(e, bool)) for e in idx)
|
||||
if num_dimensions_consumed > arr_ndim:
|
||||
raise IndexError(
|
||||
f"Too many indices for {array_name}: {len_without_none} "
|
||||
f"Too many indices for {array_name}: {num_dimensions_consumed} "
|
||||
f"non-None/Ellipsis indices for dim {arr_ndim}.")
|
||||
ellipses = (i for i, elt in enumerate(idx) if elt is Ellipsis)
|
||||
ellipsis_index = next(ellipses, None)
|
||||
@ -5021,10 +5039,10 @@ def _canonicalize_tuple_index(arr_ndim, idx, array_name='array'):
|
||||
if next(ellipses, None) is not None:
|
||||
raise IndexError(
|
||||
f"Multiple ellipses (...) not supported: {list(map(type, idx))}.")
|
||||
colons = (slice(None),) * (arr_ndim - len_without_none)
|
||||
colons = (slice(None),) * (arr_ndim - num_dimensions_consumed)
|
||||
idx = idx[:ellipsis_index] + colons + idx[ellipsis_index + 1:]
|
||||
elif len_without_none < arr_ndim:
|
||||
colons = (slice(None),) * (arr_ndim - len_without_none)
|
||||
elif num_dimensions_consumed < arr_ndim:
|
||||
colons = (slice(None),) * (arr_ndim - num_dimensions_consumed)
|
||||
idx = tuple(idx) + colons
|
||||
return idx
|
||||
|
||||
|
@ -103,6 +103,9 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
|
||||
idx = jnp._merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
|
||||
indexer = jnp._index_to_gather(jnp.shape(x), idx,
|
||||
normalize_indices=normalize_indices)
|
||||
# TODO(jakevdp): implement scalar boolean logic.
|
||||
if indexer.scalar_bool_dims:
|
||||
raise TypeError("Scalar boolean indices are not allowed in scatter.")
|
||||
|
||||
# Avoid calling scatter if the slice shape is empty, both as a fast path and
|
||||
# to handle cases like zeros(0)[array([], int32)].
|
||||
|
@ -1,8 +1,5 @@
|
||||
# Known failures for the array api tests.
|
||||
|
||||
# JAX doesn't yet support scalar boolean indexing
|
||||
array_api_tests/test_array_object.py::test_getitem_masking
|
||||
|
||||
# Test suite attempts in-place mutation:
|
||||
array_api_tests/test_special_cases.py::test_binary
|
||||
array_api_tests/test_special_cases.py::test_iop
|
||||
|
@ -876,15 +876,6 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
i = np.array([True, True, False])
|
||||
self.assertRaises(IndexError, lambda: jax.jit(lambda x, i: x[i])(x, i))
|
||||
|
||||
def testScalarBooleanIndexingNotImplemented(self):
|
||||
msg = "JAX arrays do not support boolean scalar indices"
|
||||
with self.assertRaisesRegex(TypeError, msg):
|
||||
jnp.arange(4)[True]
|
||||
with self.assertRaisesRegex(TypeError, msg):
|
||||
jnp.arange(4)[False]
|
||||
with self.assertRaisesRegex(TypeError, msg):
|
||||
jnp.arange(4)[..., True]
|
||||
|
||||
def testIssue187(self):
|
||||
x = jnp.ones((5, 5))
|
||||
x[[0, 2, 4], [0, 2, 4]] # doesn't crash
|
||||
@ -1033,6 +1024,29 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=[(2, 3, 4, 5)],
|
||||
idx=[
|
||||
np.index_exp[True],
|
||||
np.index_exp[False],
|
||||
np.index_exp[..., True],
|
||||
np.index_exp[..., False],
|
||||
np.index_exp[0, :2, True],
|
||||
np.index_exp[0, :2, False],
|
||||
np.index_exp[:2, 0, True],
|
||||
np.index_exp[:2, 0, False],
|
||||
np.index_exp[:2, np.array([0, 2]), True],
|
||||
np.index_exp[np.array([1, 0]), :, True],
|
||||
np.index_exp[True, :, True, :, np.array(True)],
|
||||
]
|
||||
)
|
||||
def testScalarBooleanIndexing(self, shape, idx):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(shape, np.int32)]
|
||||
np_fun = lambda x: np.asarray(x)[idx]
|
||||
jnp_fun = lambda x: jnp.asarray(x)[idx]
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
|
||||
def testFloatIndexingError(self):
|
||||
BAD_INDEX_TYPE_ERROR = "Indexer must have integer or boolean type, got indexer with type"
|
||||
with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
|
||||
@ -1158,8 +1172,7 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
def testWrongNumberOfIndices(self):
|
||||
with self.assertRaisesRegex(
|
||||
IndexError,
|
||||
"Too many indices for array: array has ndim of 1, "
|
||||
"but was indexed with 2 non-None/Ellipsis indices"):
|
||||
"Too many indices for array: 2 non-None/Ellipsis indices for dim 1."):
|
||||
jnp.zeros(3)[:, 5]
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user