rocm_jax/jax/_src/numpy/indexing.py
Peter Hawkins 1e5d9a9158 Add an allow_negative_indices option to lax.dynamic_slice and lax.dynamic_update_slice.
The goal of this change is to avoid generating code to wrap negative indices back into range in cases where we know it doesn't matter. Change scan to pass allow_negative_indices=False to avoid emitting index wrapping code for each scan argument.

PiperOrigin-RevId: 731812827
2025-02-27 12:04:28 -08:00

1268 lines
49 KiB
Python

# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pytype: skip-file
"""Indexing code for jax.numpy."""
from functools import partial
import operator
import string
from typing import Any, NamedTuple, Sequence
import numpy as np
import jax
from jax import lax
from jax._src import array
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import errors
from jax._src.api import jit
from jax._src.lax import lax as lax_internal
from jax._src.numpy import einsum
from jax._src import mesh as mesh_lib
from jax._src.pjit import auto_axes
from jax._src.numpy import lax_numpy
from jax._src.numpy import ufuncs
from jax._src.numpy import util
from jax._src.tree_util import tree_flatten
from jax._src.typing import Array, ArrayLike, StaticScalar
from jax._src.util import canonicalize_axis, set_module, tuple_replace, safe_zip
# Decorator applied (as ``@export``) to the public functions in this module.
# ``set_module('jax.numpy')`` presumably rewrites their ``__module__`` so they
# are reported as living in ``jax.numpy`` -- see jax._src.util.set_module.
export = set_module('jax.numpy')
@export
def take(
    a: ArrayLike,
    indices: ArrayLike,
    axis: int | None = None,
    out: None = None,
    mode: str | None = None,
    unique_indices: bool = False,
    indices_are_sorted: bool = False,
    fill_value: StaticScalar | None = None,
) -> Array:
  """Take elements from an array.

  JAX implementation of :func:`numpy.take`, implemented in terms of
  :func:`jax.lax.gather`. JAX's behavior differs from NumPy in the case
  of out-of-bound indices; see the ``mode`` parameter below.

  Args:
    a: array from which to take values.
    indices: N-dimensional array of integer indices of values to take from the array.
    axis: the axis along which to take values. If not specified, the array will
      be flattened before indexing is applied.
    out: not supported by JAX; must be ``None``, otherwise
      :class:`NotImplementedError` is raised.
    mode: Out-of-bounds indexing mode: ``"fill"``, ``"clip"`` or ``"wrap"``.
      The default ``mode="fill"`` returns invalid values (e.g. NaN) for
      out-of-bounds indices; the ``fill_value`` argument gives control over
      this value. ``"clip"`` clamps indices into the valid range, and
      ``"wrap"`` wraps indices around the axis (modulo its length). NumPy's
      ``"raise"`` mode is not supported. For more discussion of ``mode``
      options, see :attr:`jax.numpy.ndarray.at`.
    unique_indices: If True, the implementation will assume that the indices are unique,
      which can result in more efficient execution on some backends. If set to True and
      indices are not unique, the output is undefined.
    indices_are_sorted: If True, the implementation will assume that the indices are
      sorted in ascending order, which can lead to more efficient execution on some
      backends. If set to True and indices are not sorted, the output is undefined.
    fill_value: The fill value to return for out-of-bounds indices when mode is
      ``"fill"``. Ignored otherwise. Defaults to NaN for inexact types, the
      largest negative value for signed types, the largest positive value for
      unsigned types, and True for booleans.

  Returns:
    Array of values extracted from ``a``.

  Raises:
    NotImplementedError: if ``out`` is not ``None`` or ``mode="raise"``.
    ValueError: if ``mode`` is not one of the supported modes.
    IndexError: if ``indices`` is non-empty and the indexed axis has length 0.

  See also:
    - :attr:`jax.numpy.ndarray.at`: take values via indexing syntax.
    - :func:`jax.numpy.take_along_axis`: take values along an axis

  Examples:
    >>> x = jnp.array([[1., 2., 3.],
    ...                [4., 5., 6.]])
    >>> indices = jnp.array([2, 0])

    Passing no axis results in indexing into the flattened array:

    >>> jnp.take(x, indices)
    Array([3., 1.], dtype=float32)
    >>> x.ravel()[indices]  # equivalent indexing syntax
    Array([3., 1.], dtype=float32)

    Passing an axis results in applying the index to every subarray along the axis:

    >>> jnp.take(x, indices, axis=1)
    Array([[3., 1.],
           [6., 4.]], dtype=float32)
    >>> x[:, indices]  # equivalent indexing syntax
    Array([[3., 1.],
           [6., 4.]], dtype=float32)

    Out-of-bound indices fill with invalid values. For float inputs, this is `NaN`:

    >>> jnp.take(x, indices, axis=0)
    Array([[nan, nan, nan],
           [ 1.,  2.,  3.]], dtype=float32)
    >>> x.at[indices].get(mode='fill', fill_value=jnp.nan)  # equivalent indexing syntax
    Array([[nan, nan, nan],
           [ 1.,  2.,  3.]], dtype=float32)

    This default out-of-bound behavior can be adjusted using the ``mode`` parameter, for
    example, we can instead clip to the last valid value:

    >>> jnp.take(x, indices, axis=0, mode='clip')
    Array([[4., 5., 6.],
           [1., 2., 3.]], dtype=float32)
    >>> x.at[indices].get(mode='clip')  # equivalent indexing syntax
    Array([[4., 5., 6.],
           [1., 2., 3.]], dtype=float32)
  """
  # ``axis`` is a static jit argument of _take, so normalize it to a plain int
  # (operator.index rejects floats and other non-integral values).
  return _take(a, indices, None if axis is None else operator.index(axis), out,
               mode, unique_indices=unique_indices, indices_are_sorted=indices_are_sorted,
               fill_value=fill_value)
@partial(jit, static_argnames=('axis', 'mode', 'unique_indices', 'indices_are_sorted', 'fill_value'))
def _take(a, indices, axis: int | None = None, out=None, mode=None,
unique_indices=False, indices_are_sorted=False, fill_value=None):
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.take is not supported.")
a, indices = util.ensure_arraylike("take", a, indices)
if axis is None:
a = a.ravel()
axis_idx = 0
else:
axis_idx = canonicalize_axis(axis, np.ndim(a))
if mode is None or mode == "fill":
gather_mode = lax.GatherScatterMode.FILL_OR_DROP
# lax.gather() does not support negative indices, so we wrap them here
indices = util._where(indices < 0, indices + a.shape[axis_idx], indices)
elif mode == "raise":
# TODO(phawkins): we have no way to report out of bounds errors yet.
raise NotImplementedError("The 'raise' mode to jnp.take is not supported.")
elif mode == "wrap":
indices = ufuncs.mod(indices, lax_internal._const(indices, a.shape[axis_idx]))
gather_mode = lax.GatherScatterMode.PROMISE_IN_BOUNDS
elif mode == "clip":
gather_mode = lax.GatherScatterMode.CLIP
else:
raise ValueError(f"Invalid mode '{mode}' for np.take")
index_dims = len(np.shape(indices))
slice_sizes = list(np.shape(a))
if slice_sizes[axis_idx] == 0:
if indices.size != 0:
raise IndexError("Cannot do a non-empty jnp.take() from an empty axis.")
return a
if indices.size == 0:
out_shape = (slice_sizes[:axis_idx] + list(indices.shape) +
slice_sizes[axis_idx + 1:])
return lax.full_like(a, 0, shape=out_shape)
slice_sizes[axis_idx] = 1
dnums = lax.GatherDimensionNumbers(
offset_dims=tuple(
list(range(axis_idx)) +
list(range(axis_idx + index_dims, len(a.shape) + index_dims - 1))),
collapsed_slice_dims=(axis_idx,),
start_index_map=(axis_idx,))
return lax.gather(a, indices[..., None], dimension_numbers=dnums,
slice_sizes=tuple(slice_sizes),
mode=gather_mode, unique_indices=unique_indices,
indices_are_sorted=indices_are_sorted, fill_value=fill_value)
def _normalize_index(index, axis_size):
  """Maps an index in ``[-N, N)`` onto the equivalent index in ``[0, N)``.

  Args:
    index: a static Python/NumPy integer or an integer array/tracer.
    axis_size: ``N``, the length of the indexed axis; may be a symbolic
      (non-constant) dimension.

  Returns:
    ``index + N`` where ``index`` is negative, ``index`` otherwise. Unsigned
    indices cannot be negative and are returned untouched. Static
    non-negative integers are also returned as-is.
  """
  index_dtype = dtypes.dtype(index, canonicalize=True)
  if dtypes.issubdtype(index_dtype, np.unsignedinteger):
    return index

  # Express the axis length in the index's own dtype so the addition below
  # does not change the result dtype.
  if not core.is_constant_dim(axis_size):
    size = lax.convert_element_type(core.dimension_as_value(axis_size),
                                    index_dtype)
  else:
    size = lax_internal._const(index, axis_size)

  if not isinstance(index, (int, np.integer)):
    # Traced/array index: select elementwise.
    return lax.select(index < 0, lax.add(index, size), index)
  # Static scalar index: decide in Python.
  if index < 0:
    return lax.add(index, size)
  return index
