mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #16642 from jakevdp:slice-in-dim
PiperOrigin-RevId: 546044892
This commit is contained in:
commit
f08e52faef
@ -54,6 +54,50 @@ def slice(operand: ArrayLike, start_indices: Sequence[int],
|
||||
"""Wraps XLA's `Slice
|
||||
<https://www.tensorflow.org/xla/operation_semantics#slice>`_
|
||||
operator.
|
||||
|
||||
Args:
|
||||
operand: an array to slice
|
||||
start_indices: a sequence of ``operand.ndim`` start indices.
|
||||
limit_indices: a sequence of ``operand.ndim`` limit indices.
|
||||
strides: an optional sequence of ``operand.ndim`` strides.
|
||||
|
||||
Returns:
|
||||
The sliced array
|
||||
|
||||
Examples:
|
||||
Here are some examples of simple two-dimensional slices:
|
||||
|
||||
>>> x = jnp.arange(12).reshape(3, 4)
|
||||
>>> x
|
||||
Array([[ 0, 1, 2, 3],
|
||||
[ 4, 5, 6, 7],
|
||||
[ 8, 9, 10, 11]], dtype=int32)
|
||||
|
||||
>>> lax.slice(x, (1, 0), (3, 2))
|
||||
Array([[4, 5],
|
||||
[8, 9]], dtype=int32)
|
||||
|
||||
>>> lax.slice(x, (0, 0), (3, 4), (1, 2))
|
||||
Array([[ 0, 2],
|
||||
[ 4, 6],
|
||||
[ 8, 10]], dtype=int32)
|
||||
|
||||
These two examples are equivalent to the following Python slicing syntax:
|
||||
|
||||
>>> x[1:3, 0:2]
|
||||
Array([[4, 5],
|
||||
[8, 9]], dtype=int32)
|
||||
|
||||
>>> x[0:3, 0:4:2]
|
||||
Array([[ 0, 2],
|
||||
[ 4, 6],
|
||||
[ 8, 10]], dtype=int32)
|
||||
|
||||
See Also:
|
||||
- :attr:`jax.numpy.ndarray.at`
|
||||
- :func:`jax.lax.slice_in_dim`
|
||||
- :func:`jax.lax.index_in_dim`
|
||||
- :func:`jax.lax.dynamic_slice`
|
||||
"""
|
||||
return slice_p.bind(operand, start_indices=tuple(start_indices),
|
||||
limit_indices=tuple(limit_indices),
|
||||
@ -101,6 +145,12 @@ def dynamic_slice(
|
||||
>>> dynamic_slice(x, (1, 1), (2, 4))
|
||||
Array([[ 4, 5, 6, 7],
|
||||
[ 8, 9, 10, 11]], dtype=int32)
|
||||
|
||||
See Also:
|
||||
- :attr:`jax.numpy.ndarray.at`
|
||||
- :func:`jax.lax.slice`
|
||||
- :func:`jax.lax.dynamic_slice_in_dim`
|
||||
- :func:`jax.lax.dynamic_index_in_dim`
|
||||
"""
|
||||
start_indices = _dynamic_slice_indices(operand, start_indices)
|
||||
if jax.config.jax_dynamic_shapes:
|
||||
@ -151,6 +201,11 @@ def dynamic_update_slice(operand: Union[Array, np.ndarray], update: ArrayLike,
|
||||
[0., 0., 1., 1.],
|
||||
[0., 0., 1., 1.],
|
||||
[0., 0., 0., 0.]], dtype=float32)
|
||||
|
||||
See Also:
|
||||
- :attr:`jax.numpy.ndarray.at`
|
||||
- :attr:`lax.dynamic_update_index_in_dim`
|
||||
- :attr:`lax.dynamic_update_slice_in_dim`
|
||||
"""
|
||||
start_indices = _dynamic_slice_indices(operand, start_indices)
|
||||
return dynamic_update_slice_p.bind(operand, update, *start_indices)
|
||||
@ -646,7 +701,53 @@ def index_take(src: Array, idxs: Array, axes: Sequence[int]) -> Array:
|
||||
def slice_in_dim(operand: Union[Array, np.ndarray], start_index: Optional[int],
|
||||
limit_index: Optional[int],
|
||||
stride: int = 1, axis: int = 0) -> Array:
|
||||
"""Convenience wrapper around slice applying to only one dimension."""
|
||||
"""Convenience wrapper around :func:`lax.slice` applying to only one dimension.
|
||||
|
||||
This is effectively equivalent to ``operand[..., start_index:limit_index:stride]``
|
||||
with the indexing applied on the specified axis.
|
||||
|
||||
Args:
|
||||
operand: an array to slice.
|
||||
start_index: an optional start index (defaults to zero)
|
||||
limit_index: an optional end index (defaults to operand.shape[axis])
|
||||
stride: an optional stride (defaults to 1)
|
||||
axis: the axis along which to apply the slice (defaults to 0)
|
||||
|
||||
Returns:
|
||||
An array containing the slice.
|
||||
|
||||
Examples:
|
||||
Here is a one-dimensional example:
|
||||
|
||||
>>> x = jnp.arange(4)
|
||||
>>> lax.slice_in_dim(x, 1, 3)
|
||||
Array([1, 2], dtype=int32)
|
||||
|
||||
Here are some two-dimensional examples:
|
||||
|
||||
>>> x = jnp.arange(12).reshape(4, 3)
|
||||
>>> x
|
||||
Array([[ 0, 1, 2],
|
||||
[ 3, 4, 5],
|
||||
[ 6, 7, 8],
|
||||
[ 9, 10, 11]], dtype=int32)
|
||||
|
||||
>>> lax.slice_in_dim(x, 1, 3)
|
||||
Array([[3, 4, 5],
|
||||
[6, 7, 8]], dtype=int32)
|
||||
|
||||
>>> lax.slice_in_dim(x, 1, 3, axis=1)
|
||||
Array([[ 1, 2],
|
||||
[ 4, 5],
|
||||
[ 7, 8],
|
||||
[10, 11]], dtype=int32)
|
||||
|
||||
See Also:
|
||||
- :attr:`jax.numpy.ndarray.at`
|
||||
- :func:`jax.lax.slice`
|
||||
- :func:`jax.lax.index_in_dim`
|
||||
- :func:`jax.lax.dynamic_slice_in_dim`
|
||||
"""
|
||||
start_indices = [0] * operand.ndim
|
||||
limit_indices = list(operand.shape)
|
||||
strides = [1] * operand.ndim
|
||||
@ -674,7 +775,51 @@ def slice_in_dim(operand: Union[Array, np.ndarray], start_index: Optional[int],
|
||||
|
||||
def index_in_dim(operand: Union[Array, np.ndarray], index: int, axis: int = 0,
|
||||
keepdims: bool = True) -> Array:
|
||||
"""Convenience wrapper around slice to perform int indexing."""
|
||||
"""Convenience wrapper around :func:`lax.slice` to perform int indexing.
|
||||
|
||||
This is effectively equivalent to ``operand[..., start_index:limit_index:stride]``
|
||||
with the indexing applied on the specified axis.
|
||||
|
||||
Args:
|
||||
operand: an array to index.
|
||||
index: integer index
|
||||
axis: the axis along which to apply the index (defaults to 0)
|
||||
keepdims: boolean specifying whether the output array should preserve the
|
||||
rank of the input (default=True)
|
||||
|
||||
Returns:
|
||||
The subarray at the specified index.
|
||||
|
||||
Examples:
|
||||
Here is a one-dimensional example:
|
||||
|
||||
>>> x = jnp.arange(4)
|
||||
>>> lax.index_in_dim(x, 2)
|
||||
Array([2], dtype=int32)
|
||||
|
||||
>>> lax.index_in_dim(x, 2, keepdims=False)
|
||||
Array(2, dtype=int32)
|
||||
|
||||
Here are some two-dimensional examples:
|
||||
|
||||
>>> x = jnp.arange(12).reshape(3, 4)
|
||||
>>> x
|
||||
Array([[ 0, 1, 2, 3],
|
||||
[ 4, 5, 6, 7],
|
||||
[ 8, 9, 10, 11]], dtype=int32)
|
||||
|
||||
>>> lax.index_in_dim(x, 1)
|
||||
Array([[4, 5, 6, 7]], dtype=int32)
|
||||
|
||||
>>> lax.index_in_dim(x, 1, axis=1, keepdims=False)
|
||||
Array([1, 5, 9], dtype=int32)
|
||||
|
||||
See Also:
|
||||
- :attr:`jax.numpy.ndarray.at`
|
||||
- :func:`jax.lax.slice`
|
||||
- :func:`jax.lax.slice_in_dim`
|
||||
- :func:`jax.lax.dynamic_index_in_dim`
|
||||
"""
|
||||
index, axis = core._canonicalize_dimension(index), int(axis)
|
||||
axis_size = operand.shape[axis]
|
||||
wrapped_index = index + axis_size if index < 0 else index
|
||||
@ -691,7 +836,52 @@ def index_in_dim(operand: Union[Array, np.ndarray], index: int, axis: int = 0,
|
||||
def dynamic_slice_in_dim(operand: Union[Array, np.ndarray],
|
||||
start_index: ArrayLike,
|
||||
slice_size: int, axis: int = 0) -> Array:
|
||||
"""Convenience wrapper around dynamic_slice applying to one dimension."""
|
||||
"""Convenience wrapper around :func:`lax.dynamic_slice` applied to one dimension.
|
||||
|
||||
This is roughly equivalent to the following Python indexing syntax applied
|
||||
along the specified axis: ``operand[..., start_index:start_index + slice_size]``.
|
||||
|
||||
Args:
|
||||
operand: an array to slice.
|
||||
start_index: the (possibly dynamic) start index
|
||||
slice_size: the static slice size
|
||||
axis: the axis along which to apply the slice (defaults to 0)
|
||||
|
||||
Returns:
|
||||
An array containing the slice.
|
||||
|
||||
Examples:
|
||||
Here is a one-dimensional example:
|
||||
|
||||
>>> x = jnp.arange(5)
|
||||
>>> dynamic_slice_in_dim(x, 1, 3)
|
||||
Array([1, 2, 3], dtype=int32)
|
||||
|
||||
Like `jax.lax.dynamic_slice`, out-of-bound slices will be clipped to the
|
||||
valid range:
|
||||
|
||||
>>> dynamic_slice_in_dim(x, 4, 3)
|
||||
Array([2, 3, 4], dtype=int32)
|
||||
|
||||
Here is a two-dimensional example:
|
||||
|
||||
>>> x = jnp.arange(12).reshape(3, 4)
|
||||
>>> x
|
||||
Array([[ 0, 1, 2, 3],
|
||||
[ 4, 5, 6, 7],
|
||||
[ 8, 9, 10, 11]], dtype=int32)
|
||||
|
||||
>>> dynamic_slice_in_dim(x, 1, 2, axis=1)
|
||||
Array([[ 1, 2],
|
||||
[ 5, 6],
|
||||
[ 9, 10]], dtype=int32)
|
||||
|
||||
See Also:
|
||||
- :attr:`jax.numpy.ndarray.at`
|
||||
- :func:`jax.lax.slice_in_dim`
|
||||
- :func:`jax.lax.dynamic_slice`
|
||||
- :func:`jax.lax.dynamic_index_in_dim`
|
||||
"""
|
||||
start_indices: list[ArrayLike] = [lax._const(start_index, 0)] * operand.ndim
|
||||
slice_sizes = list(operand.shape)
|
||||
|
||||
@ -704,7 +894,48 @@ def dynamic_slice_in_dim(operand: Union[Array, np.ndarray],
|
||||
def dynamic_index_in_dim(operand: Union[Array, np.ndarray],
|
||||
index: Union[int, Array],
|
||||
axis: int = 0, keepdims: bool = True) -> Array:
|
||||
"""Convenience wrapper around dynamic_slice to perform int indexing."""
|
||||
"""Convenience wrapper around dynamic_slice to perform int indexing.
|
||||
|
||||
This is roughly equivalent to the following Python indexing syntax applied
|
||||
along the specified axis: ``operand[..., index]``.
|
||||
|
||||
Args:
|
||||
operand: an array to slice.
|
||||
index: the (possibly dynamic) start index
|
||||
axis: the axis along which to apply the slice (defaults to 0)
|
||||
keepdims: boolean specifying whether the output should have the same rank as
|
||||
the input (default = True)
|
||||
|
||||
Returns:
|
||||
An array containing the slice.
|
||||
|
||||
Examples:
|
||||
Here is a one-dimensional example:
|
||||
|
||||
>>> x = jnp.arange(5)
|
||||
>>> dynamic_index_in_dim(x, 1)
|
||||
Array([1], dtype=int32)
|
||||
|
||||
>>> dynamic_index_in_dim(x, 1, keepdims=False)
|
||||
Array(1, dtype=int32)
|
||||
|
||||
Here is a two-dimensional example:
|
||||
|
||||
>>> x = jnp.arange(12).reshape(3, 4)
|
||||
>>> x
|
||||
Array([[ 0, 1, 2, 3],
|
||||
[ 4, 5, 6, 7],
|
||||
[ 8, 9, 10, 11]], dtype=int32)
|
||||
|
||||
>>> dynamic_index_in_dim(x, 1, axis=1, keepdims=False)
|
||||
Array([1, 5, 9], dtype=int32)
|
||||
|
||||
See Also:
|
||||
- :attr:`jax.numpy.ndarray.at`
|
||||
- :func:`jax.lax.index_in_dim`
|
||||
- :func:`jax.lax.dynamic_slice`
|
||||
- :func:`jax.lax.dynamic_slice_in_dim`
|
||||
"""
|
||||
result = dynamic_slice_in_dim(operand, index, 1, axis)
|
||||
if keepdims:
|
||||
return result
|
||||
@ -715,8 +946,58 @@ def dynamic_index_in_dim(operand: Union[Array, np.ndarray],
|
||||
def dynamic_update_slice_in_dim(operand: Union[Array, np.ndarray],
|
||||
update: ArrayLike,
|
||||
start_index: ArrayLike, axis: int) -> Array:
|
||||
"""Convenience wrapper around :func:`dynamic_update_slice` to update a slice
|
||||
in a single ``axis``.
|
||||
"""Convenience wrapper around :func:`dynamic_update_slice` to update
|
||||
a slice in a single ``axis``.
|
||||
|
||||
Args:
|
||||
operand: an array to slice.
|
||||
update: an array containing the new values to write onto `operand`.
|
||||
start_index: a single scalar index
|
||||
axis: the axis of the update.
|
||||
|
||||
Returns:
|
||||
The updated array
|
||||
|
||||
Examples:
|
||||
|
||||
>>> x = jnp.zeros(6)
|
||||
>>> y = jnp.ones(3)
|
||||
>>> dynamic_update_slice_in_dim(x, y, 2, axis=0)
|
||||
Array([0., 0., 1., 1., 1., 0.], dtype=float32)
|
||||
|
||||
If the update slice is too large to fit in the array, the start
|
||||
index will be adjusted to make it fit:
|
||||
|
||||
>>> dynamic_update_slice_in_dim(x, y, 3, axis=0)
|
||||
Array([0., 0., 0., 1., 1., 1.], dtype=float32)
|
||||
>>> dynamic_update_slice_in_dim(x, y, 5, axis=0)
|
||||
Array([0., 0., 0., 1., 1., 1.], dtype=float32)
|
||||
|
||||
Here is an example of a two-dimensional slice update:
|
||||
|
||||
>>> x = jnp.zeros((4, 4))
|
||||
>>> y = jnp.ones((2, 4))
|
||||
>>> dynamic_update_slice_in_dim(x, y, 1, axis=0)
|
||||
Array([[0., 0., 0., 0.],
|
||||
[1., 1., 1., 1.],
|
||||
[1., 1., 1., 1.],
|
||||
[0., 0., 0., 0.]], dtype=float32)
|
||||
|
||||
Note that the shape of the additional axes in ``update`` need not
|
||||
match the associated dimensions of the ``operand``:
|
||||
|
||||
>>> y = jnp.ones((2, 3))
|
||||
>>> dynamic_update_slice_in_dim(x, y, 1, axis=0)
|
||||
Array([[0., 0., 0., 0.],
|
||||
[1., 1., 1., 0.],
|
||||
[1., 1., 1., 0.],
|
||||
[0., 0., 0., 0.]], dtype=float32)
|
||||
|
||||
See Also:
|
||||
- :attr:`jax.numpy.ndarray.at`
|
||||
- :func:`jax.lax.dynamic_update_slice`
|
||||
- :func:`jax.lax.dynamic_update_index_in_dim`
|
||||
- :func:`jax.lax.dynamic_slice_in_dim`
|
||||
"""
|
||||
axis = int(axis)
|
||||
start_indices: list[ArrayLike] = [lax._const(start_index, 0)] * lax._ndim(operand)
|
||||
@ -728,7 +1009,59 @@ def dynamic_update_index_in_dim(operand: Union[Array, np.ndarray],
|
||||
update: ArrayLike, index: ArrayLike,
|
||||
axis: int) -> Array:
|
||||
"""Convenience wrapper around :func:`dynamic_update_slice` to update a slice
|
||||
of size 1 in a single ``axis``.
|
||||
of size 1 in a single ``axis``.
|
||||
|
||||
Args:
|
||||
operand: an array to slice.
|
||||
update: an array containing the new values to write onto `operand`.
|
||||
index: a single scalar index
|
||||
axis: the axis of the update.
|
||||
|
||||
Returns:
|
||||
The updated array
|
||||
|
||||
Examples:
|
||||
|
||||
>>> x = jnp.zeros(6)
|
||||
>>> y = 1.0
|
||||
>>> dynamic_update_index_in_dim(x, y, 2, axis=0)
|
||||
Array([0., 0., 1., 0., 0., 0.], dtype=float32)
|
||||
|
||||
>>> y = jnp.array([1.0])
|
||||
>>> dynamic_update_index_in_dim(x, y, 2, axis=0)
|
||||
Array([0., 0., 1., 0., 0., 0.], dtype=float32)
|
||||
|
||||
If the specified index is out of bounds, the index will be clipped to the
|
||||
valid range:
|
||||
|
||||
>>> dynamic_update_index_in_dim(x, y, 10, axis=0)
|
||||
Array([0., 0., 0., 0., 0., 1.], dtype=float32)
|
||||
|
||||
Here is an example of a two-dimensional dynamic index update:
|
||||
|
||||
>>> x = jnp.zeros((4, 4))
|
||||
>>> y = jnp.ones(4)
|
||||
>>> dynamic_update_index_in_dim(x, y, 1, axis=0)
|
||||
Array([[0., 0., 0., 0.],
|
||||
[1., 1., 1., 1.],
|
||||
[0., 0., 0., 0.],
|
||||
[0., 0., 0., 0.]], dtype=float32)
|
||||
|
||||
Note that the shape of the additional axes in ``update`` need not
|
||||
match the associated dimensions of the ``operand``:
|
||||
|
||||
>>> y = jnp.ones((1, 3))
|
||||
>>> dynamic_update_index_in_dim(x, y, 1, 0)
|
||||
Array([[0., 0., 0., 0.],
|
||||
[1., 1., 1., 0.],
|
||||
[0., 0., 0., 0.],
|
||||
[0., 0., 0., 0.]], dtype=float32)
|
||||
|
||||
See Also:
|
||||
- :attr:`jax.numpy.ndarray.at`
|
||||
- :func:`jax.lax.dynamic_update_slice`
|
||||
- :func:`jax.lax.dynamic_update_index_in_dim`
|
||||
- :func:`jax.lax.dynamic_index_in_dim`
|
||||
"""
|
||||
axis = int(axis)
|
||||
if lax._ndim(update) != lax._ndim(operand):
|
||||
|
Loading…
x
Reference in New Issue
Block a user