tile_docstring_added

This commit is contained in:
selamw1 2024-10-15 15:34:01 -07:00
parent 1df58b1854
commit 24b6f50938

View File

@ -4483,8 +4483,38 @@ def unstack(x: ArrayLike, /, *, axis: int = 0) -> tuple[Array, ...]:
)
return tuple(moveaxis(x, axis, 0))
@util.implements(np.tile)
def tile(A: ArrayLike, reps: DimSize | Sequence[DimSize]) -> Array:
"""Construct an array by repeating ``A`` along specified dimensions.
JAX implementation of :func:`numpy.tile`.
If ``A`` is an array of shape ``(d1, d2, ..., dn)`` and ``reps`` is a sequence of integers,
the resulting array will have a shape of ``(reps[0] * d1, reps[1] * d2, ..., reps[n] * dn)``,
with ``A`` tiled along each dimension.
Args:
A: input array to be repeated. Can be of any shape or dimension.
reps: specifies the number of repetitions along each axis.
Returns:
a new array where the input array has been repeated according to ``reps``.
See also:
- :func:`jax.numpy.repeat`: Construct an array from repeated elements.
- :func:`jax.numpy.broadcast_to`: Broadcast an array to a specified shape.
Examples:
>>> arr = jnp.array([1, 2])
>>> jnp.tile(arr, 2)
Array([1, 2, 1, 2], dtype=int32)
>>> arr = jnp.array([[1, 2],
... [3, 4,]])
>>> jnp.tile(arr, (2, 1))
Array([[1, 2],
[3, 4],
[1, 2],
[3, 4]], dtype=int32)
"""
util.check_arraylike("tile", A)
try:
iter(reps) # type: ignore[arg-type]