@export
@partial(jit, static_argnames=('axis', 'mode', 'fill_value'))
def take_along_axis(
    arr: ArrayLike,
    indices: ArrayLike,
    axis: int | None,
    mode: str | lax.GatherScatterMode | None = None,
    fill_value: StaticScalar | None = None,
) -> Array:
  """Take elements from an array.

  JAX implementation of :func:`numpy.take_along_axis`, implemented in
  terms of :func:`jax.lax.gather`. JAX's behavior differs from NumPy
  in the case of out-of-bound indices; see the ``mode`` parameter below.

  Args:
    arr: array from which to take values.
    indices: array of integer indices. If ``axis`` is ``None``, must be one-dimensional.
      If ``axis`` is not None, must have ``arr.ndim == indices.ndim``, and ``arr`` must be
      broadcast-compatible with ``indices`` along dimensions other than ``axis``.
    axis: the axis along which to take values. If not specified, the array will
      be flattened before indexing is applied.
    mode: Out-of-bounds indexing mode, either ``"fill"`` or ``"clip"``. The default
      ``mode="fill"`` returns invalid values (e.g. NaN) for out-of bounds indices.
      ``"promise_in_bounds"`` additionally skips the wrapping of negative
      indices, and ``"one_hot"`` computes the result with a one-hot
      contraction instead of a gather.
      For more discussion of ``mode`` options, see :attr:`jax.numpy.ndarray.at`.
    fill_value: passed through to :func:`jax.lax.gather`, which uses it for
      out-of-bounds indices in ``"fill"`` mode. Not used by ``"one_hot"``.

  Returns:
    Array of values extracted from ``arr``.

  Raises:
    TypeError: if ``indices`` is not of integer type.
    ValueError: if ``axis`` is None and ``indices`` is not 1D, if ``arr`` and
      ``indices`` have different ranks, or if ``mode="one_hot"`` is used with
      more than 50 leading dimensions.

  See also:
    - :attr:`jax.numpy.ndarray.at`: take values via indexing syntax.
    - :func:`jax.numpy.take`: take the same indices along every axis slice.

  Examples:
    >>> x = jnp.array([[1., 2., 3.],
    ...                [4., 5., 6.]])
    >>> indices = jnp.array([[0, 2],
    ...                      [1, 0]])
    >>> jnp.take_along_axis(x, indices, axis=1)
    Array([[1., 3.],
           [5., 4.]], dtype=float32)
    >>> x[jnp.arange(2)[:, None], indices]  # equivalent via indexing syntax
    Array([[1., 3.],
           [5., 4.]], dtype=float32)

    Out-of-bound indices fill with invalid values. For float inputs, this is `NaN`:

    >>> indices = jnp.array([[1, 0, 2]])
    >>> jnp.take_along_axis(x, indices, axis=0)
    Array([[ 4.,  2., nan]], dtype=float32)
    >>> x.at[indices, jnp.arange(3)].get(
    ...     mode='fill', fill_value=jnp.nan)  # equivalent via indexing syntax
    Array([[ 4.,  2., nan]], dtype=float32)

    ``take_along_axis`` is helpful for extracting values from multi-dimensional
    argsorts and arg reductions. For example, here we compute :func:`~jax.numpy.argsort`
    indices along an axis, and use ``take_along_axis`` to construct the sorted
    array:

    >>> x = jnp.array([[5, 3, 4],
    ...                [2, 7, 6]])
    >>> indices = jnp.argsort(x, axis=1)
    >>> indices
    Array([[1, 2, 0],
           [0, 2, 1]], dtype=int32)
    >>> jnp.take_along_axis(x, indices, axis=1)
    Array([[3, 4, 5],
           [2, 6, 7]], dtype=int32)

    Similarly, we can use :func:`~jax.numpy.argmin` with ``keepdims=True`` and
    use ``take_along_axis`` to extract the minimum value:

    >>> idx = jnp.argmin(x, axis=1, keepdims=True)
    >>> idx
    Array([[1],
           [0]], dtype=int32)
    >>> jnp.take_along_axis(x, idx, axis=1)
    Array([[3],
           [2]], dtype=int32)
  """
  a, indices = util.ensure_arraylike("take_along_axis", arr, indices)
  index_dtype = dtypes.dtype(indices)
  idx_shape = np.shape(indices)
  if not dtypes.issubdtype(index_dtype, np.integer):
    raise TypeError("take_along_axis indices must be of integer type, got "
                    f"{index_dtype}")
  if axis is None:
    if np.ndim(indices) != 1:
      msg = "take_along_axis indices must be 1D if axis=None, got shape {}"
      raise ValueError(msg.format(idx_shape))
    a = a.ravel()
    axis = 0
  rank = a.ndim
  if rank != np.ndim(indices):
    msg = "indices and arr must have the same number of dimensions; {} vs. {}"
    raise ValueError(msg.format(np.ndim(indices), a.ndim))
  axis_int = canonicalize_axis(axis, rank)

  def replace(tup, val):
    # Returns ``tup`` with the entry at position ``axis_int`` set to ``val``.
    lst = list(tup)
    lst[axis_int] = val
    return tuple(lst)

  # Use 64-bit gather indices only when some dimension could overflow int32.
  use_64bit_index = any(not core.is_constant_dim(d) or d >= (1 << 31) for d in a.shape)
  index_dtype = np.dtype('int64' if use_64bit_index else 'int32')
  indices = lax.convert_element_type(indices, index_dtype)

  axis_size = a.shape[axis_int]
  arr_shape = replace(a.shape, 1)
  out_shape = lax.broadcast_shapes(idx_shape, arr_shape)
  if axis_size == 0:
    return lax.full(out_shape, 0, a.dtype)

  if mode == "one_hot":
    indices = _normalize_index(indices, axis_size)
    hot = jax.nn.one_hot(indices, axis_size, dtype=np.bool_)
    if a.ndim == 1:
      return einsum.einsum("...b,b->...", hot, a, preferred_element_type=a.dtype)
    # Labels for the leading (pre-``axis``) dimensions. "y" (the index axis)
    # and "z" (the one-hot axis) are reserved in the equation below, so they
    # must not be used here; otherwise labels would collide once axis_int
    # reaches 25.
    label_pool = [c for c in string.ascii_letters if c not in "yz"]
    if axis_int > len(label_pool):
      raise ValueError(
          "One Hot indexing is only supported for up to 50 leading dimensions."
      )
    labels = "".join(label_pool[:axis_int])
    eq = labels + "y...z," + labels + "z...->" + labels + "y..."
    return einsum.einsum(
        eq,
        hot,
        a,
        precision=lax.Precision.HIGHEST,
        preferred_element_type=a.dtype,
    )

  # Dimensions of the output that are driven by the index array: the take
  # axis, plus every dimension where ``indices`` is not (definitely) size 1.
  index_dims = [i for i, idx in enumerate(idx_shape) if i == axis_int or not core.definitely_equal(idx, 1)]
  gather_index_shape = tuple(np.array(out_shape)[index_dims]) + (1,)

  gather_indices = []
  slice_sizes = []
  offset_dims = []
  start_index_map = []
  collapsed_slice_dims = []
  operand_batching_dims = []
  start_indices_batching_dims = []

  # We will squeeze the array. i is the index of the unsqueezed shape, while
  # new_i is the index of the squeezed shape. j is the index of the gather
  # indices.
  dims_to_squeeze = []
  new_i = 0
  j = 0
  for i in range(rank):
    if i == axis_int:
      if mode != 'promise_in_bounds':
        indices = _normalize_index(indices, axis_size)
      gather_indices.append(lax.reshape(indices, gather_index_shape))
      slice_sizes.append(1)
      start_index_map.append(new_i)
      collapsed_slice_dims.append(new_i)
      new_i += 1
      j += 1
    elif core.definitely_equal(idx_shape[i], 1):
      # If idx_shape[i] == 1, we can just take the entirety of the arr's axis
      # and avoid forming an iota index.
      offset_dims.append(i)
      slice_sizes.append(arr_shape[i])
      new_i += 1
    elif core.definitely_equal(arr_shape[i], 1):
      # If the array dimension is 1 but the index dimension is not, we will
      # squeeze this dimension.
      dims_to_squeeze.append(i)
      j += 1
    else:
      # Otherwise, idx_shape[i] == arr_shape[i]. Mark the dimensions in both
      # array and index as batching so corresponding elements are gathered.
      if core.definitely_equal(arr_shape[i], 0):
        slice_sizes.append(0)
      else:
        slice_sizes.append(1)
      operand_batching_dims.append(new_i)
      start_indices_batching_dims.append(j)
      new_i += 1
      j += 1

  # Squeeze a to remove singleton dimensions.
  a = lax.squeeze(a, dims_to_squeeze)
  gather_indices_arr = lax.concatenate(gather_indices, dimension=j)
  dnums = lax.GatherDimensionNumbers(
      offset_dims=tuple(offset_dims),
      collapsed_slice_dims=tuple(collapsed_slice_dims),
      start_index_map=tuple(start_index_map),
      operand_batching_dims=tuple(operand_batching_dims),
      start_indices_batching_dims=tuple(start_indices_batching_dims))
  return lax.gather(a, gather_indices_arr, dnums, tuple(slice_sizes),
                    mode="fill" if mode is None else mode, fill_value=fill_value)
