mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #24189 from jakevdp:average-doc
PiperOrigin-RevId: 683660910
This commit is contained in:
commit
54101771d7
@ -32,7 +32,7 @@ from jax._src import deprecations
|
||||
from jax._src import dtypes
|
||||
from jax._src.numpy.util import (
|
||||
_broadcast_to, check_arraylike, _complex_elem_type,
|
||||
promote_dtypes_inexact, promote_dtypes_numeric, _where, implements)
|
||||
promote_dtypes_inexact, promote_dtypes_numeric, _where)
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.typing import Array, ArrayLike, DType, DTypeLike, DeprecatedArg
|
||||
from jax._src.util import (
|
||||
@ -700,9 +700,8 @@ def mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
|
||||
An array of the mean along the given axis.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.sum`: Compute the sum of array elements over a given axis.
|
||||
- :func:`jax.numpy.max`: Compute the maximum of array elements over given axis.
|
||||
- :func:`jax.numpy.min`: Compute the minimum of array elements over given axis.
|
||||
- :func:`jax.numpy.average`: Compute the weighted average of array elements
|
||||
- :func:`jax.numpy.sum`: Compute the sum of array elements.
|
||||
|
||||
Examples:
|
||||
By default, the mean is computed along all the axes.
|
||||
@ -782,9 +781,59 @@ def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, *
|
||||
@overload
|
||||
def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None,
|
||||
returned: bool = False, keepdims: bool = False) -> Array | tuple[Array, Array]: ...
|
||||
@implements(np.average)
|
||||
def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None,
|
||||
returned: bool = False, keepdims: bool = False) -> Array | tuple[Array, Array]:
|
||||
"""Compute the weighed average.
|
||||
|
||||
JAX Implementation of :func:`numpy.average`.
|
||||
|
||||
Args:
|
||||
a: array to be averaged
|
||||
axis: an optional integer or sequence of integers specifying the axis along which
|
||||
the mean to be computed. If not specified, mean is computed along all the axes.
|
||||
weights: an optional array of weights for a weighted average. Must be
|
||||
broadcast-compatible with ``a``.
|
||||
returned: If False (default) then return only the average. If True then return both
|
||||
the average and the normalization factor (i.e. the sum of weights).
|
||||
keepdims: If True, reduced axes are left in the result with size 1. If False (default)
|
||||
then reduced axes are squeezed out.
|
||||
|
||||
Returns:
|
||||
An array ``average`` or tuple of arrays ``(average, normalization)`` if
|
||||
``returned`` is True.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.mean`: unweighted mean.
|
||||
|
||||
Examples:
|
||||
Simple average:
|
||||
|
||||
>>> x = jnp.array([1, 2, 3, 2, 4])
|
||||
>>> jnp.average(x)
|
||||
Array(2.4, dtype=float32)
|
||||
|
||||
Weighted average:
|
||||
|
||||
>>> weights = jnp.array([2, 1, 3, 2, 2])
|
||||
>>> jnp.average(x, weights=weights)
|
||||
Array(2.5, dtype=float32)
|
||||
|
||||
Use ``returned=True`` to optionally return the normalization, i.e. the
|
||||
sum of weights:
|
||||
|
||||
>>> jnp.average(x, returned=True)
|
||||
(Array(2.4, dtype=float32), Array(5., dtype=float32))
|
||||
>>> jnp.average(x, weights=weights, returned=True)
|
||||
(Array(2.5, dtype=float32), Array(10., dtype=float32))
|
||||
|
||||
Weighted average along a specified axis:
|
||||
|
||||
>>> x = jnp.array([[8, 2, 7],
|
||||
... [3, 6, 4]])
|
||||
>>> weights = jnp.array([1, 2, 3])
|
||||
>>> jnp.average(x, weights=weights, axis=1)
|
||||
Array([5.5, 4.5], dtype=float32)
|
||||
"""
|
||||
return _average(a, _ensure_optional_axes(axis), weights, returned, keepdims)
|
||||
|
||||
@partial(api.jit, static_argnames=('axis', 'returned', 'keepdims'), inline=True)
|
||||
|
Loading…
x
Reference in New Issue
Block a user