mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Improve documentation of jax.numpy.clip
This commit is contained in:
parent
d7bc1ac8d3
commit
e2c139be80
@ -2250,48 +2250,65 @@ def array_split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | Array
|
||||
return _split("array_split", ary, indices_or_sections, axis=axis)
|
||||
|
||||
|
||||
_DEPRECATED_CLIP_ARG = DeprecatedArg()
|
||||
@util.implements(
|
||||
np.clip,
|
||||
skip_params=['a', 'a_min'],
|
||||
extra_params=_dedent("""
|
||||
x : array_like
|
||||
Array containing elements to clip.
|
||||
min : array_like, optional
|
||||
Minimum value. If ``None``, clipping is not performed on the
|
||||
corresponding edge. The value of ``min`` is broadcast against x.
|
||||
max : array_like, optional
|
||||
Maximum value. If ``None``, clipping is not performed on the
|
||||
corresponding edge. The value of ``max`` is broadcast against x.
|
||||
""")
|
||||
)
|
||||
@jit
|
||||
def clip(
|
||||
x: ArrayLike | None = None, # Default to preserve backwards compatability
|
||||
arr: ArrayLike | None = None,
|
||||
/,
|
||||
min: ArrayLike | None = None,
|
||||
max: ArrayLike | None = None,
|
||||
*,
|
||||
a: ArrayLike | DeprecatedArg = _DEPRECATED_CLIP_ARG,
|
||||
a_min: ArrayLike | None | DeprecatedArg = _DEPRECATED_CLIP_ARG,
|
||||
a_max: ArrayLike | None | DeprecatedArg = _DEPRECATED_CLIP_ARG
|
||||
a: ArrayLike | DeprecatedArg = DeprecatedArg(),
|
||||
a_min: ArrayLike | None | DeprecatedArg = DeprecatedArg(),
|
||||
a_max: ArrayLike | None | DeprecatedArg = DeprecatedArg()
|
||||
) -> Array:
|
||||
"""Clip array values to a specified range.
|
||||
|
||||
JAX implementation of :func:`numpy.clip`.
|
||||
|
||||
Args:
|
||||
arr: N-dimensional array to be clipped.
|
||||
min: optional minimum value of the clipped range; if ``None`` (default) then
|
||||
result will not be clipped to any minimum value. If specified, it should be
|
||||
broadcast-compatible with ``arr`` and ``max``.
|
||||
max: optional maximum value of the clipped range; if ``None`` (default) then
|
||||
result will not be clipped to any maximum value. If specified, it should be
|
||||
broadcast-compatible with ``arr`` and ``min``.
|
||||
a: deprecated alias of the ``arr`` argument. Will result in a
|
||||
:class:`DeprecationWarning` if used.
|
||||
a_min: deprecated alias of the ``min`` argument. Will result in a
|
||||
:class:`DeprecationWarning` if used.
|
||||
a_max: deprecated alias of the ``max`` argument. Will result in a
|
||||
:class:`DeprecationWarning` if used.
|
||||
|
||||
Returns:
|
||||
An array containing values from ``arr``, with values smaller than ``min`` set
|
||||
to ``min``, and values larger than ``max`` set to ``max``.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.minimum`: Compute the element-wise minimum value of two arrays.
|
||||
- :func:`jax.numpy.maximum`: Compute the element-wise maximum value of two arrays.
|
||||
|
||||
Examples:
|
||||
>>> arr = jnp.array([0, 1, 2, 3, 4, 5, 6, 7])
|
||||
>>> jnp.clip(arr, 2, 5)
|
||||
Array([2, 2, 2, 3, 4, 5, 5, 5], dtype=int32)
|
||||
"""
|
||||
# TODO(micky774): deprecated 2024-4-2, remove after deprecation expires.
|
||||
x = a if not isinstance(a, DeprecatedArg) else x
|
||||
if x is None:
|
||||
arr = a if not isinstance(a, DeprecatedArg) else arr
|
||||
if arr is None:
|
||||
raise ValueError("No input was provided to the clip function.")
|
||||
min = a_min if not isinstance(a_min, DeprecatedArg) else min
|
||||
max = a_max if not isinstance(a_max, DeprecatedArg) else max
|
||||
if any(not isinstance(t, DeprecatedArg) for t in (a, a_min, a_max)):
|
||||
warnings.warn(
|
||||
"Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy.clip is "
|
||||
"deprecated. Please use 'x', 'min', and 'max' respectively instead.",
|
||||
"Passing arguments 'a', 'a_min' or 'a_max' to jax.numpy.clip is "
|
||||
"deprecated. Please use 'arr', 'min' or 'max' respectively instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
util.check_arraylike("clip", x)
|
||||
if any(jax.numpy.iscomplexobj(t) for t in (x, min, max)):
|
||||
util.check_arraylike("clip", arr)
|
||||
if any(jax.numpy.iscomplexobj(t) for t in (arr, min, max)):
|
||||
# TODO(micky774): Deprecated 2024-4-2, remove after deprecation expires.
|
||||
warnings.warn(
|
||||
"Clip received a complex value either through the input or the min/max "
|
||||
@ -2302,10 +2319,10 @@ def clip(
|
||||
DeprecationWarning, stacklevel=2,
|
||||
)
|
||||
if min is not None:
|
||||
x = ufuncs.maximum(min, x)
|
||||
arr = ufuncs.maximum(min, arr)
|
||||
if max is not None:
|
||||
x = ufuncs.minimum(max, x)
|
||||
return asarray(x)
|
||||
arr = ufuncs.minimum(max, arr)
|
||||
return asarray(arr)
|
||||
|
||||
@util.implements(np.around, skip_params=['out'])
|
||||
@partial(jit, static_argnames=('decimals',))
|
||||
|
Loading…
x
Reference in New Issue
Block a user