def _make_along_axis_idx(shape, indices, axis):
return tuple_replace(lax_numpy.indices(shape, sparse=True), axis, indices)
@export
@partial(jit, static_argnames=('axis', 'inplace', 'mode'))
def put_along_axis(
    arr: ArrayLike,
    indices: ArrayLike,
    values: ArrayLike,
    axis: int | None,
    inplace: bool = True,
    *,
    mode: str | None = None,
) -> Array:
  """Put values into the destination array by matching 1d index and data slices.

  JAX implementation of :func:`numpy.put_along_axis`.

  The semantics of :func:`numpy.put_along_axis` are to modify arrays in-place, which
  is not possible for JAX's immutable arrays. The JAX version returns a modified
  copy of the input, and adds the ``inplace`` parameter which must be set to
  ``False`` by the user as a reminder of this API difference.

  Args:
    arr: array into which values will be put.
    indices: array of indices at which to put values. Must have the same number
      of dimensions as ``arr`` (``arr`` is flattened first when ``axis`` is None).
    values: array of values to put into the array. Must be broadcastable to
      the shape of ``indices``.
    axis: the axis along which to put values. If not specified, the array will
      be flattened before indexing is applied.
    inplace: must be set to False to indicate that the input is not modified
      in-place, but rather a modified copy is returned.
    mode: Out-of-bounds indexing mode. For more discussion of ``mode`` options,
      see :attr:`jax.numpy.ndarray.at`.

  Returns:
    A copy of ``arr`` with specified entries updated.

  Raises:
    ValueError: if ``inplace`` is True, if ``arr`` and ``indices`` have
      different numbers of dimensions, or if ``values`` cannot be broadcast to
      ``indices.shape``.

  See Also:
    - :func:`jax.numpy.put`: put elements into an array at given indices.
    - :func:`jax.numpy.place`: place elements into an array via boolean mask.
    - :func:`jax.numpy.ndarray.at`: array updates using NumPy-style indexing.
    - :func:`jax.numpy.take`: extract values from an array at given indices.
    - :func:`jax.numpy.take_along_axis`: extract values from an array along an axis.

  Examples:
    >>> from jax import numpy as jnp
    >>> a = jnp.array([[10, 30, 20], [60, 40, 50]])
    >>> i = jnp.argmax(a, axis=1, keepdims=True)
    >>> print(i)
    [[1]
     [0]]
    >>> b = jnp.put_along_axis(a, i, 99, axis=1, inplace=False)
    >>> print(b)
    [[10 99 20]
     [99 40 50]]
  """
  if inplace:
    raise ValueError(
        "jax.numpy.put_along_axis cannot modify arrays in-place, because JAX arrays "
        "are immutable. Pass inplace=False to instead return an updated array.")
  arr, indices, values = util.ensure_arraylike("put_along_axis", arr, indices, values)

  # Remember the caller's view so a flattened result can be reshaped back.
  original_axis = axis
  original_arr_shape = arr.shape

  if axis is None:
    arr = arr.ravel()
    axis = 0

  if arr.ndim != indices.ndim:
    raise ValueError(
        "put_along_axis arguments 'arr' and 'indices' must have same ndim. Got "
        f"{arr.ndim=} and {indices.ndim=}."
    )

  try:
    values = util._broadcast_to(values, indices.shape)
  except ValueError as err:
    raise ValueError(
        "put_along_axis argument 'values' must be broadcastable to 'indices'. Got "
        f"{values.shape=} and {indices.shape=}."
    ) from err

  idx = _make_along_axis_idx(arr.shape, indices, axis)
  result = arr.at[idx].set(values, mode=mode)

  if original_axis is None:
    result = result.reshape(original_arr_shape)

  return result
### Indexing
def _is_integer_index(idx: Any) -> bool:
  """Returns True if ``idx`` is a concrete Python/NumPy integer.

  Booleans (``bool`` and ``np.bool_``) are deliberately excluded, even though
  ``bool`` subclasses ``int``.
  """
  if isinstance(idx, (bool, np.bool_)):
    return False
  return isinstance(idx, (int, np.integer))
def _is_simple_reverse_slice(idx: Any) -> bool:
  """Returns True only for the full reversal slice ``::-1``.

  Start and stop must both be omitted and the step must be the Python int -1.
  """
  if not isinstance(idx, slice):
    return False
  if idx.start is not None or idx.stop is not None:
    return False
  step = idx.step
  return isinstance(step, int) and step == -1
def _is_valid_integer_index_for_slice(idx, size, mode):
  """Returns True if a scalar integer index can be lowered to a slice.

  Used by the slice-based fast path of ``rewriting_take``.

  Args:
    idx: a single element of an index tuple.
    size: the static length of the axis that ``idx`` indexes.
    mode: the out-of-bounds mode requested by the caller (a string, a
      ``lax.GatherScatterMode`` or ``None``).

  Returns:
    True for a static integer in ``[-size, size)``, or for a dynamic scalar
    integer when ``mode`` is ``None`` or ``"promise_in_bounds"``.
    False otherwise, including for any index into an empty axis.
  """
  if size == 0:
    return False
  if _is_integer_index(idx):
    return -size <= idx < size
  try:
    shape, dtype = np.shape(idx), dtypes.dtype(idx, canonicalize=True)
  except Exception:
    # Not array-like (e.g. a slice, Ellipsis or an arbitrary object).
    return False
  if shape == () and np.issubdtype(dtype, np.integer):
    # A dynamic scalar index cannot be bounds-checked here. lax.dynamic_slice
    # clamps out-of-range starts, which is only acceptable for the default
    # mode or when the caller promised the index is in bounds.
    return mode in (None, 'promise_in_bounds',
                    lax.GatherScatterMode.PROMISE_IN_BOUNDS)
  return False
def _is_contiguous_slice(idx):
  """Returns True for a unit-stride slice with static integer bounds.

  ``start`` and ``stop`` must each be omitted or a concrete integer, and
  ``step`` must be omitted or a concrete integer equal to 1.
  """
  if not isinstance(idx, slice):
    return False
  start, stop, step = idx.start, idx.stop, idx.step
  if start is not None and not _is_integer_index(start):
    return False
  if stop is not None and not _is_integer_index(stop):
    return False
  return step is None or (_is_integer_index(step) and step == 1)
