Add an allow_negative_indices option to lax.dynamic_slice and lax.dynamic_update_slice.

The goal of this change is to avoid generating code to wrap negative indices back into range in cases where we know it doesn't matter. Change scan to pass allow_negative_indices=False to avoid emitting index wrapping code for each scan argument.

PiperOrigin-RevId: 731812827
This commit is contained in:
Peter Hawkins 2025-02-27 12:02:13 -08:00 committed by jax authors
parent 3450e2cee0
commit 1e5d9a9158
4 changed files with 114 additions and 23 deletions

View File

@ -16,6 +16,13 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
## Unreleased
* New Features
* Added a `allow_negative_indices` option to {func}`jax.lax.dynamic_slice`,
{func}`jax.lax.dynamic_update_slice` and related functions. The default is
true, matching the current behavior. If set to false, JAX does not need to
emit code clamping negative indices, which improves code size.
## jax 0.5.1 (Feb 24, 2025)
* New Features

View File

@ -465,10 +465,19 @@ def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear,
def body_fun(while_carry):
i_, carry, yss = while_carry
i = num_trips - i_ - 1 if reverse else i_
xs = [slicing.dynamic_index_in_dim(xs, i, keepdims=False) for xs in xss]
xs = [
slicing.dynamic_index_in_dim(
xs, i, keepdims=False, allow_negative_indices=False
)
for xs in xss
]
carry, ys = inner(unroll, carry, xs)
yss = [slicing.dynamic_update_index_in_dim(ys, upd, i, 0)
for ys, upd in zip(yss, ys)]
yss = [
slicing.dynamic_update_index_in_dim(
ys, upd, i, 0, allow_negative_indices=False
)
for ys, upd in zip(yss, ys)
]
return i_ + 1, carry, yss
def inner(n, carry, xs):
ys = []

View File

