# rocm_jax/jax/_src/numpy/reductions.py

# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import builtins
from functools import partial
import math
import operator
from typing import overload, Any, Callable, Literal, Optional, Sequence, Tuple, Union
import warnings
import numpy as np
from jax import lax
from jax._src import api
from jax._src import core
from jax._src import dtypes
from jax._src.numpy.util import (
_broadcast_to, _check_arraylike, _complex_elem_type,
_promote_dtypes_inexact, _promote_dtypes_numeric, _where, _wraps)
from jax._src.lax import lax as lax_internal
from jax._src.typing import Array, ArrayLike, DType, DTypeLike
from jax._src.util import (
canonicalize_axis as _canonicalize_axis, maybe_named_axis)
_all = builtins.all
_lax_const = lax_internal._const
Axis = Union[None, int, Sequence[int]]
def _asarray(a: ArrayLike) -> Array:
# simplified version of jnp.asarray() for local use.
return a if isinstance(a, Array) else api.device_put(a)
def _isscalar(element: Any) -> bool:
if hasattr(element, '__jax_array__'):
element = element.__jax_array__()
return dtypes.is_python_scalar(element) or np.isscalar(element)
def _moveaxis(a: ArrayLike, source: int, destination: int) -> Array:
# simplified version of jnp.moveaxis() for local use.
_check_arraylike("moveaxis", a)
a = _asarray(a)
source = _canonicalize_axis(source, np.ndim(a))
destination = _canonicalize_axis(destination, np.ndim(a))
perm = [i for i in range(np.ndim(a)) if i != source]
perm.insert(destination, source)
return lax.transpose(a, perm)
def _upcast_f16(dtype: DTypeLike) -> DType:
if np.dtype(dtype) in [np.float16, dtypes.bfloat16]:
return np.dtype('float32')
return np.dtype(dtype)
ReductionOp = Callable[[Any, Any], Any]
def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val: ArrayLike,
*, has_identity: bool = True,
preproc: Optional[Callable[[ArrayLike], ArrayLike]] = None,
bool_op: Optional[ReductionOp] = None,
upcast_f16_for_computation: bool = False,
axis: Axis = None, dtype: DTypeLike = None, out: None = None,
keepdims: bool = False, initial: Optional[ArrayLike] = None,
where_: Optional[ArrayLike] = None,
parallel_reduce: Optional[Callable[..., Array]] = None,
promote_integers: bool = False) -> Array:
bool_op = bool_op or op
# Note: we must accept out=None as an argument, because numpy reductions delegate to
# object methods. For example `np.sum(x)` will call `x.sum()` if the `sum()` method
# exists, passing along all its arguments.
if out is not None:
raise NotImplementedError(f"The 'out' argument to jnp.{name} is not supported.")
_check_arraylike(name, a)
lax_internal._check_user_dtype_supported(dtype, name)
axis = core.concrete_or_error(None, axis, f"axis argument to jnp.{name}().")
if initial is None and not has_identity and where_ is not None:
raise ValueError(f"reduction operation {name} does not have an identity, so to use a "
f"where mask one has to specify 'initial'")
a = a if isinstance(a, Array) else _asarray(a)
a = preproc(a) if preproc else a
pos_dims, dims = _reduction_dims(a, axis)
if initial is None and not has_identity:
shape = np.shape(a)
if not _all(core.greater_equal_dim(shape[d], 1) for d in pos_dims):
raise ValueError(f"zero-size array to reduction operation {name} which has no identity")
result_dtype = dtype or dtypes.dtype(a)
if dtype is None and promote_integers:
# Note: NumPy always promotes to 64-bit; jax instead promotes to the
# default dtype as defined by dtypes.int_ or dtypes.uint.
if dtypes.issubdtype(result_dtype, np.bool_):
result_dtype = dtypes.int_
elif dtypes.issubdtype(result_dtype, np.unsignedinteger):
if np.iinfo(result_dtype).bits < np.iinfo(dtypes.uint).bits:
result_dtype = dtypes.uint
elif dtypes.issubdtype(result_dtype, np.integer):
if np.iinfo(result_dtype).bits < np.iinfo(dtypes.int_).bits:
result_dtype = dtypes.int_
result_dtype = dtypes.canonicalize_dtype(result_dtype)
if upcast_f16_for_computation and dtypes.issubdtype(result_dtype, np.inexact):
computation_dtype = _upcast_f16(result_dtype)
else:
computation_dtype = result_dtype
a = lax.convert_element_type(a, computation_dtype)
op = op if computation_dtype != np.bool_ else bool_op
# NB: in XLA, init_val must be an identity for the op, so the user-specified
# initial value must be applied afterward.
init_val = _reduction_init_val(a, init_val)
if where_ is not None:
a = _where(where_, a, init_val)
if pos_dims is not dims:
if parallel_reduce is None:
raise NotImplementedError(f"Named reductions not implemented for jnp.{name}()")
result = parallel_reduce(a, dims)
else:
result = lax.reduce(a, init_val, op, dims)
if initial is not None:
initial_arr = lax.convert_element_type(initial, _asarray(a).dtype)
if initial_arr.shape != ():
raise ValueError("initial value must be a scalar. "
f"Got array of shape {initial_arr.shape}")
result = op(initial_arr, result)
if keepdims:
result = lax.expand_dims(result, pos_dims)
return lax.convert_element_type(result, dtype or result_dtype)
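# Minimal sketch of how `_reduction` assembles a reduction from a lax binary op.
# `_example_reduction_sketch` is a hypothetical helper used only for illustration;
# it is not part of the module's API and is never called at import time.
def _example_reduction_sketch():
  x = np.arange(6.0).reshape(2, 3)
  # Reduce with lax.add over axis 1; `keepdims` re-expands the reduced dimension,
  # so the result has shape (2, 1) with values [[3.], [12.]].
  return _reduction(x, "example_sum", np.sum, lax.add, 0, axis=1, keepdims=True)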
def _canonicalize_axis_allow_named(x, rank):
return maybe_named_axis(x, lambda i: _canonicalize_axis(i, rank), lambda name: name)
def _reduction_dims(a: ArrayLike, axis: Axis):
if axis is None:
return (tuple(range(np.ndim(a))),) * 2
elif not isinstance(axis, (np.ndarray, tuple, list)):
axis = (axis,) # type: ignore[assignment]
canon_axis = tuple(_canonicalize_axis_allow_named(x, np.ndim(a))
for x in axis) # type: ignore[union-attr]
if len(canon_axis) != len(set(canon_axis)):
raise ValueError(f"duplicate value in 'axis': {axis}")
canon_pos_axis = tuple(x for x in canon_axis if isinstance(x, int))
if len(canon_pos_axis) != len(canon_axis):
return canon_pos_axis, canon_axis
else:
return canon_axis, canon_axis
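# Quick sketch of `_reduction_dims` on a purely positional `axis` argument.
# `_example_reduction_dims` is a hypothetical name used only for illustration.
def _example_reduction_dims():
  x = np.zeros((2, 3, 4))
  # Negative axes are canonicalized; with no named axes both returned tuples are
  # identical: ((2, 0), (2, 0)). Duplicate axes raise a ValueError.
  return _reduction_dims(x, (-1, 0))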
def _reduction_init_val(a: ArrayLike, init_val: Any) -> np.ndarray:
# This function uses np.* functions because lax pattern matches against the
# specific concrete values of the reduction inputs.
a_dtype = dtypes.canonicalize_dtype(dtypes.dtype(a))
if a_dtype == 'bool':
return np.array(init_val > 0, dtype=a_dtype)
try:
return np.array(init_val, dtype=a_dtype)
except OverflowError:
assert dtypes.issubdtype(a_dtype, np.integer)
sign, info = np.sign(init_val), dtypes.iinfo(a_dtype)
return np.array(info.min if sign < 0 else info.max, dtype=a_dtype)
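# Sketch of why `_reduction_init_val` catches OverflowError: for integer inputs,
# an infinite identity (as used by max/min) is clamped to the dtype's extremum.
# `_example_reduction_init_val` is a hypothetical helper for illustration only.
def _example_reduction_init_val():
  x = np.arange(3, dtype=np.int32)
  # -inf does not fit in int32, so the identity becomes np.iinfo(np.int32).min.
  return _reduction_init_val(x, -np.inf)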
def _cast_to_bool(operand: ArrayLike) -> Array:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=np.ComplexWarning)
return lax.convert_element_type(operand, np.bool_)
def _cast_to_numeric(operand: ArrayLike) -> Array:
return _promote_dtypes_numeric(operand)[0]
def _ensure_optional_axes(x: Axis) -> Axis:
def force(x):
if x is None:
return None
try:
return operator.index(x)
except TypeError:
return tuple(i if isinstance(i, str) else operator.index(i) for i in x)
return core.concrete_or_error(
force, x, "The axis argument must be known statically.")
# TODO(jakevdp) change promote_integers default to False
_PROMOTE_INTEGERS_DOC = """
promote_integers : bool, default=True
  If True, then integer inputs will be promoted to the default integer dtype (as
  defined by ``dtypes.int_`` or ``dtypes.uint``), rather than to 64 bits as NumPy
  does. If False, the result will have the same dtype as the input.
  ``promote_integers`` is ignored if ``dtype`` is specified.
"""
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims', 'promote_integers'), inline=True)
def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None,
out: None = None, keepdims: bool = False,
initial: Optional[ArrayLike] = None, where: Optional[ArrayLike] = None,
promote_integers: bool = True) -> Array:
return _reduction(a, "sum", np.sum, lax.add, 0, preproc=_cast_to_numeric,
bool_op=lax.bitwise_or, upcast_f16_for_computation=True,
axis=axis, dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where_=where, parallel_reduce=lax.psum,
promote_integers=promote_integers)
@_wraps(np.sum, skip_params=['out'], extra_params=_PROMOTE_INTEGERS_DOC)
def sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None,
out: None = None, keepdims: bool = False, initial: Optional[ArrayLike] = None,
where: Optional[ArrayLike] = None, promote_integers: bool = True) -> Array:
return _reduce_sum(a, axis=_ensure_optional_axes(axis), dtype=dtype, out=out,
keepdims=keepdims, initial=initial, where=where,
promote_integers=promote_integers)
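# Usage sketch for `promote_integers`; the helper name `_example_sum_promotion`
# is hypothetical and the values below are illustrative only.
def _example_sum_promotion():
  x = np.arange(10, dtype=np.int8)
  a = sum(x)                           # promoted to the default int dtype
                                       # (int32, or int64 when x64 is enabled)
  b = sum(x, promote_integers=False)   # result keeps the input dtype, int8
  c = sum(x, dtype=np.int16)           # explicit dtype wins; promotion is ignored
  return a.dtype, b.dtype, c.dtype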
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims', 'promote_integers'), inline=True)
def _reduce_prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None,
out: None = None, keepdims: bool = False,
initial: Optional[ArrayLike] = None, where: Optional[ArrayLike] = None,
promote_integers: bool = True) -> Array:
return _reduction(a, "prod", np.prod, lax.mul, 1, preproc=_cast_to_numeric,
bool_op=lax.bitwise_and, upcast_f16_for_computation=True,
axis=axis, dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where_=where, promote_integers=promote_integers)
@_wraps(np.prod, skip_params=['out'], extra_params=_PROMOTE_INTEGERS_DOC)
def prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None,
out: None = None, keepdims: bool = False,
initial: Optional[ArrayLike] = None, where: Optional[ArrayLike] = None,
promote_integers: bool = True) -> Array:
return _reduce_prod(a, axis=_ensure_optional_axes(axis), dtype=dtype,
out=out, keepdims=keepdims, initial=initial, where=where,
promote_integers=promote_integers)
@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True)
def _reduce_max(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, initial: Optional[ArrayLike] = None,
where: Optional[ArrayLike] = None) -> Array:
return _reduction(a, "max", np.max, lax.max, -np.inf, has_identity=False,
axis=axis, out=out, keepdims=keepdims,
initial=initial, where_=where, parallel_reduce=lax.pmax)
@_wraps(np.max, skip_params=['out'])
def max(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, initial: Optional[ArrayLike] = None,
where: Optional[ArrayLike] = None) -> Array:
return _reduce_max(a, axis=_ensure_optional_axes(axis), out=out,
keepdims=keepdims, initial=initial, where=where)
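# Sketch of the `initial`/`where` interaction for reductions without an identity.
# `_example_max_where` is a hypothetical helper used only for illustration.
def _example_max_where():
  x = np.array([[1.0, 5.0], [2.0, 3.0]])
  mask = np.array([[True, False], [True, True]])
  # `max` has no identity, so a `where` mask requires an explicit `initial`;
  # the masked-out 5.0 is ignored and the result is 3.0.
  return max(x, initial=-np.inf, where=mask)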
@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True)
def _reduce_min(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, initial: Optional[ArrayLike] = None,
where: Optional[ArrayLike] = None) -> Array:
return _reduction(a, "min", np.min, lax.min, np.inf, has_identity=False,
axis=axis, out=out, keepdims=keepdims,
initial=initial, where_=where, parallel_reduce=lax.pmin)
@_wraps(np.min, skip_params=['out'])
def min(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, initial: Optional[ArrayLike] = None,
where: Optional[ArrayLike] = None) -> Array:
return _reduce_min(a, axis=_ensure_optional_axes(axis), out=out,
keepdims=keepdims, initial=initial, where=where)
@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True)
def _reduce_all(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, *, where: Optional[ArrayLike] = None) -> Array:
return _reduction(a, "all", np.all, lax.bitwise_and, True, preproc=_cast_to_bool,
axis=axis, out=out, keepdims=keepdims, where_=where)
@_wraps(np.all, skip_params=['out'])
def all(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, *, where: Optional[ArrayLike] = None) -> Array:
return _reduce_all(a, axis=_ensure_optional_axes(axis), out=out,
keepdims=keepdims, where=where)
@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True)
def _reduce_any(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, *, where: Optional[ArrayLike] = None) -> Array:
return _reduction(a, "any", np.any, lax.bitwise_or, False, preproc=_cast_to_bool,
axis=axis, out=out, keepdims=keepdims, where_=where)
@_wraps(np.any, skip_params=['out'])
def any(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, *, where: Optional[ArrayLike] = None) -> Array:
return _reduce_any(a, axis=_ensure_optional_axes(axis), out=out,
keepdims=keepdims, where=where)
product = prod
amin = min
amax = max
alltrue = all
sometrue = any
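# Sketch of the boolean reductions: inputs are cast to bool via `_cast_to_bool`,
# and a `where` mask drops elements by substituting the identity (True for `all`,
# False for `any`). `_example_bool_reductions` is a hypothetical helper.
def _example_bool_reductions():
  x = np.array([[0, 1], [2, 0]])
  mask = np.array([[True, False], [True, True]])
  return any(x, axis=0), all(x, where=mask)   # -> [True, True], False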
def _axis_size(a: ArrayLike, axis: Union[int, Sequence[int]]):
if not isinstance(axis, (tuple, list)):
axis_seq: Sequence[int] = (axis,) # type: ignore[assignment]
else:
axis_seq = axis
size = 1
a_shape = np.shape(a)
  for d in axis_seq:
    size *= maybe_named_axis(d, lambda i: a_shape[i], lambda name: lax.psum(1, name))
return size
@_wraps(np.mean, skip_params=['out'])
def mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None,
out: None = None, keepdims: bool = False, *,
where: Optional[ArrayLike] = None) -> Array:
return _mean(a, _ensure_optional_axes(axis), dtype, out, keepdims,
where=where)
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'), inline=True)
def _mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None,
out: None = None, keepdims: bool = False, *,
where: Optional[ArrayLike] = None) -> Array:
_check_arraylike("mean", a)
lax_internal._check_user_dtype_supported(dtype, "mean")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.mean is not supported.")
if where is None:
if axis is None:
normalizer = core.dimension_as_value(np.size(a))
else:
normalizer = core.dimension_as_value(_axis_size(a, axis))
else:
normalizer = sum(_broadcast_to(where, np.shape(a)), axis, dtype=dtype, keepdims=keepdims)
if dtype is None:
dtype = dtypes.to_inexact_dtype(dtypes.dtype(a))
dtype = dtypes.canonicalize_dtype(dtype)
return lax.div(
sum(a, axis, dtype=dtype, keepdims=keepdims, where=where),
lax.convert_element_type(normalizer, dtype))
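# Sketch of how `mean` normalizes under a `where` mask: the denominator is the
# count of True entries, not the array size. `_example_mean_where` is a
# hypothetical helper for illustration.
def _example_mean_where():
  x = np.array([[1.0, 100.0], [3.0, 5.0]])
  mask = np.array([[True, False], [True, True]])
  return mean(x, where=mask)   # (1 + 3 + 5) / 3 == 3.0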
@overload
def average(a: ArrayLike, axis: Axis = None, weights: Optional[ArrayLike] = None,
returned: Literal[False] = False, keepdims: bool = False) -> Array: ...
@overload
def average(a: ArrayLike, axis: Axis = None, weights: Optional[ArrayLike] = None, *,
returned: Literal[True], keepdims: bool = False) -> Array: ...
@overload
def average(a: ArrayLike, axis: Axis = None, weights: Optional[ArrayLike] = None,
returned: bool = False, keepdims: bool = False) -> Union[Array, Tuple[Array, Array]]: ...
@_wraps(np.average)
def average(a: ArrayLike, axis: Axis = None, weights: Optional[ArrayLike] = None,
returned: bool = False, keepdims: bool = False) -> Union[Array, Tuple[Array, Array]]:
return _average(a, _ensure_optional_axes(axis), weights, returned, keepdims)
@partial(api.jit, static_argnames=('axis', 'returned', 'keepdims'), inline=True)
def _average(a: ArrayLike, axis: Axis = None, weights: Optional[ArrayLike] = None,
returned: bool = False, keepdims: bool = False) -> Union[Array, Tuple[Array, Array]]:
if weights is None: # Treat all weights as 1
_check_arraylike("average", a)
a, = _promote_dtypes_inexact(a)
avg = mean(a, axis=axis, keepdims=keepdims)
if axis is None:
weights_sum = lax.full((), core.dimension_as_value(a.size), dtype=avg.dtype)
elif isinstance(axis, tuple):
weights_sum = lax.full_like(avg, math.prod(core.dimension_as_value(a.shape[d]) for d in axis))
else:
weights_sum = lax.full_like(avg, core.dimension_as_value(a.shape[axis])) # type: ignore[index]
else:
_check_arraylike("average", a, weights)
a, weights = _promote_dtypes_inexact(a, weights)
a_shape = np.shape(a)
a_ndim = len(a_shape)
weights_shape = np.shape(weights)
if axis is None:
pass
elif isinstance(axis, tuple):
axis = tuple(_canonicalize_axis(d, a_ndim) for d in axis)
else:
axis = _canonicalize_axis(axis, a_ndim)
if a_shape != weights_shape:
# Make sure the dimensions work out
if len(weights_shape) != 1:
raise ValueError("1D weights expected when shapes of a and "
"weights differ.")
if axis is None:
raise ValueError("Axis must be specified when shapes of a and "
"weights differ.")
elif isinstance(axis, tuple):
raise ValueError("Single axis expected when shapes of a and weights differ")
elif not core.symbolic_equal_dim(weights_shape[0], a_shape[axis]):
raise ValueError("Length of weights not "
"compatible with specified axis.")
weights = _broadcast_to(weights, (a_ndim - 1) * (1,) + weights_shape)
weights = _moveaxis(weights, -1, axis)
weights_sum = sum(weights, axis=axis, keepdims=keepdims)
avg = sum(a * weights, axis=axis, keepdims=keepdims) / weights_sum
if returned:
if avg.shape != weights_sum.shape:
weights_sum = _broadcast_to(weights_sum, avg.shape)
return avg, weights_sum
return avg
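# Usage sketch for weighted `average`; `_example_average_weights` is a
# hypothetical helper, not part of the module's API.
def _example_average_weights():
  x = np.array([1.0, 2.0, 3.0])
  w = np.array([3.0, 1.0, 0.0])
  # sum(x * w) / sum(w) == 5.0 / 4.0; `returned=True` also yields sum(w).
  return average(x, weights=w, returned=True)   # (1.25, 4.0)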
@_wraps(np.var, skip_params=['out'])
def var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None,
out: None = None, ddof: int = 0, keepdims: bool = False, *,
where: Optional[ArrayLike] = None) -> Array:
return _var(a, _ensure_optional_axes(axis), dtype, out, ddof, keepdims,
where=where)
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None,
out: None = None, ddof: int = 0, keepdims: bool = False, *,
where: Optional[ArrayLike] = None) -> Array:
_check_arraylike("var", a)
lax_internal._check_user_dtype_supported(dtype, "var")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.var is not supported.")
computation_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype)
a = _asarray(a).astype(computation_dtype)
a_mean = mean(a, axis, dtype=computation_dtype, keepdims=True, where=where)
centered = lax.sub(a, a_mean)
if dtypes.issubdtype(centered.dtype, np.complexfloating):
centered = lax.real(lax.mul(centered, lax.conj(centered)))
else:
centered = lax.square(centered)
if where is None:
if axis is None:
normalizer = core.dimension_as_value(np.size(a))
else:
normalizer = core.dimension_as_value(_axis_size(a, axis))
else:
normalizer = sum(_broadcast_to(where, np.shape(a)), axis, dtype=dtype, keepdims=keepdims)
normalizer = normalizer - ddof
result = sum(centered, axis, keepdims=keepdims, where=where)
result = lax.div(result, lax.convert_element_type(normalizer, result.dtype))
return lax.convert_element_type(result, dtype)
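# Sketch of the `ddof` normalizer in `var`: population variance divides by n,
# sample variance by n - ddof. `_example_var_ddof` is a hypothetical helper.
def _example_var_ddof():
  x = np.array([1.0, 2.0, 3.0, 4.0])
  return var(x), var(x, ddof=1)   # 1.25 and 5/3 ≈ 1.6667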
def _var_promote_types(a_dtype: DTypeLike, dtype: DTypeLike) -> Tuple[DType, DType]:
if dtype:
if (not dtypes.issubdtype(dtype, np.complexfloating) and
dtypes.issubdtype(a_dtype, np.complexfloating)):
msg = ("jax.numpy.var does not yet support real dtype parameters when "
"computing the variance of an array of complex values. The "
"semantics of numpy.var seem unclear in this case. Please comment "
"on https://github.com/google/jax/issues/2283 if this behavior is "
"important to you.")
raise ValueError(msg)
computation_dtype = dtype
else:
if not dtypes.issubdtype(a_dtype, np.inexact):
dtype = dtypes.to_inexact_dtype(a_dtype)
computation_dtype = dtype
else:
dtype = _complex_elem_type(a_dtype)
computation_dtype = a_dtype
return _upcast_f16(computation_dtype), np.dtype(dtype)
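# Sketch of `_var_promote_types` for a complex input with no user dtype: the
# computation stays complex while the reported result dtype is the matching real
# element type. `_example_var_promote_types` is a hypothetical helper.
def _example_var_promote_types():
  computation_dtype, result_dtype = _var_promote_types(np.dtype('complex64'), None)
  return computation_dtype, result_dtype   # complex64, float32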
@_wraps(np.std, skip_params=['out'])
def std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None,
out: None = None, ddof: int = 0, keepdims: bool = False, *,
where: Optional[ArrayLike] = None) -> Array:
return _std(a, _ensure_optional_axes(axis), dtype, out, ddof, keepdims,
where=where)
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def _std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None,
out: None = None, ddof: int = 0, keepdims: bool = False, *,
where: Optional[ArrayLike] = None) -> Array:
_check_arraylike("std", a)
lax_internal._check_user_dtype_supported(dtype, "std")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.std is not supported.")
return lax.sqrt(var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where))
@_wraps(np.ptp, skip_params=['out'])
def ptp(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False) -> Array:
return _ptp(a, _ensure_optional_axes(axis), out, keepdims)
@partial(api.jit, static_argnames=('axis', 'keepdims'))
def _ptp(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False) -> Array:
_check_arraylike("ptp", a)
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.ptp is not supported.")
x = amax(a, axis=axis, keepdims=keepdims)
y = amin(a, axis=axis, keepdims=keepdims)
return lax.sub(x, y)
@_wraps(np.count_nonzero)
@partial(api.jit, static_argnames=('axis', 'keepdims'))
def count_nonzero(a: ArrayLike, axis: Axis = None,
keepdims: bool = False) -> Array:
_check_arraylike("count_nonzero", a)
return sum(lax.ne(a, _lax_const(a, 0)), axis=axis,
dtype=dtypes.canonicalize_dtype(np.int_), keepdims=keepdims)
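# Usage sketch for `count_nonzero` along an axis; the helper name
# `_example_count_nonzero` is illustrative only.
def _example_count_nonzero():
  x = np.array([[0, 1, 7], [0, 0, 3]])
  return count_nonzero(x, axis=0)   # [0, 1, 2], in the default int dtype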
def _nan_reduction(a: ArrayLike, name: str, jnp_reduction: Callable[..., Array],
init_val: ArrayLike, nan_if_all_nan: bool,
axis: Axis = None, keepdims: bool = False, **kwargs) -> Array:
_check_arraylike(name, a)
if not dtypes.issubdtype(dtypes.dtype(a), np.inexact):
return jnp_reduction(a, axis=axis, keepdims=keepdims, **kwargs)
out = jnp_reduction(_where(lax_internal._isnan(a), _reduction_init_val(a, init_val), a),
axis=axis, keepdims=keepdims, **kwargs)
if nan_if_all_nan:
return _where(all(lax_internal._isnan(a), axis=axis, keepdims=keepdims),
_lax_const(a, np.nan), out)
else:
return out
@_wraps(np.nanmin, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'keepdims'))
def nanmin(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, initial: Optional[ArrayLike] = None,
where: Optional[ArrayLike] = None) -> Array:
return _nan_reduction(a, 'nanmin', min, np.inf, nan_if_all_nan=initial is None,
axis=axis, out=out, keepdims=keepdims,
initial=initial, where=where)
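# Sketch of the nan_if_all_nan behavior: an all-NaN slice yields NaN unless an
# explicit `initial` is supplied, in which case NaNs are simply treated as the
# identity. `_example_nanmin` is a hypothetical helper for illustration.
def _example_nanmin():
  x = np.array([[np.nan, np.nan], [1.0, np.nan]])
  return nanmin(x, axis=1), nanmin(x, axis=1, initial=10.0)   # [nan, 1.], [10., 1.]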
@_wraps(np.nanmax, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'keepdims'))
def nanmax(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, initial: Optional[ArrayLike] = None,
where: Optional[ArrayLike] = None) -> Array:
return _nan_reduction(a, 'nanmax', max, -np.inf, nan_if_all_nan=initial is None,
axis=axis, out=out, keepdims=keepdims,
initial=initial, where=where)
@_wraps(np.nansum, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nansum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, out: None = None,
keepdims: bool = False, initial: Optional[ArrayLike] = None,
where: Optional[ArrayLike] = None) -> Array:
  lax_internal._check_user_dtype_supported(dtype, "nansum")
return _nan_reduction(a, 'nansum', sum, 0, nan_if_all_nan=False,
axis=axis, dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where=where)
# Work around a sphinx documentation warning in NumPy 1.22.
if nansum.__doc__ is not None:
nansum.__doc__ = nansum.__doc__.replace("\n\n\n", "\n\n")
@_wraps(np.nanprod, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanprod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, out: None = None,
keepdims: bool = False, initial: Optional[ArrayLike] = None,
where: Optional[ArrayLike] = None) -> Array:
lax_internal._check_user_dtype_supported(dtype, "nanprod")
return _nan_reduction(a, 'nanprod', prod, 1, nan_if_all_nan=False,
axis=axis, dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where=where)
@_wraps(np.nanmean, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanmean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, out: None = None,
keepdims: bool = False, where: Optional[ArrayLike] = None) -> Array:
_check_arraylike("nanmean", a)
lax_internal._check_user_dtype_supported(dtype, "nanmean")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.nanmean is not supported.")
if dtypes.issubdtype(dtypes.dtype(a), np.bool_) or dtypes.issubdtype(dtypes.dtype(a), np.integer):
return mean(a, axis, dtype, out, keepdims, where=where)
if dtype is None:
dtype = dtypes.dtype(a)
nan_mask = lax_internal.bitwise_not(lax_internal._isnan(a))
normalizer = sum(nan_mask, axis=axis, dtype=np.int32, keepdims=keepdims, where=where)
normalizer = lax.convert_element_type(normalizer, dtype)
td = lax.div(nansum(a, axis, dtype=dtype, keepdims=keepdims, where=where), normalizer)
return td
@_wraps(np.nanvar, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, out: None = None,
ddof: int = 0, keepdims: bool = False,
where: Optional[ArrayLike] = None) -> Array:
_check_arraylike("nanvar", a)
lax_internal._check_user_dtype_supported(dtype, "nanvar")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.nanvar is not supported.")
computation_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype)
a = _asarray(a).astype(computation_dtype)
a_mean = nanmean(a, axis, dtype=computation_dtype, keepdims=True, where=where)
centered = _where(lax_internal._isnan(a), 0, lax.sub(a, a_mean)) # double-where trick for gradients.
if dtypes.issubdtype(centered.dtype, np.complexfloating):
centered = lax.real(lax.mul(centered, lax.conj(centered)))
else:
centered = lax.square(centered)
normalizer = sum(lax_internal.bitwise_not(lax_internal._isnan(a)),
axis=axis, keepdims=keepdims, where=where)
normalizer = normalizer - ddof
normalizer_mask = lax.le(normalizer, lax_internal._zero(normalizer))
result = sum(centered, axis, keepdims=keepdims, where=where)
result = _where(normalizer_mask, np.nan, result)
divisor = _where(normalizer_mask, 1, normalizer)
result = lax.div(result, lax.convert_element_type(divisor, result.dtype))
return lax.convert_element_type(result, dtype)
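# Sketch of the normalizer masking in `nanvar`: when the non-NaN count minus
# `ddof` is not positive, the result is NaN instead of a division by zero.
# `_example_nanvar_ddof` is a hypothetical helper for illustration.
def _example_nanvar_ddof():
  x = np.array([[1.0, np.nan], [2.0, 4.0]])
  return nanvar(x, axis=1, ddof=1)   # [nan, 2.0]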
@_wraps(np.nanstd, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanstd(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, out: None = None,
ddof: int = 0, keepdims: bool = False,
where: Optional[ArrayLike] = None) -> Array:
_check_arraylike("nanstd", a)
lax_internal._check_user_dtype_supported(dtype, "nanstd")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.nanstd is not supported.")
return lax.sqrt(nanvar(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where))
# TODO(jakevdp): use a protocol here for better typing?
def _make_cumulative_reduction(np_reduction: Any, reduction: Callable[..., Array],
fill_nan: bool = False, fill_value: ArrayLike = 0) -> Callable[..., Array]:
@_wraps(np_reduction, skip_params=['out'])
def cumulative_reduction(a: ArrayLike, axis: Axis = None,
dtype: DTypeLike = None, out: None = None) -> Array:
return _cumulative_reduction(a, _ensure_optional_axes(axis), dtype, out)
@partial(api.jit, static_argnames=('axis', 'dtype'))
def _cumulative_reduction(a: ArrayLike, axis: Axis = None,
dtype: DTypeLike = None, out: None = None) -> Array:
_check_arraylike(np_reduction.__name__, a)
if out is not None:
raise NotImplementedError(f"The 'out' argument to jnp.{np_reduction.__name__} "
f"is not supported.")
lax_internal._check_user_dtype_supported(dtype, np_reduction.__name__)
if axis is None or _isscalar(a):
a = lax.reshape(a, (np.size(a),))
axis = 0
a_shape = list(np.shape(a))
num_dims = len(a_shape)
axis = _canonicalize_axis(axis, num_dims)
if fill_nan:
a = _where(lax_internal._isnan(a), _lax_const(a, fill_value), a)
if not dtype and dtypes.dtype(a) == np.bool_:
dtype = dtypes.canonicalize_dtype(dtypes.int_)
if dtype:
a = lax.convert_element_type(a, dtype)
return reduction(a, axis)
return cumulative_reduction
cumsum = _make_cumulative_reduction(np.cumsum, lax.cumsum, fill_nan=False)
cumprod = _make_cumulative_reduction(np.cumprod, lax.cumprod, fill_nan=False)
cumproduct = cumprod
nancumsum = _make_cumulative_reduction(np.nancumsum, lax.cumsum,
fill_nan=True, fill_value=0)
nancumprod = _make_cumulative_reduction(np.nancumprod, lax.cumprod,
fill_nan=True, fill_value=1)
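# Usage sketch for the cumulative reductions defined above: `nancumsum` fills NaN
# with 0 before accumulating, and boolean inputs to `cumsum` are promoted to the
# default int dtype. `_example_cumulative` is a hypothetical helper.
def _example_cumulative():
  x = np.array([1.0, np.nan, 3.0])
  b = np.array([True, False, True])
  return nancumsum(x), cumsum(b)   # [1., 1., 4.] and [1, 1, 2]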