def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None) -> Array | None:
  """Tries to compute ``arr[idx]`` without a gather.

  This is the fast path of ``rewriting_take``. It applies when every element
  of ``idx`` is one of the following:
    * an integer accepted by ``_is_valid_integer_index_for_slice``;
    * a unit-stride slice with static bounds;
    * a full reversal ``::-1``.
  The indexing is then expressed as an optional ``lax.rev``, followed by a
  ``lax.slice`` (all starts static and non-negative) or ``lax.dynamic_slice``,
  followed by a ``lax.squeeze`` of the integer-indexed dimensions.

  Args:
    arr: the array being indexed.
    idx: the index expression (a single index or a tuple of indices).
    mode: the out-of-bounds mode requested by the caller. It only influences
      whether dynamic scalar integer indices are eligible.

  Returns:
    The indexed array, or ``None`` if this fast path does not apply. In that
    case the caller falls back to the general gather implementation.
  """
  # attempt to compute _rewriting_take via lax.slice(); return None if not possible.
  idx = idx if isinstance(idx, tuple) else (idx,)
  if not all(isinstance(i, int) for i in arr.shape):
    return None
  if len(idx) > arr.ndim:
    return None
  if any(i is None for i in idx):
    return None  # TODO(jakevdp): handle newaxis case
  # For symbolic dimensions fallback to gather
  if any(core.is_symbolic_dim(elt)
         for i in idx if isinstance(i, slice)
         for elt in (i.start, i.stop, i.step)):
    return None
  if any(i is Ellipsis for i in idx):
    # Remove ellipses and add trailing `slice(None)`.
    idx = _canonicalize_tuple_index(arr.ndim, idx=idx)

  # Classify each index position by how it can be lowered.
  simple_revs = {i for i, ind in enumerate(idx) if _is_simple_reverse_slice(ind)}
  int_indices = {i for i, (ind, size) in enumerate(zip(idx, arr.shape))
                 if _is_valid_integer_index_for_slice(ind, size, mode)}
  contiguous_slices = {i for i, ind in enumerate(idx) if _is_contiguous_slice(ind)}

  # For sharded inputs, indexing (like x[0]) and partial slices (like x[:2] as
  # opposed to x[:]) lead to incorrect sharding semantics when computed via
  # dynamic_slice, so we fall back to gather.
  # TODO(yashkatariya): fix dynamic_slice with sharding
  is_sharded = (isinstance(arr, array.ArrayImpl) and
                not dispatch.is_single_device_sharding(arr.sharding))
  has_partial_slices = any(idx[i].indices(arr.shape[i]) != (0, arr.shape[i], 1)
                           for i in contiguous_slices)
  if is_sharded and (int_indices or has_partial_slices):
    return None

  # Bail out unless every position fell into one of the three categories.
  if len(simple_revs) + len(int_indices) + len(contiguous_slices) != len(idx):
    return None

  # Apply full reversals up front; afterwards those positions behave like ':'.
  if simple_revs:
    arr = lax.rev(arr, tuple(simple_revs))
    idx = tuple(slice(None) if i in simple_revs else ind
                for i, ind in enumerate(idx))
    contiguous_slices |= simple_revs

  # Only full slices / reversals: nothing left to slice.
  if not (int_indices or has_partial_slices):
    return arr

  # Pad with full slices for trailing, unindexed dimensions.
  idx += (arr.ndim - len(idx)) * (slice(None),)

  start_indices: Sequence[ArrayLike] = []
  slice_sizes: Sequence[int] = []
  allow_negative_indices: list[bool] = []

  for ind, size in safe_zip(idx, arr.shape):
    if isinstance(ind, slice):
      start, stop, step = ind.indices(size)
      assert step == 1  # checked above
      start_indices.append(start)
      slice_sizes.append(max(0, stop - start))
      # slice.indices() already maps start/stop into [0, size] for a positive
      # step, so this is expected to be False.
      allow_negative_indices.append(start < 0 or stop < 0)
    else:
      assert np.issubdtype(dtypes.dtype(ind), np.integer)  # checked above
      assert np.shape(ind) == ()  # checked above
      start_indices.append(ind)
      slice_sizes.append(1)
      # Static negative ints (validated to lie in [-size, size)) and dynamic
      # ints (which may be negative) need dynamic_slice to wrap them.
      allow_negative_indices.append(
          not isinstance(ind, (int, np.integer)) or bool(ind < 0))

  # Try to use static slicing when possible.
  if all(isinstance(i, (int, np.integer)) and i >= 0 for i in start_indices):
    int_start_indices = [int(i) for i in start_indices]  # type: ignore
    int_limit_indices = [i + s for i, s in zip(int_start_indices, slice_sizes)]
    arr = lax.slice(
        arr, start_indices=int_start_indices, limit_indices=int_limit_indices)
  else:
    # We must be careful with dtypes because dynamic_slice requires all
    # start indices to have matching types.
    if len(start_indices) > 1:
      start_indices = util.promote_dtypes(*start_indices)
    arr = lax.dynamic_slice(
        arr, start_indices=start_indices, slice_sizes=slice_sizes,
        allow_negative_indices=allow_negative_indices)
  # Integer indices were sliced with size 1; drop those dimensions.
  if int_indices:
    arr = lax.squeeze(arr, tuple(int_indices))
  return arr
def rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
                   mode=None, fill_value=None, out_sharding=None):
  """Computes ``arr[idx]`` with NumPy indexing semantics.

  Strategies are tried in order:
    1. a ``lax.slice`` / ``lax.dynamic_slice`` fast path for simple integer
       and slice indices;
    2. ``lax.dynamic_index_in_dim`` for a scalar integer index with a
       ``DShapedArray`` aval, when ``config.dynamic_shapes`` is enabled;
    3. the general ``lax.gather``-based implementation.

  Args:
    arr: the array to index.
    idx: the index expression, as passed to ``__getitem__``.
    indices_are_sorted: hint forwarded to ``lax.gather``.
    unique_indices: hint forwarded to ``lax.gather``.
    mode: out-of-bounds mode (see :attr:`jax.numpy.ndarray.at`).
    fill_value: value for out-of-bounds reads in ``"fill"`` mode.
    out_sharding: optional output sharding for the gather.

  Returns:
    The indexed array.
  """
  # Computes arr[idx].
  # All supported cases of indexing can be implemented as an XLA gather,
  # followed by an optional reverse and broadcast_in_dim.

  # For simplicity of generated primitives, we call lax.dynamic_slice in the
  # simplest cases: i.e. non-dynamic arrays indexed with integers and slices.
  if (result := _attempt_rewriting_take_via_slice(arr, idx, mode)) is not None:
    return result

  # TODO(mattjj,dougalm): expand dynamic shape indexing support
  if config.dynamic_shapes.value and arr.ndim > 0:
    try:
      aval = core.get_aval(idx)
    except Exception:
      # idx is not a single JAX value (e.g. a tuple or a slice): this special
      # case does not apply, so fall through to the general path. A bare
      # ``except:`` here would also swallow KeyboardInterrupt/SystemExit.
      pass
    else:
      if (isinstance(aval, core.DShapedArray) and aval.shape == () and
          dtypes.issubdtype(aval.dtype, np.integer) and
          not dtypes.issubdtype(aval.dtype, dtypes.bool_) and
          isinstance(arr.shape[0], int)):
        return lax.dynamic_index_in_dim(arr, idx, keepdims=False)

  treedef, static_idx, dynamic_idx = split_index_for_jit(idx, arr.shape)
  return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
                 unique_indices, mode, fill_value, out_sharding)
