mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add an allow_negative_indices option to lax.dynamic_slice and lax.dynamic_update_slice.
The goal of this change is to avoid generating code to wrap negative indices back into range in cases where we know it doesn't matter. Change scan to pass allow_negative_indices=False to avoid emitting index wrapping code for each scan argument. PiperOrigin-RevId: 731812827
This commit is contained in:
parent
3450e2cee0
commit
1e5d9a9158
@ -16,6 +16,13 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
|
||||
## Unreleased
|
||||
|
||||
* New Features
|
||||
|
||||
* Added a `allow_negative_indices` option to {func}`jax.lax.dynamic_slice`,
|
||||
{func}`jax.lax.dynamic_update_slice` and related functions. The default is
|
||||
true, matching the current behavior. If set to false, JAX does not need to
|
||||
emit code clamping negative indices, which improves code size.
|
||||
|
||||
## jax 0.5.1 (Feb 24, 2025)
|
||||
|
||||
* New Features
|
||||
|
@ -465,10 +465,19 @@ def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear,
|
||||
def body_fun(while_carry):
|
||||
i_, carry, yss = while_carry
|
||||
i = num_trips - i_ - 1 if reverse else i_
|
||||
xs = [slicing.dynamic_index_in_dim(xs, i, keepdims=False) for xs in xss]
|
||||
xs = [
|
||||
slicing.dynamic_index_in_dim(
|
||||
xs, i, keepdims=False, allow_negative_indices=False
|
||||
)
|
||||
for xs in xss
|
||||
]
|
||||
carry, ys = inner(unroll, carry, xs)
|
||||
yss = [slicing.dynamic_update_index_in_dim(ys, upd, i, 0)
|
||||
for ys, upd in zip(yss, ys)]
|
||||
yss = [
|
||||
slicing.dynamic_update_index_in_dim(
|
||||
ys, upd, i, 0, allow_negative_indices=False
|
||||
)
|
||||
for ys, upd in zip(yss, ys)
|
||||
]
|
||||
return i_ + 1, carry, yss
|
||||
def inner(n, carry, xs):
|
||||
ys = []
|
||||
|
@ -114,6 +114,8 @@ def dynamic_slice(
|
||||
operand: Array | np.ndarray,
|
||||
start_indices: Array | np.ndarray | Sequence[ArrayLike],
|
||||
slice_sizes: Shape,
|
||||
*,
|
||||
allow_negative_indices: bool | Sequence[bool] = True
|
||||
) -> Array:
|
||||
"""Wraps XLA's `DynamicSlice
|
||||
<https://www.tensorflow.org/xla/operation_semantics#dynamicslice>`_
|
||||
@ -127,6 +129,12 @@ def dynamic_slice(
|
||||
integers with length equal to `ndim(operand)`. Inside a JIT compiled
|
||||
function, only static values are supported (all JAX arrays inside JIT
|
||||
must have statically known size).
|
||||
allow_negative_indices: a bool or sequence of bools, one per dimension; if
|
||||
a bool is passed, it applies to all dimensions. For each dimension,
|
||||
if true, negative indices are permitted and are are interpreted relative
|
||||
to the end of the array. If false, negative indices are treated as if they
|
||||
were out of bounds and the result is implementation defined, typically
|
||||
clamped to the first index.
|
||||
|
||||
Returns:
|
||||
An array containing the slice.
|
||||
@ -158,7 +166,8 @@ def dynamic_slice(
|
||||
- :func:`jax.lax.dynamic_slice_in_dim`
|
||||
- :func:`jax.lax.dynamic_index_in_dim`
|
||||
"""
|
||||
start_indices = _dynamic_slice_indices(operand, start_indices)
|
||||
start_indices = _dynamic_slice_indices(
|
||||
operand, start_indices, allow_negative_indices)
|
||||
if config.dynamic_shapes.value:
|
||||
dynamic_sizes, static_sizes = lax._extract_tracers_dyn_shape(slice_sizes)
|
||||
else:
|
||||
@ -168,8 +177,12 @@ def dynamic_slice(
|
||||
slice_sizes=tuple(static_sizes))
|
||||
|
||||
|
||||
def dynamic_update_slice(operand: Array | np.ndarray, update: ArrayLike,
|
||||
start_indices: Array | Sequence[ArrayLike]) -> Array:
|
||||
def dynamic_update_slice(
|
||||
operand: Array | np.ndarray, update: ArrayLike,
|
||||
start_indices: Array | Sequence[ArrayLike],
|
||||
*,
|
||||
allow_negative_indices: bool | Sequence[bool] = True
|
||||
) -> Array:
|
||||
"""Wraps XLA's `DynamicUpdateSlice
|
||||
<https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice>`_
|
||||
operator.
|
||||
@ -178,6 +191,12 @@ def dynamic_update_slice(operand: Array | np.ndarray, update: ArrayLike,
|
||||
operand: an array to slice.
|
||||
update: an array containing the new values to write onto `operand`.
|
||||
start_indices: a list of scalar indices, one per dimension.
|
||||
allow_negative_indices: a bool or sequence of bools, one per dimension; if
|
||||
a bool is passed, it applies to all dimensions. For each dimension,
|
||||
if true, negative indices are permitted and are are interpreted relative
|
||||
to the end of the array. If false, negative indices are treated as if they
|
||||
were out of bounds and the result is implementation defined, typically
|
||||
clamped to the first index.
|
||||
|
||||
Returns:
|
||||
An array containing the slice.
|
||||
@ -213,7 +232,8 @@ def dynamic_update_slice(operand: Array | np.ndarray, update: ArrayLike,
|
||||
- :attr:`lax.dynamic_update_index_in_dim`
|
||||
- :attr:`lax.dynamic_update_slice_in_dim`
|
||||
"""
|
||||
start_indices = _dynamic_slice_indices(operand, start_indices)
|
||||
start_indices = _dynamic_slice_indices(
|
||||
operand, start_indices, allow_negative_indices)
|
||||
return dynamic_update_slice_p.bind(operand, update, *start_indices)
|
||||
|
||||
|
||||
@ -997,7 +1017,8 @@ def index_in_dim(operand: Array | np.ndarray, index: int, axis: int = 0,
|
||||
|
||||
def dynamic_slice_in_dim(operand: Array | np.ndarray,
|
||||
start_index: ArrayLike,
|
||||
slice_size: int, axis: int = 0) -> Array:
|
||||
slice_size: int, axis: int = 0, *,
|
||||
allow_negative_indices: bool = True) -> Array:
|
||||
"""Convenience wrapper around :func:`lax.dynamic_slice` applied to one dimension.
|
||||
|
||||
This is roughly equivalent to the following Python indexing syntax applied
|
||||
@ -1008,6 +1029,10 @@ def dynamic_slice_in_dim(operand: Array | np.ndarray,
|
||||
start_index: the (possibly dynamic) start index
|
||||
slice_size: the static slice size
|
||||
axis: the axis along which to apply the slice (defaults to 0)
|
||||
allow_negative_indices: boolean specifying whether negative indices are
|
||||
allowed. If true, negative indices are taken relative to the end of the
|
||||
array. If false, negative indices are out of bounds and the result is
|
||||
implementation defined.
|
||||
|
||||
Returns:
|
||||
An array containing the slice.
|
||||
@ -1046,16 +1071,20 @@ def dynamic_slice_in_dim(operand: Array | np.ndarray,
|
||||
"""
|
||||
start_indices: list[ArrayLike] = [lax._const(start_index, 0)] * operand.ndim
|
||||
slice_sizes = list(operand.shape)
|
||||
|
||||
axis = int(axis)
|
||||
allow_negative = [False] * operand.ndim
|
||||
allow_negative[axis] = allow_negative_indices
|
||||
start_indices[axis] = start_index
|
||||
slice_sizes[axis] = core._canonicalize_dimension(slice_size)
|
||||
return dynamic_slice(operand, start_indices, slice_sizes)
|
||||
return dynamic_slice(operand, start_indices, slice_sizes,
|
||||
allow_negative_indices=allow_negative)
|
||||
|
||||
|
||||
def dynamic_index_in_dim(operand: Array | np.ndarray,
|
||||
index: int | Array,
|
||||
axis: int = 0, keepdims: bool = True) -> Array:
|
||||
axis: int = 0, keepdims: bool = True,
|
||||
*,
|
||||
allow_negative_indices: bool = True) -> Array:
|
||||
"""Convenience wrapper around dynamic_slice to perform int indexing.
|
||||
|
||||
This is roughly equivalent to the following Python indexing syntax applied
|
||||
@ -1067,6 +1096,10 @@ def dynamic_index_in_dim(operand: Array | np.ndarray,
|
||||
axis: the axis along which to apply the slice (defaults to 0)
|
||||
keepdims: boolean specifying whether the output should have the same rank as
|
||||
the input (default = True)
|
||||
allow_negative_indices: boolean specifying whether negative indices are
|
||||
allowed. If true, negative indices are taken relative to the end of the
|
||||
array. If false, negative indices are out of bounds and the result is
|
||||
implementation defined.
|
||||
|
||||
Returns:
|
||||
An array containing the slice.
|
||||
@ -1098,7 +1131,8 @@ def dynamic_index_in_dim(operand: Array | np.ndarray,
|
||||
- :func:`jax.lax.dynamic_slice`
|
||||
- :func:`jax.lax.dynamic_slice_in_dim`
|
||||
"""
|
||||
result = dynamic_slice_in_dim(operand, index, 1, axis)
|
||||
result = dynamic_slice_in_dim(operand, index, 1, axis,
|
||||
allow_negative_indices=allow_negative_indices)
|
||||
if keepdims:
|
||||
return result
|
||||
else:
|
||||
@ -1107,7 +1141,9 @@ def dynamic_index_in_dim(operand: Array | np.ndarray,
|
||||
|
||||
def dynamic_update_slice_in_dim(operand: Array | np.ndarray,
|
||||
update: ArrayLike,
|
||||
start_index: ArrayLike, axis: int) -> Array:
|
||||
start_index: ArrayLike, axis: int,
|
||||
*,
|
||||
allow_negative_indices: bool = True) -> Array:
|
||||
"""Convenience wrapper around :func:`dynamic_update_slice` to update
|
||||
a slice in a single ``axis``.
|
||||
|
||||
@ -1116,6 +1152,10 @@ def dynamic_update_slice_in_dim(operand: Array | np.ndarray,
|
||||
update: an array containing the new values to write onto `operand`.
|
||||
start_index: a single scalar index
|
||||
axis: the axis of the update.
|
||||
allow_negative_indices: boolean specifying whether negative indices are
|
||||
allowed. If true, negative indices are taken relative to the end of the
|
||||
array. If false, negative indices are out of bounds and the result is
|
||||
implementation defined.
|
||||
|
||||
Returns:
|
||||
The updated array
|
||||
@ -1164,12 +1204,16 @@ def dynamic_update_slice_in_dim(operand: Array | np.ndarray,
|
||||
axis = int(axis)
|
||||
start_indices: list[ArrayLike] = [lax._const(start_index, 0)] * lax._ndim(operand)
|
||||
start_indices[axis] = start_index
|
||||
return dynamic_update_slice(operand, update, start_indices)
|
||||
allow_negative = [False] * operand.ndim
|
||||
allow_negative[axis] = allow_negative_indices
|
||||
return dynamic_update_slice(operand, update, start_indices,
|
||||
allow_negative_indices=allow_negative)
|
||||
|
||||
|
||||
def dynamic_update_index_in_dim(operand: Array | np.ndarray,
|
||||
update: ArrayLike, index: ArrayLike,
|
||||
axis: int) -> Array:
|
||||
axis: int, *,
|
||||
allow_negative_indices: bool = True) -> Array:
|
||||
"""Convenience wrapper around :func:`dynamic_update_slice` to update a slice
|
||||
of size 1 in a single ``axis``.
|
||||
|
||||
@ -1178,6 +1222,10 @@ def dynamic_update_index_in_dim(operand: Array | np.ndarray,
|
||||
update: an array containing the new values to write onto `operand`.
|
||||
index: a single scalar index
|
||||
axis: the axis of the update.
|
||||
allow_negative_indices: boolean specifying whether negative indices are
|
||||
allowed. If true, negative indices are taken relative to the end of the
|
||||
array. If false, negative indices are out of bounds and the result is
|
||||
implementation defined.
|
||||
|
||||
Returns:
|
||||
The updated array
|
||||
@ -1229,7 +1277,9 @@ def dynamic_update_index_in_dim(operand: Array | np.ndarray,
|
||||
if lax._ndim(update) != lax._ndim(operand):
|
||||
assert lax._ndim(update) + 1 == lax._ndim(operand)
|
||||
update = lax.expand_dims(update, (axis,))
|
||||
return dynamic_update_slice_in_dim(operand, update, index, axis)
|
||||
return dynamic_update_slice_in_dim(
|
||||
operand, update, index, axis,
|
||||
allow_negative_indices=allow_negative_indices)
|
||||
|
||||
|
||||
def _slice_shape_rule(operand, *, start_indices, limit_indices, strides):
|
||||
@ -3038,7 +3088,8 @@ mlir.register_lowering(
|
||||
|
||||
def _dynamic_slice_indices(
|
||||
operand: Array | np.ndarray,
|
||||
start_indices: Array | np.ndarray | Sequence[ArrayLike]
|
||||
start_indices: Array | np.ndarray | Sequence[ArrayLike],
|
||||
allow_negative_indices: bool | Sequence[bool],
|
||||
) -> list[ArrayLike]:
|
||||
# Normalize the start_indices w.r.t. operand.shape
|
||||
if len(start_indices) != operand.ndim:
|
||||
@ -3051,20 +3102,39 @@ def _dynamic_slice_indices(
|
||||
.format(start_indices.shape)) # type: ignore[union-attr]
|
||||
start_indices = list(start_indices)
|
||||
result: list[ArrayLike] = []
|
||||
if isinstance(allow_negative_indices, bool):
|
||||
allow_negative_indices = [allow_negative_indices] * len(start_indices)
|
||||
# Loop to correct for negative indices.
|
||||
for i, d in zip(start_indices, operand.shape):
|
||||
for i, d, allow_negative_index in zip(
|
||||
start_indices, operand.shape, allow_negative_indices
|
||||
):
|
||||
# If i is unsigned, then it cannot be negative.
|
||||
if dtypes.issubdtype(_dtype(i), np.unsignedinteger):
|
||||
result.append(i)
|
||||
continue
|
||||
# Test whether i and d are static to avoid unnecessary staging.
|
||||
if isinstance(i, (int, np.integer)) and core.is_constant_dim(d):
|
||||
result.append(lax.convert_element_type(i + d if i < 0 else i, _dtype(i)))
|
||||
if allow_negative_index:
|
||||
result.append(lax.convert_element_type(i + d if i < 0 else i, _dtype(i)))
|
||||
elif i < 0:
|
||||
raise ValueError(f"Index {i} is out of bounds for dimension {d} if "
|
||||
"allow_negative_indices=False")
|
||||
else:
|
||||
result.append(lax.convert_element_type(i, _dtype(i)))
|
||||
continue
|
||||
d = core.dimension_as_value(d)
|
||||
if isinstance(i, (int, np.integer)):
|
||||
result.append(i + lax.convert_element_type(d, _dtype(i)) if i < 0 else i)
|
||||
if allow_negative_index:
|
||||
result.append(i + lax.convert_element_type(d, _dtype(i)) if i < 0 else i)
|
||||
elif i < 0:
|
||||
raise ValueError(f"Index {i} is out of bounds for dimension {d} if "
|
||||
"allow_negative_indices=False")
|
||||
else:
|
||||
result.append(i)
|
||||
continue
|
||||
d_arr = lax.convert_element_type(d, _dtype(i))
|
||||
result.append(lax.select(i < 0, i + d_arr, i))
|
||||
if allow_negative_index:
|
||||
d_arr = lax.convert_element_type(d, _dtype(i))
|
||||
result.append(lax.select(i < 0, i + d_arr, i))
|
||||
else:
|
||||
result.append(i)
|
||||
return result
|
||||
|
@ -571,6 +571,7 @@ def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None) ->
|
||||
idx += (arr.ndim - len(idx)) * (slice(None),)
|
||||
start_indices: Sequence[ArrayLike] = []
|
||||
slice_sizes: Sequence[int] = []
|
||||
allow_negative_indices: list[bool] = []
|
||||
|
||||
for ind, size in safe_zip(idx, arr.shape):
|
||||
if isinstance(ind, slice):
|
||||
@ -578,11 +579,14 @@ def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None) ->
|
||||
assert step == 1 # checked above
|
||||
start_indices.append(start)
|
||||
slice_sizes.append(max(0, stop - start))
|
||||
allow_negative_indices.append(start < 0 or stop < 0)
|
||||
else:
|
||||
assert np.issubdtype(dtypes.dtype(ind), np.integer) # checked above
|
||||
assert np.shape(ind) == () # checked above
|
||||
start_indices.append(ind)
|
||||
slice_sizes.append(1)
|
||||
allow_negative_indices.append(
|
||||
not isinstance(ind, (int, np.integer)) or bool(ind < 0))
|
||||
# Try to use static slicing when possible.
|
||||
if all(isinstance(i, (int, np.integer)) and i >= 0 for i in start_indices):
|
||||
int_start_indices = [int(i) for i in start_indices] # type: ignore
|
||||
@ -595,7 +599,8 @@ def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None) ->
|
||||
if len(start_indices) > 1:
|
||||
start_indices = util.promote_dtypes(*start_indices)
|
||||
arr = lax.dynamic_slice(
|
||||
arr, start_indices=start_indices, slice_sizes=slice_sizes)
|
||||
arr, start_indices=start_indices, slice_sizes=slice_sizes,
|
||||
allow_negative_indices=allow_negative_indices)
|
||||
if int_indices:
|
||||
arr = lax.squeeze(arr, tuple(int_indices))
|
||||
return arr
|
||||
|
Loading…
x
Reference in New Issue
Block a user