mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
doc: add note about f16 casting in jnp.mean
This commit is contained in:
parent
da0827b7f1
commit
9b402ecdb7
@ -812,7 +812,9 @@ def mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
|
||||
a: input array.
|
||||
axis: optional, int or sequence of ints, default=None. Axis along which the
|
||||
mean to be computed. If None, mean is computed along all the axes.
|
||||
dtype: The type of the output array. Default=None.
|
||||
dtype: The type of the output array. If None (default) then the output dtype
|
||||
will be match the input dtype for floating point inputs, or be set to float32
|
||||
or float64 for non-floating-point inputs.
|
||||
keepdims: bool, default=False. If true, reduced axes are left in the result
|
||||
with size 1.
|
||||
where: optional, boolean array, default=None. The elements to be used in the
|
||||
@ -822,6 +824,10 @@ def mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
|
||||
Returns:
|
||||
An array of the mean along the given axis.
|
||||
|
||||
Notes:
|
||||
For inputs of type `float16` or `bfloat16`, the reductions will be performed at
|
||||
float32 precision.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.average`: Compute the weighted average of array elements
|
||||
- :func:`jax.numpy.sum`: Compute the sum of array elements.
|
||||
|
Loading…
x
Reference in New Issue
Block a user