Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 03:46:06 +00:00)

Commit e9bc7ee866

Merge pull request #15184 from jakevdp:move-median

PiperOrigin-RevId: 519003606
@@ -33,6 +33,7 @@ from typing import NamedTuple
import jax
import jax._src.numpy.lax_numpy as jnp
import jax._src.numpy.linalg as jnp_linalg
from jax._src.numpy import reductions
from jax._src.numpy import ufuncs
from jax import lax
from jax._src.lax import qdwh

@@ -360,7 +361,7 @@ def _eigh_work(H, n, termination_size=256):
  def default_case(agenda, blocks, eigenvectors):
    V = _slice(eigenvectors, (0, offset), (n, b), (N, B))
    # TODO: Improve this?
    split_point = jnp.nanmedian(_mask(jnp.diag(ufuncs.real(H)), (b,), jnp.nan))
    split_point = reductions.nanmedian(_mask(jnp.diag(ufuncs.real(H)), (b,), jnp.nan))
    H_minus, V_minus, H_plus, V_plus, rank = split_spectrum(
        H, b, split_point, V0=V)
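
For context (not part of the diff): the changed line above picks a spectrum split point as a NaN-aware median of the masked diagonal of the active block. A minimal standalone sketch of that idea, using only the public jnp API rather than the internal _mask helper, with illustrative values and block size:

import jax.numpy as jnp

# Hypothetical diagonal of a Hermitian block and an active block size b.
diag = jnp.array([3.0, -1.0, 4.0, 1.5, 5.0])
b = 3
# Mask out entries beyond the active block with NaN, then take a
# NaN-ignoring median as the split point.
masked = jnp.where(jnp.arange(diag.size) < b, diag, jnp.nan)
split_point = jnp.nanmedian(masked)  # median of [3.0, -1.0, 4.0] -> 3.0
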
@@ -1609,7 +1609,7 @@ def _pad(array: ArrayLike, pad_width: PadValueLike[int],
    return array

  stat_funcs: Dict[str, PadStatFunc] = {
      "maximum": reductions.amax, "minimum": reductions.amin, "mean": reductions.mean, "median": median}
      "maximum": reductions.amax, "minimum": reductions.amin, "mean": reductions.mean, "median": reductions.median}

  pad_width = _broadcast_to_pairs(pad_width, nd, "pad_width")
  pad_width_arr = np.array(pad_width)
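
The "median" entry of stat_funcs is what backs jnp.pad(..., mode="median"); a small usage example with illustrative values:

import jax.numpy as jnp

x = jnp.array([1.0, 7.0, 3.0])
# Each padded edge is filled with the median of the array, 3.0 here.
print(jnp.pad(x, 2, mode="median"))  # [3. 3. 1. 7. 3. 3. 3.]
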
@@ -4582,161 +4582,6 @@ def corrcoef(x: ArrayLike, y: Optional[ArrayLike] = None, rowvar: bool = True) -
  return c


@util._wraps(np.quantile, skip_params=['out', 'overwrite_input'])
@partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
                               'keepdims', 'method'))
def quantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
             out: None = None, overwrite_input: bool = False, method: str = "linear",
             keepdims: bool = False, interpolation: None = None) -> Array:
  util.check_arraylike("quantile", a, q)
  if overwrite_input or out is not None:
    msg = ("jax.numpy.quantile does not support overwrite_input=True or "
           "out != None")
    raise ValueError(msg)
  if interpolation is not None:
    warnings.warn("The interpolation= argument to 'quantile' is deprecated. "
                  "Use 'method=' instead.", DeprecationWarning)
  return _quantile(asarray(a), asarray(q), axis, interpolation or method, keepdims, False)

@util._wraps(np.nanquantile, skip_params=['out', 'overwrite_input'])
@partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
                               'keepdims', 'method'))
def nanquantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
                out: None = None, overwrite_input: bool = False, method: str = "linear",
                keepdims: bool = False, interpolation: None = None) -> Array:
  util.check_arraylike("nanquantile", a, q)
  if overwrite_input or out is not None:
    msg = ("jax.numpy.nanquantile does not support overwrite_input=True or "
           "out != None")
    raise ValueError(msg)
  if interpolation is not None:
    warnings.warn("The interpolation= argument to 'nanquantile' is deprecated. "
                  "Use 'method=' instead.", DeprecationWarning)
  return _quantile(asarray(a), asarray(q), axis, interpolation or method, keepdims, True)

def _quantile(a: Array, q: Array, axis: Optional[Union[int, Tuple[int, ...]]],
              interpolation: str, keepdims: bool, squash_nans: bool) -> Array:
  if interpolation not in ["linear", "lower", "higher", "midpoint", "nearest"]:
    raise ValueError("interpolation can only be 'linear', 'lower', 'higher', "
                     "'midpoint', or 'nearest'")
  a, = util.promote_dtypes_inexact(a)
  keepdim = []
  if issubdtype(a.dtype, np.complexfloating):
    raise ValueError("quantile does not support complex input, as the operation is poorly defined.")
  if axis is None:
    a = ravel(a)
    axis = 0
  elif isinstance(axis, tuple):
    keepdim = list(shape(a))
    nd = ndim(a)
    axis = tuple(_canonicalize_axis(ax, nd) for ax in axis)
    if len(set(axis)) != len(axis):
      raise ValueError('repeated axis')
    for ax in axis:
      keepdim[ax] = 1

    keep = set(range(nd)) - set(axis)
    # prepare permutation
    dimensions = list(range(nd))
    for i, s in enumerate(sorted(keep)):
      dimensions[i], dimensions[s] = dimensions[s], dimensions[i]
    do_not_touch_shape = tuple(x for idx,x in enumerate(shape(a)) if idx not in axis)
    touch_shape = tuple(x for idx,x in enumerate(shape(a)) if idx in axis)
    a = lax.reshape(a, do_not_touch_shape + (int(np.prod(touch_shape)),), dimensions)
    axis = _canonicalize_axis(-1, ndim(a))
  else:
    axis = _canonicalize_axis(axis, ndim(a))

  q_shape = shape(q)
  q_ndim = ndim(q)
  if q_ndim > 1:
    raise ValueError(f"q must be have rank <= 1, got shape {shape(q)}")

  a_shape = shape(a)

  if squash_nans:
    a = where(ufuncs.isnan(a), nan, a) # Ensure nans are positive so they sort to the end.
    a = lax.sort(a, dimension=axis)
    counts = reductions.sum(ufuncs.logical_not(ufuncs.isnan(a)), axis=axis, dtype=q.dtype,
                            keepdims=keepdims)
    shape_after_reduction = counts.shape
    q = lax.expand_dims(
      q, tuple(range(q_ndim, len(shape_after_reduction) + q_ndim)))
    counts = lax.expand_dims(counts, tuple(range(q_ndim)))
    q = lax.mul(q, lax.sub(counts, _lax_const(q, 1)))
    low = lax.floor(q)
    high = lax.ceil(q)
    high_weight = lax.sub(q, low)
    low_weight = lax.sub(_lax_const(high_weight, 1), high_weight)

    low = lax.max(_lax_const(low, 0), lax.min(low, counts - 1))
    high = lax.max(_lax_const(high, 0), lax.min(high, counts - 1))
    low = lax.convert_element_type(low, int64)
    high = lax.convert_element_type(high, int64)
    out_shape = q_shape + shape_after_reduction
    index = [lax.broadcasted_iota(int64, out_shape, dim + q_ndim)
             for dim in range(len(shape_after_reduction))]
    if keepdims:
      index[axis] = low
    else:
      index.insert(axis, low)
    low_value = a[tuple(index)]
    index[axis] = high
    high_value = a[tuple(index)]
  else:
    a = where(reductions.any(ufuncs.isnan(a), axis=axis, keepdims=True), nan, a)
    a = lax.sort(a, dimension=axis)
    n = lax.convert_element_type(array(a_shape[axis]), lax_internal._dtype(q))
    q = lax.mul(q, n - 1)
    low = lax.floor(q)
    high = lax.ceil(q)
    high_weight = lax.sub(q, low)
    low_weight = lax.sub(_lax_const(high_weight, 1), high_weight)

    low = lax.clamp(_lax_const(low, 0), low, n - 1)
    high = lax.clamp(_lax_const(high, 0), high, n - 1)
    low = lax.convert_element_type(low, int64)
    high = lax.convert_element_type(high, int64)

    slice_sizes = list(a_shape)
    slice_sizes[axis] = 1
    dnums = lax.GatherDimensionNumbers(
      offset_dims=tuple(range(
        q_ndim,
        len(a_shape) + q_ndim if keepdims else len(a_shape) + q_ndim - 1)),
      collapsed_slice_dims=() if keepdims else (axis,),
      start_index_map=(axis,))
    low_value = lax.gather(a, low[..., None], dimension_numbers=dnums,
                           slice_sizes=slice_sizes)
    high_value = lax.gather(a, high[..., None], dimension_numbers=dnums,
                            slice_sizes=slice_sizes)
  if q_ndim == 1:
    low_weight = lax.broadcast_in_dim(low_weight, low_value.shape,
                                      broadcast_dimensions=(0,))
    high_weight = lax.broadcast_in_dim(high_weight, high_value.shape,
                                       broadcast_dimensions=(0,))

  if interpolation == "linear":
    result = lax.add(lax.mul(low_value.astype(q.dtype), low_weight),
                     lax.mul(high_value.astype(q.dtype), high_weight))
  elif interpolation == "lower":
    result = low_value
  elif interpolation == "higher":
    result = high_value
  elif interpolation == "nearest":
    pred = lax.le(high_weight, _lax_const(high_weight, 0.5))
    result = lax.select(pred, low_value, high_value)
  elif interpolation == "midpoint":
    result = lax.mul(lax.add(low_value, high_value), _lax_const(low_value, 0.5))
  else:
    raise ValueError(f"interpolation={interpolation!r} not recognized")
  if keepdims and keepdim:
    if q_ndim > 0:
      keepdim = [shape(q)[0], *keepdim]
    result = reshape(result, keepdim)
  return lax.convert_element_type(result, a.dtype)


@partial(vectorize, excluded={0, 2, 3})
def _searchsorted_via_scan(sorted_arr: Array, query: Array, side: str, dtype: type) -> Array:
  op = _sort_le_comparator if side == 'left' else _sort_lt_comparator

@@ -4859,50 +4704,6 @@ def _piecewise(x: Array, condlist: Array, consts: Dict[int, ArrayLike],
  return vectorize(lax.switch, excluded=(1,))(indices, funclist, x)


@util._wraps(np.percentile, skip_params=['out', 'overwrite_input'])
@partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
                               'keepdims', 'method'))
def percentile(a: ArrayLike, q: ArrayLike,
               axis: Optional[Union[int, Tuple[int, ...]]] = None,
               out: None = None, overwrite_input: bool = False, method: str = "linear",
               keepdims: bool = False, interpolation: None = None) -> Array:
  util.check_arraylike("percentile", a, q)
  q, = util.promote_dtypes_inexact(q)
  return quantile(a, q / 100, axis=axis, out=out, overwrite_input=overwrite_input,
                  interpolation=interpolation, method=method, keepdims=keepdims)

@util._wraps(np.nanpercentile, skip_params=['out', 'overwrite_input'])
@partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
                               'keepdims', 'method'))
def nanpercentile(a: ArrayLike, q: ArrayLike,
                  axis: Optional[Union[int, Tuple[int, ...]]] = None,
                  out: None = None, overwrite_input: bool = False, method: str = "linear",
                  keepdims: bool = False, interpolation: None = None) -> Array:
  util.check_arraylike("nanpercentile", a, q)
  q = ufuncs.true_divide(q, float32(100.0))
  return nanquantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input,
                     interpolation=interpolation, method=method,
                     keepdims=keepdims)

@util._wraps(np.median, skip_params=['out', 'overwrite_input'])
@partial(jit, static_argnames=('axis', 'overwrite_input', 'keepdims'))
def median(a: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
           out: None = None, overwrite_input: bool = False,
           keepdims: bool = False) -> Array:
  util.check_arraylike("median", a)
  return quantile(a, 0.5, axis=axis, out=out, overwrite_input=overwrite_input,
                  keepdims=keepdims, method='midpoint')

@util._wraps(np.nanmedian, skip_params=['out', 'overwrite_input'])
@partial(jit, static_argnames=('axis', 'overwrite_input', 'keepdims'))
def nanmedian(a: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
              out: None = None, overwrite_input: bool = False,
              keepdims: bool = False) -> Array:
  util.check_arraylike("nanmedian", a)
  return nanquantile(a, 0.5, axis=axis, out=out,
                     overwrite_input=overwrite_input, keepdims=keepdims,
                     method='midpoint')


@util._wraps(np.place, lax_description="""
Numpy function :func:`numpy.place` is not available in JAX and will raise a
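
The functions removed from lax_numpy.py here keep the same public signatures after the move, so user-visible behavior such as the method= options is unchanged; an illustrative check against the public API:

import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 3.0, 4.0, 7.0, 10.0])
print(jnp.quantile(a, 0.5))                   # 3.5  (default method="linear")
print(jnp.quantile(a, 0.5, method="lower"))   # 3.0
print(jnp.quantile(a, 0.5, method="higher"))  # 4.0
print(jnp.percentile(a, 50.0))                # 3.5, same as the 0.5 quantile
print(jnp.median(a))                          # 3.5
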
@@ -25,6 +25,7 @@ from jax import lax
from jax._src import api
from jax._src import core
from jax._src import dtypes
from jax._src.numpy import ufuncs
from jax._src.numpy.util import (
    _broadcast_to, check_arraylike, _complex_elem_type,
    promote_dtypes_inexact, promote_dtypes_numeric, _where, _wraps)

@@ -684,3 +685,201 @@ nancumsum = _make_cumulative_reduction(np.nancumsum, lax.cumsum,
                                       fill_nan=True, fill_value=0)
nancumprod = _make_cumulative_reduction(np.nancumprod, lax.cumprod,
                                        fill_nan=True, fill_value=1)

# Quantiles
@_wraps(np.quantile, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
                                   'keepdims', 'method'))
def quantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
             out: None = None, overwrite_input: bool = False, method: str = "linear",
             keepdims: bool = False, interpolation: None = None) -> Array:
  check_arraylike("quantile", a, q)
  if overwrite_input or out is not None:
    msg = ("jax.numpy.quantile does not support overwrite_input=True or "
           "out != None")
    raise ValueError(msg)
  if interpolation is not None:
    warnings.warn("The interpolation= argument to 'quantile' is deprecated. "
                  "Use 'method=' instead.", DeprecationWarning)
  return _quantile(_asarray(a), _asarray(q), axis, interpolation or method, keepdims, False)

@_wraps(np.nanquantile, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
                                   'keepdims', 'method'))
def nanquantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
                out: None = None, overwrite_input: bool = False, method: str = "linear",
                keepdims: bool = False, interpolation: None = None) -> Array:
  check_arraylike("nanquantile", a, q)
  if overwrite_input or out is not None:
    msg = ("jax.numpy.nanquantile does not support overwrite_input=True or "
           "out != None")
    raise ValueError(msg)
  if interpolation is not None:
    warnings.warn("The interpolation= argument to 'nanquantile' is deprecated. "
                  "Use 'method=' instead.", DeprecationWarning)
  return _quantile(_asarray(a), _asarray(q), axis, interpolation or method, keepdims, True)

def _quantile(a: Array, q: Array, axis: Optional[Union[int, Tuple[int, ...]]],
              interpolation: str, keepdims: bool, squash_nans: bool) -> Array:
  if interpolation not in ["linear", "lower", "higher", "midpoint", "nearest"]:
    raise ValueError("interpolation can only be 'linear', 'lower', 'higher', "
                     "'midpoint', or 'nearest'")
  a, = promote_dtypes_inexact(a)
  keepdim = []
  if dtypes.issubdtype(a.dtype, np.complexfloating):
    raise ValueError("quantile does not support complex input, as the operation is poorly defined.")
  if axis is None:
    a = a.ravel()
    axis = 0
  elif isinstance(axis, tuple):
    keepdim = list(a.shape)
    nd = a.ndim
    axis = tuple(_canonicalize_axis(ax, nd) for ax in axis)
    if len(set(axis)) != len(axis):
      raise ValueError('repeated axis')
    for ax in axis:
      keepdim[ax] = 1

    keep = set(range(nd)) - set(axis)
    # prepare permutation
    dimensions = list(range(nd))
    for i, s in enumerate(sorted(keep)):
      dimensions[i], dimensions[s] = dimensions[s], dimensions[i]
    do_not_touch_shape = tuple(x for idx,x in enumerate(a.shape) if idx not in axis)
    touch_shape = tuple(x for idx,x in enumerate(a.shape) if idx in axis)
    a = lax.reshape(a, do_not_touch_shape + (int(np.prod(touch_shape)),), dimensions)
    axis = _canonicalize_axis(-1, a.ndim)
  else:
    axis = _canonicalize_axis(axis, a.ndim)

  q_shape = q.shape
  q_ndim = q.ndim
  if q_ndim > 1:
    raise ValueError(f"q must be have rank <= 1, got shape {q.shape}")

  a_shape = a.shape

  if squash_nans:
    a = _where(ufuncs.isnan(a), np.nan, a) # Ensure nans are positive so they sort to the end.
    a = lax.sort(a, dimension=axis)
    counts = sum(ufuncs.logical_not(ufuncs.isnan(a)), axis=axis, dtype=q.dtype, keepdims=keepdims)
    shape_after_reduction = counts.shape
    q = lax.expand_dims(
      q, tuple(range(q_ndim, len(shape_after_reduction) + q_ndim)))
    counts = lax.expand_dims(counts, tuple(range(q_ndim)))
    q = lax.mul(q, lax.sub(counts, _lax_const(q, 1)))
    low = lax.floor(q)
    high = lax.ceil(q)
    high_weight = lax.sub(q, low)
    low_weight = lax.sub(_lax_const(high_weight, 1), high_weight)

    low = lax.max(_lax_const(low, 0), lax.min(low, counts - 1))
    high = lax.max(_lax_const(high, 0), lax.min(high, counts - 1))
    low = lax.convert_element_type(low, int)
    high = lax.convert_element_type(high, int)
    out_shape = q_shape + shape_after_reduction
    index = [lax.broadcasted_iota(int, out_shape, dim + q_ndim)
             for dim in range(len(shape_after_reduction))]
    if keepdims:
      index[axis] = low
    else:
      index.insert(axis, low)
    low_value = a[tuple(index)]
    index[axis] = high
    high_value = a[tuple(index)]
  else:
    a = _where(any(ufuncs.isnan(a), axis=axis, keepdims=True), np.nan, a)
    a = lax.sort(a, dimension=axis)
    n = lax.convert_element_type(a_shape[axis], lax_internal._dtype(q))
    q = lax.mul(q, n - 1)
    low = lax.floor(q)
    high = lax.ceil(q)
    high_weight = lax.sub(q, low)
    low_weight = lax.sub(_lax_const(high_weight, 1), high_weight)

    low = lax.clamp(_lax_const(low, 0), low, n - 1)
    high = lax.clamp(_lax_const(high, 0), high, n - 1)
    low = lax.convert_element_type(low, int)
    high = lax.convert_element_type(high, int)

    slice_sizes = list(a_shape)
    slice_sizes[axis] = 1
    dnums = lax.GatherDimensionNumbers(
      offset_dims=tuple(range(
        q_ndim,
        len(a_shape) + q_ndim if keepdims else len(a_shape) + q_ndim - 1)),
      collapsed_slice_dims=() if keepdims else (axis,),
      start_index_map=(axis,))
    low_value = lax.gather(a, low[..., None], dimension_numbers=dnums,
                           slice_sizes=slice_sizes)
    high_value = lax.gather(a, high[..., None], dimension_numbers=dnums,
                            slice_sizes=slice_sizes)
  if q_ndim == 1:
    low_weight = lax.broadcast_in_dim(low_weight, low_value.shape,
                                      broadcast_dimensions=(0,))
    high_weight = lax.broadcast_in_dim(high_weight, high_value.shape,
                                       broadcast_dimensions=(0,))

  if interpolation == "linear":
    result = lax.add(lax.mul(low_value.astype(q.dtype), low_weight),
                     lax.mul(high_value.astype(q.dtype), high_weight))
  elif interpolation == "lower":
    result = low_value
  elif interpolation == "higher":
    result = high_value
  elif interpolation == "nearest":
    pred = lax.le(high_weight, _lax_const(high_weight, 0.5))
    result = lax.select(pred, low_value, high_value)
  elif interpolation == "midpoint":
    result = lax.mul(lax.add(low_value, high_value), _lax_const(low_value, 0.5))
  else:
    raise ValueError(f"interpolation={interpolation!r} not recognized")
  if keepdims and keepdim:
    if q_ndim > 0:
      keepdim = [np.shape(q)[0], *keepdim]
    result = result.reshape(keepdim)
  return lax.convert_element_type(result, a.dtype)

@_wraps(np.percentile, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
                                   'keepdims', 'method'))
def percentile(a: ArrayLike, q: ArrayLike,
               axis: Optional[Union[int, Tuple[int, ...]]] = None,
               out: None = None, overwrite_input: bool = False, method: str = "linear",
               keepdims: bool = False, interpolation: None = None) -> Array:
  check_arraylike("percentile", a, q)
  q, = promote_dtypes_inexact(q)
  return quantile(a, q / 100, axis=axis, out=out, overwrite_input=overwrite_input,
                  interpolation=interpolation, method=method, keepdims=keepdims)

@_wraps(np.nanpercentile, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
                                   'keepdims', 'method'))
def nanpercentile(a: ArrayLike, q: ArrayLike,
                  axis: Optional[Union[int, Tuple[int, ...]]] = None,
                  out: None = None, overwrite_input: bool = False, method: str = "linear",
                  keepdims: bool = False, interpolation: None = None) -> Array:
  check_arraylike("nanpercentile", a, q)
  q = ufuncs.true_divide(q, 100.0)
  return nanquantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input,
                     interpolation=interpolation, method=method,
                     keepdims=keepdims)

@_wraps(np.median, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'keepdims'))
def median(a: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
           out: None = None, overwrite_input: bool = False,
           keepdims: bool = False) -> Array:
  check_arraylike("median", a)
  return quantile(a, 0.5, axis=axis, out=out, overwrite_input=overwrite_input,
                  keepdims=keepdims, method='midpoint')

@_wraps(np.nanmedian, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'keepdims'))
def nanmedian(a: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
              out: None = None, overwrite_input: bool = False,
              keepdims: bool = False) -> Array:
  check_arraylike("nanmedian", a)
  return nanquantile(a, 0.5, axis=axis, out=out,
                     overwrite_input=overwrite_input, keepdims=keepdims,
                     method='midpoint')
@@ -107,6 +107,8 @@ from jax._src.numpy.lax_numpy import (
    float16 as float16,
    float32 as float32,
    float64 as float64,
    float8_e4m3fn as float8_e4m3fn,
    float8_e5m2 as float8_e5m2,
    float_ as float_,
    floating as floating,
    fmax as fmax,

@@ -166,7 +168,6 @@ from jax._src.numpy.lax_numpy import (
    logspace as logspace,
    mask_indices as mask_indices,
    matmul as matmul,
    median as median,
    meshgrid as meshgrid,
    moveaxis as moveaxis,
    msort as msort,

@@ -175,9 +176,6 @@ from jax._src.numpy.lax_numpy import (
    nanargmax as nanargmax,
    nanargmin as nanargmin,
    argpartition as argpartition,
    nanmedian as nanmedian,
    nanpercentile as nanpercentile,
    nanquantile as nanquantile,
    ndim as ndim,
    newaxis as newaxis,
    nonzero as nonzero,

@@ -189,14 +187,12 @@ from jax._src.numpy.lax_numpy import (
    packbits as packbits,
    pad as pad,
    partition as partition,
    percentile as percentile,
    pi as pi,
    piecewise as piecewise,
    place as place,
    printoptions as printoptions,
    promote_types as promote_types,
    put as put,
    quantile as quantile,
    ravel as ravel,
    ravel_multi_index as ravel_multi_index,
    repeat as repeat,

@@ -258,11 +254,6 @@ from jax._src.numpy.lax_numpy import (
    zeros_like as zeros_like,
)

from jax._src.numpy.lax_numpy import (
    float8_e4m3fn,
    float8_e5m2,
)

from jax._src.numpy.index_tricks import (
    c_ as c_,
    index_exp as index_exp,

@@ -298,19 +289,25 @@ from jax._src.numpy.reductions import (
    cumproduct as cumproduct,
    max as max,
    mean as mean,
    median as median,
    min as min,
    nancumsum as nancumsum,
    nancumprod as nancumprod,
    nanmax as nanmax,
    nanmean as nanmean,
    nanmedian as nanmedian,
    nanmin as nanmin,
    nanpercentile as nanpercentile,
    nanprod as nanprod,
    nanquantile as nanquantile,
    nanstd as nanstd,
    nansum as nansum,
    nanvar as nanvar,
    percentile as percentile,
    prod as prod,
    product as product,
    ptp as ptp,
    quantile as quantile,
    sometrue as sometrue,
    std as std,
    sum as sum,
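
These import-list changes only move where the public names are sourced from (jax._src.numpy.reductions instead of jax._src.numpy.lax_numpy); code that goes through the jax.numpy namespace is unaffected, for example:

import jax.numpy as jnp

print(jnp.median(jnp.arange(5.0)))                    # 2.0
print(jnp.nanmedian(jnp.array([1.0, jnp.nan, 3.0])))  # 2.0
print(jnp.quantile(jnp.arange(5.0), 0.25))            # 1.0
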
@@ -16,6 +16,7 @@
import collections
from functools import partial
import itertools
import unittest

from absl.testing import absltest
from absl.testing import parameterized

@@ -655,6 +656,104 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
    self._CompileAndCheck(jnp_fun, args_maker, atol=tol,
                          rtol=tol)

  @jtu.sample_product(
    [dict(op=op, q_rng=q_rng)
     for (op, q_rng) in (
       ("percentile", partial(jtu.rand_uniform, low=0., high=100.)),
       ("quantile", partial(jtu.rand_uniform, low=0., high=1.)),
       ("nanpercentile", partial(jtu.rand_uniform, low=0., high=100.)),
       ("nanquantile", partial(jtu.rand_uniform, low=0., high=1.)),
     )
    ],
    [dict(a_shape=a_shape, axis=axis)
     for a_shape, axis in (
       ((7,), None),
       ((47, 7), 0),
       ((47, 7), ()),
       ((4, 101), 1),
       ((4, 47, 7), (1, 2)),
       ((4, 47, 7), (0, 2)),
       ((4, 47, 7), (1, 0, 2)),
     )
    ],
    a_dtype=default_dtypes,
    q_dtype=[np.float32],
    q_shape=scalar_shapes + [(1,), (4,)],
    keepdims=[False, True],
    method=['linear', 'lower', 'higher', 'nearest', 'midpoint'],
  )
  def testQuantile(self, op, q_rng, a_shape, a_dtype, q_shape, q_dtype,
                   axis, keepdims, method):
    a_rng = jtu.rand_some_nan(self.rng())
    q_rng = q_rng(self.rng())
    if "median" in op:
      args_maker = lambda: [a_rng(a_shape, a_dtype)]
    else:
      args_maker = lambda: [a_rng(a_shape, a_dtype), q_rng(q_shape, q_dtype)]

    @jtu.ignore_warning(category=RuntimeWarning,
                        message="All-NaN slice encountered")
    def np_fun(*args):
      args = [x if jnp.result_type(x) != jnp.bfloat16 else
              np.asarray(x, np.float32) for x in args]
      if numpy_version <= (1, 22):
        return getattr(np, op)(*args, axis=axis, keepdims=keepdims,
                               interpolation=method)
      else:
        return getattr(np, op)(*args, axis=axis, keepdims=keepdims,
                               method=method)
    jnp_fun = partial(getattr(jnp, op), axis=axis, keepdims=keepdims,
                      method=method)

    # TODO(phawkins): we currently set dtype=False because we aren't as
    # aggressive about promoting to float64. It's not clear we want to mimic
    # Numpy here.
    tol_spec = {np.float16: 1E-2, np.float32: 2e-4, np.float64: 5e-6}
    tol = max(jtu.tolerance(a_dtype, tol_spec),
              jtu.tolerance(q_dtype, tol_spec))
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False,
                            tol=tol)
    self._CompileAndCheck(jnp_fun, args_maker, rtol=tol)

  @unittest.skipIf(not config.jax_enable_x64, "test requires X64")
  @unittest.skipIf(jtu.device_under_test() != 'cpu', "test is for CPU float64 precision")
  def testPercentilePrecision(self):
    # Regression test for https://github.com/google/jax/issues/8513
    x = jnp.float64([1, 2, 3, 4, 7, 10])
    self.assertEqual(jnp.percentile(x, 50), 3.5)

  @jtu.sample_product(
    [dict(a_shape=a_shape, axis=axis)
     for a_shape, axis in (
       ((7,), None),
       ((47, 7), 0),
       ((4, 101), 1),
     )
    ],
    a_dtype=default_dtypes,
    keepdims=[False, True],
    op=["median", "nanmedian"],
  )
  def testMedian(self, op, a_shape, a_dtype, axis, keepdims):
    if op == "median":
      a_rng = jtu.rand_default(self.rng())
    else:
      a_rng = jtu.rand_some_nan(self.rng())
    args_maker = lambda: [a_rng(a_shape, a_dtype)]
    def np_fun(*args):
      args = [x if jnp.result_type(x) != jnp.bfloat16 else
              np.asarray(x, np.float32) for x in args]
      return getattr(np, op)(*args, axis=axis, keepdims=keepdims)
    jnp_fun = partial(getattr(jnp, op), axis=axis, keepdims=keepdims)
    # TODO(phawkins): we currently set dtype=False because we aren't as
    # aggressive about promoting to float64. It's not clear we want to mimic
    # Numpy here.
    tol_spec = {np.float32: 2e-4, np.float64: 5e-6}
    tol = jtu.tolerance(a_dtype, tol_spec)
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False,
                            tol=tol)
    self._CompileAndCheck(jnp_fun, args_maker, rtol=tol)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())
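
Outside the test harness, the regression case that testPercentilePrecision covers (google/jax#8513) can be reproduced as below; enabling x64 this way is assumed to happen before any other JAX work:

import jax
jax.config.update("jax_enable_x64", True)  # request float64 before creating arrays

import jax.numpy as jnp

x = jnp.float64([1, 2, 3, 4, 7, 10])
print(jnp.percentile(x, 50))  # 3.5 exactly
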
@@ -3912,104 +3912,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
    self._CompileAndCheck(jnp_fun, args_maker)

  @jtu.sample_product(
    [dict(op=op, q_rng=q_rng)
     for (op, q_rng) in (
       ("percentile", partial(jtu.rand_uniform, low=0., high=100.)),
       ("quantile", partial(jtu.rand_uniform, low=0., high=1.)),
       ("nanpercentile", partial(jtu.rand_uniform, low=0., high=100.)),
       ("nanquantile", partial(jtu.rand_uniform, low=0., high=1.)),
     )
    ],
    [dict(a_shape=a_shape, axis=axis)
     for a_shape, axis in (
       ((7,), None),
       ((47, 7), 0),
       ((47, 7), ()),
       ((4, 101), 1),
       ((4, 47, 7), (1, 2)),
       ((4, 47, 7), (0, 2)),
       ((4, 47, 7), (1, 0, 2)),
     )
    ],
    a_dtype=default_dtypes,
    q_dtype=[np.float32],
    q_shape=scalar_shapes + [(1,), (4,)],
    keepdims=[False, True],
    method=['linear', 'lower', 'higher', 'nearest', 'midpoint'],
  )
  def testQuantile(self, op, q_rng, a_shape, a_dtype, q_shape, q_dtype,
                   axis, keepdims, method):
    a_rng = jtu.rand_some_nan(self.rng())
    q_rng = q_rng(self.rng())
    if "median" in op:
      args_maker = lambda: [a_rng(a_shape, a_dtype)]
    else:
      args_maker = lambda: [a_rng(a_shape, a_dtype), q_rng(q_shape, q_dtype)]

    @jtu.ignore_warning(category=RuntimeWarning,
                        message="All-NaN slice encountered")
    def np_fun(*args):
      args = [x if jnp.result_type(x) != jnp.bfloat16 else
              np.asarray(x, np.float32) for x in args]
      if numpy_version <= (1, 22):
        return getattr(np, op)(*args, axis=axis, keepdims=keepdims,
                               interpolation=method)
      else:
        return getattr(np, op)(*args, axis=axis, keepdims=keepdims,
                               method=method)
    jnp_fun = partial(getattr(jnp, op), axis=axis, keepdims=keepdims,
                      method=method)

    # TODO(phawkins): we currently set dtype=False because we aren't as
    # aggressive about promoting to float64. It's not clear we want to mimic
    # Numpy here.
    tol_spec = {np.float16: 1E-2, np.float32: 2e-4, np.float64: 5e-6}
    tol = max(jtu.tolerance(a_dtype, tol_spec),
              jtu.tolerance(q_dtype, tol_spec))
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False,
                            tol=tol)
    self._CompileAndCheck(jnp_fun, args_maker, rtol=tol)

  @unittest.skipIf(not config.jax_enable_x64, "test requires X64")
  @unittest.skipIf(jtu.device_under_test() != 'cpu', "test is for CPU float64 precision")
  def testPercentilePrecision(self):
    # Regression test for https://github.com/google/jax/issues/8513
    x = jnp.float64([1, 2, 3, 4, 7, 10])
    self.assertEqual(jnp.percentile(x, 50), 3.5)

  @jtu.sample_product(
    [dict(a_shape=a_shape, axis=axis)
     for a_shape, axis in (
       ((7,), None),
       ((47, 7), 0),
       ((4, 101), 1),
     )
    ],
    a_dtype=default_dtypes,
    keepdims=[False, True],
    op=["median", "nanmedian"],
  )
  def testMedian(self, op, a_shape, a_dtype, axis, keepdims):
    if op == "median":
      a_rng = jtu.rand_default(self.rng())
    else:
      a_rng = jtu.rand_some_nan(self.rng())
    args_maker = lambda: [a_rng(a_shape, a_dtype)]
    def np_fun(*args):
      args = [x if jnp.result_type(x) != jnp.bfloat16 else
              np.asarray(x, np.float32) for x in args]
      return getattr(np, op)(*args, axis=axis, keepdims=keepdims)
    jnp_fun = partial(getattr(jnp, op), axis=axis, keepdims=keepdims)
    # TODO(phawkins): we currently set dtype=False because we aren't as
    # aggressive about promoting to float64. It's not clear we want to mimic
    # Numpy here.
    tol_spec = {np.float32: 2e-4, np.float64: 5e-6}
    tol = jtu.tolerance(a_dtype, tol_spec)
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False,
                            tol=tol)
    self._CompileAndCheck(jnp_fun, args_maker, rtol=tol)

  @jtu.sample_product(
    shape=all_shapes,
    dtype=all_dtypes,