Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
Factor-out reductions from lax_numpy.py
parent e9f59aed84
commit 121d8d6320
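
This change moves the jax.numpy reduction implementations (sum, prod, max/min, all/any, mean, average, var/std, ptp, count_nonzero, the nan* variants, and the cumulative reductions) out of lax_numpy.py into the new module jax._src.numpy.reductions, and re-exports them so the public jax.numpy API is unchanged.

[Illustrative usage sketch, not part of the diff; array values are made up to show that the public entry points keep working as before.]

    import jax.numpy as jnp

    x = jnp.array([[1., 2.], [3., 4.]])
    jnp.sum(x)                          # 10.0, still exported from jax.numpy
    jnp.prod(x, axis=0)                 # [3., 8.]
    jnp.mean(x, axis=1, keepdims=True)  # shape (2, 1)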
@@ -51,6 +51,12 @@ from jax._src.api_util import _ensure_index_tuple
from jax._src.lax.lax import _array_copy, _sort_lt_comparator, _sort_le_comparator
from jax._src.lax import lax as lax_internal
from jax._src.numpy.ndarray import ndarray
from jax._src.numpy.reductions import ( # noqa: F401
  _ensure_optional_axes, _reduction_dims,
  alltrue, amin, amax, any, all, average, count_nonzero, cumsum, cumprod, cumproduct,
  max, mean, min, nancumsum, nancumprod, nanmax, nanmean, nanmin, nanprod, nanstd,
  nansum, nanvar, prod, product, ptp, sometrue, std, sum, var,
)
from jax._src.numpy.ufuncs import ( # noqa: F401
  abs, absolute, add, arccos, arccosh, arcsin, arcsinh, arctan, arctan2, arctanh,
  bitwise_and, bitwise_not, bitwise_or, bitwise_xor, cbrt, ceil, conj, conjugate,
@@ -69,7 +75,7 @@ from jax._src.numpy.util import ( # noqa: F401
from jax._src.numpy.vectorize import vectorize
from jax._src.ops import scatter
from jax._src.util import (unzip2, prod as _prod, subvals, safe_zip, ceil_of_ratio,
                           canonicalize_axis as _canonicalize_axis, maybe_named_axis)
                           canonicalize_axis as _canonicalize_axis)

newaxis = None

@@ -1341,375 +1347,6 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None):
    x = where(isneginf(x), array(neginf, dtype=x.dtype), x)
  return x

### Reducers

def _reduction(a, name, np_fun, op, init_val, has_identity=True,
               preproc=None, bool_op=None, upcast_f16_for_computation=False,
               axis=None, dtype=None, out=None, keepdims=False, initial=None,
               where_=None, parallel_reduce=None):
  bool_op = bool_op or op
  # Note: we must accept out=None as an argument, because numpy reductions delegate to
  # object methods. For example `np.sum(x)` will call `x.sum()` if the `sum()` method
  # exists, passing along all its arguments.
  if out is not None:
    raise NotImplementedError(f"The 'out' argument to jnp.{name} is not supported.")
  _check_arraylike(name, a)
  lax_internal._check_user_dtype_supported(dtype, name)
  axis = core.concrete_or_error(None, axis, f"axis argument to jnp.{name}().")

  if initial is None and not has_identity:
    if not _all(core.greater_equal_dim(d, 1) for d in np.shape(a)):
      raise ValueError(f"zero-size array to reduction operation {name} which has no identity")
    if where_ is not None:
      raise ValueError(f"reduction operation {name} does not have an identity, so to use a "
                       f"where mask one has to specify 'initial'")

  a = a if isinstance(a, ndarray) else asarray(a)
  a = preproc(a) if preproc else a
  pos_dims, dims = _reduction_dims(a, axis)
  result_dtype = dtypes.canonicalize_dtype(dtype or _dtype(np_fun(np.ones((), dtype=_dtype(a)))))
  if upcast_f16_for_computation and issubdtype(result_dtype, inexact):
    computation_dtype = promote_types(result_dtype, float32)
  else:
    computation_dtype = result_dtype
  a = lax.convert_element_type(a, computation_dtype)
  op = op if computation_dtype != np.bool_ else bool_op
  # NB: in XLA, init_val must be an identity for the op, so the user-specified
  # initial value must be applied afterward.
  init_val = _reduction_init_val(a, init_val)
  if where_ is not None:
    a = where(where_, a, init_val)
  if pos_dims is not dims:
    if parallel_reduce is None:
      raise NotImplementedError(f"Named reductions not implemented for jnp.{name}()")
    result = parallel_reduce(a, dims)
  else:
    result = lax.reduce(a, init_val, op, dims)
  if initial is not None:
    result = op(lax.convert_element_type(initial, a.dtype), result)
  if keepdims:
    result = expand_dims(result, pos_dims)
  return lax.convert_element_type(result, dtype or result_dtype)

def _canonicalize_axis_allow_named(x, rank):
  return maybe_named_axis(x, lambda i: _canonicalize_axis(i, rank), lambda name: name)

def _reduction_dims(a, axis):
  if axis is None:
    return (tuple(range(ndim(a))),) * 2
  elif not isinstance(axis, (np.ndarray, tuple, list)):
    axis = (axis,)
  canon_axis = tuple(_canonicalize_axis_allow_named(x, ndim(a))
                     for x in axis)
  if len(canon_axis) != len(set(canon_axis)):
    raise ValueError(f"duplicate value in 'axis': {axis}")
  canon_pos_axis = tuple(x for x in canon_axis if isinstance(x, int))
  if len(canon_pos_axis) != len(canon_axis):
    return canon_pos_axis, canon_axis
  else:
    return canon_axis, canon_axis

def _reduction_init_val(a, init_val):
  # This function uses np.* functions because lax pattern matches against the
  # specific concrete values of the reduction inputs.
  a_dtype = dtypes.canonicalize_dtype(_dtype(a))
  if a_dtype == 'bool':
    return np.array(init_val > 0, dtype=a_dtype)
  try:
    return np.array(init_val, dtype=a_dtype)
  except OverflowError:
    assert issubdtype(a_dtype, integer)
    sign, info = np.sign(init_val), iinfo(a_dtype)
    return np.array(info.min if sign < 0 else info.max, dtype=a_dtype)

def _cast_to_bool(operand):
  with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=np.ComplexWarning)
    return lax.convert_element_type(operand, bool_)


def _ensure_optional_axes(x):
  def force(x):
    if x is None:
      return None
    try:
      return operator.index(x)
    except TypeError:
      return tuple(i if isinstance(i, str) else operator.index(i) for i in x)
  return core.concrete_or_error(
      force, x, "The axis argument must be known statically.")


@partial(jit, static_argnames=('axis', 'dtype', 'keepdims'), inline=True)
def _reduce_sum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None,
                dtype=None, out=None, keepdims=None, initial=None, where=None):
  return _reduction(a, "sum", np.sum, lax.add, 0,
                    bool_op=lax.bitwise_or, upcast_f16_for_computation=True,
                    axis=axis, dtype=dtype, out=out, keepdims=keepdims,
                    initial=initial, where_=where, parallel_reduce=lax.psum)

@_wraps(np.sum, skip_params=['out'])
def sum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
        out=None, keepdims=None, initial=None, where=None):
  return _reduce_sum(a, axis=_ensure_optional_axes(axis), dtype=dtype, out=out,
                     keepdims=keepdims, initial=initial, where=where)


@partial(jit, static_argnames=('axis', 'dtype', 'keepdims'), inline=True)
def _reduce_prod(a, axis: Optional[Union[int, Tuple[int, ...]]] = None,
                 dtype=None, out=None, keepdims=None, initial=None, where=None):
  return _reduction(a, "prod", np.prod, lax.mul, 1,
                    bool_op=lax.bitwise_and, upcast_f16_for_computation=True,
                    axis=axis, dtype=dtype, out=out, keepdims=keepdims,
                    initial=initial, where_=where)

@_wraps(np.prod, skip_params=['out'])
def prod(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
         out=None, keepdims=None, initial=None, where=None):
  return _reduce_prod(a, axis=_ensure_optional_axes(axis), dtype=dtype,
                      out=out, keepdims=keepdims, initial=initial, where=where)


@partial(jit, static_argnames=('axis', 'keepdims'), inline=True)
def _reduce_max(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
                keepdims=None, initial=None, where=None):
  return _reduction(a, "max", np.max, lax.max, -np.inf, has_identity=False,
                    axis=axis, out=out, keepdims=keepdims,
                    initial=initial, where_=where, parallel_reduce=lax.pmax)

@_wraps(np.max, skip_params=['out'])
def max(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
        keepdims=None, initial=None, where=None):
  return _reduce_max(a, axis=_ensure_optional_axes(axis), out=out,
                     keepdims=keepdims, initial=initial, where=where)

@partial(jit, static_argnames=('axis', 'keepdims'), inline=True)
def _reduce_min(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
                keepdims=None, initial=None, where=None):
  return _reduction(a, "min", np.min, lax.min, np.inf, has_identity=False,
                    axis=axis, out=out, keepdims=keepdims,
                    initial=initial, where_=where, parallel_reduce=lax.pmin)

@_wraps(np.min, skip_params=['out'])
def min(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
        keepdims=None, initial=None, where=None):
  return _reduce_min(a, axis=_ensure_optional_axes(axis), out=out,
                     keepdims=keepdims, initial=initial, where=where)

@partial(jit, static_argnames=('axis', 'keepdims'), inline=True)
def _reduce_all(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
                keepdims=None, *, where=None):
  return _reduction(a, "all", np.all, lax.bitwise_and, True, preproc=_cast_to_bool,
                    axis=axis, out=out, keepdims=keepdims, where_=where)

@_wraps(np.all, skip_params=['out'])
def all(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
        keepdims=None, *, where=None):
  return _reduce_all(a, axis=_ensure_optional_axes(axis), out=out,
                     keepdims=keepdims, where=where)

@partial(jit, static_argnames=('axis', 'keepdims'), inline=True)
def _reduce_any(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
                keepdims=None, *, where=None):
  return _reduction(a, "any", np.any, lax.bitwise_or, False, preproc=_cast_to_bool,
                    axis=axis, out=out, keepdims=keepdims, where_=where)

@_wraps(np.any, skip_params=['out'])
def any(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
        keepdims=None, *, where=None):
  return _reduce_any(a, axis=_ensure_optional_axes(axis), out=out,
                     keepdims=keepdims, where=where)

product = prod
amin = min
amax = max
alltrue = all
sometrue = any

def _axis_size(a, axis):
  if not isinstance(axis, (tuple, list)):
    axis = (axis,)
  size = 1
  a_shape = shape(a)
  for a in axis:
    size *= maybe_named_axis(a, lambda i: a_shape[i], lambda name: lax.psum(1, name))
  return size

@_wraps(np.mean, skip_params=['out'])
def mean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
         out=None, keepdims=False, *, where=None):
  return _mean(a, _ensure_optional_axes(axis), dtype, out, keepdims,
               where=where)

@partial(jit, static_argnames=('axis', 'dtype', 'keepdims'), inline=True)
def _mean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
          out=None, keepdims=False, *, where=None):
  _check_arraylike("mean", a)
  lax_internal._check_user_dtype_supported(dtype, "mean")
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.mean is not supported.")

  if where is None:
    if axis is None:
      normalizer = core.dimension_as_value(size(a))
    else:
      normalizer = core.dimension_as_value(_axis_size(a, axis))
  else:
    normalizer = sum(broadcast_to(where, shape(a)), axis, dtype=dtype, keepdims=keepdims)

  if dtype is None:
    if issubdtype(_dtype(a), bool_) or issubdtype(_dtype(a), integer):
      dtype = float_
    else:
      dtype = _dtype(a)
  dtype = dtypes.canonicalize_dtype(dtype)

  return lax.div(
      sum(a, axis, dtype=dtype, keepdims=keepdims, where=where),
      lax.convert_element_type(normalizer, dtype))

@_wraps(np.average)
def average(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, weights=None,
            returned=False):
  return _average(a, _ensure_optional_axes(axis), weights, returned)

@partial(jit, static_argnames=('axis', 'returned'), inline=True)
def _average(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, weights=None,
             returned=False):
  a = asarray(a)

  if weights is None: # Treat all weights as 1
    avg = mean(a, axis=axis)
    if axis is None:
      weights_sum = full((), core.dimension_as_value(size(a)), dtype=avg.dtype)
    else:
      weights_sum = full_like(avg, core.dimension_as_value(a.shape[axis]), dtype=avg.dtype)
  else:
    weights = asarray(weights)

    if issubdtype(a.dtype, inexact):
      out_dtype = result_type(a.dtype, weights.dtype)
    else:
      out_dtype = result_type(a.dtype, weights.dtype, float_)
    out_dtype = dtypes.canonicalize_dtype(out_dtype)

    a_shape = shape(a)
    a_ndim = len(a_shape)
    weights_shape = shape(weights)
    axis = None if axis is None else _canonicalize_axis(axis, a_ndim)

    if a_shape != weights_shape:
      # Make sure the dimensions work out
      if axis is None:
        raise ValueError("Axis must be specified when shapes of a and "
                         "weights differ.")
      if len(weights_shape) != 1:
        raise ValueError("1D weights expected when shapes of a and "
                         "weights differ.")
      if not core.symbolic_equal_dim(weights_shape[0], a_shape[axis]):
        raise ValueError("Length of weights not "
                         "compatible with specified axis.")

      weights = broadcast_to(weights, (a_ndim - 1) * (1,) + weights_shape)
      weights = moveaxis(weights, -1, axis)

    weights_sum = sum(weights, axis=axis, dtype=out_dtype)
    avg = sum(multiply(a, weights), axis=axis, dtype=out_dtype) / weights_sum

  if returned:
    if avg.shape != weights_sum.shape:
      weights_sum = broadcast_to(weights_sum, avg.shape)
    return avg, weights_sum
  return avg


@_wraps(np.var, skip_params=['out'])
def var(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
        out=None, ddof=0, keepdims=False, *, where=None):
  return _var(a, _ensure_optional_axes(axis), dtype, out, ddof, keepdims,
              where=where)

@partial(jit, static_argnames=('axis', 'dtype', 'keepdims'))
def _var(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
         out=None, ddof=0, keepdims=False, *, where=None):
  _check_arraylike("var", a)
  lax_internal._check_user_dtype_supported(dtype, "var")
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.var is not supported.")

  a_dtype, dtype = _var_promote_types(_dtype(a), dtype)
  a_mean = mean(a, axis, dtype=a_dtype, keepdims=True, where=where)
  centered = a - a_mean
  if issubdtype(centered.dtype, complexfloating):
    centered = lax.real(lax.mul(centered, lax.conj(centered)))
  else:
    centered = lax.square(centered)

  if where is None:
    if axis is None:
      normalizer = core.dimension_as_value(size(a))
    else:
      normalizer = core.dimension_as_value(_axis_size(a, axis))
  else:
    normalizer = sum(broadcast_to(where, shape(a)), axis, dtype=dtype, keepdims=keepdims)
  normalizer = normalizer - ddof

  result = sum(centered, axis, keepdims=keepdims, where=where)
  out = lax.div(result, lax.convert_element_type(normalizer, result.dtype))
  return lax.convert_element_type(out, dtype)


def _var_promote_types(a_dtype, dtype):
  if dtype:
    if (not issubdtype(dtype, complexfloating) and
        issubdtype(a_dtype, complexfloating)):
      msg = ("jax.numpy.var does not yet support real dtype parameters when "
             "computing the variance of an array of complex values. The "
             "semantics of numpy.var seem unclear in this case. Please comment "
             "on https://github.com/google/jax/issues/2283 if this behavior is "
             "important to you.")
      raise ValueError(msg)
    a_dtype = promote_types(a_dtype, dtype)
  else:
    if not issubdtype(a_dtype, inexact):
      dtype = a_dtype = dtypes.canonicalize_dtype(float_)
    else:
      dtype = _complex_elem_type(a_dtype)
      a_dtype = promote_types(a_dtype, float32)
  return a_dtype, dtype


@_wraps(np.std, skip_params=['out'])
def std(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
        out=None, ddof=0, keepdims=False, *, where=None):
  return _std(a, _ensure_optional_axes(axis), dtype, out, ddof, keepdims,
              where=where)

@partial(jit, static_argnames=('axis', 'dtype', 'keepdims'))
def _std(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
         out=None, ddof=0, keepdims=False, *, where=None):
  _check_arraylike("std", a)
  lax_internal._check_user_dtype_supported(dtype, "std")
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.std is not supported.")
  return sqrt(var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where))


@_wraps(np.ptp, skip_params=['out'])
def ptp(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
        keepdims=False):
  return _ptp(a, _ensure_optional_axes(axis), out, keepdims)

@partial(jit, static_argnames=('axis', 'keepdims'))
def _ptp(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
         keepdims=False):
  _check_arraylike("ptp", a)
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.ptp is not supported.")
  x = amax(a, axis=axis, keepdims=keepdims)
  y = amin(a, axis=axis, keepdims=keepdims)
  return lax.sub(x, y)


@_wraps(np.allclose)
@partial(jit, static_argnames=('equal_nan',))
@@ -1718,15 +1355,6 @@ def allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
  return all(isclose(a, b, rtol, atol, equal_nan))


@_wraps(np.count_nonzero)
@partial(jit, static_argnames=('axis', 'keepdims'))
def count_nonzero(a, axis: Optional[Union[int, Tuple[int, ...]]] = None,
                  keepdims=False):
  _check_arraylike("count_nonzero", a)
  return sum(lax.ne(a, _lax_const(a, 0)), axis=axis,
             dtype=dtypes.canonicalize_dtype(np.int_), keepdims=keepdims)


_NONZERO_DOC = """\
Because the size of the output of ``nonzero`` is data-dependent, the function is not
typically compatible with JIT. The JAX version adds the optional ``size`` argument which
@@ -1770,162 +1398,6 @@ def flatnonzero(a, *, size=None, fill_value=None):
  return nonzero(ravel(a), size=size, fill_value=fill_value)[0]


def _nan_reduction(a, name, jnp_reduction, init_val, nan_if_all_nan,
                   axis=None, keepdims=None, **kwargs):
  _check_arraylike(name, a)
  if not issubdtype(_dtype(a), inexact):
    return jnp_reduction(a, axis=axis, keepdims=keepdims, **kwargs)

  out = jnp_reduction(where(isnan(a), _reduction_init_val(a, init_val), a),
                      axis=axis, keepdims=keepdims, **kwargs)
  if nan_if_all_nan:
    return where(all(isnan(a), axis=axis, keepdims=keepdims),
                 _lax_const(a, nan), out)
  else:
    return out

@_wraps(np.nanmin, skip_params=['out'])
@partial(jit, static_argnames=('axis', 'keepdims'))
def nanmin(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
           keepdims=None, initial=None, where=None):
  return _nan_reduction(a, 'nanmin', min, inf, nan_if_all_nan=initial is None,
                        axis=axis, out=out, keepdims=keepdims,
                        initial=initial, where=where)

@_wraps(np.nanmax, skip_params=['out'])
@partial(jit, static_argnames=('axis', 'keepdims'))
def nanmax(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
           keepdims=None, initial=None, where=None):
  return _nan_reduction(a, 'nanmax', max, -inf, nan_if_all_nan=initial is None,
                        axis=axis, out=out, keepdims=keepdims,
                        initial=initial, where=where)

@_wraps(np.nansum, skip_params=['out'])
@partial(jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nansum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
           out=None, keepdims=None, initial=None, where=None):
  lax_internal._check_user_dtype_supported(dtype, "nanprod")
  return _nan_reduction(a, 'nansum', sum, 0, nan_if_all_nan=False,
                        axis=axis, dtype=dtype, out=out, keepdims=keepdims,
                        initial=initial, where=where)

# Work around a sphinx documentation warning in NumPy 1.22.
nansum.__doc__ = nansum.__doc__.replace("\n\n\n", "\n\n")

@_wraps(np.nanprod, skip_params=['out'])
@partial(jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanprod(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
            out=None, keepdims=None, initial=None, where=None):
  lax_internal._check_user_dtype_supported(dtype, "nanprod")
  return _nan_reduction(a, 'nanprod', prod, 1, nan_if_all_nan=False,
                        axis=axis, dtype=dtype, out=out, keepdims=keepdims,
                        initial=initial, where=where)

@_wraps(np.nanmean, skip_params=['out'])
@partial(jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanmean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
            out=None, keepdims=False, where=None):
  _check_arraylike("nanmean", a)
  lax_internal._check_user_dtype_supported(dtype, "nanmean")
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.nanmean is not supported.")
  if issubdtype(_dtype(a), bool_) or issubdtype(_dtype(a), integer):
    return mean(a, axis, dtype, out, keepdims, where=where)
  if dtype is None:
    dtype = _dtype(a)
  nan_mask = logical_not(isnan(a))
  normalizer = sum(nan_mask, axis=axis, dtype=int32, keepdims=keepdims, where=where)
  normalizer = lax.convert_element_type(normalizer, dtype)
  td = lax.div(nansum(a, axis, dtype=dtype, keepdims=keepdims, where=where), normalizer)
  return td


@_wraps(np.nanvar, skip_params=['out'])
@partial(jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanvar(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
           out=None, ddof=0, keepdims=False, where=None):
  _check_arraylike("nanvar", a)
  lax_internal._check_user_dtype_supported(dtype, "nanvar")
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.nanvar is not supported.")

  a_dtype, dtype = _var_promote_types(_dtype(a), dtype)
  a_mean = nanmean(a, axis, dtype=a_dtype, keepdims=True, where=where)

  centered = _where(isnan(a), 0, a - a_mean) # double-where trick for gradients.
  if issubdtype(centered.dtype, complexfloating):
    centered = lax.real(lax.mul(centered, lax.conj(centered)))
  else:
    centered = lax.square(centered)

  normalizer = sum(logical_not(isnan(a)), axis=axis, keepdims=keepdims, where=where)
  normalizer = normalizer - ddof
  normalizer_mask = lax.le(normalizer, 0)
  result = sum(centered, axis, keepdims=keepdims, where=where)
  result = _where(normalizer_mask, nan, result)
  divisor = _where(normalizer_mask, 1, normalizer)
  out = lax.div(result, lax.convert_element_type(divisor, result.dtype))
  return lax.convert_element_type(out, dtype)


@_wraps(np.nanstd, skip_params=['out'])
@partial(jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanstd(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
           out=None, ddof=0, keepdims=False, where=None):
  _check_arraylike("nanstd", a)
  lax_internal._check_user_dtype_supported(dtype, "nanstd")
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.nanstd is not supported.")
  return sqrt(nanvar(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where))


def _make_cumulative_reduction(np_reduction, reduction, fill_nan=False, fill_value=0):
  @_wraps(np_reduction, skip_params=['out'])
  def cumulative_reduction(a,
                           axis: Optional[Union[int, Tuple[int, ...]]] = None,
                           dtype=None, out=None):
    return _cumulative_reduction(a, _ensure_optional_axes(axis), dtype, out)

  @partial(jit, static_argnames=('axis', 'dtype'))
  def _cumulative_reduction(a,
                            axis: Optional[Union[int, Tuple[int, ...]]] = None,
                            dtype=None, out=None):
    _check_arraylike(np_reduction.__name__, a)
    if out is not None:
      raise NotImplementedError(f"The 'out' argument to jnp.{np_reduction.__name__} "
                                f"is not supported.")
    lax_internal._check_user_dtype_supported(dtype, np_reduction.__name__)

    if axis is None or isscalar(a):
      a = ravel(a)
      axis = 0

    a_shape = list(shape(a))
    num_dims = len(a_shape)
    axis = _canonicalize_axis(axis, num_dims)

    if fill_nan:
      a = where(isnan(a), _lax_const(a, fill_value), a)

    if not dtype and _dtype(a) == bool_:
      dtype = int_
    if dtype:
      a = lax.convert_element_type(a, dtype)

    return reduction(a, axis)

  return cumulative_reduction


cumsum = _make_cumulative_reduction(np.cumsum, lax.cumsum, fill_nan=False)
cumprod = _make_cumulative_reduction(np.cumprod, lax.cumprod, fill_nan=False)
cumproduct = cumprod
nancumsum = _make_cumulative_reduction(np.nancumsum, lax.cumsum,
                                       fill_nan=True, fill_value=0)
nancumprod = _make_cumulative_reduction(np.nancumprod, lax.cumprod,
                                        fill_nan=True, fill_value=1)


@_wraps(np.unwrap)
@partial(jit, static_argnames=('axis',))
def unwrap(p, discont=pi, axis: int = -1):
@@ -5163,7 +4635,7 @@ def _itemsize(arr):
  return _dtype(arr).itemsize


def _clip(number, min=None, max=None, out=None, *, a_min=None, a_max=None):
def _clip(number, min=None, max=None, out=None, *, a_min=None, a_max=None): # noqa: F811
  # ndarray.clip has a slightly different API from clip (min -> a_min, max -> a_max)
  # TODO: remove after deprecation window
  if a_min is not None or a_max is not None:
@@ -5609,7 +5081,7 @@ class _IndexUpdateRef:
                             indices_are_sorted=indices_are_sorted,
                             unique_indices=unique_indices, mode=mode))

  def min(self, values, indices_are_sorted=False, unique_indices=False,
  def min(self, values, indices_are_sorted=False, unique_indices=False, # noqa: F811
          mode=None):
    """Pure equivalent of ``x[idx] = minimum(x[idx], y)``.
@@ -5624,7 +5096,7 @@ class _IndexUpdateRef:
                             indices_are_sorted=indices_are_sorted,
                             unique_indices=unique_indices, mode=mode)

  def max(self, values, indices_are_sorted=False, unique_indices=False,
  def max(self, values, indices_are_sorted=False, unique_indices=False, # noqa: F811
          mode=None):
    """Pure equivalent of ``x[idx] = maximum(x[idx], y)``.

jax/_src/numpy/reductions.py (new file, 588 lines)
@@ -0,0 +1,588 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import builtins
from functools import partial
import operator
from typing import Optional, Tuple, Union
import warnings

import numpy as np

from jax import core
from jax import lax
from jax._src import api
from jax._src import dtypes
from jax._src.numpy.ndarray import ndarray
from jax._src.numpy.util import _broadcast_to, _check_arraylike, _complex_elem_type, _where, _wraps
from jax._src.numpy.ufuncs import isnan, logical_not
from jax._src.lax import lax as lax_internal
from jax._src.util import canonicalize_axis as _canonicalize_axis, maybe_named_axis


_all = builtins.all
_lax_const = lax_internal._const


def _asarray(a):
  # simplified version of jnp.asarray() for local use.
  return a if isinstance(a, ndarray) else api.device_put(a)

def _isscalar(element):
  if hasattr(element, '__jax_array__'):
    element = element.__jax_array__()
  return dtypes.is_python_scalar(element) or np.isscalar(element)

def _moveaxis(a, source: int, destination: int):
  # simplified version of jnp.moveaxis() for local use.
  _check_arraylike("moveaxis", a)
  a = _asarray(a)
  source = _canonicalize_axis(source, np.ndim(a))
  destination = _canonicalize_axis(destination, np.ndim(a))
  perm = [i for i in range(np.ndim(a)) if i != source]
  perm.insert(destination, source)
  return lax.transpose(a, perm)


def _reduction(a, name, np_fun, op, init_val, has_identity=True,
               preproc=None, bool_op=None, upcast_f16_for_computation=False,
               axis=None, dtype=None, out=None, keepdims=False, initial=None,
               where_=None, parallel_reduce=None):
  bool_op = bool_op or op
  # Note: we must accept out=None as an argument, because numpy reductions delegate to
  # object methods. For example `np.sum(x)` will call `x.sum()` if the `sum()` method
  # exists, passing along all its arguments.
  if out is not None:
    raise NotImplementedError(f"The 'out' argument to jnp.{name} is not supported.")
  _check_arraylike(name, a)
  lax_internal._check_user_dtype_supported(dtype, name)
  axis = core.concrete_or_error(None, axis, f"axis argument to jnp.{name}().")

  if initial is None and not has_identity:
    if not _all(core.greater_equal_dim(d, 1) for d in np.shape(a)):
      raise ValueError(f"zero-size array to reduction operation {name} which has no identity")
    if where_ is not None:
      raise ValueError(f"reduction operation {name} does not have an identity, so to use a "
                       f"where mask one has to specify 'initial'")

  a = a if isinstance(a, ndarray) else _asarray(a)
  a = preproc(a) if preproc else a
  pos_dims, dims = _reduction_dims(a, axis)
  result_dtype = dtypes.canonicalize_dtype(dtype or dtypes.dtype(np_fun(np.ones((), dtype=dtypes.dtype(a)))))
  if upcast_f16_for_computation and dtypes.issubdtype(result_dtype, np.inexact):
    computation_dtype = dtypes.promote_types(result_dtype, np.float32)
  else:
    computation_dtype = result_dtype
  a = lax.convert_element_type(a, computation_dtype)
  op = op if computation_dtype != np.bool_ else bool_op
  # NB: in XLA, init_val must be an identity for the op, so the user-specified
  # initial value must be applied afterward.
  init_val = _reduction_init_val(a, init_val)
  if where_ is not None:
    a = _where(where_, a, init_val)
  if pos_dims is not dims:
    if parallel_reduce is None:
      raise NotImplementedError(f"Named reductions not implemented for jnp.{name}()")
    result = parallel_reduce(a, dims)
  else:
    result = lax.reduce(a, init_val, op, dims)
  if initial is not None:
    result = op(lax.convert_element_type(initial, a.dtype), result)
  if keepdims:
    result = lax.expand_dims(result, pos_dims)
  return lax.convert_element_type(result, dtype or result_dtype)
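
[Illustrative example, not part of the diff. How the where_/initial handling in _reduction behaves through the public wrappers: masked-out elements contribute the op's identity, and a user-supplied initial is combined after the reduction, which is also why reductions without an identity require initial when where is given.]

    import jax.numpy as jnp

    x = jnp.array([[1., 2.], [3., 4.]])
    mask = jnp.array([[True, False], [True, True]])
    jnp.sum(x, where=mask)               # 8.0, masked entries act as the identity 0
    jnp.sum(x, where=mask, initial=10.)  # 18.0, initial folded in after the reduction
    jnp.max(x, where=mask, initial=0.)   # 4.0, max has no identity so initial is required with where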

def _canonicalize_axis_allow_named(x, rank):
  return maybe_named_axis(x, lambda i: _canonicalize_axis(i, rank), lambda name: name)

def _reduction_dims(a, axis):
  if axis is None:
    return (tuple(range(np.ndim(a))),) * 2
  elif not isinstance(axis, (np.ndarray, tuple, list)):
    axis = (axis,)
  canon_axis = tuple(_canonicalize_axis_allow_named(x, np.ndim(a))
                     for x in axis)
  if len(canon_axis) != len(set(canon_axis)):
    raise ValueError(f"duplicate value in 'axis': {axis}")
  canon_pos_axis = tuple(x for x in canon_axis if isinstance(x, int))
  if len(canon_pos_axis) != len(canon_axis):
    return canon_pos_axis, canon_axis
  else:
    return canon_axis, canon_axis
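
[Illustrative example, not part of the diff. Axis handling in _reduction_dims: axes are canonicalized (including negative values) and duplicates are rejected.]

    import jax.numpy as jnp

    jnp.sum(jnp.ones((2, 3, 4)), axis=(0, -1)).shape  # (3,), axes canonicalized to (0, 2)
    jnp.sum(jnp.ones((2, 3)), axis=(0, 0))            # raises ValueError: duplicate value in 'axis'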

def _reduction_init_val(a, init_val):
  # This function uses np.* functions because lax pattern matches against the
  # specific concrete values of the reduction inputs.
  a_dtype = dtypes.canonicalize_dtype(dtypes.dtype(a))
  if a_dtype == 'bool':
    return np.array(init_val > 0, dtype=a_dtype)
  try:
    return np.array(init_val, dtype=a_dtype)
  except OverflowError:
    assert dtypes.issubdtype(a_dtype, np.integer)
    sign, info = np.sign(init_val), dtypes.iinfo(a_dtype)
    return np.array(info.min if sign < 0 else info.max, dtype=a_dtype)

def _cast_to_bool(operand):
  with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=np.ComplexWarning)
    return lax.convert_element_type(operand, np.bool_)


def _ensure_optional_axes(x):
  def force(x):
    if x is None:
      return None
    try:
      return operator.index(x)
    except TypeError:
      return tuple(i if isinstance(i, str) else operator.index(i) for i in x)
  return core.concrete_or_error(
      force, x, "The axis argument must be known statically.")


@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'), inline=True)
def _reduce_sum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None,
                dtype=None, out=None, keepdims=None, initial=None, where=None):
  return _reduction(a, "sum", np.sum, lax.add, 0,
                    bool_op=lax.bitwise_or, upcast_f16_for_computation=True,
                    axis=axis, dtype=dtype, out=out, keepdims=keepdims,
                    initial=initial, where_=where, parallel_reduce=lax.psum)

@_wraps(np.sum, skip_params=['out'])
def sum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
        out=None, keepdims=None, initial=None, where=None):
  return _reduce_sum(a, axis=_ensure_optional_axes(axis), dtype=dtype, out=out,
                     keepdims=keepdims, initial=initial, where=where)


@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'), inline=True)
def _reduce_prod(a, axis: Optional[Union[int, Tuple[int, ...]]] = None,
                 dtype=None, out=None, keepdims=None, initial=None, where=None):
  return _reduction(a, "prod", np.prod, lax.mul, 1,
                    bool_op=lax.bitwise_and, upcast_f16_for_computation=True,
                    axis=axis, dtype=dtype, out=out, keepdims=keepdims,
                    initial=initial, where_=where)

@_wraps(np.prod, skip_params=['out'])
def prod(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
         out=None, keepdims=None, initial=None, where=None):
  return _reduce_prod(a, axis=_ensure_optional_axes(axis), dtype=dtype,
                      out=out, keepdims=keepdims, initial=initial, where=where)


@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True)
def _reduce_max(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
                keepdims=None, initial=None, where=None):
  return _reduction(a, "max", np.max, lax.max, -np.inf, has_identity=False,
                    axis=axis, out=out, keepdims=keepdims,
                    initial=initial, where_=where, parallel_reduce=lax.pmax)

@_wraps(np.max, skip_params=['out'])
def max(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
        keepdims=None, initial=None, where=None):
  return _reduce_max(a, axis=_ensure_optional_axes(axis), out=out,
                     keepdims=keepdims, initial=initial, where=where)

@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True)
def _reduce_min(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
                keepdims=None, initial=None, where=None):
  return _reduction(a, "min", np.min, lax.min, np.inf, has_identity=False,
                    axis=axis, out=out, keepdims=keepdims,
                    initial=initial, where_=where, parallel_reduce=lax.pmin)

@_wraps(np.min, skip_params=['out'])
def min(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
        keepdims=None, initial=None, where=None):
  return _reduce_min(a, axis=_ensure_optional_axes(axis), out=out,
                     keepdims=keepdims, initial=initial, where=where)

@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True)
def _reduce_all(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
                keepdims=None, *, where=None):
  return _reduction(a, "all", np.all, lax.bitwise_and, True, preproc=_cast_to_bool,
                    axis=axis, out=out, keepdims=keepdims, where_=where)

@_wraps(np.all, skip_params=['out'])
def all(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
        keepdims=None, *, where=None):
  return _reduce_all(a, axis=_ensure_optional_axes(axis), out=out,
                     keepdims=keepdims, where=where)

@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True)
def _reduce_any(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
                keepdims=None, *, where=None):
  return _reduction(a, "any", np.any, lax.bitwise_or, False, preproc=_cast_to_bool,
                    axis=axis, out=out, keepdims=keepdims, where_=where)

@_wraps(np.any, skip_params=['out'])
def any(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
        keepdims=None, *, where=None):
  return _reduce_any(a, axis=_ensure_optional_axes(axis), out=out,
                     keepdims=keepdims, where=where)

product = prod
amin = min
amax = max
alltrue = all
sometrue = any
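
[Illustrative example, not part of the diff. The NumPy-compatibility aliases above are plain rebindings, so the public names are the same function objects.]

    import jax.numpy as jnp

    jnp.alltrue is jnp.all   # True
    jnp.amax is jnp.max      # True
    jnp.product is jnp.prod  # True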

def _axis_size(a, axis):
  if not isinstance(axis, (tuple, list)):
    axis = (axis,)
  size = 1
  a_shape = np.shape(a)
  for a in axis:
    size *= maybe_named_axis(a, lambda i: a_shape[i], lambda name: lax.psum(1, name))
  return size

@_wraps(np.mean, skip_params=['out'])
def mean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
         out=None, keepdims=False, *, where=None):
  return _mean(a, _ensure_optional_axes(axis), dtype, out, keepdims,
               where=where)

@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'), inline=True)
def _mean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
          out=None, keepdims=False, *, where=None):
  _check_arraylike("mean", a)
  lax_internal._check_user_dtype_supported(dtype, "mean")
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.mean is not supported.")

  if where is None:
    if axis is None:
      normalizer = core.dimension_as_value(np.size(a))
    else:
      normalizer = core.dimension_as_value(_axis_size(a, axis))
  else:
    normalizer = sum(_broadcast_to(where, np.shape(a)), axis, dtype=dtype, keepdims=keepdims)

  if dtype is None:
    if dtypes.issubdtype(dtypes.dtype(a), np.bool_) or dtypes.issubdtype(dtypes.dtype(a), np.integer):
      dtype = dtypes.float_
    else:
      dtype = dtypes.dtype(a)
  dtype = dtypes.canonicalize_dtype(dtype)

  return lax.div(
      sum(a, axis, dtype=dtype, keepdims=keepdims, where=where),
      lax.convert_element_type(normalizer, dtype))
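
[Illustrative example, not part of the diff. With a where mask, _mean divides by the number of selected elements rather than the axis size.]

    import jax.numpy as jnp

    x = jnp.array([1., 2., 3., 4.])
    mask = jnp.array([True, True, False, False])
    jnp.mean(x)              # 2.5, divides by size(x) == 4
    jnp.mean(x, where=mask)  # 1.5, divides by the mask count, 2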

@_wraps(np.average)
def average(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, weights=None,
            returned=False):
  return _average(a, _ensure_optional_axes(axis), weights, returned)

@partial(api.jit, static_argnames=('axis', 'returned'), inline=True)
def _average(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, weights=None,
             returned=False):
  a = _asarray(a)

  if weights is None: # Treat all weights as 1
    avg = mean(a, axis=axis)
    if axis is None:
      weights_sum = lax.full((), core.dimension_as_value(np.size(a)), dtype=avg.dtype)
    else:
      weights_sum = lax.full_like(avg, core.dimension_as_value(a.shape[axis]), dtype=avg.dtype)
  else:
    weights = _asarray(weights)

    if dtypes.issubdtype(a.dtype, np.inexact):
      out_dtype = dtypes.result_type(a.dtype, weights.dtype)
    else:
      out_dtype = dtypes.result_type(a.dtype, weights.dtype, dtypes.float_)
    out_dtype = dtypes.canonicalize_dtype(out_dtype)

    a_shape = np.shape(a)
    a_ndim = len(a_shape)
    weights_shape = np.shape(weights)
    axis = None if axis is None else _canonicalize_axis(axis, a_ndim)

    if a_shape != weights_shape:
      # Make sure the dimensions work out
      if axis is None:
        raise ValueError("Axis must be specified when shapes of a and "
                         "weights differ.")
      if len(weights_shape) != 1:
        raise ValueError("1D weights expected when shapes of a and "
                         "weights differ.")
      if not core.symbolic_equal_dim(weights_shape[0], a_shape[axis]):
        raise ValueError("Length of weights not "
                         "compatible with specified axis.")

      weights = _broadcast_to(weights, (a_ndim - 1) * (1,) + weights_shape)
      weights = _moveaxis(weights, -1, axis)

    weights_sum = sum(weights, axis=axis, dtype=out_dtype)
    avg = sum(a * weights, axis=axis, dtype=out_dtype) / weights_sum

  if returned:
    if avg.shape != weights_sum.shape:
      weights_sum = _broadcast_to(weights_sum, avg.shape)
    return avg, weights_sum
  return avg
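
[Illustrative example, not part of the diff. In the weighted branch above, 1D weights are broadcast and moved to the reduction axis, and returned=True also yields the sum of weights broadcast to the result shape.]

    import jax.numpy as jnp

    jnp.average(jnp.array([1., 2., 3.]), weights=jnp.array([3., 1., 0.]))  # 1.25
    a = jnp.arange(6.).reshape(2, 3)
    avg, wsum = jnp.average(a, axis=1, weights=jnp.array([1., 2., 3.]), returned=True)
    avg.shape, wsum.shape  # ((2,), (2,))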


@_wraps(np.var, skip_params=['out'])
def var(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
        out=None, ddof=0, keepdims=False, *, where=None):
  return _var(a, _ensure_optional_axes(axis), dtype, out, ddof, keepdims,
              where=where)

@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def _var(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
         out=None, ddof=0, keepdims=False, *, where=None):
  _check_arraylike("var", a)
  lax_internal._check_user_dtype_supported(dtype, "var")
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.var is not supported.")

  a_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype)
  a_mean = mean(a, axis, dtype=a_dtype, keepdims=True, where=where)
  centered = a - a_mean
  if dtypes.issubdtype(centered.dtype, np.complexfloating):
    centered = lax.real(lax.mul(centered, lax.conj(centered)))
  else:
    centered = lax.square(centered)

  if where is None:
    if axis is None:
      normalizer = core.dimension_as_value(np.size(a))
    else:
      normalizer = core.dimension_as_value(_axis_size(a, axis))
  else:
    normalizer = sum(_broadcast_to(where, np.shape(a)), axis, dtype=dtype, keepdims=keepdims)
  normalizer = normalizer - ddof

  result = sum(centered, axis, keepdims=keepdims, where=where)
  out = lax.div(result, lax.convert_element_type(normalizer, result.dtype))
  return lax.convert_element_type(out, dtype)
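
[Illustrative example, not part of the diff. ddof adjusts the divisor to n - ddof, and complex inputs give a real-dtyped variance via x * conj(x).]

    import jax.numpy as jnp

    x = jnp.array([1., 2., 3., 4.])
    jnp.var(x)          # 1.25, population variance (divide by n)
    jnp.var(x, ddof=1)  # ~1.667, sample variance (divide by n - ddof)
    jnp.var(jnp.array([1. + 1.j, 2. - 1.j])).dtype  # a real dtype (float32 by default), not complex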


def _var_promote_types(a_dtype, dtype):
  if dtype:
    if (not dtypes.issubdtype(dtype, np.complexfloating) and
        dtypes.issubdtype(a_dtype, np.complexfloating)):
      msg = ("jax.numpy.var does not yet support real dtype parameters when "
             "computing the variance of an array of complex values. The "
             "semantics of numpy.var seem unclear in this case. Please comment "
             "on https://github.com/google/jax/issues/2283 if this behavior is "
             "important to you.")
      raise ValueError(msg)
    a_dtype = dtypes.promote_types(a_dtype, dtype)
  else:
    if not dtypes.issubdtype(a_dtype, np.inexact):
      dtype = a_dtype = dtypes.canonicalize_dtype(dtypes.float_)
    else:
      dtype = _complex_elem_type(a_dtype)
      a_dtype = dtypes.promote_types(a_dtype, np.float32)
  return a_dtype, dtype


@_wraps(np.std, skip_params=['out'])
def std(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
        out=None, ddof=0, keepdims=False, *, where=None):
  return _std(a, _ensure_optional_axes(axis), dtype, out, ddof, keepdims,
              where=where)

@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def _std(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
         out=None, ddof=0, keepdims=False, *, where=None):
  _check_arraylike("std", a)
  lax_internal._check_user_dtype_supported(dtype, "std")
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.std is not supported.")
  return lax.sqrt(var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where))


@_wraps(np.ptp, skip_params=['out'])
def ptp(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
        keepdims=False):
  return _ptp(a, _ensure_optional_axes(axis), out, keepdims)

@partial(api.jit, static_argnames=('axis', 'keepdims'))
def _ptp(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
         keepdims=False):
  _check_arraylike("ptp", a)
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.ptp is not supported.")
  x = amax(a, axis=axis, keepdims=keepdims)
  y = amin(a, axis=axis, keepdims=keepdims)
  return lax.sub(x, y)


@_wraps(np.count_nonzero)
@partial(api.jit, static_argnames=('axis', 'keepdims'))
def count_nonzero(a, axis: Optional[Union[int, Tuple[int, ...]]] = None,
                  keepdims=False):
  _check_arraylike("count_nonzero", a)
  return sum(lax.ne(a, _lax_const(a, 0)), axis=axis,
             dtype=dtypes.canonicalize_dtype(np.int_), keepdims=keepdims)
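
[Illustrative example, not part of the diff. Quick checks of ptp and count_nonzero.]

    import jax.numpy as jnp

    x = jnp.array([[0, 1, 7], [0, 0, 3]])
    jnp.ptp(x, axis=1)            # [7, 3], max minus min along each row
    jnp.count_nonzero(x)          # 3
    jnp.count_nonzero(x, axis=0)  # [0, 1, 2]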


def _nan_reduction(a, name, jnp_reduction, init_val, nan_if_all_nan,
                   axis=None, keepdims=None, **kwargs):
  _check_arraylike(name, a)
  if not dtypes.issubdtype(dtypes.dtype(a), np.inexact):
    return jnp_reduction(a, axis=axis, keepdims=keepdims, **kwargs)

  out = jnp_reduction(_where(isnan(a), _reduction_init_val(a, init_val), a),
                      axis=axis, keepdims=keepdims, **kwargs)
  if nan_if_all_nan:
    return _where(all(isnan(a), axis=axis, keepdims=keepdims),
                  _lax_const(a, np.nan), out)
  else:
    return out

@_wraps(np.nanmin, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'keepdims'))
def nanmin(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
           keepdims=None, initial=None, where=None):
  return _nan_reduction(a, 'nanmin', min, np.inf, nan_if_all_nan=initial is None,
                        axis=axis, out=out, keepdims=keepdims,
                        initial=initial, where=where)

@_wraps(np.nanmax, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'keepdims'))
def nanmax(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
           keepdims=None, initial=None, where=None):
  return _nan_reduction(a, 'nanmax', max, -np.inf, nan_if_all_nan=initial is None,
                        axis=axis, out=out, keepdims=keepdims,
                        initial=initial, where=where)
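
[Illustrative example, not part of the diff. The nan_if_all_nan logic above: NaNs are replaced by the reduction's init value, and an all-NaN input maps back to NaN unless an explicit initial is supplied.]

    import jax.numpy as jnp

    jnp.nanmin(jnp.array([jnp.nan, 1., 2.]))                # 1.0
    jnp.nanmin(jnp.array([jnp.nan, jnp.nan]))               # nan, all-NaN input
    jnp.nanmin(jnp.array([jnp.nan, jnp.nan]), initial=0.0)  # 0.0, initial disables the NaN fallback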

@_wraps(np.nansum, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nansum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
           out=None, keepdims=None, initial=None, where=None):
  lax_internal._check_user_dtype_supported(dtype, "nanprod")
  return _nan_reduction(a, 'nansum', sum, 0, nan_if_all_nan=False,
                        axis=axis, dtype=dtype, out=out, keepdims=keepdims,
                        initial=initial, where=where)

# Work around a sphinx documentation warning in NumPy 1.22.
nansum.__doc__ = nansum.__doc__.replace("\n\n\n", "\n\n")

@_wraps(np.nanprod, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanprod(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
            out=None, keepdims=None, initial=None, where=None):
  lax_internal._check_user_dtype_supported(dtype, "nanprod")
  return _nan_reduction(a, 'nanprod', prod, 1, nan_if_all_nan=False,
                        axis=axis, dtype=dtype, out=out, keepdims=keepdims,
                        initial=initial, where=where)

@_wraps(np.nanmean, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanmean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
            out=None, keepdims=False, where=None):
  _check_arraylike("nanmean", a)
  lax_internal._check_user_dtype_supported(dtype, "nanmean")
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.nanmean is not supported.")
  if dtypes.issubdtype(dtypes.dtype(a), np.bool_) or dtypes.issubdtype(dtypes.dtype(a), np.integer):
    return mean(a, axis, dtype, out, keepdims, where=where)
  if dtype is None:
    dtype = dtypes.dtype(a)
  nan_mask = logical_not(isnan(a))
  normalizer = sum(nan_mask, axis=axis, dtype=np.int32, keepdims=keepdims, where=where)
  normalizer = lax.convert_element_type(normalizer, dtype)
  td = lax.div(nansum(a, axis, dtype=dtype, keepdims=keepdims, where=where), normalizer)
  return td
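
[Illustrative example, not part of the diff. nanmean excludes NaNs from both the sum and the count used as the normalizer.]

    import jax.numpy as jnp

    x = jnp.array([1., jnp.nan, 3.])
    jnp.mean(x)     # nan
    jnp.nanmean(x)  # 2.0, nansum(x) == 4.0 divided by the 2 non-NaN entries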


@_wraps(np.nanvar, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanvar(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
           out=None, ddof=0, keepdims=False, where=None):
  _check_arraylike("nanvar", a)
  lax_internal._check_user_dtype_supported(dtype, "nanvar")
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.nanvar is not supported.")

  a_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype)
  a_mean = nanmean(a, axis, dtype=a_dtype, keepdims=True, where=where)

  centered = _where(isnan(a), 0, a - a_mean) # double-where trick for gradients.
  if dtypes.issubdtype(centered.dtype, np.complexfloating):
    centered = lax.real(lax.mul(centered, lax.conj(centered)))
  else:
    centered = lax.square(centered)

  normalizer = sum(logical_not(isnan(a)), axis=axis, keepdims=keepdims, where=where)
  normalizer = normalizer - ddof
  normalizer_mask = lax.le(normalizer, 0)
  result = sum(centered, axis, keepdims=keepdims, where=where)
  result = _where(normalizer_mask, np.nan, result)
  divisor = _where(normalizer_mask, 1, normalizer)
  out = lax.div(result, lax.convert_element_type(divisor, result.dtype))
  return lax.convert_element_type(out, dtype)


@_wraps(np.nanstd, skip_params=['out'])
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanstd(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
           out=None, ddof=0, keepdims=False, where=None):
  _check_arraylike("nanstd", a)
  lax_internal._check_user_dtype_supported(dtype, "nanstd")
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.nanstd is not supported.")
  return lax.sqrt(nanvar(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where))


def _make_cumulative_reduction(np_reduction, reduction, fill_nan=False, fill_value=0):
  @_wraps(np_reduction, skip_params=['out'])
  def cumulative_reduction(a,
                           axis: Optional[Union[int, Tuple[int, ...]]] = None,
                           dtype=None, out=None):
    return _cumulative_reduction(a, _ensure_optional_axes(axis), dtype, out)

  @partial(api.jit, static_argnames=('axis', 'dtype'))
  def _cumulative_reduction(a,
                            axis: Optional[Union[int, Tuple[int, ...]]] = None,
                            dtype=None, out=None):
    _check_arraylike(np_reduction.__name__, a)
    if out is not None:
      raise NotImplementedError(f"The 'out' argument to jnp.{np_reduction.__name__} "
                                f"is not supported.")
    lax_internal._check_user_dtype_supported(dtype, np_reduction.__name__)

    if axis is None or _isscalar(a):
      a = lax.reshape(a, (np.size(a),))
      axis = 0

    a_shape = list(np.shape(a))
    num_dims = len(a_shape)
    axis = _canonicalize_axis(axis, num_dims)

    if fill_nan:
      a = _where(isnan(a), _lax_const(a, fill_value), a)

    if not dtype and dtypes.dtype(a) == np.bool_:
      dtype = dtypes.canonicalize_dtype(dtypes.int_)
    if dtype:
      a = lax.convert_element_type(a, dtype)

    return reduction(a, axis)

  return cumulative_reduction


cumsum = _make_cumulative_reduction(np.cumsum, lax.cumsum, fill_nan=False)
cumprod = _make_cumulative_reduction(np.cumprod, lax.cumprod, fill_nan=False)
cumproduct = cumprod
nancumsum = _make_cumulative_reduction(np.nancumsum, lax.cumsum,
                                       fill_nan=True, fill_value=0)
nancumprod = _make_cumulative_reduction(np.nancumprod, lax.cumprod,
                                        fill_nan=True, fill_value=1)
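
[Illustrative example, not part of the diff. The nan* cumulative reductions fill NaNs with the operation's neutral value before accumulating, and bool inputs are promoted to the default integer dtype.]

    import jax.numpy as jnp

    x = jnp.array([1., jnp.nan, 2.])
    jnp.cumsum(x)      # [1., nan, nan]
    jnp.nancumsum(x)   # [1., 1., 3.], NaN treated as 0
    jnp.nancumprod(x)  # [1., 1., 2.], NaN treated as 1
    jnp.cumsum(jnp.array([True, True, False]))  # [1, 2, 2], integer dtype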

@@ -308,6 +308,38 @@ from jax._src.numpy.polynomial import (
  roots as roots,
)

from jax._src.numpy.reductions import (
  alltrue as alltrue,
  amin as amin,
  amax as amax,
  any as any,
  all as all,
  average as average,
  count_nonzero as count_nonzero,
  cumsum as cumsum,
  cumprod as cumprod,
  cumproduct as cumproduct,
  max as max,
  mean as mean,
  min as min,
  nancumsum as nancumsum,
  nancumprod as nancumprod,
  nanmax as nanmax,
  nanmean as nanmean,
  nanmin as nanmin,
  nanprod as nanprod,
  nanstd as nanstd,
  nansum as nansum,
  nanvar as nanvar,
  prod as prod,
  product as product,
  ptp as ptp,
  sometrue as sometrue,
  std as std,
  sum as sum,
  var as var,
)

from jax._src.numpy.ufuncs import (
  abs as abs,
  absolute as absolute,