doc: add note about f16 casting in jnp.mean

This commit is contained in:
Jake VanderPlas 2025-02-05 10:46:07 -08:00
parent da0827b7f1
commit 9b402ecdb7

View File

@ -812,7 +812,9 @@ def mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
a: input array.
axis: optional, int or sequence of ints, default=None. Axis along which the
mean to be computed. If None, mean is computed along all the axes.
dtype: The type of the output array. Default=None.
dtype: The type of the output array. If None (default) then the output dtype
will be match the input dtype for floating point inputs, or be set to float32
or float64 for non-floating-point inputs.
keepdims: bool, default=False. If true, reduced axes are left in the result
with size 1.
where: optional, boolean array, default=None. The elements to be used in the
@ -822,6 +824,10 @@ def mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
Returns:
An array of the mean along the given axis.
Notes:
For inputs of type `float16` or `bfloat16`, the reductions will be performed at
float32 precision.
See also:
- :func:`jax.numpy.average`: Compute the weighted average of array elements
- :func:`jax.numpy.sum`: Compute the sum of array elements.