# TODO(phawkins): re-enable jit after fixing excessive recompilation for
# slice indexes (e.g., slice(0, 5, None), slice(10, 15, None), etc.).
# @partial(jit, static_argnums=(1, 2))
def _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
            unique_indices, mode, fill_value, out_sharding):
  """Evaluates ``arr[idx]`` with ``lax.gather``.

  ``idx`` is reassembled from the pieces produced by ``split_index_for_jit``.

  Args:
    arr: the array being indexed.
    treedef, static_idx, dynamic_idx: the split index; recombined with
      ``merge_static_and_dynamic_indices``.
    indices_are_sorted, unique_indices: hints OR-ed with what the indexer can
      determine itself, then forwarded to ``lax.gather``.
    mode, fill_value: out-of-bounds handling forwarded to ``lax.gather``.
      ``fill_value`` must be a concrete scalar.
    out_sharding: if not ``None``, the gather is wrapped in ``auto_axes`` over
      all axis names of the current abstract mesh with this output sharding.

  Returns:
    The indexed array.

  Raises:
    ValueError: if ``fill_value`` is not a scalar.
  """
  full_index = merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
  # The index-to-gather translation is shared with _scatter_update.
  indexer = index_to_gather(np.shape(arr), full_index)

  if fill_value is not None:
    # The fill value must be known at trace time and be a scalar; NumPy
    # arrays are unwrapped into Python scalars.
    core.concrete_or_error(None, fill_value,
                           "fill_value argument to indexed get()")
    if np.ndim(fill_value) != 0:
      raise ValueError("fill_value argument to indexed get() must be a scalar")
    if isinstance(fill_value, np.ndarray):
      fill_value = fill_value.item()

  result = arr
  # Scalar boolean indices index an extra length-1 axis, inserted here.
  if indexer.scalar_bool_dims:
    result = lax.expand_dims(result, indexer.scalar_bool_dims)

  # An empty output needs no gather at all. This is both a fast path and a
  # way to handle cases like zeros(0)[array([], int32)].
  if core.is_empty_shape(indexer.slice_shape):
    return lax.full_like(result, 0, shape=indexer.slice_shape)

  # Only emit a gather when there actually are gather indices.
  if not core.is_empty_shape(indexer.gather_indices.shape):
    gather_fn = partial(
        lax.gather,
        dimension_numbers=indexer.dnums,
        slice_sizes=indexer.gather_slice_shape,
        unique_indices=unique_indices or indexer.unique_indices,
        indices_are_sorted=indices_are_sorted or indexer.indices_are_sorted,
        mode=mode, fill_value=fill_value)
    if out_sharding is not None:
      gather_fn = auto_axes(
          gather_fn, axes=mesh_lib.get_abstract_mesh().axis_names,
          out_shardings=out_sharding)
    result = gather_fn(result, indexer.gather_indices)

  # Negative-stride slices were gathered in ascending order; flip them back.
  if indexer.reversed_y_dims:
    result = lax.rev(result, indexer.reversed_y_dims)
  # Finally insert the np.newaxis / None dimensions.
  return lax.expand_dims(result, indexer.newaxis_dims)
class _Indexer(NamedTuple):
  """Recipe for evaluating a NumPy-style index expression with ``lax.gather``.

  Produced by ``index_to_gather``. Per the comment at its call site in
  ``_gather``, the same translation is shared with ``_scatter_update``.
  """
  # The expected shape of the slice output.
  slice_shape: Sequence[int]
  # The slice shape to pass to lax.gather().
  gather_slice_shape: Sequence[int]
  # The gather indices to use.
  gather_indices: ArrayLike
  # A GatherDimensionNumbers object describing the gather to perform.
  dnums: lax.GatherDimensionNumbers

  # Are the gather_indices known to be non-overlapping and/or sorted?
  # (In practice, these translate to "there are no advanced indices", because
  # only advanced indices could lead to index repetition.)
  unique_indices: bool
  indices_are_sorted: bool

  # Slice dimensions that have negative strides, and so must be reversed after
  # the gather.
  reversed_y_dims: Sequence[int]

  # Keep track of any axes created by `newaxis`. These must be inserted for
  # gathers and eliminated for scatters.
  newaxis_dims: Sequence[int]

  # Keep track of dimensions with scalar bool indices. These must be inserted
  # for gathers before performing other index operations.
  scalar_bool_dims: Sequence[int]
def split_index_for_jit(idx, shape):
  """Splits an index expression into static and dynamic parts.

  Used to pass indices into a `jit`-ted function.

  Args:
    idx: the raw index expression, as passed to ``__getitem__``.
    shape: shape of the indexed array; used to expand concrete boolean indices.

  Returns:
    A ``(treedef, static, dynamic)`` triple describing the flattened,
    normalized index:
      * ``static[k]`` holds ``Ellipsis`` as-is, or a ``(start, stop, step)``
        triple for a slice (slice objects aren't hashable);
      * ``dynamic[k]`` holds every other leaf;
      * the unused slot of each pair is ``None``.

  Raises:
    TypeError: for string indices, or (via
      ``eliminate_deprecated_list_indexing``) for non-tuple sequences.
  """
  # Convert list indices to tuples in cases (deprecated by NumPy.)
  idx = eliminate_deprecated_list_indexing(idx)
  if any(isinstance(i, str) for i in idx):
    raise TypeError(f"JAX does not support string indexing; got {idx=}")

  # Concrete boolean indices become advanced integer indices, which the rest
  # of the indexing logic already knows how to handle.
  idx = _expand_bool_indices(idx, shape)

  leaves, treedef = tree_flatten(idx)
  static = []
  dynamic = []
  for leaf in leaves:
    if leaf is Ellipsis:
      static_part, dynamic_part = leaf, None
    elif isinstance(leaf, slice):
      static_part, dynamic_part = (leaf.start, leaf.stop, leaf.step), None
    else:
      static_part, dynamic_part = None, leaf
    static.append(static_part)
    dynamic.append(dynamic_part)
  return treedef, tuple(static), dynamic
def merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx):
  """Inverse of ``split_index_for_jit``: rebuilds the original index pytree.

  Args:
    treedef: the tree structure returned by ``split_index_for_jit``.
    static_idx: per-leaf static parts. Each is a ``(start, stop, step)``
      triple for a slice, ``Ellipsis``, or ``None``.
    dynamic_idx: per-leaf dynamic parts; ``None`` where the leaf was static.

  Returns:
    The index expression obtained by unflattening the recombined leaves.
  """
  def restore_leaf(static, dynamic):
    # A non-None dynamic entry always wins.
    if dynamic is not None:
      return dynamic
    if isinstance(static, tuple):
      return slice(static[0], static[1], static[2])
    return static

  leaves = [restore_leaf(s, d) for s, d in zip(static_idx, dynamic_idx)]
  return treedef.unflatten(leaves)
def _int(aval):
  """Returns whether ``aval`` describes a rank-0 value of integer dtype."""
  if aval.shape:
    return False
  return dtypes.issubdtype(aval.dtype, np.integer)