@ -114,6 +114,8 @@ def dynamic_slice(
operand: Array | np.ndarray,
start_indices: Array | np.ndarray | Sequence[ArrayLike],
slice_sizes: Shape,
*,
allow_negative_indices: bool | Sequence[bool] = True
) -> Array:
"""Wraps XLA's `DynamicSlice
<https://www.tensorflow.org/xla/operation_semantics#dynamicslice>`_
@ -127,6 +129,12 @@ def dynamic_slice(
integers with length equal to `ndim(operand)`. Inside a JIT compiled
function, only static values are supported (all JAX arrays inside JIT
must have statically known size).
allow_negative_indices: a bool or sequence of bools, one per dimension; if
a bool is passed, it applies to all dimensions. For each dimension,
if true, negative indices are permitted and are are interpreted relative
to the end of the array. If false, negative indices are treated as if they
were out of bounds and the result is implementation defined, typically
clamped to the first index.
Returns:
An array containing the slice.
@ -158,7 +166,8 @@ def dynamic_slice(
- :func:`jax.lax.dynamic_slice_in_dim`
- :func:`jax.lax.dynamic_index_in_dim`
"""
start_indices = _dynamic_slice_indices(operand, start_indices)
start_indices = _dynamic_slice_indices(
operand, start_indices, allow_negative_indices)
if config.dynamic_shapes.value:
dynamic_sizes, static_sizes = lax._extract_tracers_dyn_shape(slice_sizes)
else:
@ -168,8 +177,12 @@ def dynamic_slice(
slice_sizes=tuple(static_sizes))
def dynamic_update_slice(operand: Array | np.ndarray, update: ArrayLike,
start_indices: Array | Sequence[ArrayLike]) -> Array:
def dynamic_update_slice(
operand: Array | np.ndarray, update: ArrayLike,
start_indices: Array | Sequence[ArrayLike],
*,
allow_negative_indices: bool | Sequence[bool] = True
) -> Array:
"""Wraps XLA's `DynamicUpdateSlice
<https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice>`_
operator.
@ -178,6 +191,12 @@ def dynamic_update_slice(operand: Array | np.ndarray, update: ArrayLike,
operand: an array to slice.
update: an array containing the new values to write onto `operand`.
start_indices: a list of scalar indices, one per dimension.
allow_negative_indices: a bool or sequence of bools, one per dimension; if
a bool is passed, it applies to all dimensions. For each dimension,
if true, negative indices are permitted and are are interpreted relative
to the end of the array. If false, negative indices are treated as if they
were out of bounds and the result is implementation defined, typically
clamped to the first index.
Returns:
An array containing the slice.
@ -213,7 +232,8 @@ def dynamic_update_slice(operand: Array | np.ndarray, update: ArrayLike,
- :attr:`lax.dynamic_update_index_in_dim`
- :attr:`lax.dynamic_update_slice_in_dim`
"""
start_indices = _dynamic_slice_indices(operand, start_indices)
start_indices = _dynamic_slice_indices(
operand, start_indices, allow_negative_indices)
return dynamic_update_slice_p.bind(operand, update, *start_indices)
@ -997,7 +1017,8 @@ def index_in_dim(operand: Array | np.ndarray, index: int, axis: int = 0,
def dynamic_slice_in_dim(operand: Array | np.ndarray,
start_index: ArrayLike,
slice_size: int, axis: int = 0) -> Array:
slice_size: int, axis: int = 0, *,
allow_negative_indices: bool = True) -> Array:
"""Convenience wrapper around :func:`lax.dynamic_slice` applied to one dimension.
This is roughly equivalent to the following Python indexing syntax applied
@ -1008,6 +1029,10 @@ def dynamic_slice_in_dim(operand: Array | np.ndarray,
start_index: the (possibly dynamic) start index
slice_size: the static slice size
axis: the axis along which to apply the slice (defaults to 0)
allow_negative_indices: boolean specifying whether negative indices are
allowed. If true, negative indices are taken relative to the end of the
array. If false, negative indices are out of bounds and the result is
implementation defined.
Returns:
An array containing the slice.
@ -1046,16 +1071,20 @@ def dynamic_slice_in_dim(operand: Array | np.ndarray,
"""
start_indices: list[ArrayLike] = [lax._const(start_index, 0)] * operand.ndim
slice_sizes = list(operand.shape)
axis = int(axis)
allow_negative = [False] * operand.ndim
allow_negative[axis] = allow_negative_indices
start_indices[axis] = start_index
slice_sizes[axis] = core._canonicalize_dimension(slice_size)
return dynamic_slice(operand, start_indices, slice_sizes)
return dynamic_slice(operand, start_indices, slice_sizes,
allow_negative_indices=allow_negative)
def dynamic_index_in_dim(operand: Array | np.ndarray,
index: int | Array,
axis: int = 0, keepdims: bool = True) -> Array:
axis: int = 0, keepdims: bool = True,
*,
allow_negative_indices: bool = True) -> Array:
"""Convenience wrapper around dynamic_slice to perform int indexing.
This is roughly equivalent to the following Python indexing syntax applied
@ -1067,6 +1096,10 @@ def dynamic_index_in_dim(operand: Array | np.ndarray,
axis: the axis along which to apply the slice (defaults to 0)
keepdims: boolean specifying whether the output should have the same rank as
the input (default = True)
allow_negative_indices: boolean specifying whether negative indices are
allowed. If true, negative indices are taken relative to the end of the
array. If false, negative indices are out of bounds and the result is
implementation defined.
Returns:
An array containing the slice.
@ -1098,7 +1131,8 @@ def dynamic_index_in_dim(operand: Array | np.ndarray,
- :func:`jax.lax.dynamic_slice`
- :func:`jax.lax.dynamic_slice_in_dim`
"""
result = dynamic_slice_in_dim(operand, index, 1, axis)
result = dynamic_slice_in_dim(operand, index, 1, axis,
allow_negative_indices=allow_negative_indices)
if keepdims:
return result
else:
@ -1107,7 +1141,9 @@ def dynamic_index_in_dim(operand: Array | np.ndarray,
def dynamic_update_slice_in_dim(operand: Array | np.ndarray,
update: ArrayLike,
start_index: ArrayLike, axis: int) -> Array:
start_index: ArrayLike, axis: int,
*,
allow_negative_indices: bool = True) -> Array:
"""Convenience wrapper around :func:`dynamic_update_slice` to update
a slice in a single ``axis``.
@ -1116,6 +1152,10 @@ def dynamic_update_slice_in_dim(operand: Array | np.ndarray,
update: an array containing the new values to write onto `operand`.
start_index: a single scalar index
axis: the axis of the update.
allow_negative_indices: boolean specifying whether negative indices are
allowed. If true, negative indices are taken relative to the end of the
array. If false, negative indices are out of bounds and the result is
implementation defined.
Returns:
The updated array
@ -1164,12 +1204,16 @@ def dynamic_update_slice_in_dim(operand: Array | np.ndarray,
axis = int(axis)
start_indices: list[ArrayLike] = [lax._const(start_index, 0)] * lax._ndim(operand)
start_indices[axis] = start_index
return dynamic_update_slice(operand, update, start_indices)
allow_negative = [False] * operand.ndim
allow_negative[axis] = allow_negative_indices
return dynamic_update_slice(operand, update, start_indices,
allow_negative_indices=allow_negative)
def dynamic_update_index_in_dim(operand: Array | np.ndarray,
update: ArrayLike, index: ArrayLike,
axis: int) -> Array:
axis: int, *,
allow_negative_indices: bool = True) -> Array:
"""Convenience wrapper around :func:`dynamic_update_slice` to update a slice
of size 1 in a single ``axis``.
@ -1178,6 +1222,10 @@ def dynamic_update_index_in_dim(operand: Array | np.ndarray,
update: an array containing the new values to write onto `operand`.
index: a single scalar index
axis: the axis of the update.
allow_negative_indices: boolean specifying whether negative indices are
allowed. If true, negative indices are taken relative to the end of the
array. If false, negative indices are out of bounds and the result is
implementation defined.
Returns:
The updated array
@ -1229,7 +1277,9 @@ def dynamic_update_index_in_dim(operand: Array | np.ndarray,
if lax._ndim(update) != lax._ndim(operand):
assert lax._ndim(update) + 1 == lax._ndim(operand)
update = lax.expand_dims(update, (axis,))
return dynamic_update_slice_in_dim(operand, update, index, axis)
return dynamic_update_slice_in_dim(
operand, update, index, axis,
allow_negative_indices=allow_negative_indices)
def _slice_shape_rule(operand, *, start_indices, limit_indices, strides):
@ -3038,7 +3088,8 @@ mlir.register_lowering(
def _dynamic_slice_indices(
operand: Array | np.ndarray,
start_indices: Array | np.ndarray | Sequence[ArrayLike]
start_indices: Array | np.ndarray | Sequence[ArrayLike],
allow_negative_indices: bool | Sequence[bool],
) -> list[ArrayLike]:
# Normalize the start_indices w.r.t. operand.shape
if len(start_indices) != operand.ndim:
@ -3051,20 +3102,39 @@ def _dynamic_slice_indices(
.format(start_indices.shape)) # type: ignore[union-attr]
start_indices = list(start_indices)
result: list[ArrayLike] = []
if isinstance(allow_negative_indices, bool):
allow_negative_indices = [allow_negative_indices] * len(start_indices)
# Loop to correct for negative indices.
for i, d in zip(start_indices, operand.shape):
for i, d, allow_negative_index in zip(
start_indices, operand.shape, allow_negative_indices
):
# If i is unsigned, then it cannot be negative.
if dtypes.issubdtype(_dtype(i), np.unsignedinteger):
result.append(i)
continue
# Test whether i and d are static to avoid unnecessary staging.
if isinstance(i, (int, np.integer)) and core.is_constant_dim(d):
result.append(lax.convert_element_type(i + d if i < 0 else i, _dtype(i)))
if allow_negative_index:
result.append(lax.convert_element_type(i + d if i < 0 else i, _dtype(i)))
elif i < 0:
raise ValueError(f"Index {i} is out of bounds for dimension {d} if "
"allow_negative_indices=False")
else:
result.append(lax.convert_element_type(i, _dtype(i)))
continue
d = core.dimension_as_value(d)
if isinstance(i, (int, np.integer)):
result.append(i + lax.convert_element_type(d, _dtype(i)) if i < 0 else i)
if allow_negative_index:
result.append(i + lax.convert_element_type(d, _dtype(i)) if i < 0 else i)
elif i < 0:
raise ValueError(f"Index {i} is out of bounds for dimension {d} if "
"allow_negative_indices=False")
else:
result.append(i)
continue
d_arr = lax.convert_element_type(d, _dtype(i))
result.append(lax.select(i < 0, i + d_arr, i))
if allow_negative_index:
d_arr = lax.convert_element_type(d, _dtype(i))
result.append(lax.select(i < 0, i + d_arr, i))
else:
result.append(i)
return result

View File

@ -571,6 +571,7 @@ def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None) ->
idx += (arr.ndim - len(idx)) * (slice(None),)
start_indices: Sequence[ArrayLike] = []
slice_sizes: Sequence[int] = []
allow_negative_indices: list[bool] = []
for ind, size in safe_zip(idx, arr.shape):
if isinstance(ind, slice):
@ -578,11 +579,14 @@ def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None) ->
assert step == 1 # checked above
start_indices.append(start)
slice_sizes.append(max(0, stop - start))
allow_negative_indices.append(start < 0 or stop < 0)
else:
assert np.issubdtype(dtypes.dtype(ind), np.integer) # checked above
assert np.shape(ind) == () # checked above
start_indices.append(ind)
slice_sizes.append(1)
allow_negative_indices.append(
not isinstance(ind, (int, np.integer)) or bool(ind < 0))
# Try to use static slicing when possible.
if all(isinstance(i, (int, np.integer)) and i >= 0 for i in start_indices):
int_start_indices = [int(i) for i in start_indices] # type: ignore
@ -595,7 +599,8 @@ def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None) ->
if len(start_indices) > 1:
start_indices = util.promote_dtypes(*start_indices)
arr = lax.dynamic_slice(
arr, start_indices=start_indices, slice_sizes=slice_sizes)
arr, start_indices=start_indices, slice_sizes=slice_sizes,
allow_negative_indices=allow_negative_indices)
if int_indices:
arr = lax.squeeze(arr, tuple(int_indices))
return arr