Merge pull request #16642 from jakevdp:slice-in-dim

PiperOrigin-RevId: 546044892
This commit is contained in:
jax authors 2023-07-06 11:27:07 -07:00
commit f08e52faef

View File

@ -54,6 +54,50 @@ def slice(operand: ArrayLike, start_indices: Sequence[int],
"""Wraps XLA's `Slice
<https://www.tensorflow.org/xla/operation_semantics#slice>`_
operator.
Args:
operand: an array to slice
start_indices: a sequence of ``operand.ndim`` start indices.
limit_indices: a sequence of ``operand.ndim`` limit indices.
strides: an optional sequence of ``operand.ndim`` strides.
Returns:
The sliced array
Examples:
Here are some examples of simple two-dimensional slices:
>>> x = jnp.arange(12).reshape(3, 4)
>>> x
Array([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]], dtype=int32)
>>> lax.slice(x, (1, 0), (3, 2))
Array([[4, 5],
[8, 9]], dtype=int32)
>>> lax.slice(x, (0, 0), (3, 4), (1, 2))
Array([[ 0, 2],
[ 4, 6],
[ 8, 10]], dtype=int32)
These two examples are equivalent to the following Python slicing syntax:
>>> x[1:3, 0:2]
Array([[4, 5],
[8, 9]], dtype=int32)
>>> x[0:3, 0:4:2]
Array([[ 0, 2],
[ 4, 6],
[ 8, 10]], dtype=int32)
See Also:
- :attr:`jax.numpy.ndarray.at`
- :func:`jax.lax.slice_in_dim`
- :func:`jax.lax.index_in_dim`
- :func:`jax.lax.dynamic_slice`
"""
return slice_p.bind(operand, start_indices=tuple(start_indices),
limit_indices=tuple(limit_indices),
@ -101,6 +145,12 @@ def dynamic_slice(
>>> dynamic_slice(x, (1, 1), (2, 4))
Array([[ 4, 5, 6, 7],
[ 8, 9, 10, 11]], dtype=int32)
See Also:
- :attr:`jax.numpy.ndarray.at`
- :func:`jax.lax.slice`
- :func:`jax.lax.dynamic_slice_in_dim`
- :func:`jax.lax.dynamic_index_in_dim`
"""
start_indices = _dynamic_slice_indices(operand, start_indices)
if jax.config.jax_dynamic_shapes:
@ -151,6 +201,11 @@ def dynamic_update_slice(operand: Union[Array, np.ndarray], update: ArrayLike,
[0., 0., 1., 1.],
[0., 0., 1., 1.],
[0., 0., 0., 0.]], dtype=float32)
See Also:
- :attr:`jax.numpy.ndarray.at`
- :attr:`lax.dynamic_update_index_in_dim`
- :attr:`lax.dynamic_update_slice_in_dim`
"""
start_indices = _dynamic_slice_indices(operand, start_indices)
return dynamic_update_slice_p.bind(operand, update, *start_indices)
@ -646,7 +701,53 @@ def index_take(src: Array, idxs: Array, axes: Sequence[int]) -> Array:
def slice_in_dim(operand: Union[Array, np.ndarray], start_index: Optional[int],
limit_index: Optional[int],
stride: int = 1, axis: int = 0) -> Array:
"""Convenience wrapper around slice applying to only one dimension."""
"""Convenience wrapper around :func:`lax.slice` applying to only one dimension.
This is effectively equivalent to ``operand[..., start_index:limit_index:stride]``
with the indexing applied on the specified axis.
Args:
operand: an array to slice.
start_index: an optional start index (defaults to zero)
limit_index: an optional end index (defaults to operand.shape[axis])
stride: an optional stride (defaults to 1)
axis: the axis along which to apply the slice (defaults to 0)
Returns:
An array containing the slice.
Examples:
Here is a one-dimensional example:
>>> x = jnp.arange(4)
>>> lax.slice_in_dim(x, 1, 3)
Array([1, 2], dtype=int32)
Here are some two-dimensional examples:
>>> x = jnp.arange(12).reshape(4, 3)
>>> x
Array([[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]], dtype=int32)
>>> lax.slice_in_dim(x, 1, 3)
Array([[3, 4, 5],
[6, 7, 8]], dtype=int32)
>>> lax.slice_in_dim(x, 1, 3, axis=1)
Array([[ 1, 2],
[ 4, 5],
[ 7, 8],
[10, 11]], dtype=int32)
See Also:
- :attr:`jax.numpy.ndarray.at`
- :func:`jax.lax.slice`
- :func:`jax.lax.index_in_dim`
- :func:`jax.lax.dynamic_slice_in_dim`
"""
start_indices = [0] * operand.ndim
limit_indices = list(operand.shape)
strides = [1] * operand.ndim
@ -674,7 +775,51 @@ def slice_in_dim(operand: Union[Array, np.ndarray], start_index: Optional[int],
def index_in_dim(operand: Union[Array, np.ndarray], index: int, axis: int = 0,
keepdims: bool = True) -> Array:
"""Convenience wrapper around slice to perform int indexing."""
"""Convenience wrapper around :func:`lax.slice` to perform int indexing.
This is effectively equivalent to ``operand[..., start_index:limit_index:stride]``
with the indexing applied on the specified axis.
Args:
operand: an array to index.
index: integer index
axis: the axis along which to apply the index (defaults to 0)
keepdims: boolean specifying whether the output array should preserve the
rank of the input (default=True)
Returns:
The subarray at the specified index.
Examples:
Here is a one-dimensional example:
>>> x = jnp.arange(4)
>>> lax.index_in_dim(x, 2)
Array([2], dtype=int32)
>>> lax.index_in_dim(x, 2, keepdims=False)
Array(2, dtype=int32)
Here are some two-dimensional examples:
>>> x = jnp.arange(12).reshape(3, 4)
>>> x
Array([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]], dtype=int32)
>>> lax.index_in_dim(x, 1)
Array([[4, 5, 6, 7]], dtype=int32)
>>> lax.index_in_dim(x, 1, axis=1, keepdims=False)
Array([1, 5, 9], dtype=int32)
See Also:
- :attr:`jax.numpy.ndarray.at`
- :func:`jax.lax.slice`
- :func:`jax.lax.slice_in_dim`
- :func:`jax.lax.dynamic_index_in_dim`
"""
index, axis = core._canonicalize_dimension(index), int(axis)
axis_size = operand.shape[axis]
wrapped_index = index + axis_size if index < 0 else index
@ -691,7 +836,52 @@ def index_in_dim(operand: Union[Array, np.ndarray], index: int, axis: int = 0,
def dynamic_slice_in_dim(operand: Union[Array, np.ndarray],
start_index: ArrayLike,
slice_size: int, axis: int = 0) -> Array:
"""Convenience wrapper around dynamic_slice applying to one dimension."""
"""Convenience wrapper around :func:`lax.dynamic_slice` applied to one dimension.
This is roughly equivalent to the following Python indexing syntax applied
along the specified axis: ``operand[..., start_index:start_index + slice_size]``.
Args:
operand: an array to slice.
start_index: the (possibly dynamic) start index
slice_size: the static slice size
axis: the axis along which to apply the slice (defaults to 0)
Returns:
An array containing the slice.
Examples:
Here is a one-dimensional example:
>>> x = jnp.arange(5)
>>> dynamic_slice_in_dim(x, 1, 3)
Array([1, 2, 3], dtype=int32)
Like `jax.lax.dynamic_slice`, out-of-bound slices will be clipped to the
valid range:
>>> dynamic_slice_in_dim(x, 4, 3)
Array([2, 3, 4], dtype=int32)
Here is a two-dimensional example:
>>> x = jnp.arange(12).reshape(3, 4)
>>> x
Array([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]], dtype=int32)
>>> dynamic_slice_in_dim(x, 1, 2, axis=1)
Array([[ 1, 2],
[ 5, 6],
[ 9, 10]], dtype=int32)
See Also:
- :attr:`jax.numpy.ndarray.at`
- :func:`jax.lax.slice_in_dim`
- :func:`jax.lax.dynamic_slice`
- :func:`jax.lax.dynamic_index_in_dim`
"""
start_indices: list[ArrayLike] = [lax._const(start_index, 0)] * operand.ndim
slice_sizes = list(operand.shape)
@ -704,7 +894,48 @@ def dynamic_slice_in_dim(operand: Union[Array, np.ndarray],
def dynamic_index_in_dim(operand: Union[Array, np.ndarray],
index: Union[int, Array],
axis: int = 0, keepdims: bool = True) -> Array:
"""Convenience wrapper around dynamic_slice to perform int indexing."""
"""Convenience wrapper around dynamic_slice to perform int indexing.
This is roughly equivalent to the following Python indexing syntax applied
along the specified axis: ``operand[..., index]``.
Args:
operand: an array to slice.
index: the (possibly dynamic) start index
axis: the axis along which to apply the slice (defaults to 0)
keepdims: boolean specifying whether the output should have the same rank as
the input (default = True)
Returns:
An array containing the slice.
Examples:
Here is a one-dimensional example:
>>> x = jnp.arange(5)
>>> dynamic_index_in_dim(x, 1)
Array([1], dtype=int32)
>>> dynamic_index_in_dim(x, 1, keepdims=False)
Array(1, dtype=int32)
Here is a two-dimensional example:
>>> x = jnp.arange(12).reshape(3, 4)
>>> x
Array([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]], dtype=int32)
>>> dynamic_index_in_dim(x, 1, axis=1, keepdims=False)
Array([1, 5, 9], dtype=int32)
See Also:
- :attr:`jax.numpy.ndarray.at`
- :func:`jax.lax.index_in_dim`
- :func:`jax.lax.dynamic_slice`
- :func:`jax.lax.dynamic_slice_in_dim`
"""
result = dynamic_slice_in_dim(operand, index, 1, axis)
if keepdims:
return result
@ -715,8 +946,58 @@ def dynamic_index_in_dim(operand: Union[Array, np.ndarray],
def dynamic_update_slice_in_dim(operand: Union[Array, np.ndarray],
update: ArrayLike,
start_index: ArrayLike, axis: int) -> Array:
"""Convenience wrapper around :func:`dynamic_update_slice` to update a slice
in a single ``axis``.
"""Convenience wrapper around :func:`dynamic_update_slice` to update
a slice in a single ``axis``.
Args:
operand: an array to slice.
update: an array containing the new values to write onto `operand`.
start_index: a single scalar index
axis: the axis of the update.
Returns:
The updated array
Examples:
>>> x = jnp.zeros(6)
>>> y = jnp.ones(3)
>>> dynamic_update_slice_in_dim(x, y, 2, axis=0)
Array([0., 0., 1., 1., 1., 0.], dtype=float32)
If the update slice is too large to fit in the array, the start
index will be adjusted to make it fit:
>>> dynamic_update_slice_in_dim(x, y, 3, axis=0)
Array([0., 0., 0., 1., 1., 1.], dtype=float32)
>>> dynamic_update_slice_in_dim(x, y, 5, axis=0)
Array([0., 0., 0., 1., 1., 1.], dtype=float32)
Here is an example of a two-dimensional slice update:
>>> x = jnp.zeros((4, 4))
>>> y = jnp.ones((2, 4))
>>> dynamic_update_slice_in_dim(x, y, 1, axis=0)
Array([[0., 0., 0., 0.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[0., 0., 0., 0.]], dtype=float32)
Note that the shape of the additional axes in ``update`` need not
match the associated dimensions of the ``operand``:
>>> y = jnp.ones((2, 3))
>>> dynamic_update_slice_in_dim(x, y, 1, axis=0)
Array([[0., 0., 0., 0.],
[1., 1., 1., 0.],
[1., 1., 1., 0.],
[0., 0., 0., 0.]], dtype=float32)
See Also:
- :attr:`jax.numpy.ndarray.at`
- :func:`jax.lax.dynamic_update_slice`
- :func:`jax.lax.dynamic_update_index_in_dim`
- :func:`jax.lax.dynamic_slice_in_dim`
"""
axis = int(axis)
start_indices: list[ArrayLike] = [lax._const(start_index, 0)] * lax._ndim(operand)
@ -728,7 +1009,59 @@ def dynamic_update_index_in_dim(operand: Union[Array, np.ndarray],
update: ArrayLike, index: ArrayLike,
axis: int) -> Array:
"""Convenience wrapper around :func:`dynamic_update_slice` to update a slice
of size 1 in a single ``axis``.
of size 1 in a single ``axis``.
Args:
operand: an array to slice.
update: an array containing the new values to write onto `operand`.
index: a single scalar index
axis: the axis of the update.
Returns:
The updated array
Examples:
>>> x = jnp.zeros(6)
>>> y = 1.0
>>> dynamic_update_index_in_dim(x, y, 2, axis=0)
Array([0., 0., 1., 0., 0., 0.], dtype=float32)
>>> y = jnp.array([1.0])
>>> dynamic_update_index_in_dim(x, y, 2, axis=0)
Array([0., 0., 1., 0., 0., 0.], dtype=float32)
If the specified index is out of bounds, the index will be clipped to the
valid range:
>>> dynamic_update_index_in_dim(x, y, 10, axis=0)
Array([0., 0., 0., 0., 0., 1.], dtype=float32)
Here is an example of a two-dimensional dynamic index update:
>>> x = jnp.zeros((4, 4))
>>> y = jnp.ones(4)
>>> dynamic_update_index_in_dim(x, y, 1, axis=0)
Array([[0., 0., 0., 0.],
[1., 1., 1., 1.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]], dtype=float32)
Note that the shape of the additional axes in ``update`` need not
match the associated dimensions of the ``operand``:
>>> y = jnp.ones((1, 3))
>>> dynamic_update_index_in_dim(x, y, 1, 0)
Array([[0., 0., 0., 0.],
[1., 1., 1., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]], dtype=float32)
See Also:
- :attr:`jax.numpy.ndarray.at`
- :func:`jax.lax.dynamic_update_slice`
- :func:`jax.lax.dynamic_update_index_in_dim`
- :func:`jax.lax.dynamic_index_in_dim`
"""
axis = int(axis)
if lax._ndim(update) != lax._ndim(operand):