def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
                    normalize_indices: bool = True) -> _Indexer:
  """Translates a NumPy-style index into an equivalent ``lax.gather`` recipe.

  Args:
    x_shape: shape of the array being indexed.
    idx: the index expression as a tuple of index elements: ints, slices,
      ``None``, ``Ellipsis``, bools, or integer arrays/sequences.
    normalize_indices: if True, negative integer indices are wrapped into
      range with ``_normalize_index``.

  Returns:
    An ``_Indexer`` describing the gather, plus the post-processing that
    must follow it (reversals, newaxis insertion, scalar-bool dims).

  Raises:
    IndexError: for slices with non-static bounds, for integer indexing into
      an axis of size 0, and for unsupported indexer kinds.
    TypeError: for indexers whose dtype is neither integer nor boolean.
  """
  # Check whether advanced indices are contiguous. We must do this before
  # removing ellipses (https://github.com/jax-ml/jax/issues/25109)
  # If advanced indexing axes do not appear contiguously, NumPy semantics
  # move the advanced axes to the front.
  is_advanced, = np.nonzero([isinstance(e, (int, Sequence, Array, np.ndarray))
                             or lax_numpy.isscalar(e) for e in idx])
  advanced_axes_are_contiguous = np.all(np.diff(is_advanced) == 1)

  # Remove ellipses and add trailing slice(None)s.
  idx = _canonicalize_tuple_index(len(x_shape), idx)

  # Check for scalar boolean indexing: this requires inserting extra dimensions
  # before performing the rest of the logic.
  scalar_bool_dims: Sequence[int] = [n for n, i in enumerate(idx) if isinstance(i, bool)]
  if scalar_bool_dims:
    # True indexes a new length-1 axis at position 0 (arange(1)); False selects
    # nothing from it (arange(0)).
    idx = tuple(np.arange(int(i)) if isinstance(i, bool) else i for i in idx)
    x_shape = list(x_shape)
    for i in sorted(scalar_bool_dims):
      x_shape.insert(i, 1)
    x_shape = tuple(x_shape)

  # Check for advanced indexing:
  # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing

  advanced_indexes: Sequence[Array | np.ndarray] | None = None

  # The positions of the advanced indexing axes in `idx`.
  idx_advanced_axes: Sequence[int] = []

  # The positions of the advanced indexes in x's shape.
  # collapsed, after None axes have been removed. See below.
  x_advanced_axes: Sequence[int] | None = None

  if _is_advanced_int_indexer(idx):
    idx_no_nones = [(i, d) for i, d in enumerate(idx) if d is not None]
    # Triples (index array, position in idx, position in x's shape).
    advanced_pairs = (
      (lax_numpy.asarray(e), i, j) for j, (i, e) in enumerate(idx_no_nones)
      if lax_numpy.isscalar(e) or isinstance(e, (Sequence, Array, np.ndarray)))
    if normalize_indices:
      advanced_pairs = ((_normalize_index(e, x_shape[j]), i, j)
                        for e, i, j in advanced_pairs)
    advanced_indexes, idx_advanced_axes, x_advanced_axes = zip(*advanced_pairs)

  x_axis = 0  # Current axis in x.
  y_axis = 0  # Current axis in y, before collapsing. See below.
  collapsed_y_axis = 0  # Current axis in y, after collapsing.

  # Scatter dimension numbers.
  offset_dims: list[int] = []
  collapsed_slice_dims: list[int] = []
  start_index_map: list[int] = []

  # Use 64-bit indices only when some dimension could overflow int32 and x64
  # is enabled.
  use_64bit_index = (
    any(not core.is_constant_dim(d) or d >= (1 << 31) for d in x_shape) and
    config.enable_x64.value)
  index_dtype = np.dtype('int64') if use_64bit_index else np.dtype('int32')

  # Gather indices.
  # Pairs of (array, start_dim) values. These will be broadcast into
  # gather_indices_shape, with the array dimensions aligned to start_dim, and
  # then concatenated.
  gather_indices: list[tuple[Array, int]] = []
  gather_indices_shape: list[int] = []

  # NOTE(review): the comments below describe the transformations from the
  # scatter (indexed update) point of view; _gather applies the corresponding
  # steps to the gather output instead.
  # We perform three transformations to y before the scatter op, in order:
  # First, y is broadcast to slice_shape. In general `y` only need broadcast to
  # the right shape.
  slice_shape: list[int] = []

  # Next, y is squeezed to remove newaxis_dims. This removes np.newaxis/`None`
  # indices, which the scatter cannot remove itself.
  newaxis_dims: list[int] = []

  # Finally, we reverse reversed_y_dims to handle slices with negative strides.
  reversed_y_dims: list[int] = []

  gather_slice_shape: list[int] = []

  for idx_pos, i in enumerate(idx):
    # Handle the advanced indices here if:
    # * the advanced indices were not contiguous and we are the start.
    # * we are at the position of the first advanced index.
    if (advanced_indexes is not None and
        (advanced_axes_are_contiguous and idx_pos == idx_advanced_axes[0] or
         not advanced_axes_are_contiguous and idx_pos == 0)):
      # All advanced indices are broadcast together and contribute one shared
      # block of gather-index dimensions.
      advanced_index_arrs = util._broadcast_arrays(*advanced_indexes)
      shape = advanced_index_arrs[0].shape
      ndim = len(shape)

      start_dim = len(gather_indices_shape)
      gather_indices.extend(
          (lax.convert_element_type(a, index_dtype), start_dim)
          for a in advanced_index_arrs
      )
      gather_indices_shape += shape

      assert x_advanced_axes is not None
      start_index_map.extend(x_advanced_axes)
      collapsed_slice_dims.extend(x_advanced_axes)
      slice_shape.extend(shape)
      y_axis += ndim
      collapsed_y_axis += ndim

    # Per-index bookkeeping for advanced indexes.
    if idx_pos in idx_advanced_axes:
      x_axis += 1
      gather_slice_shape.append(1)
      continue

    try:
      abstract_i = core.get_aval(i)
    except TypeError:
      abstract_i = None
    # Handle basic int indexes.
    if isinstance(abstract_i, core.ShapedArray) and _int(abstract_i):
      if core.definitely_equal(x_shape[x_axis], 0):
        # XLA gives error when indexing into an axis of size 0
        raise IndexError(f"index is out of bounds for axis {x_axis} with size 0")
      i = _normalize_index(i, x_shape[x_axis]) if normalize_indices else i
      i_converted = lax.convert_element_type(i, index_dtype)
      gather_indices.append((i_converted, len(gather_indices_shape)))
      collapsed_slice_dims.append(x_axis)
      gather_slice_shape.append(1)
      start_index_map.append(x_axis)
      x_axis += 1
    # Handle np.newaxis (None)
    elif i is None:
      slice_shape.append(1)
      newaxis_dims.append(y_axis)
      y_axis += 1

    elif isinstance(i, slice):
      # Handle slice index (only static, otherwise an error is raised)
      if not all(_is_slice_element_none_or_constant_or_symbolic(elt)
                 for elt in (i.start, i.stop, i.step)):
        msg = ("Array slice indices must have static start/stop/step to be used "
               "with NumPy indexing syntax. "
               f"Found slice({i.start}, {i.stop}, {i.step}). "
               "To index a statically sized "
               "array at a dynamic position, try lax.dynamic_slice/"
               "dynamic_update_slice (JAX does not support dynamically sized "
               "arrays within JIT compiled functions).")
        raise IndexError(msg)

      start, step, slice_size = core.canonicalize_slice(i, x_shape[x_axis])
      slice_shape.append(slice_size)

      if core.definitely_equal(step, 1):
        # Unit stride: express the slice as a contiguous window (offset dim).
        # Avoid generating trivial gather (an optimization)
        if not core.definitely_equal(slice_size, x_shape[x_axis]):
          gather_indices.append((lax.convert_element_type(start, index_dtype),
                                 len(gather_indices_shape)))
          start_index_map.append(x_axis)
        gather_slice_shape.append(slice_size)
        offset_dims.append(collapsed_y_axis)
      else:
        # Strided slice: gather one element per position of an explicit
        # start + step * iota index vector.
        indices = (lax_numpy.array(start, dtype=index_dtype) +
                   lax_numpy.array(step, dtype=index_dtype) * lax.iota(index_dtype, slice_size))
        if step < 0:
          # Gather in ascending order and record the dim for a later lax.rev.
          reversed_y_dims.append(collapsed_y_axis)
          indices = lax.rev(indices, dimensions=(0,))

        gather_slice_shape.append(1)
        gather_indices.append((indices, len(gather_indices_shape)))
        start_index_map.append(x_axis)
        gather_indices_shape.append(slice_size)
        collapsed_slice_dims.append(x_axis)

      collapsed_y_axis += 1
      y_axis += 1
      x_axis += 1
    else:
      if (abstract_i is not None and
          not (dtypes.issubdtype(abstract_i.dtype, np.integer) or dtypes.issubdtype(abstract_i.dtype, np.bool_))):
        msg = ("Indexer must have integer or boolean type, got indexer "
               "with type {} at position {}, indexer value {}")
        raise TypeError(msg.format(abstract_i.dtype.name, idx_pos, i))

      raise IndexError("Indexing mode not yet supported. Got unsupported indexer "
                       f"at position {idx_pos}: {i!r}")

  # Assemble the final (..., num_index_components) gather index array.
  if len(gather_indices) == 0:
    gather_indices_array: ArrayLike = np.zeros((0,), dtype=index_dtype)
  elif len(gather_indices) == 1:
    g, _ = gather_indices[0]
    gather_indices_array = lax.expand_dims(g, (g.ndim,))
  else:
    last_dim = len(gather_indices_shape)
    gather_indices_shape.append(1)
    gather_indices_array = lax.concatenate([
      lax.broadcast_in_dim(g, gather_indices_shape, tuple(range(i, i + g.ndim)))
      for g, i in gather_indices],
      last_dim)

  dnums = lax.GatherDimensionNumbers(
    offset_dims = tuple(offset_dims),
    collapsed_slice_dims = tuple(sorted(collapsed_slice_dims)),
    start_index_map = tuple(start_index_map)
  )
  return _Indexer(
    slice_shape=slice_shape,
    newaxis_dims=tuple(newaxis_dims),
    gather_slice_shape=gather_slice_shape,
    reversed_y_dims=reversed_y_dims,
    dnums=dnums,
    gather_indices=gather_indices_array,
    unique_indices=advanced_indexes is None,
    indices_are_sorted=advanced_indexes is None,
    scalar_bool_dims=scalar_bool_dims)
