mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #22403 from jakevdp:repeat-doc
PiperOrigin-RevId: 651525013
This commit is contained in:
commit
8f24bc2fe3
@ -4464,19 +4464,71 @@ def indices(dimensions: Sequence[int], dtype: DTypeLike = int32,
|
||||
return stack(output, 0) if output else array([], dtype=dtype)
|
||||
|
||||
|
||||
_TOTAL_REPEAT_LENGTH_DOC = """\
|
||||
JAX adds the optional `total_repeat_length` parameter which specifies the total
|
||||
number of repeat, and defaults to sum(repeats). It must be specified for repeat
|
||||
to be compilable. If `sum(repeats)` is larger than the specified
|
||||
`total_repeat_length` the remaining values will be discarded. In the case of
|
||||
`sum(repeats)` being smaller than the specified target length, the final value
|
||||
will be repeated.
|
||||
"""
|
||||
|
||||
|
||||
@util.implements(np.repeat, lax_description=_TOTAL_REPEAT_LENGTH_DOC)
|
||||
def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *,
|
||||
total_repeat_length: int | None = None) -> Array:
|
||||
"""Construct an array from repeated elements.
|
||||
|
||||
JAX implementation of :func:`numpy.repeat`.
|
||||
|
||||
Args:
|
||||
a: N-dimensional array
|
||||
repeats: 1D integer array specifying the number of repeats. Must match the
|
||||
length of the repeated axis.
|
||||
axis: integer specifying the axis of ``a`` along which to construct the
|
||||
repeated array. If None (default) then ``a`` is first flattened.
|
||||
total_repeat_length: this must be specified statically for ``jnp.repeat``
|
||||
to be compatible with :func:`~jax.jit` and other JAX transformations.
|
||||
If ``sum(repeats)`` is larger than the specified ``total_repeat_length``,
|
||||
the remaining values will be discarded. If ``sum(repeats)`` is smaller
|
||||
than ``total_repeat_length``, the final value will be repeated.
|
||||
|
||||
Returns:
|
||||
an array constructed from repeated values of ``a``.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.numpy.tile`: repeat a full array rather than individual values.
|
||||
|
||||
Examples:
|
||||
Repeat each value twice along the last axis:
|
||||
|
||||
>>> a = jnp.array([[1, 2],
|
||||
... [3, 4]])
|
||||
>>> jnp.repeat(a, 2, axis=-1)
|
||||
Array([[1, 1, 2, 2],
|
||||
[3, 3, 4, 4]], dtype=int32)
|
||||
|
||||
If ``axis`` is not specified, the input array will be flattened:
|
||||
|
||||
>>> jnp.repeat(a, 2)
|
||||
Array([1, 1, 2, 2, 3, 3, 4, 4], dtype=int32)
|
||||
|
||||
Pass an array to ``repeats`` to repeat each value a different number of times:
|
||||
|
||||
>>> repeats = jnp.array([2, 3])
|
||||
>>> jnp.repeat(a, repeats, axis=1)
|
||||
Array([[1, 1, 2, 2, 2],
|
||||
[3, 3, 4, 4, 4]], dtype=int32)
|
||||
|
||||
In order to use ``repeat`` within ``jit`` and other JAX transformations, the
|
||||
size of the output must be specified statically using ``total_repeat_length``:
|
||||
|
||||
>>> jit_repeat = jax.jit(jnp.repeat, static_argnames=['axis', 'total_repeat_length'])
|
||||
>>> jit_repeat(a, repeats, axis=1, total_repeat_length=5)
|
||||
Array([[1, 1, 2, 2, 2],
|
||||
[3, 3, 4, 4, 4]], dtype=int32)
|
||||
|
||||
If `total_repeat_length` is smaller than ``sum(repeats)``, the result will be truncated:
|
||||
|
||||
>>> jit_repeat(a, repeats, axis=1, total_repeat_length=4)
|
||||
Array([[1, 1, 2, 2],
|
||||
[3, 3, 4, 4]], dtype=int32)
|
||||
|
||||
If it is larger, then the additional entries will be filled with the final value:
|
||||
|
||||
>>> jit_repeat(a, repeats, axis=1, total_repeat_length=7)
|
||||
Array([[1, 1, 2, 2, 2, 2, 2],
|
||||
[3, 3, 4, 4, 4, 4, 4]], dtype=int32)
|
||||
"""
|
||||
util.check_arraylike("repeat", a)
|
||||
core.is_dim(repeats) or util.check_arraylike("repeat", repeats)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user