# rocm_jax/jax/_src/numpy/reductions.py
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import builtins
from collections.abc import Callable, Sequence
from functools import partial
import math
import operator
from typing import overload, Any, Literal, Protocol, Union
import warnings
import numpy as np
import jax
from jax import lax
from jax._src import api
from jax._src import core
from jax._src import deprecations
from jax._src import dtypes
from jax._src.numpy.util import (
_broadcast_to, check_arraylike, _complex_elem_type,
promote_dtypes_inexact, promote_dtypes_numeric, _where)
from jax._src.lax import lax as lax_internal
from jax._src.typing import Array, ArrayLike, DType, DTypeLike, DeprecatedArg
from jax._src.util import (
canonicalize_axis as _canonicalize_axis, maybe_named_axis,
set_module, NumpyComplexWarning)
export = set_module('jax.numpy')
_all = builtins.all
_lax_const = lax_internal._const
Axis = Union[int, Sequence[int], None]
def _isscalar(element: Any) -> bool:
if hasattr(element, '__jax_array__'):
element = element.__jax_array__()
return dtypes.is_python_scalar(element) or np.isscalar(element)
def _moveaxis(a: ArrayLike, source: int, destination: int) -> Array:
# simplified version of jnp.moveaxis() for local use.
check_arraylike("moveaxis", a)
a = lax_internal.asarray(a)
source = _canonicalize_axis(source, np.ndim(a))
destination = _canonicalize_axis(destination, np.ndim(a))
perm = [i for i in range(np.ndim(a)) if i != source]
perm.insert(destination, source)
return lax.transpose(a, perm)
def _upcast_f16(dtype: DTypeLike) -> DType:
if np.dtype(dtype) in [np.float16, dtypes.bfloat16]:
return np.dtype('float32')
return np.dtype(dtype)
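# For example, half-precision accumulation dtypes are widened while wider
# dtypes pass through unchanged:
#   _upcast_f16(np.float16)      -> dtype('float32')
#   _upcast_f16(dtypes.bfloat16) -> dtype('float32')
#   _upcast_f16(np.float64)      -> dtype('float64')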
def _promote_integer_dtype(dtype: DTypeLike) -> DTypeLike:
# Note: NumPy always promotes to 64-bit; jax instead promotes to the
# default dtype as defined by dtypes.int_ or dtypes.uint.
if dtypes.issubdtype(dtype, np.bool_):
return dtypes.int_
elif dtypes.issubdtype(dtype, np.unsignedinteger):
if np.iinfo(dtype).bits < np.iinfo(dtypes.uint).bits:
return dtypes.uint
elif dtypes.issubdtype(dtype, np.integer):
if np.iinfo(dtype).bits < np.iinfo(dtypes.int_).bits:
return dtypes.int_
return dtype
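# Illustrative results, assuming the default 32-bit config where dtypes.int_ is
# int32 and dtypes.uint is uint32:
#   _promote_integer_dtype(np.bool_) -> dtypes.int_  (int32)
#   _promote_integer_dtype(np.int8)  -> dtypes.int_  (int32)
#   _promote_integer_dtype(np.uint8) -> dtypes.uint  (uint32)
#   _promote_integer_dtype(np.int64) -> np.int64     (already wide enough)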
def check_where(name: str, where: ArrayLike | None) -> Array | None:
if where is None:
return where
check_arraylike(name, where)
where_arr = lax_internal.asarray(where)
if where_arr.dtype != bool:
# Deprecation added 2024-12-05
deprecations.warn(
'jax-numpy-reduction-non-boolean-where',
f"jnp.{name}: where must be None or a boolean array; got dtype={where_arr.dtype}.",
stacklevel=2)
return where_arr.astype(bool)
return where_arr
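# Sketch of the deprecation path above: a non-boolean ``where`` still works for
# now, but emits a DeprecationWarning and is cast to bool (nonzero -> True),
# e.g. check_where("sum", jnp.array([2, 0, 1])) warns and returns a boolean
# array [True, False, True].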
ReductionOp = Callable[[Any, Any], Any]
def _reduction(a: ArrayLike, name: str, op: ReductionOp, init_val: ArrayLike,
*, has_identity: bool = True,
preproc: Callable[[ArrayLike], ArrayLike] | None = None,
bool_op: ReductionOp | None = None,
upcast_f16_for_computation: bool = False,
axis: Axis = None, dtype: DTypeLike | None = None, out: None = None,
keepdims: bool = False, initial: ArrayLike | None = None,
where_: ArrayLike | None = None,
parallel_reduce: Callable[..., Array] | None = None,
promote_integers: bool = False) -> Array:
bool_op = bool_op or op
# Note: we must accept out=None as an argument, because numpy reductions delegate to
# object methods. For example `np.sum(x)` will call `x.sum()` if the `sum()` method
# exists, passing along all its arguments.
if out is not None:
raise NotImplementedError(f"The 'out' argument to jnp.{name} is not supported.")
check_arraylike(name, a)
where_ = check_where(name, where_)
dtypes.check_user_dtype_supported(dtype, name)
axis = core.concrete_or_error(None, axis, f"axis argument to jnp.{name}().")
if initial is None and not has_identity and where_ is not None:
raise ValueError(f"reduction operation {name} does not have an identity, so to use a "
f"where mask one has to specify 'initial'")
a = a if isinstance(a, Array) else lax_internal.asarray(a)
a = preproc(a) if preproc else a
pos_dims, dims = _reduction_dims(a, axis)
if initial is None and not has_identity:
shape = np.shape(a)
if not _all(shape[d] >= 1 for d in pos_dims):
raise ValueError(f"zero-size array to reduction operation {name} which has no identity")
result_dtype = dtype or dtypes.dtype(a)
if dtype is None and promote_integers:
result_dtype = _promote_integer_dtype(result_dtype)
result_dtype = dtypes.canonicalize_dtype(result_dtype)
if upcast_f16_for_computation and dtypes.issubdtype(result_dtype, np.inexact):
computation_dtype = _upcast_f16(result_dtype)
else:
computation_dtype = result_dtype
a = lax.convert_element_type(a, computation_dtype)
op = op if computation_dtype != np.bool_ else bool_op
# NB: in XLA, init_val must be an identity for the op, so the user-specified
# initial value must be applied afterward.
init_val = _reduction_init_val(a, init_val)
if where_ is not None:
a = _where(where_, a, init_val)
if pos_dims is not dims:
if parallel_reduce is None:
raise NotImplementedError(f"Named reductions not implemented for jnp.{name}()")
result = parallel_reduce(a, dims)
else:
result = lax.reduce(a, init_val, op, dims)
if initial is not None:
initial_arr = lax.convert_element_type(initial, lax_internal.asarray(a).dtype)
if initial_arr.shape != ():
raise ValueError("initial value must be a scalar. "
f"Got array of shape {initial_arr.shape}")
result = op(initial_arr, result)
if keepdims:
result = lax.expand_dims(result, pos_dims)
return lax.convert_element_type(result, dtype or result_dtype)
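# Rough sketch of how the public reductions map onto _reduction above, ignoring
# dtype promotion and f16 upcasting: a masked sum such as
#   jnp.sum(x, axis=0, where=mask, initial=5)
# behaves like
#   lax.add(5, lax.reduce(jnp.where(mask, x, 0), 0, lax.add, (0,)))
# masked-out elements are replaced by the identity ``init_val`` before the XLA
# reduction, and ``initial`` is folded in afterwards with ``op``.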
def _canonicalize_axis_allow_named(x, rank):
return maybe_named_axis(x, lambda i: _canonicalize_axis(i, rank), lambda name: name)
def _reduction_dims(a: ArrayLike, axis: Axis):
if axis is None:
return (tuple(range(np.ndim(a))),) * 2
elif not isinstance(axis, (np.ndarray, tuple, list)):
axis = (axis,) # type: ignore[assignment]
canon_axis = tuple(_canonicalize_axis_allow_named(x, np.ndim(a))
for x in axis) # type: ignore[union-attr]
if len(canon_axis) != len(set(canon_axis)):
raise ValueError(f"duplicate value in 'axis': {axis}")
canon_pos_axis = tuple(x for x in canon_axis if isinstance(x, int))
if len(canon_pos_axis) != len(canon_axis):
return canon_pos_axis, canon_axis
else:
return canon_axis, canon_axis
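# Example: for a rank-3 input, _reduction_dims(a, (0, -1)) returns
# ((0, 2), (0, 2)); with a named mapped axis such as ('i', 1), the first tuple
# keeps only the positional dimensions: ((1,), ('i', 1)).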
def _reduction_init_val(a: ArrayLike, init_val: Any) -> np.ndarray:
# This function uses np.* functions because lax pattern matches against the
# specific concrete values of the reduction inputs.
a_dtype = dtypes.canonicalize_dtype(dtypes.dtype(a))
if a_dtype == 'bool':
return np.array(init_val > 0, dtype=a_dtype)
if (np.isinf(init_val) and dtypes.issubdtype(a_dtype, np.floating)
and not dtypes.supports_inf(a_dtype)):
init_val = np.array(dtypes.finfo(a_dtype).min if np.isneginf(init_val)
else dtypes.finfo(a_dtype).max, dtype=a_dtype)
try:
return np.array(init_val, dtype=a_dtype)
except OverflowError:
assert dtypes.issubdtype(a_dtype, np.integer)
sign, info = np.sign(init_val), dtypes.iinfo(a_dtype)
return np.array(info.min if sign < 0 else info.max, dtype=a_dtype)
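# Example: the +/-inf identities used by min/max get clamped to representable
# values, e.g. for an int32 array an init_val of -inf becomes
# np.iinfo(np.int32).min, and for a bool array it becomes ``init_val > 0``.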
def _cast_to_bool(operand: ArrayLike) -> Array:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=NumpyComplexWarning)
return lax.convert_element_type(operand, np.bool_)
def _cast_to_numeric(operand: ArrayLike) -> Array:
return promote_dtypes_numeric(operand)[0]
def _require_integer(operand: ArrayLike) -> Array:
arr = lax_internal.asarray(operand)
if not dtypes.isdtype(arr, ("bool", "integral")):
raise ValueError(f"integer argument required; got dtype={arr.dtype}")
return arr
def _ensure_optional_axes(x: Axis) -> Axis:
def force(x):
if x is None:
return None
try:
return operator.index(x)
except TypeError:
return tuple(i if isinstance(i, str) else operator.index(i) for i in x)
return core.concrete_or_error(
force, x, "The axis argument must be known statically.")
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims', 'promote_integers'), inline=True)
def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False,
initial: ArrayLike | None = None, where: ArrayLike | None = None,
promote_integers: bool = True) -> Array:
return _reduction(a, "sum", lax.add, 0, preproc=_cast_to_numeric,
bool_op=lax.bitwise_or, upcast_f16_for_computation=True,
axis=axis, dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where_=where, parallel_reduce=lax.psum,
promote_integers=promote_integers)
@export
def sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None, promote_integers: bool = True) -> Array:
r"""Sum of the elements of the array over a given axis.
JAX implementation of :func:`numpy.sum`.
Args:
a: Input array.
    axis: int or array, default=None. Axis along which the sum is computed.
If None, the sum is computed along all the axes.
dtype: The type of the output array. Default=None.
out: Unused by JAX
keepdims: bool, default=False. If true, reduced axes are left in the result
with size 1.
    initial: int or array, default=None. Initial value for the sum.
    where: array of boolean dtype, default=None. The elements to be used in the
      sum. Array should be broadcast compatible with the input.
promote_integers : bool, default=True. If True, then integer inputs will be
promoted to the widest available integer dtype, following numpy's behavior.
If False, the result will have the same dtype as the input.
``promote_integers`` is ignored if ``dtype`` is specified.
Returns:
An array of the sum along the given axis.
See also:
- :func:`jax.numpy.prod`: Compute the product of array elements over a given
axis.
- :func:`jax.numpy.max`: Compute the maximum of array elements over given axis.
- :func:`jax.numpy.min`: Compute the minimum of array elements over given axis.
Examples:
By default, the sum is computed along all the axes.
>>> x = jnp.array([[1, 3, 4, 2],
... [5, 2, 6, 3],
... [8, 1, 3, 9]])
>>> jnp.sum(x)
Array(47, dtype=int32)
If ``axis=1``, the sum is computed along axis 1.
>>> jnp.sum(x, axis=1)
Array([10, 16, 21], dtype=int32)
If ``keepdims=True``, ``ndim`` of the output is equal to that of the input.
>>> jnp.sum(x, axis=1, keepdims=True)
Array([[10],
[16],
[21]], dtype=int32)
To include only specific elements in the sum, you can use ``where``.
>>> where=jnp.array([[0, 0, 1, 0],
... [0, 0, 1, 1],
... [1, 1, 1, 0]], dtype=bool)
>>> jnp.sum(x, axis=1, keepdims=True, where=where)
Array([[ 4],
[ 9],
[12]], dtype=int32)
>>> where=jnp.array([[False],
... [False],
... [False]])
>>> jnp.sum(x, axis=0, keepdims=True, where=where)
Array([[0, 0, 0, 0]], dtype=int32)
"""
return _reduce_sum(a, axis=_ensure_optional_axes(axis), dtype=dtype, out=out,
keepdims=keepdims, initial=initial, where=where,
promote_integers=promote_integers)
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims', 'promote_integers'), inline=True)
def _reduce_prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False,
initial: ArrayLike | None = None, where: ArrayLike | None = None,
promote_integers: bool = True) -> Array:
return _reduction(a, "prod", lax.mul, 1, preproc=_cast_to_numeric,
bool_op=lax.bitwise_and, upcast_f16_for_computation=True,
axis=axis, dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where_=where, promote_integers=promote_integers)
@export
def prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False,
initial: ArrayLike | None = None, where: ArrayLike | None = None,
promote_integers: bool = True) -> Array:
r"""Return product of the array elements over a given axis.
JAX implementation of :func:`numpy.prod`.
Args:
a: Input array.
    axis: int or array, default=None. Axis along which the product is computed.
If None, the product is computed along all the axes.
dtype: The type of the output array. Default=None.
keepdims: bool, default=False. If true, reduced axes are left in the result
with size 1.
    initial: int or array, default=None. Initial value for the product.
    where: array of boolean dtype, default=None. The elements to be used in the
      product. Array should be broadcast compatible with the input.
promote_integers : bool, default=True. If True, then integer inputs will be
promoted to the widest available integer dtype, following numpy's behavior.
If False, the result will have the same dtype as the input.
``promote_integers`` is ignored if ``dtype`` is specified.
out: Unused by JAX.
Returns:
An array of the product along the given axis.
See also:
- :func:`jax.numpy.sum`: Compute the sum of array elements over a given axis.
- :func:`jax.numpy.max`: Compute the maximum of array elements over given axis.
- :func:`jax.numpy.min`: Compute the minimum of array elements over given axis.
Examples:
By default, ``jnp.prod`` computes along all the axes.
>>> x = jnp.array([[1, 3, 4, 2],
... [5, 2, 1, 3],
... [2, 1, 3, 1]])
>>> jnp.prod(x)
Array(4320, dtype=int32)
    If ``axis=1``, the product is computed along axis 1.
>>> jnp.prod(x, axis=1)
Array([24, 30, 6], dtype=int32)
If ``keepdims=True``, ``ndim`` of the output is equal to that of the input.
>>> jnp.prod(x, axis=1, keepdims=True)
Array([[24],
[30],
[ 6]], dtype=int32)
    To include only specific elements in the product, you can use ``where``.
>>> where=jnp.array([[1, 0, 1, 0],
... [0, 0, 1, 1],
... [1, 1, 1, 0]], dtype=bool)
>>> jnp.prod(x, axis=1, keepdims=True, where=where)
Array([[4],
[3],
[6]], dtype=int32)
>>> where = jnp.array([[False],
... [False],
... [False]])
>>> jnp.prod(x, axis=1, keepdims=True, where=where)
Array([[1],
[1],
[1]], dtype=int32)
"""
return _reduce_prod(a, axis=_ensure_optional_axes(axis), dtype=dtype,
out=out, keepdims=keepdims, initial=initial, where=where,
promote_integers=promote_integers)
@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True)
def _reduce_max(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None) -> Array:
return _reduction(a, "max", lax.max, -np.inf, has_identity=False,
axis=axis, out=out, keepdims=keepdims,
initial=initial, where_=where, parallel_reduce=lax.pmax)
@export
def max(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None) -> Array:
r"""Return the maximum of the array elements along a given axis.
JAX implementation of :func:`numpy.max`.
Args:
a: Input array.
    axis: int or array, default=None. Axis along which the maximum is computed.
If None, the maximum is computed along all the axes.
keepdims: bool, default=False. If true, reduced axes are left in the result
with size 1.
initial: int or array, default=None. Initial value for the maximum.
    where: array of boolean dtype, default=None. The elements to be used in the
      maximum. Array should be broadcast compatible with the input.
``initial`` must be specified when ``where`` is used.
out: Unused by JAX.
Returns:
An array of maximum values along the given axis.
See also:
- :func:`jax.numpy.min`: Compute the minimum of array elements along a given
axis.
- :func:`jax.numpy.sum`: Compute the sum of array elements along a given axis.
- :func:`jax.numpy.prod`: Compute the product of array elements along a given
axis.
Examples:
By default, ``jnp.max`` computes the maximum of elements along all the axes.
>>> x = jnp.array([[9, 3, 4, 5],
... [5, 2, 7, 4],
... [8, 1, 3, 6]])
>>> jnp.max(x)
Array(9, dtype=int32)
If ``axis=1``, the maximum will be computed along axis 1.
>>> jnp.max(x, axis=1)
Array([9, 7, 8], dtype=int32)
    If ``keepdims=True``, ``ndim`` of the output will be the same as that of the input.
>>> jnp.max(x, axis=1, keepdims=True)
Array([[9],
[7],
[8]], dtype=int32)
    To include only specific elements in computing the maximum, you can use
    ``where``. It can either have the same dimensions as the input
>>> where=jnp.array([[0, 0, 1, 0],
... [0, 0, 1, 1],
... [1, 1, 1, 0]], dtype=bool)
>>> jnp.max(x, axis=1, keepdims=True, initial=0, where=where)
Array([[4],
[7],
[8]], dtype=int32)
    or be broadcast compatible with the input.
>>> where = jnp.array([[False],
... [False],
... [False]])
>>> jnp.max(x, axis=0, keepdims=True, initial=0, where=where)
Array([[0, 0, 0, 0]], dtype=int32)
"""
return _reduce_max(a, axis=_ensure_optional_axes(axis), out=out,
keepdims=keepdims, initial=initial, where=where)
@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True)
def _reduce_min(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None) -> Array:
return _reduction(a, "min", lax.min, np.inf, has_identity=False,
axis=axis, out=out, keepdims=keepdims,
initial=initial, where_=where, parallel_reduce=lax.pmin)
@export
def min(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None) -> Array:
r"""Return the minimum of array elements along a given axis.
JAX implementation of :func:`numpy.min`.
Args:
a: Input array.
    axis: int or array, default=None. Axis along which the minimum is computed.
If None, the minimum is computed along all the axes.
keepdims: bool, default=False. If true, reduced axes are left in the result
with size 1.
    initial: int or array, default=None. Initial value for the minimum.
    where: array of boolean dtype, default=None. The elements to be used in the
      minimum. Array should be broadcast compatible with the input. ``initial``
      must be specified when ``where`` is used.
out: Unused by JAX.
Returns:
An array of minimum values along the given axis.
See also:
- :func:`jax.numpy.max`: Compute the maximum of array elements along a given
axis.
- :func:`jax.numpy.sum`: Compute the sum of array elements along a given axis.
- :func:`jax.numpy.prod`: Compute the product of array elements along a given
axis.
Examples:
By default, the minimum is computed along all the axes.
>>> x = jnp.array([[2, 5, 1, 6],
... [3, -7, -2, 4],
... [8, -4, 1, -3]])
>>> jnp.min(x)
Array(-7, dtype=int32)
If ``axis=1``, the minimum is computed along axis 1.
>>> jnp.min(x, axis=1)
Array([ 1, -7, -4], dtype=int32)
    If ``keepdims=True``, ``ndim`` of the output will be the same as that of the input.
>>> jnp.min(x, axis=1, keepdims=True)
Array([[ 1],
[-7],
[-4]], dtype=int32)
To include only specific elements in computing the minimum, you can use
    ``where``. It can either have the same dimensions as the input
>>> where=jnp.array([[1, 0, 1, 0],
... [0, 0, 1, 1],
... [1, 1, 1, 0]], dtype=bool)
>>> jnp.min(x, axis=1, keepdims=True, initial=0, where=where)
Array([[ 0],
[-2],
[-4]], dtype=int32)
    or be broadcast compatible with the input.
>>> where = jnp.array([[False],
... [False],
... [False]])
>>> jnp.min(x, axis=0, keepdims=True, initial=0, where=where)
Array([[0, 0, 0, 0]], dtype=int32)
"""
return _reduce_min(a, axis=_ensure_optional_axes(axis), out=out,
keepdims=keepdims, initial=initial, where=where)
@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True)
def _reduce_all(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, *, where: ArrayLike | None = None) -> Array:
return _reduction(a, "all", lax.bitwise_and, True, preproc=_cast_to_bool,
axis=axis, out=out, keepdims=keepdims, where_=where)
@export
def all(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, *, where: ArrayLike | None = None) -> Array:
r"""Test whether all array elements along a given axis evaluate to True.
JAX implementation of :func:`numpy.all`.
Args:
a: Input array.
    axis: int or array, default=None. Axis along which the test is performed.
      If None, tests along all the axes.
keepdims: bool, default=False. If true, reduced axes are left in the result
with size 1.
    where: array of boolean dtype, default=None. The elements to be used in the
      test. Array should be broadcast compatible with the input.
out: Unused by JAX.
Returns:
An array of boolean values.
Examples:
By default, ``jnp.all`` tests for True values along all the axes.
>>> x = jnp.array([[True, True, True, False],
... [True, False, True, False],
... [True, True, False, False]])
>>> jnp.all(x)
Array(False, dtype=bool)
If ``axis=0``, tests for True values along axis 0.
>>> jnp.all(x, axis=0)
Array([ True, False, False, False], dtype=bool)
    If ``keepdims=True``, ``ndim`` of the output will be the same as that of the input.
>>> jnp.all(x, axis=0, keepdims=True)
Array([[ True, False, False, False]], dtype=bool)
    To include specific elements in testing for True values, you can use ``where``.
>>> where=jnp.array([[1, 0, 1, 0],
... [0, 0, 1, 1],
... [1, 1, 1, 0]], dtype=bool)
>>> jnp.all(x, axis=0, keepdims=True, where=where)
Array([[ True, True, False, False]], dtype=bool)
"""
return _reduce_all(a, axis=_ensure_optional_axes(axis), out=out,
keepdims=keepdims, where=where)
@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True)
def _reduce_any(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, *, where: ArrayLike | None = None) -> Array:
return _reduction(a, "any", lax.bitwise_or, False, preproc=_cast_to_bool,
axis=axis, out=out, keepdims=keepdims, where_=where)
@export
def any(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, *, where: ArrayLike | None = None) -> Array:
r"""Test whether any of the array elements along a given axis evaluate to True.
JAX implementation of :func:`numpy.any`.
Args:
a: Input array.
    axis: int or array, default=None. Axis along which the test is performed.
      If None, tests along all the axes.
keepdims: bool, default=False. If true, reduced axes are left in the result
with size 1.
    where: array of boolean dtype, default=None. The elements to be used in the
      test. Array should be broadcast compatible with the input.
out: Unused by JAX.
Returns:
An array of boolean values.
Examples:
By default, ``jnp.any`` tests along all the axes.
>>> x = jnp.array([[True, True, True, False],
... [True, False, True, False],
... [True, True, False, False]])
>>> jnp.any(x)
Array(True, dtype=bool)
If ``axis=0``, tests along axis 0.
>>> jnp.any(x, axis=0)
Array([ True, True, True, False], dtype=bool)
    If ``keepdims=True``, ``ndim`` of the output will be the same as that of the input.
>>> jnp.any(x, axis=0, keepdims=True)
Array([[ True, True, True, False]], dtype=bool)
    To include specific elements in testing for True values, you can use ``where``.
>>> where=jnp.array([[1, 0, 1, 0],
... [0, 1, 0, 1],
... [1, 0, 1, 0]], dtype=bool)
>>> jnp.any(x, axis=0, keepdims=True, where=where)
Array([[ True, False, True, False]], dtype=bool)
"""
return _reduce_any(a, axis=_ensure_optional_axes(axis), out=out,
keepdims=keepdims, where=where)
@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True)
def _reduce_bitwise_and(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False,
initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
arr = lax_internal.asarray(a)
init_val = np.array(-1, dtype=dtype or arr.dtype)
return _reduction(arr, name="reduce_bitwise_and", op=lax.bitwise_and, init_val=init_val, preproc=_require_integer,
axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where_=where)
@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True)
def _reduce_bitwise_or(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False,
initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
return _reduction(a, name="reduce_bitwise_or", op=lax.bitwise_or, init_val=0, preproc=_require_integer,
axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where_=where)
@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True)
def _reduce_bitwise_xor(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False,
initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
return _reduction(a, name="reduce_bitwise_xor", op=lax.bitwise_xor, init_val=0, preproc=_require_integer,
axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where_=where)
@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True)
def _reduce_logical_and(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False,
initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
return _reduction(a, name="reduce_logical_and", op=lax.bitwise_and, init_val=True, preproc=_cast_to_bool,
axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where_=where)
@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True)
def _reduce_logical_or(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False,
initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
return _reduction(a, name="reduce_logical_or", op=lax.bitwise_or, init_val=False, preproc=_cast_to_bool,
axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where_=where)
@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True)
def _reduce_logical_xor(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False,
initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
return _reduction(a, name="reduce_logical_xor", op=lax.bitwise_xor, init_val=False, preproc=_cast_to_bool,
axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where_=where)
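# Illustrative results for the private ufunc-style reducers above (exact result
# dtypes depend on the input and config):
#   _reduce_bitwise_and(jnp.array([0b1100, 0b1010]))    -> 8  (0b1000)
#   _reduce_bitwise_or(jnp.array([0b1100, 0b1010]))     -> 14 (0b1110)
#   _reduce_logical_xor(jnp.array([True, True, False])) -> False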
def _logsumexp(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False,
initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
"""Compute log(sum(exp(a))) while avoiding precision loss."""
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.logaddexp.reduce is not supported.")
dtypes.check_user_dtype_supported(dtype, "jnp.logaddexp.reduce")
check_arraylike("logsumexp", a)
where = check_where("logsumexp", where)
a_arr, = promote_dtypes_inexact(a)
pos_dims, dims = _reduction_dims(a_arr, axis)
amax = max(a_arr.real, axis=dims, keepdims=keepdims, where=where, initial=-np.inf)
amax = lax.stop_gradient(lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0)))
amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)
exp_a = lax.exp(lax.sub(a_arr, amax_with_dims.astype(a_arr.dtype)))
sumexp = exp_a.sum(axis=dims, keepdims=keepdims, where=where)
result = lax.add(lax.log(sumexp), amax.astype(sumexp.dtype))
return result if initial is None else lax.logaddexp(initial, result)
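# The helper above relies on the usual log-sum-exp shift:
#   log(sum_i exp(a_i)) == amax + log(sum_i exp(a_i - amax)),  amax = max(a)
# e.g. _logsumexp(jnp.array([1000., 1000.])) is ~1000.6931 (1000 + log(2)),
# whereas exponentiating directly would overflow to inf in float32.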
def _logsumexp2(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False,
initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
"""Compute log2(sum(2 ** a)) via logsumexp."""
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.logaddexp2.reduce is not supported.")
dtypes.check_user_dtype_supported(dtype, "jnp.logaddexp2.reduce")
check_arraylike("logsumexp2", a)
where = check_where("logsumexp2", where)
ln2 = float(np.log(2))
if initial is not None:
initial *= ln2
return _logsumexp(a * ln2, axis=axis, dtype=dtype, keepdims=keepdims,
where=where, initial=initial) / ln2
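# Base-2 variant via the change-of-base identity
#   log2(sum_i 2**a_i) == logsumexp(a * ln2) / ln2
# e.g. _logsumexp2(jnp.array([3., 3.])) -> 4.0, since 2**3 + 2**3 == 2**4.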
@export
def amin(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None) -> Array:
"""Alias of :func:`jax.numpy.min`."""
return min(a, axis=axis, out=out, keepdims=keepdims,
initial=initial, where=where)
@export
def amax(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None) -> Array:
"""Alias of :func:`jax.numpy.max`."""
return max(a, axis=axis, out=out, keepdims=keepdims,
initial=initial, where=where)
def _axis_size(a: ArrayLike, axis: int | Sequence[int]):
if not isinstance(axis, (tuple, list)):
axis_seq: Sequence[int] = (axis,) # type: ignore[assignment]
else:
axis_seq = axis
size = 1
a_shape = np.shape(a)
for a in axis_seq:
size *= maybe_named_axis(a, lambda i: a_shape[i], lambda name: lax.psum(1, name))
return size
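# Example: for an array of shape (2, 3, 4), _axis_size(a, (0, 2)) is 2 * 4 = 8;
# a named mapped axis instead contributes its size via lax.psum(1, name).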
@export
def mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False, *,
where: ArrayLike | None = None) -> Array:
r"""Return the mean of array elements along a given axis.
JAX implementation of :func:`numpy.mean`.
Args:
a: input array.
    axis: optional, int or sequence of ints, default=None. Axis along which the
      mean is computed. If None, the mean is computed along all the axes.
dtype: The type of the output array. Default=None.
keepdims: bool, default=False. If true, reduced axes are left in the result
with size 1.
where: optional, boolean array, default=None. The elements to be used in the
      mean. Array should be broadcast compatible with the input.
out: Unused by JAX.
Returns:
An array of the mean along the given axis.
See also:
- :func:`jax.numpy.average`: Compute the weighted average of array elements
- :func:`jax.numpy.sum`: Compute the sum of array elements.
Examples:
By default, the mean is computed along all the axes.
>>> x = jnp.array([[1, 3, 4, 2],
... [5, 2, 6, 3],
... [8, 1, 2, 9]])
>>> jnp.mean(x)
Array(3.8333335, dtype=float32)
If ``axis=1``, the mean is computed along axis 1.
>>> jnp.mean(x, axis=1)
Array([2.5, 4. , 5. ], dtype=float32)
If ``keepdims=True``, ``ndim`` of the output is equal to that of the input.
>>> jnp.mean(x, axis=1, keepdims=True)
Array([[2.5],
[4. ],
[5. ]], dtype=float32)
To use only specific elements of ``x`` to compute the mean, you can use
``where``.
>>> where = jnp.array([[1, 0, 1, 0],
... [0, 1, 0, 1],
... [1, 1, 0, 1]], dtype=bool)
>>> jnp.mean(x, axis=1, keepdims=True, where=where)
Array([[2.5],
[2.5],
[6. ]], dtype=float32)
"""
return _mean(a, _ensure_optional_axes(axis), dtype, out, keepdims,
where=where)
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'), inline=True)
def _mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False, *,
upcast_f16_for_computation: bool = True,
where: ArrayLike | None = None) -> Array:
check_arraylike("mean", a)
where = check_where("mean", where)
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.mean is not supported.")
if dtype is None:
result_dtype = dtypes.to_inexact_dtype(dtypes.dtype(a, canonicalize=True))
else:
dtypes.check_user_dtype_supported(dtype, "mean")
result_dtype = dtypes.canonicalize_dtype(dtype)
if upcast_f16_for_computation and dtypes.issubdtype(result_dtype, np.inexact):
computation_dtype = _upcast_f16(result_dtype)
else:
computation_dtype = result_dtype
if where is None:
if axis is None:
normalizer = core.dimension_as_value(np.size(a))
else:
normalizer = core.dimension_as_value(_axis_size(a, axis))
else:
normalizer = sum(_broadcast_to(where, np.shape(a)), axis,
dtype=computation_dtype, keepdims=keepdims)
return lax.div(
sum(a, axis, dtype=computation_dtype, keepdims=keepdims, where=where),
lax.convert_element_type(normalizer, computation_dtype)
).astype(result_dtype)
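# Sketch, ignoring dtype details: the mean is a masked sum divided by the count
# of contributing elements, so
#   _mean(jnp.array([1., 2., 3., 4.]), where=jnp.array([True, False, True, False]))
# is (1 + 3) / 2 == 2.0, with the normalizer obtained by summing the broadcast
# ``where`` mask.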
@overload
def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None,
returned: Literal[False] = False, keepdims: bool = False) -> Array: ...
@overload
def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, *,
returned: Literal[True], keepdims: bool = False) -> Array: ...
@overload
def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None,
returned: bool = False, keepdims: bool = False) -> Array | tuple[Array, Array]: ...
@export
def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None,
            returned: bool = False, keepdims: bool = False) -> Array | tuple[Array, Array]:
  """Compute the weighted average.
JAX Implementation of :func:`numpy.average`.
Args:
a: array to be averaged
    axis: an optional integer or sequence of integers specifying the axis along
      which the average is computed. If not specified, the average is computed
      along all the axes.
weights: an optional array of weights for a weighted average. Must be
broadcast-compatible with ``a``.
returned: If False (default) then return only the average. If True then return both
the average and the normalization factor (i.e. the sum of weights).
keepdims: If True, reduced axes are left in the result with size 1. If False (default)
then reduced axes are squeezed out.
Returns:
An array ``average`` or tuple of arrays ``(average, normalization)`` if
``returned`` is True.
See also:
- :func:`jax.numpy.mean`: unweighted mean.
Examples:
Simple average:
>>> x = jnp.array([1, 2, 3, 2, 4])
>>> jnp.average(x)
Array(2.4, dtype=float32)
Weighted average:
>>> weights = jnp.array([2, 1, 3, 2, 2])
>>> jnp.average(x, weights=weights)
Array(2.5, dtype=float32)
Use ``returned=True`` to optionally return the normalization, i.e. the
sum of weights:
>>> jnp.average(x, returned=True)
(Array(2.4, dtype=float32), Array(5., dtype=float32))
>>> jnp.average(x, weights=weights, returned=True)
(Array(2.5, dtype=float32), Array(10., dtype=float32))
Weighted average along a specified axis:
>>> x = jnp.array([[8, 2, 7],
... [3, 6, 4]])
>>> weights = jnp.array([1, 2, 3])
>>> jnp.average(x, weights=weights, axis=1)
Array([5.5, 4.5], dtype=float32)
"""
return _average(a, _ensure_optional_axes(axis), weights, returned, keepdims)
@partial(api.jit, static_argnames=('axis', 'returned', 'keepdims'), inline=True)
def _average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None,
returned: bool = False, keepdims: bool = False) -> Array | tuple[Array, Array]:
if weights is None: # Treat all weights as 1
check_arraylike("average", a)
a, = promote_dtypes_inexact(a)
avg = mean(a, axis=axis, keepdims=keepdims)
if axis is None:
weights_sum = lax.full((), core.dimension_as_value(a.size), dtype=avg.dtype)
elif isinstance(axis, tuple):
weights_sum = lax.full_like(avg, math.prod(core.dimension_as_value(a.shape[d]) for d in axis))
else:
weights_sum = lax.full_like(avg, core.dimension_as_value(a.shape[axis])) # type: ignore[index]
else:
check_arraylike("average", a, weights)
a, weights = promote_dtypes_inexact(a, weights)
a_shape = np.shape(a)
a_ndim = len(a_shape)
weights_shape = np.shape(weights)
if axis is None:
pass
elif isinstance(axis, tuple):
axis = tuple(_canonicalize_axis(d, a_ndim) for d in axis)
else:
axis = _canonicalize_axis(axis, a_ndim)
if a_shape != weights_shape:
# Make sure the dimensions work out
if len(weights_shape) != 1:
raise ValueError("1D weights expected when shapes of a and "
"weights differ.")
if axis is None:
raise ValueError("Axis must be specified when shapes of a and "
"weights differ.")
elif isinstance(axis, tuple):
raise ValueError("Single axis expected when shapes of a and weights differ")
elif not core.definitely_equal(weights_shape[0], a_shape[axis]):
raise ValueError("Length of weights not "
"compatible with specified axis.")
weights = _broadcast_to(weights, (a_ndim - 1) * (1,) + weights_shape)
weights = _moveaxis(weights, -1, axis)
weights_sum = sum(weights, axis=axis, keepdims=keepdims)
avg = sum(a * weights, axis=axis, keepdims=keepdims) / weights_sum
if returned:
if avg.shape != weights_sum.shape:
weights_sum = _broadcast_to(weights_sum, avg.shape)
return avg, weights_sum
return avg
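# Sketch: with weights, the result is sum(a * weights, axis) / sum(weights, axis),
# e.g. averaging [1., 2., 3.] with weights [3., 1., 1.] gives
# (1*3 + 2*1 + 3*1) / 5 == 1.6; without weights it reduces to mean() with the
# normalizer equal to the number of averaged elements.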
@export
def var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, ddof: int = 0, keepdims: bool = False, *,
where: ArrayLike | None = None, correction: int | float | None = None) -> Array:
r"""Compute the variance along a given axis.
JAX implementation of :func:`numpy.var`.
Args:
a: input array.
axis: optional, int or sequence of ints, default=None. Axis along which the
variance is computed. If None, variance is computed along all the axes.
dtype: The type of the output array. Default=None.
    ddof: int, default=0. Degrees of freedom. The divisor in the variance
      computation is ``N - ddof``, where ``N`` is the number of elements along
      the given axis.
keepdims: bool, default=False. If true, reduced axes are left in the result
with size 1.
where: optional, boolean array, default=None. The elements to be used in the
      variance. Array should be broadcast compatible with the input.
correction: int or float, default=None. Alternative name for ``ddof``.
      ``ddof`` and ``correction`` cannot be provided simultaneously.
out: Unused by JAX.
Returns:
An array of the variance along the given axis.
See also:
- :func:`jax.numpy.mean`: Compute the mean of array elements over a given axis.
- :func:`jax.numpy.std`: Compute the standard deviation of array elements over
given axis.
    - :func:`jax.numpy.nanvar`: Compute the variance along a given axis, ignoring
      NaN values.
    - :func:`jax.numpy.nanstd`: Compute the standard deviation along a given axis,
      ignoring NaN values.
Examples:
By default, ``jnp.var`` computes the variance along all axes.
>>> x = jnp.array([[1, 3, 4, 2],
... [5, 2, 6, 3],
... [8, 4, 2, 9]])
>>> with jnp.printoptions(precision=2, suppress=True):
... jnp.var(x)
Array(5.74, dtype=float32)
If ``axis=1``, variance is computed along axis 1.
>>> jnp.var(x, axis=1)
Array([1.25 , 2.5 , 8.1875], dtype=float32)
To preserve the dimensions of input, you can set ``keepdims=True``.
>>> jnp.var(x, axis=1, keepdims=True)
Array([[1.25 ],
[2.5 ],
[8.1875]], dtype=float32)
If ``ddof=1``:
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jnp.var(x, axis=1, keepdims=True, ddof=1))
[[ 1.67]
[ 3.33]
[10.92]]
    To include only specific elements of the array in the variance computation,
    you can use ``where``.
>>> where = jnp.array([[1, 0, 1, 0],
... [0, 1, 1, 0],
... [1, 1, 1, 0]], dtype=bool)
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jnp.var(x, axis=1, keepdims=True, where=where))
[[2.25]
[4. ]
[6.22]]
"""
if correction is None:
correction = ddof
elif not isinstance(ddof, int) or ddof != 0:
raise ValueError("ddof and correction can't be provided simultaneously.")
return _var(a, _ensure_optional_axes(axis), dtype, out, correction, keepdims,
where=where)
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, correction: int | float = 0, keepdims: bool = False, *,
where: ArrayLike | None = None) -> Array:
check_arraylike("var", a)
where = check_where("var", where)
dtypes.check_user_dtype_supported(dtype, "var")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.var is not supported.")
computation_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype)
a = lax_internal.asarray(a).astype(computation_dtype)
a_mean = mean(a, axis, dtype=computation_dtype, keepdims=True, where=where)
centered = lax.sub(a, a_mean)
if dtypes.issubdtype(computation_dtype, np.complexfloating):
centered = lax.real(lax.mul(centered, lax.conj(centered)))
computation_dtype = centered.dtype # avoid casting to complex below.
else:
centered = lax.square(centered)
if where is None:
if axis is None:
normalizer = core.dimension_as_value(np.size(a))
else:
normalizer = core.dimension_as_value(_axis_size(a, axis))
normalizer = lax.convert_element_type(normalizer, computation_dtype)
else:
normalizer = sum(_broadcast_to(where, np.shape(a)), axis,
dtype=computation_dtype, keepdims=keepdims)
normalizer = lax.sub(normalizer, lax.convert_element_type(correction, computation_dtype))
result = sum(centered, axis, dtype=computation_dtype, keepdims=keepdims, where=where)
result = lax.div(result, normalizer).astype(dtype)
with jax.debug_nans(False):
result = _where(normalizer > 0, result, np.nan)
return result
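# Sketch for real inputs, ignoring dtype promotion:
#   var = sum(|x - mean(x)|**2, where) / (N_effective - correction)
# e.g. _var(jnp.array([1., 2., 3., 4.])) is 1.25, and with correction=1 it is
# 5/3 ~= 1.667. For complex inputs the squared magnitude real(z * conj(z)) is
# used, and a non-positive normalizer yields nan.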
def _var_promote_types(a_dtype: DTypeLike, dtype: DTypeLike | None) -> tuple[DType, DType]:
if dtype:
if (not dtypes.issubdtype(dtype, np.complexfloating) and
dtypes.issubdtype(a_dtype, np.complexfloating)):
msg = ("jax.numpy.var does not yet support real dtype parameters when "
"computing the variance of an array of complex values. The "
"semantics of numpy.var seem unclear in this case. Please comment "
"on https://github.com/jax-ml/jax/issues/2283 if this behavior is "
"important to you.")
raise ValueError(msg)
computation_dtype = dtype
else:
if not dtypes.issubdtype(a_dtype, np.inexact):
dtype = dtypes.to_inexact_dtype(a_dtype)
computation_dtype = dtype
else:
dtype = _complex_elem_type(a_dtype)
computation_dtype = a_dtype
return _upcast_f16(computation_dtype), np.dtype(dtype)
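# Illustrative (computation_dtype, output_dtype) pairs, assuming the default
# 32-bit config:
#   _var_promote_types(np.dtype('complex64'), None) -> (complex64, float32)
#   _var_promote_types(np.dtype('float16'), None)   -> (float32, float16)
#   _var_promote_types(np.dtype('int32'), None)     -> (float32, float32)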
@export
def std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, ddof: int = 0, keepdims: bool = False, *,
where: ArrayLike | None = None, correction: int | float | None = None) -> Array:
r"""Compute the standard deviation along a given axis.
JAX implementation of :func:`numpy.std`.
Args:
a: input array.
    axis: optional, int or sequence of ints, default=None. Axis along which the
      standard deviation is computed. If None, the standard deviation is computed
      along all the axes.
dtype: The type of the output array. Default=None.
    ddof: int, default=0. Degrees of freedom. The divisor in the standard
      deviation computation is ``N - ddof``, where ``N`` is the number of
      elements along the given axis.
keepdims: bool, default=False. If true, reduced axes are left in the result
with size 1.
where: optional, boolean array, default=None. The elements to be used in the
      standard deviation. Array should be broadcast compatible with the input.
correction: int or float, default=None. Alternative name for ``ddof``.
      ``ddof`` and ``correction`` cannot be provided simultaneously.
out: Unused by JAX.
Returns:
An array of the standard deviation along the given axis.
See also:
- :func:`jax.numpy.var`: Compute the variance of array elements over given
axis.
- :func:`jax.numpy.mean`: Compute the mean of array elements over a given axis.
    - :func:`jax.numpy.nanvar`: Compute the variance along a given axis, ignoring
      NaN values.
    - :func:`jax.numpy.nanstd`: Compute the standard deviation along a given axis,
      ignoring NaN values.
Examples:
By default, ``jnp.std`` computes the standard deviation along all axes.
>>> x = jnp.array([[1, 3, 4, 2],
... [4, 2, 5, 3],
... [5, 4, 2, 3]])
>>> with jnp.printoptions(precision=2, suppress=True):
... jnp.std(x)
Array(1.21, dtype=float32)
If ``axis=0``, computes along axis 0.
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jnp.std(x, axis=0))
[1.7 0.82 1.25 0.47]
To preserve the dimensions of input, you can set ``keepdims=True``.
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jnp.std(x, axis=0, keepdims=True))
[[1.7 0.82 1.25 0.47]]
If ``ddof=1``:
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jnp.std(x, axis=0, keepdims=True, ddof=1))
[[2.08 1. 1.53 0.58]]
    To include only specific elements of the array in the standard deviation
    computation, you can use ``where``.
>>> where = jnp.array([[1, 0, 1, 0],
... [0, 1, 0, 1],
... [1, 1, 1, 0]], dtype=bool)
>>> jnp.std(x, axis=0, keepdims=True, where=where)
Array([[2., 1., 1., 0.]], dtype=float32)
"""
if correction is None:
correction = ddof
elif not isinstance(ddof, int) or ddof != 0:
raise ValueError("ddof and correction can't be provided simultaneously.")
return _std(a, _ensure_optional_axes(axis), dtype, out, correction, keepdims,
where=where)
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def _std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, correction: int | float = 0, keepdims: bool = False, *,
where: ArrayLike | None = None) -> Array:
check_arraylike("std", a)
where = check_where("std", where)
dtypes.check_user_dtype_supported(dtype, "std")
if dtype is not None and not dtypes.issubdtype(dtype, np.inexact):
raise ValueError(f"dtype argument to jnp.std must be inexact; got {dtype}")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.std is not supported.")
return lax.sqrt(var(a, axis=axis, dtype=dtype, correction=correction, keepdims=keepdims, where=where))
@export
def ptp(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False) -> Array:
r"""Return the peak-to-peak range along a given axis.
JAX implementation of :func:`numpy.ptp`.
Args:
a: input array.
axis: optional, int or sequence of ints, default=None. Axis along which the
range is computed. If None, the range is computed on the flattened array.
keepdims: bool, default=False. If true, reduced axes are left in the result
with size 1.
out: Unused by JAX.
Returns:
    An array with the range of elements along the specified axis of the input.
Examples:
By default, ``jnp.ptp`` computes the range along all axes.
>>> x = jnp.array([[1, 3, 5, 2],
... [4, 6, 8, 1],
... [7, 9, 3, 4]])
>>> jnp.ptp(x)
Array(8, dtype=int32)
If ``axis=1``, computes the range along axis 1.
>>> jnp.ptp(x, axis=1)
Array([4, 7, 6], dtype=int32)
To preserve the dimensions of input, you can set ``keepdims=True``.
>>> jnp.ptp(x, axis=1, keepdims=True)
Array([[4],
[7],
[6]], dtype=int32)
"""
return _ptp(a, _ensure_optional_axes(axis), out, keepdims)
@partial(api.jit, static_argnames=('axis', 'keepdims'))
def _ptp(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False) -> Array:
check_arraylike("ptp", a)
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.ptp is not supported.")
x = amax(a, axis=axis, keepdims=keepdims)
y = amin(a, axis=axis, keepdims=keepdims)
return lax.sub(x, y)
@export
@partial(api.jit, static_argnames=('axis', 'keepdims'))
def count_nonzero(a: ArrayLike, axis: Axis = None,
keepdims: bool = False) -> Array:
r"""Return the number of nonzero elements along a given axis.
JAX implementation of :func:`numpy.count_nonzero`.
Args:
a: input array.
    axis: optional, int or sequence of ints, default=None. Axis along which the
      nonzero elements are counted. If None, counts within the flattened array.
keepdims: bool, default=False. If true, reduced axes are left in the result
with size 1.
Returns:
    An array with the number of nonzero elements along the specified axis of the input.
Examples:
By default, ``jnp.count_nonzero`` counts the nonzero values along all axes.
>>> x = jnp.array([[1, 0, 0, 0],
... [0, 0, 1, 0],
... [1, 1, 1, 0]])
>>> jnp.count_nonzero(x)
Array(5, dtype=int32)
If ``axis=1``, counts along axis 1.
>>> jnp.count_nonzero(x, axis=1)
Array([1, 1, 3], dtype=int32)
To preserve the dimensions of input, you can set ``keepdims=True``.
>>> jnp.count_nonzero(x, axis=1, keepdims=True)
Array([[1],
[1],
[3]], dtype=int32)
"""
check_arraylike("count_nonzero", a)
return sum(lax.ne(a, _lax_const(a, 0)), axis=axis,
dtype=dtypes.canonicalize_dtype(int), keepdims=keepdims)
def _nan_reduction(a: ArrayLike, name: str, jnp_reduction: Callable[..., Array],
init_val: ArrayLike, nan_if_all_nan: bool,
axis: Axis = None, keepdims: bool = False, where: ArrayLike | None = None,
**kwargs) -> Array:
check_arraylike(name, a)
where = check_where(name, where)
if not dtypes.issubdtype(dtypes.dtype(a), np.inexact):
return jnp_reduction(a, axis=axis, keepdims=keepdims, where=where, **kwargs)
out = jnp_reduction(_where(lax_internal._isnan(a), _reduction_init_val(a, init_val), a),
axis=axis, keepdims=keepdims, where=where, **kwargs)
if nan_if_all_nan:
return _where(all(lax_internal._isnan(a), axis=axis, keepdims=keepdims),
_lax_const(a, np.nan), out)
else:
return out
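# Sketch: NaNs are replaced by the reduction identity before delegating to the
# regular jnp reduction, e.g. nansum([1., nan, 2.]) effectively sums
# [1., 0., 2.] -> 3.0. For min/max without ``initial`` (nan_if_all_nan=True),
# slices that are entirely NaN are mapped back to nan afterwards.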
@export
@partial(api.jit, static_argnames=('axis', 'keepdims'))
def nanmin(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None) -> Array:
r"""Return the minimum of the array elements along a given axis, ignoring NaNs.
JAX implementation of :func:`numpy.nanmin`.
Args:
a: Input array.
axis: int or sequence of ints, default=None. Axis along which the minimum is
computed. If None, the minimum is computed along the flattened array.
keepdims: bool, default=False. If True, reduced axes are left in the result
with size 1.
initial: int or array, default=None. Initial value for the minimum.
where: array of boolean dtype, default=None. The elements to be used in the
      minimum. Array should be broadcast compatible with the input. ``initial``
must be specified when ``where`` is used.
out: Unused by JAX.
Returns:
An array of minimum values along the given axis, ignoring NaNs. If all values
are NaNs along the given axis, returns ``nan``.
See also:
- :func:`jax.numpy.nanmax`: Compute the maximum of array elements along a
given axis, ignoring NaNs.
- :func:`jax.numpy.nansum`: Compute the sum of array elements along a given
axis, ignoring NaNs.
- :func:`jax.numpy.nanprod`: Compute the product of array elements along a
given axis, ignoring NaNs.
- :func:`jax.numpy.nanmean`: Compute the mean of array elements along a given
axis, ignoring NaNs.
Examples:
By default, ``jnp.nanmin`` computes the minimum of elements along the flattened
array.
>>> nan = jnp.nan
>>> x = jnp.array([[1, nan, 4, 5],
... [nan, -2, nan, -4],
... [2, 1, 3, nan]])
>>> jnp.nanmin(x)
Array(-4., dtype=float32)
    If ``axis=1``, the minimum will be computed along axis 1.
>>> jnp.nanmin(x, axis=1)
Array([ 1., -4., 1.], dtype=float32)
    If ``keepdims=True``, ``ndim`` of the output will be the same as that of the input.
>>> jnp.nanmin(x, axis=1, keepdims=True)
Array([[ 1.],
[-4.],
[ 1.]], dtype=float32)
    To include only specific elements in computing the minimum, you can use
    ``where``. It can either have the same dimensions as the input
>>> where=jnp.array([[0, 0, 1, 0],
... [0, 0, 1, 1],
... [1, 1, 1, 0]], dtype=bool)
>>> jnp.nanmin(x, axis=1, keepdims=True, initial=0, where=where)
Array([[ 0.],
[-4.],
[ 0.]], dtype=float32)
    or be broadcast compatible with the input.
>>> where = jnp.array([[False],
... [True],
... [False]])
>>> jnp.nanmin(x, axis=0, keepdims=True, initial=0, where=where)
Array([[ 0., -2., 0., -4.]], dtype=float32)
"""
return _nan_reduction(a, 'nanmin', min, np.inf, nan_if_all_nan=initial is None,
axis=axis, out=out, keepdims=keepdims,
initial=initial, where=where)
@export
@partial(api.jit, static_argnames=('axis', 'keepdims'))
def nanmax(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None) -> Array:
r"""Return the maximum of the array elements along a given axis, ignoring NaNs.
JAX implementation of :func:`numpy.nanmax`.
Args:
a: Input array.
axis: int or sequence of ints, default=None. Axis along which the maximum is
computed. If None, the maximum is computed along the flattened array.
keepdims: bool, default=False. If True, reduced axes are left in the result
with size 1.
initial: int or array, default=None. Initial value for the maximum.
where: array of boolean dtype, default=None. The elements to be used in the
      maximum. Array should be broadcast compatible with the input. ``initial``
must be specified when ``where`` is used.
out: Unused by JAX.
Returns:
An array of maximum values along the given axis, ignoring NaNs. If all values
are NaNs along the given axis, returns ``nan``.
See also:
- :func:`jax.numpy.nanmin`: Compute the minimum of array elements along a
given axis, ignoring NaNs.
- :func:`jax.numpy.nansum`: Compute the sum of array elements along a given
axis, ignoring NaNs.
- :func:`jax.numpy.nanprod`: Compute the product of array elements along a
given axis, ignoring NaNs.
- :func:`jax.numpy.nanmean`: Compute the mean of array elements along a given
axis, ignoring NaNs.
Examples:
By default, ``jnp.nanmax`` computes the maximum of elements along the flattened
array.
>>> nan = jnp.nan
>>> x = jnp.array([[8, nan, 4, 6],
... [nan, -2, nan, -4],
... [-2, 1, 7, nan]])
>>> jnp.nanmax(x)
Array(8., dtype=float32)
If ``axis=1``, the maximum will be computed along axis 1.
>>> jnp.nanmax(x, axis=1)
Array([ 8., -2., 7.], dtype=float32)
    If ``keepdims=True``, ``ndim`` of the output will be the same as that of the input.
>>> jnp.nanmax(x, axis=1, keepdims=True)
Array([[ 8.],
[-2.],
[ 7.]], dtype=float32)
To include only specific elements in computing the maximum, you can use
    ``where``. It can either have the same dimensions as the input
>>> where=jnp.array([[0, 0, 1, 0],
... [0, 0, 1, 1],
... [1, 1, 1, 0]], dtype=bool)
>>> jnp.nanmax(x, axis=1, keepdims=True, initial=0, where=where)
Array([[4.],
[0.],
[7.]], dtype=float32)
    or be broadcast compatible with the input.
>>> where = jnp.array([[True],
... [False],
... [False]])
>>> jnp.nanmax(x, axis=0, keepdims=True, initial=0, where=where)
Array([[8., 0., 4., 6.]], dtype=float32)
"""
return _nan_reduction(a, 'nanmax', max, -np.inf, nan_if_all_nan=initial is None,
axis=axis, out=out, keepdims=keepdims,
initial=initial, where=where)
@export
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nansum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None,
keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None) -> Array:
r"""Return the sum of the array elements along a given axis, ignoring NaNs.
JAX implementation of :func:`numpy.nansum`.
Args:
a: Input array.
axis: int or sequence of ints, default=None. Axis along which the sum is
computed. If None, the sum is computed along the flattened array.
dtype: The type of the output array. Default=None.
keepdims: bool, default=False. If True, reduced axes are left in the result
with size 1.
initial: int or array, default=None. Initial value for the sum.
where: array of boolean dtype, default=None. The elements to be used in the
      sum. Array should be broadcast compatible with the input.
out: Unused by JAX.
Returns:
An array containing the sum of array elements along the given axis, ignoring
NaNs. If all elements along the given axis are NaNs, returns 0.
See also:
- :func:`jax.numpy.nanmin`: Compute the minimum of array elements along a
given axis, ignoring NaNs.
- :func:`jax.numpy.nanmax`: Compute the maximum of array elements along a
given axis, ignoring NaNs.
- :func:`jax.numpy.nanprod`: Compute the product of array elements along a
given axis, ignoring NaNs.
- :func:`jax.numpy.nanmean`: Compute the mean of array elements along a given
axis, ignoring NaNs.
Examples:
By default, ``jnp.nansum`` computes the sum of elements along the flattened
array.
>>> nan = jnp.nan
>>> x = jnp.array([[3, nan, 4, 5],
... [nan, -2, nan, 7],
... [2, 1, 6, nan]])
>>> jnp.nansum(x)
Array(26., dtype=float32)
If ``axis=1``, the sum will be computed along axis 1.
>>> jnp.nansum(x, axis=1)
Array([12., 5., 9.], dtype=float32)
    If ``keepdims=True``, ``ndim`` of the output will be the same as that of the input.
>>> jnp.nansum(x, axis=1, keepdims=True)
Array([[12.],
[ 5.],
[ 9.]], dtype=float32)
To include only specific elements in computing the sum, you can use ``where``.
>>> where=jnp.array([[1, 0, 1, 0],
... [0, 0, 1, 1],
... [1, 1, 1, 0]], dtype=bool)
>>> jnp.nansum(x, axis=1, keepdims=True, where=where)
Array([[7.],
[7.],
[9.]], dtype=float32)
    If ``where`` is ``False`` for all elements, ``jnp.nansum`` returns 0 along
the given axis.
>>> where = jnp.array([[False],
... [False],
... [False]])
>>> jnp.nansum(x, axis=0, keepdims=True, where=where)
Array([[0., 0., 0., 0.]], dtype=float32)
"""
  dtypes.check_user_dtype_supported(dtype, "nansum")
return _nan_reduction(a, 'nansum', sum, 0, nan_if_all_nan=False,
axis=axis, dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where=where)
@export
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanprod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None,
keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None) -> Array:
r"""Return the product of the array elements along a given axis, ignoring NaNs.
JAX implementation of :func:`numpy.nanprod`.
Args:
a: Input array.
axis: int or sequence of ints, default=None. Axis along which the product is
computed. If None, the product is computed along the flattened array.
dtype: The type of the output array. Default=None.
keepdims: bool, default=False. If True, reduced axes are left in the result
with size 1.
initial: int or array, default=None. Initial value for the product.
where: array of boolean dtype, default=None. The elements to be used in the
      product. Array should be broadcast compatible with the input.
out: Unused by JAX.
Returns:
An array containing the product of array elements along the given axis,
ignoring NaNs. If all elements along the given axis are NaNs, returns 1.
See also:
- :func:`jax.numpy.nanmin`: Compute the minimum of array elements along a
given axis, ignoring NaNs.
- :func:`jax.numpy.nanmax`: Compute the maximum of array elements along a
given axis, ignoring NaNs.
- :func:`jax.numpy.nansum`: Compute the sum of array elements along a given
axis, ignoring NaNs.
- :func:`jax.numpy.nanmean`: Compute the mean of array elements along a given
axis, ignoring NaNs.
Examples:
By default, ``jnp.nanprod`` computes the product of elements along the flattened
array.
>>> nan = jnp.nan
>>> x = jnp.array([[nan, 3, 4, nan],
... [5, nan, 1, 3],
... [2, 1, nan, 1]])
>>> jnp.nanprod(x)
Array(360., dtype=float32)
If ``axis=1``, the product will be computed along axis 1.
>>> jnp.nanprod(x, axis=1)
Array([12., 15., 2.], dtype=float32)
    If ``keepdims=True``, ``ndim`` of the output will be the same as that of the input.
>>> jnp.nanprod(x, axis=1, keepdims=True)
Array([[12.],
[15.],
[ 2.]], dtype=float32)
    To include only specific elements in computing the product, you can use ``where``.
>>> where=jnp.array([[1, 0, 1, 0],
... [0, 0, 1, 1],
... [1, 1, 1, 0]], dtype=bool)
>>> jnp.nanprod(x, axis=1, keepdims=True, where=where)
Array([[4.],
[3.],
[2.]], dtype=float32)
If ``where`` is ``False`` at all elements, ``jnp.nanprod`` returns 1 along
the given axis.
>>> where = jnp.array([[False],
... [False],
... [False]])
>>> jnp.nanprod(x, axis=0, keepdims=True, where=where)
Array([[1., 1., 1., 1.]], dtype=float32)
"""
dtypes.check_user_dtype_supported(dtype, "nanprod")
return _nan_reduction(a, 'nanprod', prod, 1, nan_if_all_nan=False,
axis=axis, dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where=where)
@export
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanmean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None,
keepdims: bool = False, where: ArrayLike | None = None) -> Array:
r"""Return the mean of the array elements along a given axis, ignoring NaNs.
JAX implementation of :func:`numpy.nanmean`.
Args:
a: Input array.
axis: int or sequence of ints, default=None. Axis along which the mean is
computed. If None, the mean is computed along the flattened array.
dtype: The type of the output array. Default=None.
keepdims: bool, default=False. If True, reduced axes are left in the result
with size 1.
where: array of boolean dtype, default=None. The elements to be used in
      computing the mean. Array should be broadcast compatible with the input.
out: Unused by JAX.
Returns:
An array containing the mean of array elements along the given axis, ignoring
NaNs. If all elements along the given axis are NaNs, returns ``nan``.
See also:
- :func:`jax.numpy.nanmin`: Compute the minimum of array elements along a
given axis, ignoring NaNs.
- :func:`jax.numpy.nanmax`: Compute the maximum of array elements along a
given axis, ignoring NaNs.
- :func:`jax.numpy.nansum`: Compute the sum of array elements along a given
axis, ignoring NaNs.
- :func:`jax.numpy.nanprod`: Compute the product of array elements along a
given axis, ignoring NaNs.
Examples:
By default, ``jnp.nanmean`` computes the mean of elements along the flattened
array.
>>> nan = jnp.nan
>>> x = jnp.array([[2, nan, 4, 3],
... [nan, -2, nan, 9],
... [4, -7, 6, nan]])
>>> jnp.nanmean(x)
Array(2.375, dtype=float32)
    If ``axis=1``, the mean will be computed along axis 1.
>>> jnp.nanmean(x, axis=1)
Array([3. , 3.5, 1. ], dtype=float32)
    If ``keepdims=True``, ``ndim`` of the output will be the same as that of the input.
>>> jnp.nanmean(x, axis=1, keepdims=True)
Array([[3. ],
[3.5],
[1. ]], dtype=float32)
``where`` can be used to include only specific elements in computing the mean.
>>> where = jnp.array([[1, 0, 1, 0],
... [0, 0, 1, 1],
... [1, 1, 0, 1]], dtype=bool)
>>> jnp.nanmean(x, axis=1, keepdims=True, where=where)
Array([[ 3. ],
[ 9. ],
[-1.5]], dtype=float32)
If ``where`` is ``False`` at all elements, ``jnp.nanmean`` returns ``nan``
along the given axis.
>>> where = jnp.array([[False],
... [False],
... [False]])
>>> jnp.nanmean(x, axis=0, keepdims=True, where=where)
Array([[nan, nan, nan, nan]], dtype=float32)
"""
check_arraylike("nanmean", a)
where = check_where("nanmean", where)
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.nanmean is not supported.")
if dtypes.issubdtype(dtypes.dtype(a), np.bool_) or dtypes.issubdtype(dtypes.dtype(a), np.integer):
return mean(a, axis, dtype, out, keepdims, where=where)
if dtype is None:
dtype = dtypes.to_inexact_dtype(dtypes.dtype(a, canonicalize=True))
else:
dtypes.check_user_dtype_supported(dtype, "mean")
dtype = dtypes.canonicalize_dtype(dtype)
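  # Divide the NaN-ignoring sum by the count of non-NaN (and `where`-selected)
  # elements; an all-NaN or fully-masked slice gives 0/0, i.e. NaN, as documented.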
nan_mask = lax_internal.bitwise_not(lax_internal._isnan(a))
normalizer = sum(nan_mask, axis=axis, dtype=dtype, keepdims=keepdims, where=where)
td = lax.div(nansum(a, axis, dtype=dtype, keepdims=keepdims, where=where), normalizer)
return td
@export
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None,
ddof: int = 0, keepdims: bool = False,
where: ArrayLike | None = None) -> Array:
r"""Compute the variance of array elements along a given axis, ignoring NaNs.
JAX implementation of :func:`numpy.nanvar`.
Args:
a: input array.
axis: optional, int or sequence of ints, default=None. Axis along which the
variance is computed. If None, variance is computed along flattened array.
dtype: The type of the output array. Default=None.
ddof: int, default=0. Degrees of freedom. The divisor in the variance computation
      is ``N-ddof``, where ``N`` is the number of elements along the given axis.
keepdims: bool, default=False. If true, reduced axes are left in the result
with size 1.
where: optional, boolean array, default=None. The elements to be used in the
      variance. Array should be broadcast compatible with the input.
out: Unused by JAX.
Returns:
An array containing the variance of array elements along specified axis. If
all elements along the given axis are NaNs, returns ``nan``.
See also:
- :func:`jax.numpy.nanmean`: Compute the mean of array elements over a given
axis, ignoring NaNs.
    - :func:`jax.numpy.nanstd`: Compute the standard deviation along a given axis,
ignoring NaNs.
- :func:`jax.numpy.var`: Compute the variance of array elements along a given
axis.
Examples:
By default, ``jnp.nanvar`` computes the variance along all axes.
>>> nan = jnp.nan
>>> x = jnp.array([[1, nan, 4, 3],
... [nan, 2, nan, 9],
... [4, 8, 6, nan]])
>>> jnp.nanvar(x)
Array(6.984375, dtype=float32)
    If ``axis=1``, the variance is computed along axis 1.
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jnp.nanvar(x, axis=1))
[ 1.56 12.25 2.67]
To preserve the dimensions of input, you can set ``keepdims=True``.
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jnp.nanvar(x, axis=1, keepdims=True))
[[ 1.56]
[12.25]
[ 2.67]]
If ``ddof=1``:
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jnp.nanvar(x, axis=1, keepdims=True, ddof=1))
[[ 2.33]
[24.5 ]
[ 4. ]]
    To include only specific elements in computing the variance, you can use
    ``where``.
>>> where = jnp.array([[1, 0, 1, 0],
... [0, 1, 1, 0],
... [1, 1, 0, 1]], dtype=bool)
>>> jnp.nanvar(x, axis=1, keepdims=True, where=where)
Array([[2.25],
[0. ],
[4. ]], dtype=float32)
"""
check_arraylike("nanvar", a)
where = check_where("nanvar", where)
dtypes.check_user_dtype_supported(dtype, "nanvar")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.nanvar is not supported.")
computation_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype)
a = lax_internal.asarray(a).astype(computation_dtype)
a_mean = nanmean(a, axis, dtype=computation_dtype, keepdims=True, where=where)
centered = _where(lax_internal._isnan(a), 0, lax.sub(a, a_mean)) # double-where trick for gradients.
if dtypes.issubdtype(centered.dtype, np.complexfloating):
centered = lax.real(lax.mul(centered, lax.conj(centered)))
else:
centered = lax.square(centered)
normalizer = sum(lax_internal.bitwise_not(lax_internal._isnan(a)),
axis=axis, keepdims=keepdims, where=where)
normalizer = normalizer - ddof
normalizer_mask = lax.le(normalizer, lax_internal._zero(normalizer))
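  # Where the effective count N - ddof is non-positive the variance is
  # undefined: emit NaN there and substitute a divisor of 1 so the division
  # itself stays well-defined.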
result = sum(centered, axis, keepdims=keepdims, where=where)
result = _where(normalizer_mask, np.nan, result)
divisor = _where(normalizer_mask, 1, normalizer)
result = lax.div(result, lax.convert_element_type(divisor, result.dtype))
return lax.convert_element_type(result, dtype)
@export
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanstd(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None,
ddof: int = 0, keepdims: bool = False,
where: ArrayLike | None = None) -> Array:
r"""Compute the standard deviation along a given axis, ignoring NaNs.
JAX implementation of :func:`numpy.nanstd`.
Args:
a: input array.
axis: optional, int or sequence of ints, default=None. Axis along which the
      standard deviation is computed. If None, the standard deviation is computed
      along the flattened array.
dtype: The type of the output array. Default=None.
ddof: int, default=0. Degrees of freedom. The divisor in the standard deviation
      computation is ``N-ddof``, where ``N`` is the number of elements along the
      given axis.
keepdims: bool, default=False. If true, reduced axes are left in the result
with size 1.
where: optional, boolean array, default=None. The elements to be used in the
      standard deviation. Array should be broadcast compatible with the input.
out: Unused by JAX.
Returns:
An array containing the standard deviation of array elements along the given
axis. If all elements along the given axis are NaNs, returns ``nan``.
See also:
- :func:`jax.numpy.nanmean`: Compute the mean of array elements over a given
axis, ignoring NaNs.
- :func:`jax.numpy.nanvar`: Compute the variance along the given axis, ignoring
      NaN values.
    - :func:`jax.numpy.std`: Compute the standard deviation along the given axis.
Examples:
    By default, ``jnp.nanstd`` computes the standard deviation along the
    flattened array.
>>> nan = jnp.nan
>>> x = jnp.array([[3, nan, 4, 5],
... [nan, 2, nan, 7],
... [2, 1, 6, nan]])
>>> jnp.nanstd(x)
Array(1.9843135, dtype=float32)
    If ``axis=0``, the standard deviation is computed along axis 0.
>>> jnp.nanstd(x, axis=0)
Array([0.5, 0.5, 1. , 1. ], dtype=float32)
To preserve the dimensions of input, you can set ``keepdims=True``.
>>> jnp.nanstd(x, axis=0, keepdims=True)
Array([[0.5, 0.5, 1. , 1. ]], dtype=float32)
If ``ddof=1``:
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jnp.nanstd(x, axis=0, keepdims=True, ddof=1))
[[0.71 0.71 1.41 1.41]]
    To include only specific elements in computing the standard deviation, you
    can use ``where``.
>>> where=jnp.array([[1, 0, 1, 0],
... [0, 1, 0, 1],
... [1, 1, 0, 1]], dtype=bool)
>>> jnp.nanstd(x, axis=0, keepdims=True, where=where)
Array([[0.5, 0.5, 0. , 0. ]], dtype=float32)
"""
check_arraylike("nanstd", a)
where = check_where("nanstd", where)
dtypes.check_user_dtype_supported(dtype, "nanstd")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.nanstd is not supported.")
return lax.sqrt(nanvar(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where))
class CumulativeReduction(Protocol):
def __call__(self, a: ArrayLike, axis: Axis = None,
dtype: DTypeLike | None = None, out: None = None) -> Array: ...
def _cumulative_reduction(
name: str, reduction: Callable[..., Array],
a: ArrayLike, axis: int | None, dtype: DTypeLike | None, out: None = None,
fill_nan: bool = False, fill_value: ArrayLike = 0,
promote_integers: bool = False) -> Array:
"""Helper function for implementing cumulative reductions."""
check_arraylike(name, a)
if out is not None:
raise NotImplementedError(f"The 'out' argument to jnp.{name} is not supported")
dtypes.check_user_dtype_supported(dtype, name)
if axis is None or _isscalar(a):
a = lax.reshape(a, (np.size(a),))
if axis is None:
axis = 0
a_shape = list(np.shape(a))
num_dims = len(a_shape)
axis = _canonicalize_axis(axis, num_dims)
if fill_nan:
a = _where(lax_internal._isnan(a), _lax_const(a, fill_value), a)
a_type: DType = dtypes.dtype(a)
result_type: DTypeLike = dtypes.dtype(dtype or a)
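  # Accumulate in a promoted dtype: a boolean result type is always promoted to
  # the default integer type (and downcast again below), and when
  # promote_integers is set with no explicit dtype, narrow integer types are
  # widened to the default width to reduce the chance of overflow.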
  if (dtype is None and promote_integers) or dtypes.issubdtype(result_type, np.bool_):
result_type = _promote_integer_dtype(result_type)
result_type = dtypes.canonicalize_dtype(result_type)
if a_type != np.bool_ and dtype == np.bool_:
a = lax_internal.asarray(a).astype(np.bool_)
a = lax.convert_element_type(a, result_type)
result = reduction(a, axis)
# We downcast to boolean because we accumulate in integer types
if dtype is not None and dtypes.issubdtype(dtype, np.bool_):
result = lax.convert_element_type(result, np.bool_)
return result
@export
@partial(api.jit, static_argnames=('axis', 'dtype'))
def cumsum(a: ArrayLike, axis: int | None = None,
dtype: DTypeLike | None = None, out: None = None) -> Array:
"""Cumulative sum of elements along an axis.
JAX implementation of :func:`numpy.cumsum`.
Args:
a: N-dimensional array to be accumulated.
axis: integer axis along which to accumulate. If None (default), then
array will be flattened and accumulated along the flattened axis.
dtype: optionally specify the dtype of the output. If not specified,
then the output dtype will match the input dtype.
out: unused by JAX
Returns:
An array containing the accumulated sum along the given axis.
See also:
- :func:`jax.numpy.cumulative_sum`: cumulative sum via the array API standard.
- :meth:`jax.numpy.add.accumulate`: cumulative sum via ufunc methods.
- :func:`jax.numpy.nancumsum`: cumulative sum ignoring NaN values.
- :func:`jax.numpy.sum`: sum along axis
Examples:
>>> x = jnp.array([[1, 2, 3],
... [4, 5, 6]])
>>> jnp.cumsum(x) # flattened cumulative sum
Array([ 1, 3, 6, 10, 15, 21], dtype=int32)
>>> jnp.cumsum(x, axis=1) # cumulative sum along axis 1
Array([[ 1, 3, 6],
[ 4, 9, 15]], dtype=int32)
"""
return _cumulative_reduction("cumsum", lax.cumsum, a, axis, dtype, out)
@export
@partial(api.jit, static_argnames=('axis', 'dtype'))
def cumprod(a: ArrayLike, axis: int | None = None,
dtype: DTypeLike | None = None, out: None = None) -> Array:
"""Cumulative product of elements along an axis.
JAX implementation of :func:`numpy.cumprod`.
Args:
a: N-dimensional array to be accumulated.
axis: integer axis along which to accumulate. If None (default), then
array will be flattened and accumulated along the flattened axis.
dtype: optionally specify the dtype of the output. If not specified,
then the output dtype will match the input dtype.
out: unused by JAX
Returns:
An array containing the accumulated product along the given axis.
See also:
- :meth:`jax.numpy.multiply.accumulate`: cumulative product via ufunc methods.
- :func:`jax.numpy.nancumprod`: cumulative product ignoring NaN values.
- :func:`jax.numpy.prod`: product along axis
Examples:
>>> x = jnp.array([[1, 2, 3],
... [4, 5, 6]])
>>> jnp.cumprod(x) # flattened cumulative product
Array([ 1, 2, 6, 24, 120, 720], dtype=int32)
>>> jnp.cumprod(x, axis=1) # cumulative product along axis 1
Array([[ 1, 2, 6],
[ 4, 20, 120]], dtype=int32)
"""
return _cumulative_reduction("cumprod", lax.cumprod, a, axis, dtype, out)
@export
@partial(api.jit, static_argnames=('axis', 'dtype'))
def nancumsum(a: ArrayLike, axis: int | None = None,
dtype: DTypeLike | None = None, out: None = None) -> Array:
"""Cumulative sum of elements along an axis, ignoring NaN values.
JAX implementation of :func:`numpy.nancumsum`.
Args:
a: N-dimensional array to be accumulated.
axis: integer axis along which to accumulate. If None (default), then
array will be flattened and accumulated along the flattened axis.
dtype: optionally specify the dtype of the output. If not specified,
then the output dtype will match the input dtype.
out: unused by JAX
Returns:
An array containing the accumulated sum along the given axis.
See also:
- :func:`jax.numpy.cumsum`: cumulative sum without ignoring NaN values.
- :func:`jax.numpy.cumulative_sum`: cumulative sum via the array API standard.
- :meth:`jax.numpy.add.accumulate`: cumulative sum via ufunc methods.
- :func:`jax.numpy.sum`: sum along axis
Examples:
>>> x = jnp.array([[1., 2., jnp.nan],
... [4., jnp.nan, 6.]])
The standard cumulative sum will propagate NaN values:
>>> jnp.cumsum(x)
Array([ 1., 3., nan, nan, nan, nan], dtype=float32)
:func:`~jax.numpy.nancumsum` will ignore NaN values, effectively replacing
them with zeros:
>>> jnp.nancumsum(x)
Array([ 1., 3., 3., 7., 7., 13.], dtype=float32)
Cumulative sum along axis 1:
>>> jnp.nancumsum(x, axis=1)
Array([[ 1., 3., 3.],
[ 4., 4., 10.]], dtype=float32)
"""
return _cumulative_reduction("nancumsum", lax.cumsum, a, axis, dtype, out,
fill_nan=True, fill_value=0)
@export
@partial(api.jit, static_argnames=('axis', 'dtype'))
def nancumprod(a: ArrayLike, axis: int | None = None,
dtype: DTypeLike | None = None, out: None = None) -> Array:
"""Cumulative product of elements along an axis, ignoring NaN values.
JAX implementation of :func:`numpy.nancumprod`.
Args:
a: N-dimensional array to be accumulated.
axis: integer axis along which to accumulate. If None (default), then
array will be flattened and accumulated along the flattened axis.
dtype: optionally specify the dtype of the output. If not specified,
then the output dtype will match the input dtype.
out: unused by JAX
Returns:
An array containing the accumulated product along the given axis.
See also:
- :func:`jax.numpy.cumprod`: cumulative product without ignoring NaN values.
- :meth:`jax.numpy.multiply.accumulate`: cumulative product via ufunc methods.
- :func:`jax.numpy.prod`: product along axis
Examples:
>>> x = jnp.array([[1., 2., jnp.nan],
... [4., jnp.nan, 6.]])
The standard cumulative product will propagate NaN values:
>>> jnp.cumprod(x)
Array([ 1., 2., nan, nan, nan, nan], dtype=float32)
:func:`~jax.numpy.nancumprod` will ignore NaN values, effectively replacing
them with ones:
>>> jnp.nancumprod(x)
Array([ 1., 2., 2., 8., 8., 48.], dtype=float32)
Cumulative product along axis 1:
>>> jnp.nancumprod(x, axis=1)
Array([[ 1., 2., 2.],
[ 4., 4., 24.]], dtype=float32)
"""
return _cumulative_reduction("nancumprod", lax.cumprod, a, axis, dtype, out,
fill_nan=True, fill_value=1)
@partial(api.jit, static_argnames=('axis', 'dtype'))
def _cumsum_with_promotion(a: ArrayLike, axis: int | None = None,
dtype: DTypeLike | None = None, out: None = None) -> Array:
"""Utility function to compute cumsum with integer promotion."""
return _cumulative_reduction("_cumsum_with_promotion", lax.cumsum,
a, axis, dtype, out, promote_integers=True)
@export
def cumulative_sum(
x: ArrayLike, /, *, axis: int | None = None,
dtype: DTypeLike | None = None,
include_initial: bool = False) -> Array:
"""Cumulative sum along the axis of an array.
JAX implementation of :func:`numpy.cumulative_sum`.
Args:
x: N-dimensional array
axis: integer axis along which to accumulate. If ``x`` is one-dimensional,
this argument is optional and defaults to zero.
dtype: optional dtype of the output.
include_initial: if True, then include the initial value in the cumulative
sum. Default is False.
Returns:
An array containing the accumulated values.
See Also:
- :func:`jax.numpy.cumsum`: alternative API for cumulative sum.
- :func:`jax.numpy.nancumsum`: cumulative sum while ignoring NaN values.
- :func:`jax.numpy.add.accumulate`: cumulative sum via the ufunc API.
Examples:
>>> x = jnp.array([[1, 2, 3],
... [4, 5, 6]])
>>> jnp.cumulative_sum(x, axis=1)
Array([[ 1, 3, 6],
[ 4, 9, 15]], dtype=int32)
>>> jnp.cumulative_sum(x, axis=1, include_initial=True)
Array([[ 0, 1, 3, 6],
[ 0, 4, 9, 15]], dtype=int32)
"""
check_arraylike("cumulative_sum", x)
x = lax_internal.asarray(x)
if x.ndim == 0:
raise ValueError(
"The input must be non-scalar to take a cumulative sum, however a "
"scalar value or scalar array was given."
)
if axis is None:
axis = 0
if x.ndim > 1:
raise ValueError(
f"The input array has rank {x.ndim}, however axis was not set to an "
"explicit value. The axis argument is only optional for one-dimensional "
"arrays.")
axis = _canonicalize_axis(axis, x.ndim)
dtypes.check_user_dtype_supported(dtype)
out = _cumsum_with_promotion(x, axis=axis, dtype=dtype)
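  # include_initial prepends the additive identity (0) along the reduction axis,
  # making the output one element longer on that axis.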
if include_initial:
zeros_shape = list(x.shape)
zeros_shape[axis] = 1
out = lax_internal.concatenate(
[lax_internal.full(zeros_shape, 0, dtype=out.dtype), out],
dimension=axis)
return out
@export
def cumulative_prod(
x: ArrayLike, /, *, axis: int | None = None,
dtype: DTypeLike | None = None,
include_initial: bool = False) -> Array:
"""Cumulative product along the axis of an array.
JAX implementation of :func:`numpy.cumulative_prod`.
Args:
x: N-dimensional array
axis: integer axis along which to accumulate. If ``x`` is one-dimensional,
this argument is optional and defaults to zero.
dtype: optional dtype of the output.
include_initial: if True, then include the initial value in the cumulative
product. Default is False.
Returns:
An array containing the accumulated values.
See Also:
- :func:`jax.numpy.cumprod`: alternative API for cumulative product.
- :func:`jax.numpy.nancumprod`: cumulative product while ignoring NaN values.
- :func:`jax.numpy.multiply.accumulate`: cumulative product via the ufunc API.
Examples:
>>> x = jnp.array([[1, 2, 3],
... [4, 5, 6]])
>>> jnp.cumulative_prod(x, axis=1)
Array([[ 1, 2, 6],
[ 4, 20, 120]], dtype=int32)
>>> jnp.cumulative_prod(x, axis=1, include_initial=True)
Array([[ 1, 1, 2, 6],
[ 1, 4, 20, 120]], dtype=int32)
"""
check_arraylike("cumulative_prod", x)
x = lax_internal.asarray(x)
if x.ndim == 0:
raise ValueError(
"The input must be non-scalar to take a cumulative product, however a "
"scalar value or scalar array was given."
)
if axis is None:
axis = 0
if x.ndim > 1:
raise ValueError(
f"The input array has rank {x.ndim}, however axis was not set to an "
"explicit value. The axis argument is only optional for one-dimensional "
"arrays.")
axis = _canonicalize_axis(axis, x.ndim)
dtypes.check_user_dtype_supported(dtype)
out = _cumulative_reduction("cumulative_prod", lax.cumprod, x, axis, dtype)
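  # include_initial prepends the multiplicative identity (1) along the reduction
  # axis, making the output one element longer on that axis.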
if include_initial:
zeros_shape = list(x.shape)
zeros_shape[axis] = 1
out = lax_internal.concatenate(
[lax_internal.full(zeros_shape, 1, dtype=out.dtype), out],
dimension=axis)
return out
# Quantiles
# TODO(jakevdp): interpolation argument deprecated 2024-05-16
@export
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method'))
def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, *, interpolation: DeprecatedArg | str = DeprecatedArg()) -> Array:
"""Compute the quantile of the data along the specified axis.
JAX implementation of :func:`numpy.quantile`.
Args:
a: N-dimensional array input.
q: scalar or 1-dimensional array specifying the desired quantiles. ``q``
should contain floating-point values between ``0.0`` and ``1.0``.
axis: optional axis or tuple of axes along which to compute the quantile
out: not implemented by JAX; will error if not None
overwrite_input: not implemented by JAX; will error if not False
method: specify the interpolation method to use. Options are one of
``["linear", "lower", "higher", "midpoint", "nearest"]``.
default is ``linear``.
keepdims: if True, then the returned array will have the same number of
dimensions as the input. Default is False.
interpolation: deprecated alias of the ``method`` argument. Will result
in a :class:`DeprecationWarning` if used.
Returns:
An array containing the specified quantiles along the specified axes.
See also:
- :func:`jax.numpy.nanquantile`: compute the quantile while ignoring NaNs
- :func:`jax.numpy.percentile`: compute the percentile (0-100)
Examples:
Computing the median and quartiles of an array, with linear interpolation:
>>> x = jnp.arange(10)
>>> q = jnp.array([0.25, 0.5, 0.75])
>>> jnp.quantile(x, q)
Array([2.25, 4.5 , 6.75], dtype=float32)
Computing the quartiles using nearest-value interpolation:
>>> jnp.quantile(x, q, method='nearest')
Array([2., 4., 7.], dtype=float32)
"""
check_arraylike("quantile", a, q)
if overwrite_input or out is not None:
raise ValueError("jax.numpy.quantile does not support overwrite_input=True "
"or out != None")
if not isinstance(interpolation, DeprecatedArg):
deprecations.warn(
"jax-numpy-quantile-interpolation",
("The interpolation= argument to 'quantile' is deprecated. "
"Use 'method=' instead."), stacklevel=2)
method = interpolation
return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, method, keepdims, False)
# TODO(jakevdp): interpolation argument deprecated 2024-05-16
@export
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method'))
def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, *, interpolation: DeprecatedArg | str = DeprecatedArg()) -> Array:
"""Compute the quantile of the data along the specified axis, ignoring NaNs.
JAX implementation of :func:`numpy.nanquantile`.
Args:
a: N-dimensional array input.
q: scalar or 1-dimensional array specifying the desired quantiles. ``q``
should contain floating-point values between ``0.0`` and ``1.0``.
axis: optional axis or tuple of axes along which to compute the quantile
out: not implemented by JAX; will error if not None
overwrite_input: not implemented by JAX; will error if not False
method: specify the interpolation method to use. Options are one of
``["linear", "lower", "higher", "midpoint", "nearest"]``.
default is ``linear``.
keepdims: if True, then the returned array will have the same number of
dimensions as the input. Default is False.
interpolation: deprecated alias of the ``method`` argument. Will result
in a :class:`DeprecationWarning` if used.
Returns:
An array containing the specified quantiles along the specified axes.
See also:
- :func:`jax.numpy.quantile`: compute the quantile without ignoring nans
- :func:`jax.numpy.nanpercentile`: compute the percentile (0-100)
Examples:
Computing the median and quartiles of a 1D array:
>>> x = jnp.array([0, 1, 2, jnp.nan, 3, 4, 5, 6])
>>> q = jnp.array([0.25, 0.5, 0.75])
Because of the NaN value, :func:`jax.numpy.quantile` returns all NaNs,
while :func:`~jax.numpy.nanquantile` ignores them:
>>> jnp.quantile(x, q)
Array([nan, nan, nan], dtype=float32)
>>> jnp.nanquantile(x, q)
Array([1.5, 3. , 4.5], dtype=float32)
"""
check_arraylike("nanquantile", a, q)
if overwrite_input or out is not None:
msg = ("jax.numpy.nanquantile does not support overwrite_input=True or "
"out != None")
raise ValueError(msg)
if not isinstance(interpolation, DeprecatedArg):
deprecations.warn(
"jax-numpy-quantile-interpolation",
("The interpolation= argument to 'nanquantile' is deprecated. "
"Use 'method=' instead."), stacklevel=2)
method = interpolation
return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, method, keepdims, True)
def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
method: str, keepdims: bool, squash_nans: bool) -> Array:
if method not in ["linear", "lower", "higher", "midpoint", "nearest"]:
raise ValueError("method can only be 'linear', 'lower', 'higher', 'midpoint', or 'nearest'")
a, = promote_dtypes_inexact(a)
keepdim = []
if dtypes.issubdtype(a.dtype, np.complexfloating):
raise ValueError("quantile does not support complex input, as the operation is poorly defined.")
if axis is None:
if keepdims:
keepdim = [1] * a.ndim
a = a.ravel()
axis = 0
elif isinstance(axis, tuple):
keepdim = list(a.shape)
nd = a.ndim
axis = tuple(_canonicalize_axis(ax, nd) for ax in axis)
if len(set(axis)) != len(axis):
raise ValueError('repeated axis')
for ax in axis:
keepdim[ax] = 1
keep = set(range(nd)) - set(axis)
# prepare permutation
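    # Permute the kept axes to the front, then collapse all reduced axes into a
    # single trailing axis so the quantile is computed over one dimension only.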
dimensions = list(range(nd))
for i, s in enumerate(sorted(keep)):
dimensions[i], dimensions[s] = dimensions[s], dimensions[i]
do_not_touch_shape = tuple(x for idx,x in enumerate(a.shape) if idx not in axis)
touch_shape = tuple(x for idx,x in enumerate(a.shape) if idx in axis)
a = lax.reshape(a, do_not_touch_shape + (math.prod(touch_shape),), dimensions)
axis = _canonicalize_axis(-1, a.ndim)
else:
axis = _canonicalize_axis(axis, a.ndim)
q_shape = q.shape
q_ndim = q.ndim
if q_ndim > 1:
raise ValueError(f"q must be have rank <= 1, got shape {q.shape}")
a_shape = a.shape
if squash_nans:
a = _where(lax_internal._isnan(a), np.nan, a) # Ensure nans are positive so they sort to the end.
a = lax.sort(a, dimension=axis)
counts = sum(lax_internal.bitwise_not(lax_internal._isnan(a)), axis=axis, dtype=q.dtype, keepdims=keepdims)
shape_after_reduction = counts.shape
q = lax.expand_dims(
q, tuple(range(q_ndim, len(shape_after_reduction) + q_ndim)))
counts = lax.expand_dims(counts, tuple(range(q_ndim)))
q = lax.mul(q, lax.sub(counts, _lax_const(q, 1)))
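    # Each quantile in [0, 1] now maps to a fractional position within the run
    # of non-NaN values; `low`/`high` bracket that position and are combined
    # below according to the requested interpolation method.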
low = lax.floor(q)
high = lax.ceil(q)
high_weight = lax.sub(q, low)
low_weight = lax.sub(_lax_const(high_weight, 1), high_weight)
low = lax.max(_lax_const(low, 0), lax.min(low, counts - 1))
high = lax.max(_lax_const(high, 0), lax.min(high, counts - 1))
low = lax.convert_element_type(low, int)
high = lax.convert_element_type(high, int)
out_shape = q_shape + shape_after_reduction
index = [lax.broadcasted_iota(int, out_shape, dim + q_ndim)
for dim in range(len(shape_after_reduction))]
if keepdims:
index[axis] = low
else:
index.insert(axis, low)
low_value = a[tuple(index)]
index[axis] = high
high_value = a[tuple(index)]
else:
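    # Without NaN-squashing, any NaN in a reduced slice makes the whole result
    # NaN (matching numpy.quantile); debug_nans is disabled because this NaN
    # fill is intentional.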
with jax.debug_nans(False):
a = _where(any(lax_internal._isnan(a), axis=axis, keepdims=True), np.nan, a)
a = lax.sort(a, dimension=axis)
n = lax.convert_element_type(a_shape[axis], lax_internal._dtype(q))
q = lax.mul(q, n - 1)
low = lax.floor(q)
high = lax.ceil(q)
high_weight = lax.sub(q, low)
low_weight = lax.sub(_lax_const(high_weight, 1), high_weight)
low = lax.clamp(_lax_const(low, 0), low, n - 1)
high = lax.clamp(_lax_const(high, 0), high, n - 1)
low = lax.convert_element_type(low, int)
high = lax.convert_element_type(high, int)
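    # Gather the order statistics at the bracketing integer positions along
    # `axis`, one slice per requested quantile.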
slice_sizes = list(a_shape)
slice_sizes[axis] = 1
dnums = lax.GatherDimensionNumbers(
offset_dims=tuple(range(
q_ndim,
len(a_shape) + q_ndim if keepdims else len(a_shape) + q_ndim - 1)),
collapsed_slice_dims=() if keepdims else (axis,),
start_index_map=(axis,))
low_value = lax.gather(a, low[..., None], dimension_numbers=dnums,
slice_sizes=slice_sizes)
high_value = lax.gather(a, high[..., None], dimension_numbers=dnums,
slice_sizes=slice_sizes)
if q_ndim == 1:
low_weight = lax.broadcast_in_dim(low_weight, low_value.shape,
broadcast_dimensions=(0,))
high_weight = lax.broadcast_in_dim(high_weight, high_value.shape,
broadcast_dimensions=(0,))
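  # Combine the bracketing values according to the requested interpolation method.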
if method == "linear":
result = lax.add(lax.mul(low_value.astype(q.dtype), low_weight),
lax.mul(high_value.astype(q.dtype), high_weight))
elif method == "lower":
result = low_value
elif method == "higher":
result = high_value
elif method == "nearest":
pred = lax.le(high_weight, _lax_const(high_weight, 0.5))
result = lax.select(pred, low_value, high_value)
elif method == "midpoint":
result = lax.mul(lax.add(low_value, high_value), _lax_const(low_value, 0.5))
else:
raise ValueError(f"{method=!r} not recognized")
if keepdims and keepdim:
if q_ndim > 0:
keepdim = [np.shape(q)[0], *keepdim]
result = result.reshape(keepdim)
return lax.convert_element_type(result, a.dtype)
# TODO(jakevdp): interpolation argument deprecated 2024-05-16
@export
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method'))
def percentile(a: ArrayLike, q: ArrayLike,
axis: int | tuple[int, ...] | None = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, *, interpolation: str | DeprecatedArg = DeprecatedArg()) -> Array:
"""Compute the percentile of the data along the specified axis.
JAX implementation of :func:`numpy.percentile`.
Args:
a: N-dimensional array input.
q: scalar or 1-dimensional array specifying the desired quantiles. ``q``
should contain integer or floating point values between ``0`` and ``100``.
axis: optional axis or tuple of axes along which to compute the quantile
out: not implemented by JAX; will error if not None
overwrite_input: not implemented by JAX; will error if not False
method: specify the interpolation method to use. Options are one of
``["linear", "lower", "higher", "midpoint", "nearest"]``.
default is ``linear``.
keepdims: if True, then the returned array will have the same number of
dimensions as the input. Default is False.
interpolation: deprecated alias of the ``method`` argument. Will result
in a :class:`DeprecationWarning` if used.
Returns:
An array containing the specified percentiles along the specified axes.
See also:
- :func:`jax.numpy.quantile`: compute the quantile (0.0-1.0)
- :func:`jax.numpy.nanpercentile`: compute the percentile while ignoring NaNs
Examples:
Computing the median and quartiles of a 1D array:
>>> x = jnp.array([0, 1, 2, 3, 4, 5, 6])
>>> q = jnp.array([25, 50, 75])
>>> jnp.percentile(x, q)
Array([1.5, 3. , 4.5], dtype=float32)
Computing the same percentiles with nearest rather than linear interpolation:
>>> jnp.percentile(x, q, method='nearest')
Array([1., 3., 4.], dtype=float32)
"""
check_arraylike("percentile", a, q)
q, = promote_dtypes_inexact(q)
if not isinstance(interpolation, DeprecatedArg):
deprecations.warn(
"jax-numpy-quantile-interpolation",
("The interpolation= argument to 'percentile' is deprecated. "
"Use 'method=' instead."), stacklevel=2)
method = interpolation
return quantile(a, q / 100, axis=axis, out=out, overwrite_input=overwrite_input,
method=method, keepdims=keepdims)
# TODO(jakevdp): interpolation argument deprecated 2024-05-16
@export
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method'))
def nanpercentile(a: ArrayLike, q: ArrayLike,
axis: int | tuple[int, ...] | None = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, *, interpolation: str | DeprecatedArg = DeprecatedArg()) -> Array:
"""Compute the percentile of the data along the specified axis, ignoring NaN values.
JAX implementation of :func:`numpy.nanpercentile`.
Args:
a: N-dimensional array input.
q: scalar or 1-dimensional array specifying the desired quantiles. ``q``
should contain integer or floating point values between ``0`` and ``100``.
axis: optional axis or tuple of axes along which to compute the quantile
out: not implemented by JAX; will error if not None
overwrite_input: not implemented by JAX; will error if not False
method: specify the interpolation method to use. Options are one of
``["linear", "lower", "higher", "midpoint", "nearest"]``.
default is ``linear``.
keepdims: if True, then the returned array will have the same number of
dimensions as the input. Default is False.
interpolation: deprecated alias of the ``method`` argument. Will result
in a :class:`DeprecationWarning` if used.
Returns:
An array containing the specified percentiles along the specified axes.
See also:
- :func:`jax.numpy.nanquantile`: compute the nan-aware quantile (0.0-1.0)
- :func:`jax.numpy.percentile`: compute the percentile without special
handling of NaNs.
Examples:
Computing the median and quartiles of a 1D array:
>>> x = jnp.array([0, 1, 2, jnp.nan, 3, 4, 5, 6])
>>> q = jnp.array([25, 50, 75])
Because of the NaN value, :func:`jax.numpy.percentile` returns all NaNs,
while :func:`~jax.numpy.nanpercentile` ignores them:
>>> jnp.percentile(x, q)
Array([nan, nan, nan], dtype=float32)
>>> jnp.nanpercentile(x, q)
Array([1.5, 3. , 4.5], dtype=float32)
"""
check_arraylike("nanpercentile", a, q)
q, = promote_dtypes_inexact(q)
q = q / 100
if not isinstance(interpolation, DeprecatedArg):
deprecations.warn(
"jax-numpy-quantile-interpolation",
("The interpolation= argument to 'nanpercentile' is deprecated. "
"Use 'method=' instead."), stacklevel=2)
method = interpolation
return nanquantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input,
method=method, keepdims=keepdims)
@export
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'keepdims'))
def median(a: ArrayLike, axis: int | tuple[int, ...] | None = None,
out: None = None, overwrite_input: bool = False,
keepdims: bool = False) -> Array:
r"""Return the median of array elements along a given axis.
JAX implementation of :func:`numpy.median`.
Args:
a: input array.
axis: optional, int or sequence of ints, default=None. Axis along which the
median to be computed. If None, median is computed for the flattened array.
keepdims: bool, default=False. If true, reduced axes are left in the result
with size 1.
out: Unused by JAX.
overwrite_input: Unused by JAX.
Returns:
An array of the median along the given axis.
See also:
- :func:`jax.numpy.mean`: Compute the mean of array elements over a given axis.
- :func:`jax.numpy.max`: Compute the maximum of array elements over given axis.
- :func:`jax.numpy.min`: Compute the minimum of array elements over given axis.
Examples:
By default, the median is computed for the flattened array.
>>> x = jnp.array([[2, 4, 7, 1],
... [3, 5, 9, 2],
... [6, 1, 8, 3]])
>>> jnp.median(x)
Array(3.5, dtype=float32)
If ``axis=1``, the median is computed along axis 1.
>>> jnp.median(x, axis=1)
Array([3. , 4. , 4.5], dtype=float32)
If ``keepdims=True``, ``ndim`` of the output is equal to that of the input.
>>> jnp.median(x, axis=1, keepdims=True)
Array([[3. ],
[4. ],
[4.5]], dtype=float32)
"""
check_arraylike("median", a)
return quantile(a, 0.5, axis=axis, out=out, overwrite_input=overwrite_input,
keepdims=keepdims, method='midpoint')
@export
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'keepdims'))
def nanmedian(a: ArrayLike, axis: int | tuple[int, ...] | None = None,
out: None = None, overwrite_input: bool = False,
keepdims: bool = False) -> Array:
r"""Return the median of array elements along a given axis, ignoring NaNs.
JAX implementation of :func:`numpy.nanmedian`.
Args:
a: input array.
axis: optional, int or sequence of ints, default=None. Axis along which the
median to be computed. If None, median is computed for the flattened array.
keepdims: bool, default=False. If true, reduced axes are left in the result
with size 1.
out: Unused by JAX.
overwrite_input: Unused by JAX.
Returns:
An array containing the median along the given axis, ignoring NaNs. If all
elements along the given axis are NaNs, returns ``nan``.
See also:
- :func:`jax.numpy.nanmean`: Compute the mean of array elements over a given
axis, ignoring NaNs.
- :func:`jax.numpy.nanmax`: Compute the maximum of array elements over given
axis, ignoring NaNs.
- :func:`jax.numpy.nanmin`: Compute the minimum of array elements over given
axis, ignoring NaNs.
Examples:
By default, the median is computed for the flattened array.
>>> nan = jnp.nan
>>> x = jnp.array([[2, nan, 7, nan],
... [nan, 5, 9, 2],
... [6, 1, nan, 3]])
>>> jnp.nanmedian(x)
Array(4., dtype=float32)
If ``axis=1``, the median is computed along axis 1.
>>> jnp.nanmedian(x, axis=1)
Array([4.5, 5. , 3. ], dtype=float32)
If ``keepdims=True``, ``ndim`` of the output is equal to that of the input.
>>> jnp.nanmedian(x, axis=1, keepdims=True)
Array([[4.5],
[5. ],
[3. ]], dtype=float32)
"""
check_arraylike("nanmedian", a)
return nanquantile(a, 0.5, axis=axis, out=out,
overwrite_input=overwrite_input, keepdims=keepdims,
method='midpoint')