def _should_unpack_list_index(x):
"""Helper for eliminate_deprecated_list_indexing."""
return (isinstance(x, (np.ndarray, Array)) and np.ndim(x) != 0
or isinstance(x, (Sequence, slice))
or x is Ellipsis or x is None)
def eliminate_deprecated_list_indexing(idx):
  """Normalize ``idx`` to a tuple, rejecting non-tuple sequence indices.

  NumPy (as of 1.16) warns for some non-tuple sequences of indices and
  converts others to arrays, based on convoluted heuristics (see
  https://github.com/numpy/numpy/blob/v1.19.2/numpy/core/src/multiarray/mapping.c#L179-L343).
  JAX instead raises an informative ``TypeError`` for *all* non-tuple
  sequences.

  Args:
    idx: the user-supplied index.

  Returns:
    ``idx`` itself if it is already a tuple, otherwise the 1-tuple ``(idx,)``.

  Raises:
    TypeError: if ``idx`` is a non-tuple ``Sequence`` that is not a string,
      a NumPy array or a JAX array.
  """
  if isinstance(idx, tuple):
    return idx
  is_plain_sequence = (isinstance(idx, Sequence)
                       and not isinstance(idx, (Array, np.ndarray, str)))
  if not is_plain_sequence:
    return (idx,)
  # Suggest tuple() when the elements look like per-dimension indexers,
  # and array() when they look like a list of integer positions.
  if any(_should_unpack_list_index(elt) for elt in idx):
    advice = "use `arr[tuple(seq)]` instead of `arr[seq]`. "
  else:
    advice = "use `arr[array(seq)]` instead of `arr[seq]`. "
  raise TypeError(
      "Using a non-tuple sequence for multidimensional indexing is not allowed; "
      + advice
      + "See https://github.com/jax-ml/jax/issues/4564 for more information.")
def _is_boolean_index(i):
  """Report whether ``i`` is a boolean index.

  Two forms qualify:
    * anything whose abstract value is a ``core.ShapedArray`` of boolean dtype;
    * a non-empty Python list whose elements are all boolean scalars.

  Values for which ``core.get_aval`` raises ``TypeError`` only qualify via the
  list form. The result is used for its truthiness.
  """
  try:
    aval = core.get_aval(i)
  except TypeError:
    aval = None
  if isinstance(aval, core.ShapedArray) and dtypes.issubdtype(aval.dtype, np.bool_):
    return True
  if not isinstance(i, list):
    return False
  return i and all(_is_scalar(elt) and dtypes.issubdtype(dtypes.dtype(elt), np.bool_)
                   for elt in i)
def _expand_bool_indices(idx, shape):
  """Converts concrete bool indexes into advanced integer indexes.

  Every non-scalar boolean mask in ``idx`` is replaced by the integer index
  arrays produced by ``np.where(mask)`` (one per mask dimension); boolean
  lists are first converted to arrays. Scalar booleans are converted to Python
  ``bool`` and left in place. All other entries pass through unchanged.

  Args:
    idx: sequence of index entries.
    shape: shape of the array being indexed.

  Returns:
    A tuple of index entries with boolean masks expanded.

  Raises:
    IndexError: if ``idx`` contains more than one ``Ellipsis``, or if a
      boolean mask has more dimensions than remain to be indexed or does not
      match the shape of the dimensions it indexes.
    errors.NonConcreteBooleanIndexError: if a boolean index is not concrete.
  """
  out = []
  total_dims = len(shape)
  num_ellipsis = sum(e is Ellipsis for e in idx)
  if num_ellipsis > 1:
    raise IndexError("an index can only have a single ellipsis ('...')")
  elif num_ellipsis == 1:
    # Array dimensions consumed by the explicit entries, used to derive how
    # many dimensions the Ellipsis spans. Scalar booleans have ndim 0 and so
    # consume no dimension.
    total_dims = sum(np.ndim(e) if _is_boolean_index(e) else 1 for e in idx
                     if e is not None and e is not Ellipsis)
  # The array dimension indexed by the next entry is
  #   len(out) + ellipsis_offset - newaxis_offset
  # ``ellipsis_offset`` corrects for the dimensions spanned by the Ellipsis
  # (which occupies a single slot in ``out``); ``newaxis_offset`` counts
  # slots in ``out`` that consume no array dimension.
  ellipsis_offset = 0
  newaxis_offset = 0
  for dim_number, i in enumerate(idx):
    try:
      abstract_i = core.get_aval(i)
    except TypeError:
      abstract_i = None
    if _is_boolean_index(i):
      if isinstance(i, list):
        i = lax_numpy.array(i)
        abstract_i = core.get_aval(i)

      if not core.is_concrete(i):
        # TODO(mattjj): improve this error by tracking _why_ the indices are not concrete
        raise errors.NonConcreteBooleanIndexError(abstract_i)
      elif np.ndim(i) == 0:
        out.append(bool(i))
        # A scalar boolean occupies a slot in ``out`` but consumes no array
        # dimension (consistent with ``total_dims`` above). Without this,
        # e.g. ``x[True, mask]`` would check ``mask`` against the wrong
        # dimensions and raise a spurious IndexError.
        newaxis_offset += 1
      else:
        i_shape = np.shape(i)
        start = len(out) + ellipsis_offset - newaxis_offset
        expected_shape = shape[start: start + np.ndim(i)]
        if len(i_shape) != len(expected_shape):
          raise IndexError(f"too many boolean indices at index {dim_number}: got mask of shape "
                           f"{i_shape}, but only {len(expected_shape)} dimensions remain.")
        if not all(s1 in (0, s2) for s1, s2 in zip(i_shape, expected_shape)):
          raise IndexError("boolean index did not match shape of indexed array in index "
                           f"{dim_number}: got {i_shape}, expected {expected_shape}")
        # One integer index array per mask dimension.
        out.extend(np.where(i))
    else:
      out.append(i)
    if i is Ellipsis:
      # The Ellipsis spans len(shape) - total_dims dimensions but occupies
      # one slot in ``out``.
      ellipsis_offset = len(shape) - total_dims - 1
    if i is None:
      newaxis_offset += 1
  return tuple(out)
