mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
tile_docstring_added
This commit is contained in:
parent
1df58b1854
commit
24b6f50938
@ -4483,8 +4483,38 @@ def unstack(x: ArrayLike, /, *, axis: int = 0) -> tuple[Array, ...]:
|
||||
)
|
||||
return tuple(moveaxis(x, axis, 0))
|
||||
|
||||
@util.implements(np.tile)
|
||||
def tile(A: ArrayLike, reps: DimSize | Sequence[DimSize]) -> Array:
|
||||
"""Construct an array by repeating ``A`` along specified dimensions.
|
||||
|
||||
JAX implementation of :func:`numpy.tile`.
|
||||
|
||||
If ``A`` is an array of shape ``(d1, d2, ..., dn)`` and ``reps`` is a sequence of integers,
|
||||
the resulting array will have a shape of ``(reps[0] * d1, reps[1] * d2, ..., reps[n] * dn)``,
|
||||
with ``A`` tiled along each dimension.
|
||||
|
||||
Args:
|
||||
A: input array to be repeated. Can be of any shape or dimension.
|
||||
reps: specifies the number of repetitions along each axis.
|
||||
|
||||
Returns:
|
||||
a new array where the input array has been repeated according to ``reps``.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.repeat`: Construct an array from repeated elements.
|
||||
- :func:`jax.numpy.broadcast_to`: Broadcast an array to a specified shape.
|
||||
|
||||
Examples:
|
||||
>>> arr = jnp.array([1, 2])
|
||||
>>> jnp.tile(arr, 2)
|
||||
Array([1, 2, 1, 2], dtype=int32)
|
||||
>>> arr = jnp.array([[1, 2],
|
||||
... [3, 4,]])
|
||||
>>> jnp.tile(arr, (2, 1))
|
||||
Array([[1, 2],
|
||||
[3, 4],
|
||||
[1, 2],
|
||||
[3, 4]], dtype=int32)
|
||||
"""
|
||||
util.check_arraylike("tile", A)
|
||||
try:
|
||||
iter(reps) # type: ignore[arg-type]
|
||||
|
Loading…
x
Reference in New Issue
Block a user