rocm_jax/jax/_src/numpy/reductions.py
Jake VanderPlas 911f745775 Make jax._src.typing.DTypeLike more strictly defined
This is in preparation for exporting this to `jax.typing.DTypeLike`. Currently this is effectively just Any, and we want to make certain it's a meaningful type before exporting.

PiperOrigin-RevId: 572260744
2023-10-10 09:01:19 -07:00

894 lines
40 KiB
Python

# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import builtins
from collections.abc import Sequence
from functools import partial
import math
import operator
from typing import overload, Any, Callable, Literal, Optional, Protocol, Union
import warnings
import numpy as np
from jax import lax
from jax._src import api
from jax._src import core
from jax._src import dtypes
from jax._src.numpy import ufuncs
from jax._src.numpy.util import (
_broadcast_to, check_arraylike, _complex_elem_type,
promote_dtypes_inexact, promote_dtypes_numeric, _where, _wraps)
from jax._src.lax import lax as lax_internal
from jax._src.typing import Array, ArrayLike, DType, DTypeLike
from jax._src.util import (
canonicalize_axis as _canonicalize_axis, maybe_named_axis,
NumpyComplexWarning)
_all = builtins.all
_lax_const = lax_internal._const
Axis = Union[None, int, Sequence[int]]
def _isscalar(element: Any) -> bool:
if hasattr(element, '__jax_array__'):
element = element.__jax_array__()
return dtypes.is_python_scalar(element) or np.isscalar(element)
def _moveaxis(a: ArrayLike, source: int, destination: int) -> Array:
# simplified version of jnp.moveaxis() for local use.
check_arraylike("moveaxis", a)
a = lax_internal.asarray(a)
source = _canonicalize_axis(source, np.ndim(a))
destination = _canonicalize_axis(destination, np.ndim(a))
perm = [i for i in range(np.ndim(a)) if i != source]
perm.insert(destination, source)
return lax.transpose(a, perm)
def _upcast_f16(dtype: DTypeLike) -> DType:
if np.dtype(dtype) in [np.float16, dtypes.bfloat16]:
return np.dtype('float32')
return np.dtype(dtype)
ReductionOp = Callable[[Any, Any], Any]
def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val: ArrayLike,
*, has_identity: bool = True,
preproc: Optional[Callable[[ArrayLike], ArrayLike]] = None,
bool_op: Optional[ReductionOp] = None,
upcast_f16_for_computation: bool = False,
axis: Axis = None, dtype: Optional[DTypeLike] = None, out: None = None,
keepdims: bool = False, initial: Optional[ArrayLike] = None,
where_: Optional[ArrayLike] = None,
parallel_reduce: Optional[Callable[..., Array]] = None,
promote_integers: bool = False) -> Array:
bool_op = bool_op or op
# Note: we must accept out=None as an argument, because numpy reductions delegate to
# object methods. For example `np.sum(x)` will call `x.sum()` if the `sum()` method
# exists, passing along all its arguments.
if out is not None:
raise NotImplementedError(f"The 'out' argument to jnp.{name} is not supported.")
check_arraylike(name, a)
dtypes.check_user_dtype_supported(dtype, name)
axis = core.concrete_or_error(None, axis, f"axis argument to jnp.{name}().")
if initial is None and not has_identity and where_ is not None:
raise ValueError(f"reduction operation {name} does not have an identity, so to use a "
f"where mask one has to specify 'initial'")
a = a if isinstance(a, Array) else lax_internal.asarray(a)
a = preproc(a) if preproc else a
pos_dims, dims = _reduction_dims(a, axis)
if initial is None and not has_identity:
shape = np.shape(a)
if not _all(shape[d] >= 1 for d in pos_dims):
raise ValueError(f"zero-size array to reduction operation {name} which has no identity")
result_dtype = dtype or dtypes.dtype(a)
if dtype is None and promote_integers:
# Note: NumPy always promotes to 64-bit; jax instead promotes to the
# default dtype as defined by dtypes.int_ or dtypes.uint.
if dtypes.issubdtype(result_dtype, np.bool_):
result_dtype = dtypes.int_
elif dtypes.issubdtype(result_dtype, np.unsignedinteger):
if np.iinfo(result_dtype).bits < np.iinfo(dtypes.uint).bits:
result_dtype = dtypes.uint
elif dtypes.issubdtype(result_dtype, np.integer):
if np.iinfo(result_dtype).bits < np.iinfo(dtypes.int_).bits:
result_dtype = dtypes.int_
result_dtype = dtypes.canonicalize_dtype(result_dtype)
if upcast_f16_for_computation and dtypes.issubdtype(result_dtype, np.inexact):
computation_dtype = _upcast_f16(result_dtype)
else:
computation_dtype = result_dtype
a = lax.convert_element_type(a, computation_dtype)
op = op if computation_dtype != np.bool_ else bool_op
# NB: in XLA, init_val must be an identity for the op, so the user-specified
# initial value must be applied afterward.
init_val = _reduction_init_val(a, init_val)
if where_ is not None:
a = _where(where_, a, init_val)
if pos_dims is not dims:
if parallel_reduce is None:
raise NotImplementedError(f"Named reductions not implemented for jnp.{name}()")
result = parallel_reduce(a, dims)
else:
result = lax.reduce(a, init_val, op, dims)
if initial is not None:
initial_arr = lax.convert_element_type(initial, lax_internal.asarray(a).dtype)
if initial_arr.shape != ():
raise ValueError("initial value must be a scalar. "
f"Got array of shape {initial_arr.shape}")
result = op(initial_arr, result)
if keepdims:
result = lax.expand_dims(result, pos_dims)
return lax.convert_element_type(result, dtype or result_dtype)
def _canonicalize_axis_allow_named(x, rank):
return maybe_named_axis(x, lambda i: _canonicalize_axis(i, rank), lambda name: name)
def _reduction_dims(a: ArrayLike, axis: Axis):
if axis is None:
return (tuple(range(np.ndim(a))),) * 2
elif not isinstance(axis, (np.ndarray, tuple, list)):
axis = (axis,) # type: ignore[assignment]
canon_axis = tuple(_canonicalize_axis_allow_named(x, np.ndim(a))
for x in axis) # type: ignore[union-attr]
if len(canon_axis) != len(set(canon_axis)):
raise ValueError(f"duplicate value in 'axis': {axis}")
canon_pos_axis = tuple(x for x in canon_axis if isinstance(x, int))
if len(canon_pos_axis) != len(canon_axis):
return canon_pos_axis, canon_axis
else:
return canon_axis, canon_axis
def _reduction_init_val(a: ArrayLike, init_val: Any) -> np.ndarray:
# This function uses np.* functions because lax pattern matches against the
# specific concrete values of the reduction inputs.
a_dtype = dtypes.canonicalize_dtype(dtypes.dtype(a))
if a_dtype == 'bool':
return np.array(init_val > 0, dtype=a_dtype)
try:
return np.array(init_val, dtype=a_dtype)
except OverflowError:
assert dtypes.issubdtype(a_dtype, np.integer)
sign, info = np.sign(init_val), dtypes.iinfo(a_dtype)
return np.array(info.min if sign < 0 else info.max, dtype=a_dtype)
def _cast_to_bool(operand: ArrayLike) -> Array:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=NumpyComplexWarning)
return lax.convert_element_type(operand, np.bool_)
def _cast_to_numeric(operand: ArrayLike) -> Array:
return promote_dtypes_numeric(operand)[0]
def _ensure_optional_axes(x: Axis) -> Axis:
def force(x):
if x is None:
return None
try:
return operator.index(x)
except TypeError:
return tuple(i if isinstance(i, str) else operator.index(i) for i in x)
return core.concrete_or_error(
force, x, "The axis argument must be known statically.")
# TODO(jakevdp) change promote_integers default to False
_PROMOTE_INTEGERS_DOC = """
promote_integers : bool, default=True
If True, then integer inputs will be promoted to the widest available integer
dtype, following numpy's behavior. If False, the result will have the same dtype
as the input. ``promote_integers`` is ignored if ``dtype`` is specified.
"""
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims', 'promote_integers'), inline=True)
def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: Optional[DTypeLike] = None,
out: None = None, keepdims: bool = False,
initial: Optional[ArrayLike] = None, where: Optional[ArrayLike] = None,
promote_integers: bool = True) -> Array:
return _reduction(a, "sum", np.sum, lax.add, 0, preproc=_cast_to_numeric,
bool_op=lax.bitwise_or, upcast_f16_for_computation=True,
axis=axis, dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where_=where, parallel_reduce=lax.psum,
promote_integers=promote_integers)
@_wraps(np.sum, skip_params=['out'], extra_params=_PROMOTE_INTEGERS_DOC)
def sum(a: ArrayLike, axis: Axis = None, dtype: Optional[DTypeLike] = None,
out: None = None, keepdims: bool = False, initial: Optional[ArrayLike] = None,
where: Optional[ArrayLike] = None, promote_integers: bool = True) -> Array:
return _reduce_sum(a, axis=_ensure_optional_axes(axis), dtype=dtype, out=out,
keepdims=keepdims, initial=initial, where=where,
promote_integers=promote_integers)
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims', 'promote_integers'), inline=True)
def _reduce_prod(a: ArrayLike, axis: Axis = None, dtype: Optional[DTypeLike] = None,
out: None = None, keepdims: bool = False,
initial: Optional[ArrayLike] = None, where: Optional[ArrayLike] = None,
promote_integers: bool = True) -> Array:
return _reduction(a, "prod", np.prod, lax.mul, 1, preproc=_cast_to_numeric,
bool_op=lax.bitwise_and, upcast_f16_for_computation=True,
axis=axis, dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where_=where, promote_integers=promote_integers)
@_wraps(np.prod, skip_params=['out'], extra_params=_PROMOTE_INTEGERS_DOC)
def prod(a: ArrayLike, axis: Axis = None, dtype: Optional[DTypeLike] = None,
out: None = None, keepdims: bool = False,
initial: Optional[ArrayLike] = None, where: Optional[ArrayLike] = None,
promote_integers: bool = True) -> Array:
return _reduce_prod(a, axis=_ensure_optional_axes(axis), dtype=dtype,
out=out, keepdims=keepdims, initial=initial, where=where,
promote_integers=promote_integers)
@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True)
def _reduce_max(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, initial: Optional[ArrayLike] = None,
where: Optional[ArrayLike] = None) -> Array:
return _reduction(a, "max", np.max, lax.max, -np.inf, has_identity=False,
axis=axis, out=out, keepdims=keepdims,
initial=initial, where_=where, parallel_reduce=lax.pmax)
@_wraps(np.max, skip_params=['out'])
def max(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, initial: Optional[ArrayLike] = None,
where: Optional[ArrayLike] = None) -> Array:
return _reduce_max(a, axis=_ensure_optional_axes(axis), out=out,
keepdims=keepdims, initial=initial, where=where)
@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True)
def _reduce_min(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, initial: Optional[ArrayLike] = None,
where: Optional[ArrayLike] = None) -> Array:
return _reduction(a, "min", np.min, lax.min, np.inf, has_identity=False,
axis=axis, out=out, keepdims=keepdims,
initial=initial, where_=where, parallel_reduce=lax.pmin)
@_wraps(np.min, skip_params=['out'])
def min(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, initial: Optional[ArrayLike] = None,
where: Optional[ArrayLike] = None) -> Array:
return _reduce_min(a, axis=_ensure_optional_axes(axis), out=out,
keepdims=keepdims, initial=initial, where=where)
@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True)
def _reduce_all(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, *, where: Optional[ArrayLike] = None) -> Array:
return _reduction(a, "all", np.all, lax.bitwise_and, True, preproc=_cast_to_bool,
axis=axis, out=out, keepdims=keepdims, where_=where)
@_wraps(np.all, skip_params=['out'])
def all(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, *, where: Optional[ArrayLike] = None) -> Array:
return _reduce_all(a, axis=_ensure_optional_axes(axis), out=out,
keepdims=keepdims, where=where)
@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True)
def _reduce_any(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, *, where: Optional[ArrayLike] = None) -> Array:
return _reduction(a, "any", np.any, lax.bitwise_or, False, preproc=_cast_to_bool,
axis=axis, out=out, keepdims=keepdims, where_=where)
@_wraps(np.any, skip_params=['out'])
def any(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, *, where: Optional[ArrayLike] = None) -> Array:
return _reduce_any(a, axis=_ensure_optional_axes(axis), out=out,
keepdims=keepdims, where=where)
amin = min
amax = max
def _axis_size(a: ArrayLike, axis: Union[int, Sequence[int]]):
if not isinstance(axis, (tuple, list)):
axis_seq: Sequence[int] = (axis,) # type: ignore[assignment]
else:
axis_seq = axis
size = 1
a_shape = np.shape(a)
for a in axis_seq:
size *= maybe_named_axis(a, lambda i: a_shape[i], lambda name: lax.psum(1, name))
return size
@_wraps(np.mean, skip_params=['out'])
def mean(a: ArrayLike, axis: Axis = None, dtype: Optional[DTypeLike] = None,
out: None = None, keepdims: bool = False, *,
where: Optional[ArrayLike] = None) -> Array:
return _mean(a, _ensure_optional_axes(axis), dtype, out, keepdims,
where=where)
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'), inline=True)
def _mean(a: ArrayLike, axis: Axis = None, dtype: Optional[DTypeLike] = None,
out: None = None, keepdims: bool = False, *,
upcast_f16_for_computation: bool = True,
where: Optional[ArrayLike] = None) -> Array:
check_arraylike("mean", a)
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.mean is not supported.")
if dtype is None:
result_dtype = dtypes.to_inexact_dtype(dtypes.dtype(a, canonicalize=True))
else:
dtypes.check_user_dtype_supported(dtype, "mean")
result_dtype = dtypes.canonicalize_dtype(dtype)
if upcast_f16_for_computation and dtypes.issubdtype(result_dtype, np.inexact):
computation_dtype = _upcast_f16(result_dtype)
else:
computation_dtype = result_dtype
if where is None:
if axis is None:
normalizer = core.dimension_as_value(np.size(a))
else:
normalizer = core.dimension_as_value(_axis_size(a, axis))
else:
normalizer = sum(_broadcast_to(where, np.shape(a)), axis, dtype=dtype, keepdims=keepdims)
return lax.div(
sum(a, axis, dtype=computation_dtype, keepdims=keepdims, where=where),
lax.convert_element_type(normalizer, computation_dtype)
).astype(result_dtype)
@overload
def average(a: ArrayLike, axis: Axis = None, weights: Optional[ArrayLike] = None,
returned: Literal[False] = False, keepdims: bool = False) -> Array: ...
@overload
def average(a: ArrayLike, axis: Axis = None, weights: Optional[ArrayLike] = None, *,
returned: Literal[True], keepdims: bool = False) -> Array: ...
@overload
def average(a: ArrayLike, axis: Axis = None, weights: Optional[ArrayLike] = None,
returned: bool = False, keepdims: bool = False) -> Union[Array, tuple[Array, Array]]: ...
@_wraps(np.average)
def average(a: ArrayLike, axis: Axis = None, weights: Optional[ArrayLike] = None,
returned: bool = False, keepdims: bool = False) -> Union[Array, tuple[Array, Array]]:
return _average(a, _ensure_optional_axes(axis), weights, returned, keepdims)
@partial(api.jit, static_argnames=('axis', 'returned', 'keepdims'), inline=True)
def _average(a: ArrayLike, axis: Axis = None, weights: Optional[ArrayLike] = None,
returned: bool = False, keepdims: bool = False) -> Union[Array, tuple[Array, Array]]:
if weights is None: # Treat all weights as 1
check_arraylike("average", a)
a, = promote_dtypes_inexact(a)
avg = mean(a, axis=axis, keepdims=keepdims)
if axis is None:
weights_sum = lax.full((), core.dimension_as_value(a.size), dtype=avg.dtype)
elif isinstance(axis, tuple):
weights_sum = lax.full_like(avg, math.prod(core.dimension_as_value(a.shape[d]) for d in axis))
else:
weights_sum = lax.full_like(avg, core.dimension_as_value(a.shape[axis])) # type: ignore[index]
else:
check_arraylike("average", a, weights)
a, weights = promote_dtypes_inexact(a, weights)
a_shape = np.shape(a)
a_ndim = len(a_shape)
weights_shape = np.shape(weights)
if axis is None:
pass
elif isinstance(axis, tuple):
axis = tuple(_canonicalize_axis(d, a_ndim) for d in axis)
else:
axis = _canonicalize_axis(axis, a_ndim)
if a_shape != weights_shape:
# Make sure the dimensions work out
if len(weights_shape) != 1:
raise ValueError("1D weights expected when shapes of a and "
"weights differ.")
if axis is None:
raise ValueError("Axis must be specified when shapes of a and "
"weights differ.")
elif isinstance(axis, tuple):
raise ValueError("Single axis expected when shapes of a and weights differ")
elif not core.definitely_equal(weights_shape[0], a_shape[axis]):
raise ValueError("Length of weights not "
"compatible with specified axis.")
weights = _broadcast_to(weights, (a_ndim - 1) * (1,) + weights_shape)
weights = _moveaxis(weights, -1, axis)
weights_sum = sum(weights, axis=axis, keepdims=keepdims)
avg = sum(a * weights, axis=axis, keepdims=keepdims) / weights_sum
if returned:
if avg.shape != weights_sum.shape:
weights_sum = _broadcast_to(weights_sum, avg.shape)
return avg, weights_sum
return avg
@_wraps(np.var, skip_params=['out'])
def var(a: ArrayLike, axis: Axis = None, dtype: Optional[DTypeLike] = None,
out: None = None, ddof: int = 0, keepdims: bool = False, *,
where: Optional[ArrayLike] = None) -> Array:
return _var(a, _ensure_optional_axes(axis), dtype, out, ddof, keepdims,
where=where)
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def _var(a: ArrayLike, axis: Axis = None, dtype: Optional[DTypeLike] = None,
out: None = None, ddof: int = 0, keepdims: bool = False, *,
where: Optional[ArrayLike] = None) -> Array:
check_arraylike("var", a)
dtypes.check_user_dtype_supported(dtype, "var")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.var is not supported.")
computation_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype)
a = lax_internal.asarray(a).astype(computation_dtype)
a_mean = mean(a, axis, dtype=computation_dtype, keepdims=True, where=where)
centered = lax.sub(a, a_mean)
if dtypes.issubdtype(computation_dtype, np.complexfloating):
centered = lax.real(lax.mul(centered, lax.conj(centered)))
computation_dtype = centered.dtype # avoid casting to complex below.
else:
centered = lax.square(centered)
if where is None:
if axis is None:
normalizer = core.dimension_as_value(np.size(a))
else:
normalizer = core.dimension_as_value(_axis_size(a, axis))
normalizer = lax.convert_element_type(normalizer, computation_dtype)
else:
normalizer = sum(_broadcast_to(where, np.shape(a)), axis,
dtype=computation_dtype, keepdims=keepdims)
normalizer = lax.sub(normalizer, lax.convert_element_type(ddof, computation_dtype))
result = sum(centered, axis, dtype=computation_dtype, keepdims=keepdims, where=where)
return lax.div(result, normalizer).astype(dtype)
def _var_promote_types(a_dtype: DTypeLike, dtype: Optional[DTypeLike]) -> tuple[DType, DType]:
if dtype:
if (not dtypes.issubdtype(dtype, np.complexfloating) and
dtypes.issubdtype(a_dtype, np.complexfloating)):
msg = ("jax.numpy.var does not yet support real dtype parameters when "
"computing the variance of an array of complex values. The "
"semantics of numpy.var seem unclear in this case. Please comment "
"on https://github.com/google/jax/issues/2283 if this behavior is "
"important to you.")
raise ValueError(msg)
computation_dtype = dtype
else:
if not dtypes.issubdtype(a_dtype, np.inexact):
dtype = dtypes.to_inexact_dtype(a_dtype)
computation_dtype = dtype
else:
dtype = _complex_elem_type(a_dtype)
computation_dtype = a_dtype
return _upcast_f16(computation_dtype), np.dtype(dtype)
@_wraps(np.std, skip_params=['out'])
def std(a: ArrayLike, axis: Axis = None, dtype: Optional[DTypeLike] = None,
out: None = None, ddof: int = 0, keepdims: bool = False, *,
where: Optional[ArrayLike] = None) -> Array:
return _std(a, _ensure_optional_axes(axis), dtype, out, ddof, keepdims,
where=where)
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def _std(a: ArrayLike, axis: Axis = None, dtype: Optional[DTypeLike] = None,
out: None = None, ddof: int = 0, keepdims: bool = False, *,
where: Optional[ArrayLike] = None) -> Array:
check_arraylike("std", a)
dtypes.check_user_dtype_supported(dtype, "std")
if dtype is not None and not dtypes.issubdtype(dtype, np.inexact):
raise ValueError(f"dtype argument to jnp.std must be inexact; got {dtype}")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.std is not supported.")
return lax.sqrt(var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where))
@_wraps(np.ptp, skip_params=['out'])
def ptp(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False) -> Array:
return _ptp(a, _ensure_optional_axes(axis), out, keepdims)
@partial(api.jit, static_argnames=('axis', 'keepdims'))
def _ptp(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False) -> Array:
check_arraylike("ptp", a)
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.ptp is not supported.")
x = amax(a, axis=axis, keepdims=keepdims)
y = amin(a, axis=axis, keepdims=keepdims)
return lax.sub(x, y)
@_wraps(np.count_nonzero)
@partial(api.jit, static_argnames=('axis', 'keepdims'))
def count_nonzero(a: ArrayLike, axis: Axis = None,
keepdims: bool = False) -> Array:
check_arraylike("count_nonzero", a)
return sum(lax.ne(a, _lax_const(a, 0)), axis=axis,
dtype=dtypes.canonicalize_dtype(int), keepdims=keepdims)
def _nan_reduction(a: ArrayLike, name: str, jnp_reduction: Callable[..., Array],
init_val: ArrayLike, nan_if_all_nan: bool,
axis: Axis = None, keepdims: bool = False, **kwargs) -> Array:
check_arraylike(name, a)
if not dtypes.issubdtype(dtypes.dtype(a), np.inexact):
return jnp_reduction(a, axis=axis, keepdims=keepdims, **kwargs)
out = jnp_reduction(_where(lax_internal._isnan(a), _reduction_init_val(a, init_val), a),
axis=axis, keepdims=keepdims, **kwargs)
if nan_if_all_nan:
return _where(all(lax_internal._isnan(a), axis=axis, keepdims=keepdims),
_lax_const(a, np.nan), out)
else:
return out
@_wraps(np.nanmin, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'keepdims'))
def nanmin(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, initial: Optional[ArrayLike] = None,
where: Optional[ArrayLike] = None) -> Array:
return _nan_reduction(a, 'nanmin', min, np.inf, nan_if_all_nan=initial is None,
axis=axis, out=out, keepdims=keepdims,
initial=initial, where=where)
@_wraps(np.nanmax, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'keepdims'))
def nanmax(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, initial: Optional[ArrayLike] = None,
where: Optional[ArrayLike] = None) -> Array:
return _nan_reduction(a, 'nanmax', max, -np.inf, nan_if_all_nan=initial is None,
axis=axis, out=out, keepdims=keepdims,
initial=initial, where=where)
@_wraps(np.nansum, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nansum(a: ArrayLike, axis: Axis = None, dtype: Optional[DTypeLike] = None, out: None = None,
keepdims: bool = False, initial: Optional[ArrayLike] = None,
where: Optional[ArrayLike] = None) -> Array:
dtypes.check_user_dtype_supported(dtype, "nanprod")
return _nan_reduction(a, 'nansum', sum, 0, nan_if_all_nan=False,
axis=axis, dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where=where)
# Work around a sphinx documentation warning in NumPy 1.22.
if nansum.__doc__ is not None:
nansum.__doc__ = nansum.__doc__.replace("\n\n\n", "\n\n")
@_wraps(np.nanprod, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanprod(a: ArrayLike, axis: Axis = None, dtype: Optional[DTypeLike] = None, out: None = None,
keepdims: bool = False, initial: Optional[ArrayLike] = None,
where: Optional[ArrayLike] = None) -> Array:
dtypes.check_user_dtype_supported(dtype, "nanprod")
return _nan_reduction(a, 'nanprod', prod, 1, nan_if_all_nan=False,
axis=axis, dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where=where)
@_wraps(np.nanmean, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanmean(a: ArrayLike, axis: Axis = None, dtype: Optional[DTypeLike] = None, out: None = None,
keepdims: bool = False, where: Optional[ArrayLike] = None) -> Array:
check_arraylike("nanmean", a)
dtypes.check_user_dtype_supported(dtype, "nanmean")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.nanmean is not supported.")
if dtypes.issubdtype(dtypes.dtype(a), np.bool_) or dtypes.issubdtype(dtypes.dtype(a), np.integer):
return mean(a, axis, dtype, out, keepdims, where=where)
if dtype is None:
dtype = dtypes.dtype(a)
nan_mask = lax_internal.bitwise_not(lax_internal._isnan(a))
normalizer = sum(nan_mask, axis=axis, dtype=np.int32, keepdims=keepdims, where=where)
normalizer = lax.convert_element_type(normalizer, dtype)
td = lax.div(nansum(a, axis, dtype=dtype, keepdims=keepdims, where=where), normalizer)
return td
@_wraps(np.nanvar, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanvar(a: ArrayLike, axis: Axis = None, dtype: Optional[DTypeLike] = None, out: None = None,
ddof: int = 0, keepdims: bool = False,
where: Optional[ArrayLike] = None) -> Array:
check_arraylike("nanvar", a)
dtypes.check_user_dtype_supported(dtype, "nanvar")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.nanvar is not supported.")
computation_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype)
a = lax_internal.asarray(a).astype(computation_dtype)
a_mean = nanmean(a, axis, dtype=computation_dtype, keepdims=True, where=where)
centered = _where(lax_internal._isnan(a), 0, lax.sub(a, a_mean)) # double-where trick for gradients.
if dtypes.issubdtype(centered.dtype, np.complexfloating):
centered = lax.real(lax.mul(centered, lax.conj(centered)))
else:
centered = lax.square(centered)
normalizer = sum(lax_internal.bitwise_not(lax_internal._isnan(a)),
axis=axis, keepdims=keepdims, where=where)
normalizer = normalizer - ddof
normalizer_mask = lax.le(normalizer, lax_internal._zero(normalizer))
result = sum(centered, axis, keepdims=keepdims, where=where)
result = _where(normalizer_mask, np.nan, result)
divisor = _where(normalizer_mask, 1, normalizer)
result = lax.div(result, lax.convert_element_type(divisor, result.dtype))
return lax.convert_element_type(result, dtype)
@_wraps(np.nanstd, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanstd(a: ArrayLike, axis: Axis = None, dtype: Optional[DTypeLike] = None, out: None = None,
ddof: int = 0, keepdims: bool = False,
where: Optional[ArrayLike] = None) -> Array:
check_arraylike("nanstd", a)
dtypes.check_user_dtype_supported(dtype, "nanstd")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.nanstd is not supported.")
return lax.sqrt(nanvar(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where))
class CumulativeReduction(Protocol):
def __call__(self, a: ArrayLike, axis: Axis = None,
dtype: Optional[DTypeLike] = None, out: None = None) -> Array: ...
def _make_cumulative_reduction(np_reduction: Any, reduction: Callable[..., Array],
fill_nan: bool = False, fill_value: ArrayLike = 0) -> CumulativeReduction:
@_wraps(np_reduction, skip_params=['out'])
def cumulative_reduction(a: ArrayLike, axis: Axis = None,
dtype: Optional[DTypeLike] = None, out: None = None) -> Array:
return _cumulative_reduction(a, _ensure_optional_axes(axis), dtype, out)
@partial(api.jit, static_argnames=('axis', 'dtype'))
def _cumulative_reduction(a: ArrayLike, axis: Axis = None,
dtype: Optional[DTypeLike] = None, out: None = None) -> Array:
check_arraylike(np_reduction.__name__, a)
if out is not None:
raise NotImplementedError(f"The 'out' argument to jnp.{np_reduction.__name__} "
f"is not supported.")
dtypes.check_user_dtype_supported(dtype, np_reduction.__name__)
if axis is None or _isscalar(a):
a = lax.reshape(a, (np.size(a),))
if axis is None:
axis = 0
a_shape = list(np.shape(a))
num_dims = len(a_shape)
axis = _canonicalize_axis(axis, num_dims)
if fill_nan:
a = _where(lax_internal._isnan(a), _lax_const(a, fill_value), a)
if not dtype and dtypes.dtype(a) == np.bool_:
dtype = dtypes.canonicalize_dtype(dtypes.int_)
if dtype:
a = lax.convert_element_type(a, dtype)
return reduction(a, axis)
return cumulative_reduction
cumsum = _make_cumulative_reduction(np.cumsum, lax.cumsum, fill_nan=False)
cumprod = _make_cumulative_reduction(np.cumprod, lax.cumprod, fill_nan=False)
nancumsum = _make_cumulative_reduction(np.nancumsum, lax.cumsum,
fill_nan=True, fill_value=0)
nancumprod = _make_cumulative_reduction(np.nancumprod, lax.cumprod,
fill_nan=True, fill_value=1)
# Quantiles
@_wraps(np.quantile, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
'keepdims', 'method'))
def quantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, interpolation: None = None) -> Array:
check_arraylike("quantile", a, q)
if overwrite_input or out is not None:
msg = ("jax.numpy.quantile does not support overwrite_input=True or "
"out != None")
raise ValueError(msg)
if interpolation is not None:
warnings.warn("The interpolation= argument to 'quantile' is deprecated. "
"Use 'method=' instead.", DeprecationWarning)
return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, interpolation or method, keepdims, False)
@_wraps(np.nanquantile, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
'keepdims', 'method'))
def nanquantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, interpolation: None = None) -> Array:
check_arraylike("nanquantile", a, q)
if overwrite_input or out is not None:
msg = ("jax.numpy.nanquantile does not support overwrite_input=True or "
"out != None")
raise ValueError(msg)
if interpolation is not None:
warnings.warn("The interpolation= argument to 'nanquantile' is deprecated. "
"Use 'method=' instead.", DeprecationWarning)
return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, interpolation or method, keepdims, True)
def _quantile(a: Array, q: Array, axis: Optional[Union[int, tuple[int, ...]]],
interpolation: str, keepdims: bool, squash_nans: bool) -> Array:
if interpolation not in ["linear", "lower", "higher", "midpoint", "nearest"]:
raise ValueError("interpolation can only be 'linear', 'lower', 'higher', "
"'midpoint', or 'nearest'")
a, = promote_dtypes_inexact(a)
keepdim = []
if dtypes.issubdtype(a.dtype, np.complexfloating):
raise ValueError("quantile does not support complex input, as the operation is poorly defined.")
if axis is None:
a = a.ravel()
axis = 0
elif isinstance(axis, tuple):
keepdim = list(a.shape)
nd = a.ndim
axis = tuple(_canonicalize_axis(ax, nd) for ax in axis)
if len(set(axis)) != len(axis):
raise ValueError('repeated axis')
for ax in axis:
keepdim[ax] = 1
keep = set(range(nd)) - set(axis)
# prepare permutation
dimensions = list(range(nd))
for i, s in enumerate(sorted(keep)):
dimensions[i], dimensions[s] = dimensions[s], dimensions[i]
do_not_touch_shape = tuple(x for idx,x in enumerate(a.shape) if idx not in axis)
touch_shape = tuple(x for idx,x in enumerate(a.shape) if idx in axis)
a = lax.reshape(a, do_not_touch_shape + (math.prod(touch_shape),), dimensions)
axis = _canonicalize_axis(-1, a.ndim)
else:
axis = _canonicalize_axis(axis, a.ndim)
q_shape = q.shape
q_ndim = q.ndim
if q_ndim > 1:
raise ValueError(f"q must be have rank <= 1, got shape {q.shape}")
a_shape = a.shape
if squash_nans:
a = _where(ufuncs.isnan(a), np.nan, a) # Ensure nans are positive so they sort to the end.
a = lax.sort(a, dimension=axis)
counts = sum(ufuncs.logical_not(ufuncs.isnan(a)), axis=axis, dtype=q.dtype, keepdims=keepdims)
shape_after_reduction = counts.shape
q = lax.expand_dims(
q, tuple(range(q_ndim, len(shape_after_reduction) + q_ndim)))
counts = lax.expand_dims(counts, tuple(range(q_ndim)))
q = lax.mul(q, lax.sub(counts, _lax_const(q, 1)))
low = lax.floor(q)
high = lax.ceil(q)
high_weight = lax.sub(q, low)
low_weight = lax.sub(_lax_const(high_weight, 1), high_weight)
low = lax.max(_lax_const(low, 0), lax.min(low, counts - 1))
high = lax.max(_lax_const(high, 0), lax.min(high, counts - 1))
low = lax.convert_element_type(low, int)
high = lax.convert_element_type(high, int)
out_shape = q_shape + shape_after_reduction
index = [lax.broadcasted_iota(int, out_shape, dim + q_ndim)
for dim in range(len(shape_after_reduction))]
if keepdims:
index[axis] = low
else:
index.insert(axis, low)
low_value = a[tuple(index)]
index[axis] = high
high_value = a[tuple(index)]
else:
a = _where(any(ufuncs.isnan(a), axis=axis, keepdims=True), np.nan, a)
a = lax.sort(a, dimension=axis)
n = lax.convert_element_type(a_shape[axis], lax_internal._dtype(q))
q = lax.mul(q, n - 1)
low = lax.floor(q)
high = lax.ceil(q)
high_weight = lax.sub(q, low)
low_weight = lax.sub(_lax_const(high_weight, 1), high_weight)
low = lax.clamp(_lax_const(low, 0), low, n - 1)
high = lax.clamp(_lax_const(high, 0), high, n - 1)
low = lax.convert_element_type(low, int)
high = lax.convert_element_type(high, int)
slice_sizes = list(a_shape)
slice_sizes[axis] = 1
dnums = lax.GatherDimensionNumbers(
offset_dims=tuple(range(
q_ndim,
len(a_shape) + q_ndim if keepdims else len(a_shape) + q_ndim - 1)),
collapsed_slice_dims=() if keepdims else (axis,),
start_index_map=(axis,))
low_value = lax.gather(a, low[..., None], dimension_numbers=dnums,
slice_sizes=slice_sizes)
high_value = lax.gather(a, high[..., None], dimension_numbers=dnums,
slice_sizes=slice_sizes)
if q_ndim == 1:
low_weight = lax.broadcast_in_dim(low_weight, low_value.shape,
broadcast_dimensions=(0,))
high_weight = lax.broadcast_in_dim(high_weight, high_value.shape,
broadcast_dimensions=(0,))
if interpolation == "linear":
result = lax.add(lax.mul(low_value.astype(q.dtype), low_weight),
lax.mul(high_value.astype(q.dtype), high_weight))
elif interpolation == "lower":
result = low_value
elif interpolation == "higher":
result = high_value
elif interpolation == "nearest":
pred = lax.le(high_weight, _lax_const(high_weight, 0.5))
result = lax.select(pred, low_value, high_value)
elif interpolation == "midpoint":
result = lax.mul(lax.add(low_value, high_value), _lax_const(low_value, 0.5))
else:
raise ValueError(f"interpolation={interpolation!r} not recognized")
if keepdims and keepdim:
if q_ndim > 0:
keepdim = [np.shape(q)[0], *keepdim]
result = result.reshape(keepdim)
return lax.convert_element_type(result, a.dtype)
@_wraps(np.percentile, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
'keepdims', 'method'))
def percentile(a: ArrayLike, q: ArrayLike,
axis: Optional[Union[int, tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, interpolation: None = None) -> Array:
check_arraylike("percentile", a, q)
q, = promote_dtypes_inexact(q)
return quantile(a, q / 100, axis=axis, out=out, overwrite_input=overwrite_input,
interpolation=interpolation, method=method, keepdims=keepdims)
@_wraps(np.nanpercentile, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
'keepdims', 'method'))
def nanpercentile(a: ArrayLike, q: ArrayLike,
axis: Optional[Union[int, tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, interpolation: None = None) -> Array:
check_arraylike("nanpercentile", a, q)
q = ufuncs.true_divide(q, 100.0)
return nanquantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input,
interpolation=interpolation, method=method,
keepdims=keepdims)
@_wraps(np.median, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'keepdims'))
def median(a: ArrayLike, axis: Optional[Union[int, tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False,
keepdims: bool = False) -> Array:
check_arraylike("median", a)
return quantile(a, 0.5, axis=axis, out=out, overwrite_input=overwrite_input,
keepdims=keepdims, method='midpoint')
@_wraps(np.nanmedian, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'keepdims'))
def nanmedian(a: ArrayLike, axis: Optional[Union[int, tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False,
keepdims: bool = False) -> Array:
check_arraylike("nanmedian", a)
return nanquantile(a, 0.5, axis=axis, out=out,
overwrite_input=overwrite_input, keepdims=keepdims,
method='midpoint')