mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #24657 from dymil:patch-1
PiperOrigin-RevId: 694536516
This commit is contained in:
commit
2b55bd5a24
@ -10185,18 +10185,18 @@ def argmin(a: ArrayLike, axis: int | None = None, out: None = None,
|
||||
keepdims: bool | None = None) -> Array:
|
||||
"""Return the index of the minimum value of an array.
|
||||
|
||||
JAX implementation of :func:`numpy.argmax`.
|
||||
JAX implementation of :func:`numpy.argmin`.
|
||||
|
||||
Args:
|
||||
a: input array
|
||||
axis: optional integer specifying the axis along which to find the maximum
|
||||
axis: optional integer specifying the axis along which to find the minimum
|
||||
value. If ``axis`` is not specified, ``a`` will be flattened.
|
||||
out: unused by JAX
|
||||
keepdims: if True, then return an array with the same number of dimensions
|
||||
as ``a``.
|
||||
|
||||
Returns:
|
||||
an array containing the index of the maximum value along the specified axis.
|
||||
an array containing the index of the minimum value along the specified axis.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.argmax`: return the index of the maximum value.
|
||||
|
Loading…
x
Reference in New Issue
Block a user