def _is_slice_element_none_or_constant_or_symbolic(elt):
"""Return True if elt is a constant or None."""
if elt is None: return True
if core.is_symbolic_dim(elt): return True
try:
return core.is_concrete(elt)
except TypeError:
return False
# TODO(mattjj): clean up this logic
def _is_advanced_int_indexer(idx):
"""Returns True if idx should trigger int array indexing, False otherwise."""
# https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing
assert isinstance(idx, tuple)
if all(e is None or e is Ellipsis or isinstance(e, slice)
or _is_scalar(e) and dtypes.issubdtype(dtypes.dtype(e), np.integer) for e in idx):
return False
return all(e is None or e is Ellipsis or isinstance(e, slice)
or _is_int_arraylike(e) for e in idx)
def _is_int_arraylike(x):
"""Returns True if x is array-like with integer dtype, False otherwise."""
return (isinstance(x, int) and not isinstance(x, bool)
or dtypes.issubdtype(getattr(x, "dtype", None), np.integer)
or isinstance(x, (list, tuple)) and all(_is_int_arraylike(e) for e in x))
def _is_scalar(x):
"""Checks if a Python or NumPy scalar."""
return np.isscalar(x) or (isinstance(x, (np.ndarray, Array))
and np.ndim(x) == 0)
def _canonicalize_tuple_index(arr_ndim, idx):
"""Helper to remove Ellipsis and add in the implicit trailing slice(None)."""
num_dimensions_consumed = sum(not (e is None or e is Ellipsis or isinstance(e, bool)) for e in idx)
if num_dimensions_consumed > arr_ndim:
index_or_indices = "index" if num_dimensions_consumed == 1 else "indices"
raise IndexError(
f"Too many indices: {arr_ndim}-dimensional array indexed "
f"with {num_dimensions_consumed} regular {index_or_indices}.")
ellipses = (i for i, elt in enumerate(idx) if elt is Ellipsis)
ellipsis_index = next(ellipses, None)
if ellipsis_index is not None:
if next(ellipses, None) is not None:
raise IndexError(
f"Multiple ellipses (...) not supported: {list(map(type, idx))}.")
colons = (slice(None),) * (arr_ndim - num_dimensions_consumed)
idx = idx[:ellipsis_index] + colons + idx[ellipsis_index + 1:]
elif num_dimensions_consumed < arr_ndim:
colons = (slice(None),) * (arr_ndim - num_dimensions_consumed)
idx = tuple(idx) + colons
return idx
@export
def place(arr: ArrayLike, mask: ArrayLike, vals: ArrayLike, *,
          inplace: bool = True) -> Array:
  """Update array elements based on a mask.

  JAX implementation of :func:`numpy.place`.

  The semantics of :func:`numpy.place` are to modify arrays in-place, which
  is not possible for JAX's immutable arrays. The JAX version returns a modified
  copy of the input, and adds the ``inplace`` parameter which must be set to
  ``False`` by the user as a reminder of this API difference.

  Args:
    arr: array into which values will be placed.
    mask: boolean mask with the same size as ``arr``.
    vals: values to be inserted into ``arr`` at the locations indicated
      by mask. If too many values are supplied, they will be truncated.
      If not enough values are supplied, they will be repeated.
    inplace: must be set to False to indicate that the input is not modified
      in-place, but rather a modified copy is returned.

  Returns:
    A copy of ``arr`` with masked values set to entries from ``vals``.

  Raises:
    ValueError: if ``inplace`` is True, if ``arr`` and ``mask`` have different
      sizes, or if ``vals`` is empty.

  See Also:
    - :func:`jax.numpy.put`: put elements into an array at numerical indices.
    - :func:`jax.numpy.ndarray.at`: array updates using NumPy-style indexing.

  Examples:
    >>> x = jnp.zeros((3, 5), dtype=int)
    >>> mask = (jnp.arange(x.size) % 3 == 0).reshape(x.shape)
    >>> mask
    Array([[ True, False, False,  True, False],
           [False,  True, False, False,  True],
           [False, False,  True, False, False]], dtype=bool)

    Placing a scalar value:

    >>> jnp.place(x, mask, 1, inplace=False)
    Array([[1, 0, 0, 1, 0],
           [0, 1, 0, 0, 1],
           [0, 0, 1, 0, 0]], dtype=int32)

    In this case, ``jnp.place`` is similar to the masked array update syntax:

    >>> x.at[mask].set(1)
    Array([[1, 0, 0, 1, 0],
           [0, 1, 0, 0, 1],
           [0, 0, 1, 0, 0]], dtype=int32)

    ``place`` differs when placing values from an array. The array is repeated
    to fill the masked entries:

    >>> vals = jnp.array([1, 3, 5])
    >>> jnp.place(x, mask, vals, inplace=False)
    Array([[1, 0, 0, 3, 0],
           [0, 5, 0, 0, 1],
           [0, 0, 3, 0, 0]], dtype=int32)
  """
  data, mask_arr, vals_arr = util.ensure_arraylike("place", arr, mask, vals)
  vals_arr = vals_arr.ravel()
  if inplace:
    raise ValueError(
      "jax.numpy.place cannot modify arrays in-place, because JAX arrays are immutable. "
      "Pass inplace=False to instead return an updated array.")
  if data.size != mask_arr.size:
    raise ValueError("place: arr and mask must be the same size")
  if not vals_arr.size:
    raise ValueError("Cannot place values from an empty array")
  if not data.size:
    return data
  # Flat positions of the True entries of the mask, padded to the static
  # length ``mask_arr.size`` with the out-of-range index ``mask_arr.size``.
  # The padding entries are discarded by ``mode='drop'`` below.
  indices = lax_numpy.where(mask_arr.ravel(), size=mask_arr.size, fill_value=mask_arr.size)[0]
  # Repeat ``vals`` cyclically so every (padded) position has a value; see the
  # last docstring example for the resulting assignment order.
  vals_arr = lax_numpy._tile_to_size(vals_arr, len(indices))
  return data.ravel().at[indices].set(vals_arr, mode='drop').reshape(data.shape)
@export
def put(a: ArrayLike, ind: ArrayLike, v: ArrayLike,
        mode: str | None = None, *, inplace: bool = True) -> Array:
  """Put elements into an array at given indices.

  JAX implementation of :func:`numpy.put`.

  The semantics of :func:`numpy.put` are to modify arrays in-place, which
  is not possible for JAX's immutable arrays. The JAX version returns a modified
  copy of the input, and adds the ``inplace`` parameter which must be set to
  ``False`` by the user as a reminder of this API difference.

  Args:
    a: array into which values will be placed.
    ind: array of indices over the flattened array at which to put values.
    v: array of values to put into the array.
    mode: string specifying how to handle out-of-bound indices. Supported values:

      - ``"clip"`` (default): clip out-of-bound indices to the final index.
      - ``"wrap"``: wrap out-of-bound indices to the beginning of the array.

    inplace: must be set to False to indicate that the input is not modified
      in-place, but rather a modified copy is returned.

  Returns:
    A copy of ``a`` with specified entries updated.

  Raises:
    ValueError: if ``inplace`` is True, or if ``mode`` is not one of
      ``None``, ``"clip"`` or ``"wrap"``.
    NotImplementedError: if ``mode="raise"``.

  See Also:
    - :func:`jax.numpy.place`: place elements into an array via boolean mask.
    - :func:`jax.numpy.ndarray.at`: array updates using NumPy-style indexing.
    - :func:`jax.numpy.take`: extract values from an array at given indices.

  Examples:
    >>> x = jnp.zeros(5, dtype=int)
    >>> indices = jnp.array([0, 2, 4])
    >>> values = jnp.array([10, 20, 30])
    >>> jnp.put(x, indices, values, inplace=False)
    Array([10,  0, 20,  0, 30], dtype=int32)

    This is equivalent to the following :attr:`jax.numpy.ndarray.at` indexing syntax:

    >>> x.at[indices].set(values)
    Array([10,  0, 20,  0, 30], dtype=int32)

    There are two modes for handling out-of-bound indices. By default they are
    clipped:

    >>> indices = jnp.array([0, 2, 6])
    >>> jnp.put(x, indices, values, inplace=False, mode='clip')
    Array([10,  0, 20,  0, 30], dtype=int32)

    Alternatively, they can be wrapped to the beginning of the array:

    >>> jnp.put(x, indices, values, inplace=False, mode='wrap')
    Array([10, 30, 20,  0,  0], dtype=int32)

    For N-dimensional inputs, the indices refer to the flattened array:

    >>> x = jnp.zeros((3, 5), dtype=int)
    >>> indices = jnp.array([0, 7, 14])
    >>> jnp.put(x, indices, values, inplace=False)
    Array([[10,  0,  0,  0,  0],
           [ 0,  0, 20,  0,  0],
           [ 0,  0,  0,  0, 30]], dtype=int32)
  """
  arr, ind_arr, v_arr = util.ensure_arraylike("put", a, ind, v)
  # Enforce the API contract and validate ``mode`` before the empty-input
  # shortcut below, so that these errors are not silently skipped when one of
  # the inputs is empty (``place`` likewise checks ``inplace`` first).
  if inplace:
    raise ValueError(
      "jax.numpy.put cannot modify arrays in-place, because JAX arrays are immutable. "
      "Pass inplace=False to instead return an updated array.")
  if mode == "raise":
    raise NotImplementedError("The 'raise' mode to jnp.put is not supported.")
  if mode not in (None, "clip", "wrap"):
    raise ValueError(f"mode should be one of 'wrap' or 'clip'; got {mode=}")
  ind_arr = ind_arr.ravel()
  v_arr = v_arr.ravel()
  if not arr.size or not ind_arr.size or not v_arr.size:
    return arr
  # Repeat the values so that there is one value per index.
  v_arr = lax_numpy._tile_to_size(v_arr, len(ind_arr))
  if mode is None:
    # NOTE(review): no explicit index adjustment here; out-of-range handling
    # is left to unravel_index plus 'drop' below. Confirm this matches the
    # "clip (default)" wording in the docstring for negative indices.
    scatter_mode = "drop"
  elif mode == "clip":
    ind_arr = lax_numpy.clip(ind_arr, 0, arr.size - 1)
    scatter_mode = "promise_in_bounds"
  else:  # mode == "wrap"
    ind_arr = ind_arr % arr.size
    scatter_mode = "promise_in_bounds"
  return arr.at[lax_numpy.unravel_index(ind_arr, arr.shape)].set(v_arr, mode=scatter_mode)