mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #23927 from jakevdp:pad-doc
PiperOrigin-RevId: 679142773
This commit is contained in:
commit
c07652fd46
@ -3750,13 +3750,123 @@ def _pad(array: ArrayLike, pad_width: PadValueLike[int], mode: str,
|
||||
"not implemented modes")
|
||||
|
||||
|
||||
@util.implements(np.pad, lax_description="""\
|
||||
Unlike numpy, JAX "function" mode's argument (which is another function) should return
|
||||
the modified array. This is because Jax arrays are immutable.
|
||||
(In numpy, "function" mode's argument should modify a rank 1 array in-place.)
|
||||
""")
|
||||
def pad(array: ArrayLike, pad_width: PadValueLike[int | Array | np.ndarray],
|
||||
mode: str | Callable[..., Any] = "constant", **kwargs) -> Array:
|
||||
"""Add padding to an array.
|
||||
|
||||
JAX implementation of :func:`numpy.pad`.
|
||||
|
||||
Args:
|
||||
array: array to pad.
|
||||
pad_width: specify the pad width for each dimension of an array. Padding widths
|
||||
may be separately specified for *before* and *after* the array. Options are:
|
||||
|
||||
- ``int`` or ``(int,)``: pad each array dimension with the same number of values
|
||||
both before and after.
|
||||
- ``(before, after)``: pad each array with ``before`` elements before, and ``after``
|
||||
elements after
|
||||
- ``((before_1, after_1), (before_2, after_2), ... (before_N, after_N))``: specify
|
||||
distinct ``before`` and ``after`` values for each array dimension.
|
||||
|
||||
mode: a string or callable. Supported pad modes are:
|
||||
|
||||
- ``'constant'`` (default): pad with a constant value, which defaults to zero.
|
||||
- ``'empty'``: pad with empty values (i.e. zero)
|
||||
- ``'edge'``: pad with the edge values of the array.
|
||||
- ``'wrap'``: pad by wrapping the array.
|
||||
- ``'linear_ramp'``: pad with a linear ramp to specified ``end_values``.
|
||||
- ``'maximum'``: pad with the maximum value.
|
||||
- ``'mean'``: pad with the mean value.
|
||||
- ``'median'``: pad with the median value.
|
||||
- ``'minimum'``: pad with the minimum value.
|
||||
- ``'reflect'``: pad by reflection.
|
||||
- ``'symmetric'``: pad by symmetric reflection.
|
||||
- ``<callable>``: a callable function. See Notes below.
|
||||
|
||||
constant_values: referenced for ``mode = 'constant'``. Specify the constant value
|
||||
to pad with.
|
||||
stat_length: referenced for ``mode in ['maximum', 'mean', 'median', 'minimum']``.
|
||||
An integer or tuple specifying the number of edge values to use when calculating
|
||||
the statistic.
|
||||
end_values: referenced for ``mode = 'linear_ramp'``. Specify the end values to
|
||||
ramp the padding values to.
|
||||
reflect_type: referenced for ``mode in ['reflect', 'symmetric']``. Specify whether
|
||||
to use even or odd reflection.
|
||||
|
||||
Returns:
|
||||
A padded copy of ``array``.
|
||||
|
||||
Notes:
|
||||
When ``mode`` is callable, it should have the following signature::
|
||||
|
||||
def pad_func(row: Array, pad_width: tuple[int, int],
|
||||
iaxis: int, kwargs: dict) -> Array:
|
||||
...
|
||||
|
||||
Here ``row`` is a 1D slice of the padded array along axis ``iaxis``, with the pad
|
||||
values filled with zeros. ``pad_width`` is a tuple specifying the ``(before, after)``
|
||||
padding sizes, and ``kwargs`` are any additional keyword arguments passed to the
|
||||
:func:`jax.numpy.pad` function.
|
||||
|
||||
Note that while in NumPy, the function should modify ``row`` in-place, in JAX the
|
||||
function should return the modified ``row``. In JAX, the custom padding function
|
||||
will be mapped across the padded axis using the :func:`jax.vmap` transformation.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.resize`: resize an array
|
||||
- :func:`jax.numpy.tile`: create a larger array by tiling a smaller array.
|
||||
- :func:`jax.numpy.repeat`: create a larger array by repeating values of a smaller array.
|
||||
|
||||
Examples:
|
||||
|
||||
Pad a 1-dimensional array with zeros:
|
||||
|
||||
>>> x = jnp.array([10, 20, 30, 40])
|
||||
>>> jnp.pad(x, 2)
|
||||
Array([ 0, 0, 10, 20, 30, 40, 0, 0], dtype=int32)
|
||||
>>> jnp.pad(x, (2, 4))
|
||||
Array([ 0, 0, 10, 20, 30, 40, 0, 0, 0, 0], dtype=int32)
|
||||
|
||||
Pad a 1-dimensional array with specified values:
|
||||
|
||||
>>> jnp.pad(x, 2, constant_values=99)
|
||||
Array([99, 99, 10, 20, 30, 40, 99, 99], dtype=int32)
|
||||
|
||||
Pad a 1-dimensional array with the mean array value:
|
||||
|
||||
>>> jnp.pad(x, 2, mode='mean')
|
||||
Array([25, 25, 10, 20, 30, 40, 25, 25], dtype=int32)
|
||||
|
||||
Pad a 1-dimensional array with reflected values:
|
||||
|
||||
>>> jnp.pad(x, 2, mode='reflect')
|
||||
Array([30, 20, 10, 20, 30, 40, 30, 20], dtype=int32)
|
||||
|
||||
Pad a 2-dimensional array with different paddings in each dimension:
|
||||
|
||||
>>> x = jnp.array([[1, 2, 3],
|
||||
... [4, 5, 6]])
|
||||
>>> jnp.pad(x, ((1, 2), (3, 0)))
|
||||
Array([[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 1, 2, 3],
|
||||
[0, 0, 0, 4, 5, 6],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0]], dtype=int32)
|
||||
|
||||
Pad a 1-dimensional array with a custom padding function:
|
||||
|
||||
>>> def custom_pad(row, pad_width, iaxis, kwargs):
|
||||
... # row represents a 1D slice of the zero-padded array.
|
||||
... before, after = pad_width
|
||||
... before_value = kwargs.get('before_value', 0)
|
||||
... after_value = kwargs.get('after_value', 0)
|
||||
... row = row.at[:before].set(before_value)
|
||||
... return row.at[len(row) - after:].set(after_value)
|
||||
>>> x = jnp.array([2, 3, 4])
|
||||
>>> jnp.pad(x, 2, custom_pad, before_value=-10, after_value=10)
|
||||
Array([-10, -10, 2, 3, 4, 10, 10], dtype=int32)
|
||||
"""
|
||||
|
||||
util.check_arraylike("pad", array)
|
||||
pad_width = _broadcast_to_pairs(pad_width, ndim(array), "pad_width")
|
||||
if pad_width and not all(core.is_dim(p[0]) and core.is_dim(p[1])
|
||||
|
@ -114,7 +114,6 @@ def _parse_parameters(body: str) -> dict[str, str]:
|
||||
def implements(
|
||||
original_fun: Callable[..., Any] | None,
|
||||
update_doc: bool = True,
|
||||
lax_description: str = "",
|
||||
sections: Sequence[str] = ('Parameters', 'Returns', 'References'),
|
||||
skip_params: Sequence[str] = (),
|
||||
module: str | None = None,
|
||||
@ -132,8 +131,6 @@ def implements(
|
||||
update_doc: whether to transform the numpy docstring to remove references of
|
||||
parameters that are supported by the numpy version but not the JAX version.
|
||||
If False, include the numpy docstring verbatim.
|
||||
lax_description: a string description that will be added to the beginning of
|
||||
the docstring.
|
||||
sections: a list of sections to include in the docstring. The default is
|
||||
["Parameters", "Returns", "References"]
|
||||
skip_params: a list of strings containing names of parameters accepted by the
|
||||
@ -146,8 +143,6 @@ def implements(
|
||||
wrapped_fun.__np_wrapped__ = original_fun
|
||||
# Allows this pattern: @implements(getattr(np, 'new_function', None))
|
||||
if original_fun is None:
|
||||
if lax_description:
|
||||
wrapped_fun.__doc__ = lax_description
|
||||
return wrapped_fun
|
||||
docstr = getattr(original_fun, "__doc__", None)
|
||||
name = getattr(original_fun, "__name__", getattr(wrapped_fun, "__name__", str(wrapped_fun)))
|
||||
@ -181,8 +176,6 @@ def implements(
|
||||
|
||||
docstr = parsed.summary.strip() + "\n" if parsed.summary else ""
|
||||
docstr += f"\nLAX-backend implementation of :func:`{name}`.\n"
|
||||
if lax_description:
|
||||
docstr += "\n" + lax_description.strip() + "\n"
|
||||
docstr += "\n*Original docstring below.*\n"
|
||||
|
||||
# We remove signatures from the docstrings, because they redundant at best and
|
||||
|
Loading…
x
Reference in New Issue
Block a user