Improve documentation of jax.numpy.clip

This commit is contained in:
Jake VanderPlas 2024-06-28 15:53:30 -07:00
parent d7bc1ac8d3
commit e2c139be80

View File

@ -2250,48 +2250,65 @@ def array_split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | Array
return _split("array_split", ary, indices_or_sections, axis=axis)
_DEPRECATED_CLIP_ARG = DeprecatedArg()
@util.implements(
np.clip,
skip_params=['a', 'a_min'],
extra_params=_dedent("""
x : array_like
Array containing elements to clip.
min : array_like, optional
Minimum value. If ``None``, clipping is not performed on the
corresponding edge. The value of ``min`` is broadcast against x.
max : array_like, optional
Maximum value. If ``None``, clipping is not performed on the
corresponding edge. The value of ``max`` is broadcast against x.
""")
)
@jit
def clip(
x: ArrayLike | None = None, # Default to preserve backwards compatability
arr: ArrayLike | None = None,
/,
min: ArrayLike | None = None,
max: ArrayLike | None = None,
*,
a: ArrayLike | DeprecatedArg = _DEPRECATED_CLIP_ARG,
a_min: ArrayLike | None | DeprecatedArg = _DEPRECATED_CLIP_ARG,
a_max: ArrayLike | None | DeprecatedArg = _DEPRECATED_CLIP_ARG
a: ArrayLike | DeprecatedArg = DeprecatedArg(),
a_min: ArrayLike | None | DeprecatedArg = DeprecatedArg(),
a_max: ArrayLike | None | DeprecatedArg = DeprecatedArg()
) -> Array:
"""Clip array values to a specified range.
JAX implementation of :func:`numpy.clip`.
Args:
arr: N-dimensional array to be clipped.
min: optional minimum value of the clipped range; if ``None`` (default) then
result will not be clipped to any minimum value. If specified, it should be
broadcast-compatible with ``arr`` and ``max``.
max: optional maximum value of the clipped range; if ``None`` (default) then
result will not be clipped to any maximum value. If specified, it should be
broadcast-compatible with ``arr`` and ``min``.
a: deprecated alias of the ``arr`` argument. Will result in a
:class:`DeprecationWarning` if used.
a_min: deprecated alias of the ``min`` argument. Will result in a
:class:`DeprecationWarning` if used.
a_max: deprecated alias of the ``max`` argument. Will result in a
:class:`DeprecationWarning` if used.
Returns:
An array containing values from ``arr``, with values smaller than ``min`` set
to ``min``, and values larger than ``max`` set to ``max``.
See also:
- :func:`jax.numpy.minimum`: Compute the element-wise minimum value of two arrays.
- :func:`jax.numpy.maximum`: Compute the element-wise maximum value of two arrays.
Examples:
>>> arr = jnp.array([0, 1, 2, 3, 4, 5, 6, 7])
>>> jnp.clip(arr, 2, 5)
Array([2, 2, 2, 3, 4, 5, 5, 5], dtype=int32)
"""
# TODO(micky774): deprecated 2024-4-2, remove after deprecation expires.
x = a if not isinstance(a, DeprecatedArg) else x
if x is None:
arr = a if not isinstance(a, DeprecatedArg) else arr
if arr is None:
raise ValueError("No input was provided to the clip function.")
min = a_min if not isinstance(a_min, DeprecatedArg) else min
max = a_max if not isinstance(a_max, DeprecatedArg) else max
if any(not isinstance(t, DeprecatedArg) for t in (a, a_min, a_max)):
warnings.warn(
"Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy.clip is "
"deprecated. Please use 'x', 'min', and 'max' respectively instead.",
"Passing arguments 'a', 'a_min' or 'a_max' to jax.numpy.clip is "
"deprecated. Please use 'arr', 'min' or 'max' respectively instead.",
DeprecationWarning,
stacklevel=2,
)
util.check_arraylike("clip", x)
if any(jax.numpy.iscomplexobj(t) for t in (x, min, max)):
util.check_arraylike("clip", arr)
if any(jax.numpy.iscomplexobj(t) for t in (arr, min, max)):
# TODO(micky774): Deprecated 2024-4-2, remove after deprecation expires.
warnings.warn(
"Clip received a complex value either through the input or the min/max "
@ -2302,10 +2319,10 @@ def clip(
DeprecationWarning, stacklevel=2,
)
if min is not None:
x = ufuncs.maximum(min, x)
arr = ufuncs.maximum(min, arr)
if max is not None:
x = ufuncs.minimum(max, x)
return asarray(x)
arr = ufuncs.minimum(max, arr)
return asarray(arr)
@util.implements(np.around, skip_params=['out'])
@partial(jit, static_argnames=('decimals',))