Merge pull request #24657 from dymil:patch-1

PiperOrigin-RevId: 694536516
This commit is contained in:
jax authors 2024-11-08 09:48:24 -08:00
commit 2b55bd5a24

View File

@ -10185,18 +10185,18 @@ def argmin(a: ArrayLike, axis: int | None = None, out: None = None,
keepdims: bool | None = None) -> Array:
"""Return the index of the minimum value of an array.
JAX implementation of :func:`numpy.argmax`.
JAX implementation of :func:`numpy.argmin`.
Args:
a: input array
axis: optional integer specifying the axis along which to find the maximum
axis: optional integer specifying the axis along which to find the minimum
value. If ``axis`` is not specified, ``a`` will be flattened.
out: unused by JAX
keepdims: if True, then return an array with the same number of dimensions
as ``a``.
Returns:
an array containing the index of the maximum value along the specified axis.
an array containing the index of the minimum value along the specified axis.
See also:
- :func:`jax.numpy.argmax`: return the index of the maximum value.