Merge pull request #22403 from jakevdp:repeat-doc

PiperOrigin-RevId: 651525013
This commit is contained in:
jax authors 2024-07-11 13:54:00 -07:00
commit 8f24bc2fe3

View File

@ -4464,19 +4464,71 @@ def indices(dimensions: Sequence[int], dtype: DTypeLike = int32,
return stack(output, 0) if output else array([], dtype=dtype)
_TOTAL_REPEAT_LENGTH_DOC = """\
JAX adds the optional `total_repeat_length` parameter which specifies the total
number of repeat, and defaults to sum(repeats). It must be specified for repeat
to be compilable. If `sum(repeats)` is larger than the specified
`total_repeat_length` the remaining values will be discarded. In the case of
`sum(repeats)` being smaller than the specified target length, the final value
will be repeated.
"""
@util.implements(np.repeat, lax_description=_TOTAL_REPEAT_LENGTH_DOC)
def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *,
total_repeat_length: int | None = None) -> Array:
"""Construct an array from repeated elements.
JAX implementation of :func:`numpy.repeat`.
Args:
a: N-dimensional array
repeats: 1D integer array specifying the number of repeats. Must match the
length of the repeated axis.
axis: integer specifying the axis of ``a`` along which to construct the
repeated array. If None (default) then ``a`` is first flattened.
total_repeat_length: this must be specified statically for ``jnp.repeat``
to be compatible with :func:`~jax.jit` and other JAX transformations.
If ``sum(repeats)`` is larger than the specified ``total_repeat_length``,
the remaining values will be discarded. If ``sum(repeats)`` is smaller
than ``total_repeat_length``, the final value will be repeated.
Returns:
an array constructed from repeated values of ``a``.
See Also:
- :func:`jax.numpy.tile`: repeat a full array rather than individual values.
Examples:
Repeat each value twice along the last axis:
>>> a = jnp.array([[1, 2],
... [3, 4]])
>>> jnp.repeat(a, 2, axis=-1)
Array([[1, 1, 2, 2],
[3, 3, 4, 4]], dtype=int32)
If ``axis`` is not specified, the input array will be flattened:
>>> jnp.repeat(a, 2)
Array([1, 1, 2, 2, 3, 3, 4, 4], dtype=int32)
Pass an array to ``repeats`` to repeat each value a different number of times:
>>> repeats = jnp.array([2, 3])
>>> jnp.repeat(a, repeats, axis=1)
Array([[1, 1, 2, 2, 2],
[3, 3, 4, 4, 4]], dtype=int32)
In order to use ``repeat`` within ``jit`` and other JAX transformations, the
size of the output must be specified statically using ``total_repeat_length``:
>>> jit_repeat = jax.jit(jnp.repeat, static_argnames=['axis', 'total_repeat_length'])
>>> jit_repeat(a, repeats, axis=1, total_repeat_length=5)
Array([[1, 1, 2, 2, 2],
[3, 3, 4, 4, 4]], dtype=int32)
If `total_repeat_length` is smaller than ``sum(repeats)``, the result will be truncated:
>>> jit_repeat(a, repeats, axis=1, total_repeat_length=4)
Array([[1, 1, 2, 2],
[3, 3, 4, 4]], dtype=int32)
If it is larger, then the additional entries will be filled with the final value:
>>> jit_repeat(a, repeats, axis=1, total_repeat_length=7)
Array([[1, 1, 2, 2, 2, 2, 2],
[3, 3, 4, 4, 4, 4, 4]], dtype=int32)
"""
util.check_arraylike("repeat", a)
core.is_dim(repeats) or util.check_arraylike("repeat", repeats)