Merge pull request #15184 from jakevdp:move-median

PiperOrigin-RevId: 519003606
This commit is contained in:
jax authors 2023-03-23 17:14:50 -07:00
commit e9bc7ee866
6 changed files with 309 additions and 310 deletions

View File

@ -33,6 +33,7 @@ from typing import NamedTuple
import jax
import jax._src.numpy.lax_numpy as jnp
import jax._src.numpy.linalg as jnp_linalg
from jax._src.numpy import reductions
from jax._src.numpy import ufuncs
from jax import lax
from jax._src.lax import qdwh
@ -360,7 +361,7 @@ def _eigh_work(H, n, termination_size=256):
def default_case(agenda, blocks, eigenvectors):
V = _slice(eigenvectors, (0, offset), (n, b), (N, B))
# TODO: Improve this?
split_point = jnp.nanmedian(_mask(jnp.diag(ufuncs.real(H)), (b,), jnp.nan))
split_point = reductions.nanmedian(_mask(jnp.diag(ufuncs.real(H)), (b,), jnp.nan))
H_minus, V_minus, H_plus, V_plus, rank = split_spectrum(
H, b, split_point, V0=V)

View File

@ -1609,7 +1609,7 @@ def _pad(array: ArrayLike, pad_width: PadValueLike[int],
return array
stat_funcs: Dict[str, PadStatFunc] = {
"maximum": reductions.amax, "minimum": reductions.amin, "mean": reductions.mean, "median": median}
"maximum": reductions.amax, "minimum": reductions.amin, "mean": reductions.mean, "median": reductions.median}
pad_width = _broadcast_to_pairs(pad_width, nd, "pad_width")
pad_width_arr = np.array(pad_width)
@ -4582,161 +4582,6 @@ def corrcoef(x: ArrayLike, y: Optional[ArrayLike] = None, rowvar: bool = True) -
return c
@util._wraps(np.quantile, skip_params=['out', 'overwrite_input'])
@partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
'keepdims', 'method'))
def quantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, interpolation: None = None) -> Array:
util.check_arraylike("quantile", a, q)
if overwrite_input or out is not None:
msg = ("jax.numpy.quantile does not support overwrite_input=True or "
"out != None")
raise ValueError(msg)
if interpolation is not None:
warnings.warn("The interpolation= argument to 'quantile' is deprecated. "
"Use 'method=' instead.", DeprecationWarning)
return _quantile(asarray(a), asarray(q), axis, interpolation or method, keepdims, False)
@util._wraps(np.nanquantile, skip_params=['out', 'overwrite_input'])
@partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
'keepdims', 'method'))
def nanquantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, interpolation: None = None) -> Array:
util.check_arraylike("nanquantile", a, q)
if overwrite_input or out is not None:
msg = ("jax.numpy.nanquantile does not support overwrite_input=True or "
"out != None")
raise ValueError(msg)
if interpolation is not None:
warnings.warn("The interpolation= argument to 'nanquantile' is deprecated. "
"Use 'method=' instead.", DeprecationWarning)
return _quantile(asarray(a), asarray(q), axis, interpolation or method, keepdims, True)
def _quantile(a: Array, q: Array, axis: Optional[Union[int, Tuple[int, ...]]],
interpolation: str, keepdims: bool, squash_nans: bool) -> Array:
if interpolation not in ["linear", "lower", "higher", "midpoint", "nearest"]:
raise ValueError("interpolation can only be 'linear', 'lower', 'higher', "
"'midpoint', or 'nearest'")
a, = util.promote_dtypes_inexact(a)
keepdim = []
if issubdtype(a.dtype, np.complexfloating):
raise ValueError("quantile does not support complex input, as the operation is poorly defined.")
if axis is None:
a = ravel(a)
axis = 0
elif isinstance(axis, tuple):
keepdim = list(shape(a))
nd = ndim(a)
axis = tuple(_canonicalize_axis(ax, nd) for ax in axis)
if len(set(axis)) != len(axis):
raise ValueError('repeated axis')
for ax in axis:
keepdim[ax] = 1
keep = set(range(nd)) - set(axis)
# prepare permutation
dimensions = list(range(nd))
for i, s in enumerate(sorted(keep)):
dimensions[i], dimensions[s] = dimensions[s], dimensions[i]
do_not_touch_shape = tuple(x for idx,x in enumerate(shape(a)) if idx not in axis)
touch_shape = tuple(x for idx,x in enumerate(shape(a)) if idx in axis)
a = lax.reshape(a, do_not_touch_shape + (int(np.prod(touch_shape)),), dimensions)
axis = _canonicalize_axis(-1, ndim(a))
else:
axis = _canonicalize_axis(axis, ndim(a))
q_shape = shape(q)
q_ndim = ndim(q)
if q_ndim > 1:
raise ValueError(f"q must be have rank <= 1, got shape {shape(q)}")
a_shape = shape(a)
if squash_nans:
a = where(ufuncs.isnan(a), nan, a) # Ensure nans are positive so they sort to the end.
a = lax.sort(a, dimension=axis)
counts = reductions.sum(ufuncs.logical_not(ufuncs.isnan(a)), axis=axis, dtype=q.dtype,
keepdims=keepdims)
shape_after_reduction = counts.shape
q = lax.expand_dims(
q, tuple(range(q_ndim, len(shape_after_reduction) + q_ndim)))
counts = lax.expand_dims(counts, tuple(range(q_ndim)))
q = lax.mul(q, lax.sub(counts, _lax_const(q, 1)))
low = lax.floor(q)
high = lax.ceil(q)
high_weight = lax.sub(q, low)
low_weight = lax.sub(_lax_const(high_weight, 1), high_weight)
low = lax.max(_lax_const(low, 0), lax.min(low, counts - 1))
high = lax.max(_lax_const(high, 0), lax.min(high, counts - 1))
low = lax.convert_element_type(low, int64)
high = lax.convert_element_type(high, int64)
out_shape = q_shape + shape_after_reduction
index = [lax.broadcasted_iota(int64, out_shape, dim + q_ndim)
for dim in range(len(shape_after_reduction))]
if keepdims:
index[axis] = low
else:
index.insert(axis, low)
low_value = a[tuple(index)]
index[axis] = high
high_value = a[tuple(index)]
else:
a = where(reductions.any(ufuncs.isnan(a), axis=axis, keepdims=True), nan, a)
a = lax.sort(a, dimension=axis)
n = lax.convert_element_type(array(a_shape[axis]), lax_internal._dtype(q))
q = lax.mul(q, n - 1)
low = lax.floor(q)
high = lax.ceil(q)
high_weight = lax.sub(q, low)
low_weight = lax.sub(_lax_const(high_weight, 1), high_weight)
low = lax.clamp(_lax_const(low, 0), low, n - 1)
high = lax.clamp(_lax_const(high, 0), high, n - 1)
low = lax.convert_element_type(low, int64)
high = lax.convert_element_type(high, int64)
slice_sizes = list(a_shape)
slice_sizes[axis] = 1
dnums = lax.GatherDimensionNumbers(
offset_dims=tuple(range(
q_ndim,
len(a_shape) + q_ndim if keepdims else len(a_shape) + q_ndim - 1)),
collapsed_slice_dims=() if keepdims else (axis,),
start_index_map=(axis,))
low_value = lax.gather(a, low[..., None], dimension_numbers=dnums,
slice_sizes=slice_sizes)
high_value = lax.gather(a, high[..., None], dimension_numbers=dnums,
slice_sizes=slice_sizes)
if q_ndim == 1:
low_weight = lax.broadcast_in_dim(low_weight, low_value.shape,
broadcast_dimensions=(0,))
high_weight = lax.broadcast_in_dim(high_weight, high_value.shape,
broadcast_dimensions=(0,))
if interpolation == "linear":
result = lax.add(lax.mul(low_value.astype(q.dtype), low_weight),
lax.mul(high_value.astype(q.dtype), high_weight))
elif interpolation == "lower":
result = low_value
elif interpolation == "higher":
result = high_value
elif interpolation == "nearest":
pred = lax.le(high_weight, _lax_const(high_weight, 0.5))
result = lax.select(pred, low_value, high_value)
elif interpolation == "midpoint":
result = lax.mul(lax.add(low_value, high_value), _lax_const(low_value, 0.5))
else:
raise ValueError(f"interpolation={interpolation!r} not recognized")
if keepdims and keepdim:
if q_ndim > 0:
keepdim = [shape(q)[0], *keepdim]
result = reshape(result, keepdim)
return lax.convert_element_type(result, a.dtype)
@partial(vectorize, excluded={0, 2, 3})
def _searchsorted_via_scan(sorted_arr: Array, query: Array, side: str, dtype: type) -> Array:
op = _sort_le_comparator if side == 'left' else _sort_lt_comparator
@ -4859,50 +4704,6 @@ def _piecewise(x: Array, condlist: Array, consts: Dict[int, ArrayLike],
return vectorize(lax.switch, excluded=(1,))(indices, funclist, x)
@util._wraps(np.percentile, skip_params=['out', 'overwrite_input'])
@partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
'keepdims', 'method'))
def percentile(a: ArrayLike, q: ArrayLike,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, interpolation: None = None) -> Array:
util.check_arraylike("percentile", a, q)
q, = util.promote_dtypes_inexact(q)
return quantile(a, q / 100, axis=axis, out=out, overwrite_input=overwrite_input,
interpolation=interpolation, method=method, keepdims=keepdims)
@util._wraps(np.nanpercentile, skip_params=['out', 'overwrite_input'])
@partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
'keepdims', 'method'))
def nanpercentile(a: ArrayLike, q: ArrayLike,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, interpolation: None = None) -> Array:
util.check_arraylike("nanpercentile", a, q)
q = ufuncs.true_divide(q, float32(100.0))
return nanquantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input,
interpolation=interpolation, method=method,
keepdims=keepdims)
@util._wraps(np.median, skip_params=['out', 'overwrite_input'])
@partial(jit, static_argnames=('axis', 'overwrite_input', 'keepdims'))
def median(a: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False,
keepdims: bool = False) -> Array:
util.check_arraylike("median", a)
return quantile(a, 0.5, axis=axis, out=out, overwrite_input=overwrite_input,
keepdims=keepdims, method='midpoint')
@util._wraps(np.nanmedian, skip_params=['out', 'overwrite_input'])
@partial(jit, static_argnames=('axis', 'overwrite_input', 'keepdims'))
def nanmedian(a: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False,
keepdims: bool = False) -> Array:
util.check_arraylike("nanmedian", a)
return nanquantile(a, 0.5, axis=axis, out=out,
overwrite_input=overwrite_input, keepdims=keepdims,
method='midpoint')
@util._wraps(np.place, lax_description="""
Numpy function :func:`numpy.place` is not available in JAX and will raise a

View File

@ -25,6 +25,7 @@ from jax import lax
from jax._src import api
from jax._src import core
from jax._src import dtypes
from jax._src.numpy import ufuncs
from jax._src.numpy.util import (
_broadcast_to, check_arraylike, _complex_elem_type,
promote_dtypes_inexact, promote_dtypes_numeric, _where, _wraps)
@ -684,3 +685,201 @@ nancumsum = _make_cumulative_reduction(np.nancumsum, lax.cumsum,
fill_nan=True, fill_value=0)
nancumprod = _make_cumulative_reduction(np.nancumprod, lax.cumprod,
fill_nan=True, fill_value=1)
# Quantiles
@_wraps(np.quantile, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
'keepdims', 'method'))
def quantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, interpolation: None = None) -> Array:
check_arraylike("quantile", a, q)
if overwrite_input or out is not None:
msg = ("jax.numpy.quantile does not support overwrite_input=True or "
"out != None")
raise ValueError(msg)
if interpolation is not None:
warnings.warn("The interpolation= argument to 'quantile' is deprecated. "
"Use 'method=' instead.", DeprecationWarning)
return _quantile(_asarray(a), _asarray(q), axis, interpolation or method, keepdims, False)
@_wraps(np.nanquantile, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
'keepdims', 'method'))
def nanquantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, interpolation: None = None) -> Array:
check_arraylike("nanquantile", a, q)
if overwrite_input or out is not None:
msg = ("jax.numpy.nanquantile does not support overwrite_input=True or "
"out != None")
raise ValueError(msg)
if interpolation is not None:
warnings.warn("The interpolation= argument to 'nanquantile' is deprecated. "
"Use 'method=' instead.", DeprecationWarning)
return _quantile(_asarray(a), _asarray(q), axis, interpolation or method, keepdims, True)
def _quantile(a: Array, q: Array, axis: Optional[Union[int, Tuple[int, ...]]],
interpolation: str, keepdims: bool, squash_nans: bool) -> Array:
if interpolation not in ["linear", "lower", "higher", "midpoint", "nearest"]:
raise ValueError("interpolation can only be 'linear', 'lower', 'higher', "
"'midpoint', or 'nearest'")
a, = promote_dtypes_inexact(a)
keepdim = []
if dtypes.issubdtype(a.dtype, np.complexfloating):
raise ValueError("quantile does not support complex input, as the operation is poorly defined.")
if axis is None:
a = a.ravel()
axis = 0
elif isinstance(axis, tuple):
keepdim = list(a.shape)
nd = a.ndim
axis = tuple(_canonicalize_axis(ax, nd) for ax in axis)
if len(set(axis)) != len(axis):
raise ValueError('repeated axis')
for ax in axis:
keepdim[ax] = 1
keep = set(range(nd)) - set(axis)
# prepare permutation
dimensions = list(range(nd))
for i, s in enumerate(sorted(keep)):
dimensions[i], dimensions[s] = dimensions[s], dimensions[i]
do_not_touch_shape = tuple(x for idx,x in enumerate(a.shape) if idx not in axis)
touch_shape = tuple(x for idx,x in enumerate(a.shape) if idx in axis)
a = lax.reshape(a, do_not_touch_shape + (int(np.prod(touch_shape)),), dimensions)
axis = _canonicalize_axis(-1, a.ndim)
else:
axis = _canonicalize_axis(axis, a.ndim)
q_shape = q.shape
q_ndim = q.ndim
if q_ndim > 1:
raise ValueError(f"q must be have rank <= 1, got shape {q.shape}")
a_shape = a.shape
if squash_nans:
a = _where(ufuncs.isnan(a), np.nan, a) # Ensure nans are positive so they sort to the end.
a = lax.sort(a, dimension=axis)
counts = sum(ufuncs.logical_not(ufuncs.isnan(a)), axis=axis, dtype=q.dtype, keepdims=keepdims)
shape_after_reduction = counts.shape
q = lax.expand_dims(
q, tuple(range(q_ndim, len(shape_after_reduction) + q_ndim)))
counts = lax.expand_dims(counts, tuple(range(q_ndim)))
q = lax.mul(q, lax.sub(counts, _lax_const(q, 1)))
low = lax.floor(q)
high = lax.ceil(q)
high_weight = lax.sub(q, low)
low_weight = lax.sub(_lax_const(high_weight, 1), high_weight)
low = lax.max(_lax_const(low, 0), lax.min(low, counts - 1))
high = lax.max(_lax_const(high, 0), lax.min(high, counts - 1))
low = lax.convert_element_type(low, int)
high = lax.convert_element_type(high, int)
out_shape = q_shape + shape_after_reduction
index = [lax.broadcasted_iota(int, out_shape, dim + q_ndim)
for dim in range(len(shape_after_reduction))]
if keepdims:
index[axis] = low
else:
index.insert(axis, low)
low_value = a[tuple(index)]
index[axis] = high
high_value = a[tuple(index)]
else:
a = _where(any(ufuncs.isnan(a), axis=axis, keepdims=True), np.nan, a)
a = lax.sort(a, dimension=axis)
n = lax.convert_element_type(a_shape[axis], lax_internal._dtype(q))
q = lax.mul(q, n - 1)
low = lax.floor(q)
high = lax.ceil(q)
high_weight = lax.sub(q, low)
low_weight = lax.sub(_lax_const(high_weight, 1), high_weight)
low = lax.clamp(_lax_const(low, 0), low, n - 1)
high = lax.clamp(_lax_const(high, 0), high, n - 1)
low = lax.convert_element_type(low, int)
high = lax.convert_element_type(high, int)
slice_sizes = list(a_shape)
slice_sizes[axis] = 1
dnums = lax.GatherDimensionNumbers(
offset_dims=tuple(range(
q_ndim,
len(a_shape) + q_ndim if keepdims else len(a_shape) + q_ndim - 1)),
collapsed_slice_dims=() if keepdims else (axis,),
start_index_map=(axis,))
low_value = lax.gather(a, low[..., None], dimension_numbers=dnums,
slice_sizes=slice_sizes)
high_value = lax.gather(a, high[..., None], dimension_numbers=dnums,
slice_sizes=slice_sizes)
if q_ndim == 1:
low_weight = lax.broadcast_in_dim(low_weight, low_value.shape,
broadcast_dimensions=(0,))
high_weight = lax.broadcast_in_dim(high_weight, high_value.shape,
broadcast_dimensions=(0,))
if interpolation == "linear":
result = lax.add(lax.mul(low_value.astype(q.dtype), low_weight),
lax.mul(high_value.astype(q.dtype), high_weight))
elif interpolation == "lower":
result = low_value
elif interpolation == "higher":
result = high_value
elif interpolation == "nearest":
pred = lax.le(high_weight, _lax_const(high_weight, 0.5))
result = lax.select(pred, low_value, high_value)
elif interpolation == "midpoint":
result = lax.mul(lax.add(low_value, high_value), _lax_const(low_value, 0.5))
else:
raise ValueError(f"interpolation={interpolation!r} not recognized")
if keepdims and keepdim:
if q_ndim > 0:
keepdim = [np.shape(q)[0], *keepdim]
result = result.reshape(keepdim)
return lax.convert_element_type(result, a.dtype)
@_wraps(np.percentile, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
'keepdims', 'method'))
def percentile(a: ArrayLike, q: ArrayLike,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, interpolation: None = None) -> Array:
check_arraylike("percentile", a, q)
q, = promote_dtypes_inexact(q)
return quantile(a, q / 100, axis=axis, out=out, overwrite_input=overwrite_input,
interpolation=interpolation, method=method, keepdims=keepdims)
@_wraps(np.nanpercentile, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
'keepdims', 'method'))
def nanpercentile(a: ArrayLike, q: ArrayLike,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, interpolation: None = None) -> Array:
check_arraylike("nanpercentile", a, q)
q = ufuncs.true_divide(q, 100.0)
return nanquantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input,
interpolation=interpolation, method=method,
keepdims=keepdims)
@_wraps(np.median, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'keepdims'))
def median(a: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False,
keepdims: bool = False) -> Array:
check_arraylike("median", a)
return quantile(a, 0.5, axis=axis, out=out, overwrite_input=overwrite_input,
keepdims=keepdims, method='midpoint')
@_wraps(np.nanmedian, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'keepdims'))
def nanmedian(a: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False,
keepdims: bool = False) -> Array:
check_arraylike("nanmedian", a)
return nanquantile(a, 0.5, axis=axis, out=out,
overwrite_input=overwrite_input, keepdims=keepdims,
method='midpoint')

View File

@ -107,6 +107,8 @@ from jax._src.numpy.lax_numpy import (
float16 as float16,
float32 as float32,
float64 as float64,
float8_e4m3fn as float8_e4m3fn,
float8_e5m2 as float8_e5m2,
float_ as float_,
floating as floating,
fmax as fmax,
@ -166,7 +168,6 @@ from jax._src.numpy.lax_numpy import (
logspace as logspace,
mask_indices as mask_indices,
matmul as matmul,
median as median,
meshgrid as meshgrid,
moveaxis as moveaxis,
msort as msort,
@ -175,9 +176,6 @@ from jax._src.numpy.lax_numpy import (
nanargmax as nanargmax,
nanargmin as nanargmin,
argpartition as argpartition,
nanmedian as nanmedian,
nanpercentile as nanpercentile,
nanquantile as nanquantile,
ndim as ndim,
newaxis as newaxis,
nonzero as nonzero,
@ -189,14 +187,12 @@ from jax._src.numpy.lax_numpy import (
packbits as packbits,
pad as pad,
partition as partition,
percentile as percentile,
pi as pi,
piecewise as piecewise,
place as place,
printoptions as printoptions,
promote_types as promote_types,
put as put,
quantile as quantile,
ravel as ravel,
ravel_multi_index as ravel_multi_index,
repeat as repeat,
@ -258,11 +254,6 @@ from jax._src.numpy.lax_numpy import (
zeros_like as zeros_like,
)
from jax._src.numpy.lax_numpy import (
float8_e4m3fn,
float8_e5m2,
)
from jax._src.numpy.index_tricks import (
c_ as c_,
index_exp as index_exp,
@ -298,19 +289,25 @@ from jax._src.numpy.reductions import (
cumproduct as cumproduct,
max as max,
mean as mean,
median as median,
min as min,
nancumsum as nancumsum,
nancumprod as nancumprod,
nanmax as nanmax,
nanmean as nanmean,
nanmedian as nanmedian,
nanmin as nanmin,
nanpercentile as nanpercentile,
nanprod as nanprod,
nanquantile as nanquantile,
nanstd as nanstd,
nansum as nansum,
nanvar as nanvar,
percentile as percentile,
prod as prod,
product as product,
ptp as ptp,
quantile as quantile,
sometrue as sometrue,
std as std,
sum as sum,

View File

@ -16,6 +16,7 @@
import collections
from functools import partial
import itertools
import unittest
from absl.testing import absltest
from absl.testing import parameterized
@ -655,6 +656,104 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
self._CompileAndCheck(jnp_fun, args_maker, atol=tol,
rtol=tol)
@jtu.sample_product(
[dict(op=op, q_rng=q_rng)
for (op, q_rng) in (
("percentile", partial(jtu.rand_uniform, low=0., high=100.)),
("quantile", partial(jtu.rand_uniform, low=0., high=1.)),
("nanpercentile", partial(jtu.rand_uniform, low=0., high=100.)),
("nanquantile", partial(jtu.rand_uniform, low=0., high=1.)),
)
],
[dict(a_shape=a_shape, axis=axis)
for a_shape, axis in (
((7,), None),
((47, 7), 0),
((47, 7), ()),
((4, 101), 1),
((4, 47, 7), (1, 2)),
((4, 47, 7), (0, 2)),
((4, 47, 7), (1, 0, 2)),
)
],
a_dtype=default_dtypes,
q_dtype=[np.float32],
q_shape=scalar_shapes + [(1,), (4,)],
keepdims=[False, True],
method=['linear', 'lower', 'higher', 'nearest', 'midpoint'],
)
def testQuantile(self, op, q_rng, a_shape, a_dtype, q_shape, q_dtype,
axis, keepdims, method):
a_rng = jtu.rand_some_nan(self.rng())
q_rng = q_rng(self.rng())
if "median" in op:
args_maker = lambda: [a_rng(a_shape, a_dtype)]
else:
args_maker = lambda: [a_rng(a_shape, a_dtype), q_rng(q_shape, q_dtype)]
@jtu.ignore_warning(category=RuntimeWarning,
message="All-NaN slice encountered")
def np_fun(*args):
args = [x if jnp.result_type(x) != jnp.bfloat16 else
np.asarray(x, np.float32) for x in args]
if numpy_version <= (1, 22):
return getattr(np, op)(*args, axis=axis, keepdims=keepdims,
interpolation=method)
else:
return getattr(np, op)(*args, axis=axis, keepdims=keepdims,
method=method)
jnp_fun = partial(getattr(jnp, op), axis=axis, keepdims=keepdims,
method=method)
# TODO(phawkins): we currently set dtype=False because we aren't as
# aggressive about promoting to float64. It's not clear we want to mimic
# Numpy here.
tol_spec = {np.float16: 1E-2, np.float32: 2e-4, np.float64: 5e-6}
tol = max(jtu.tolerance(a_dtype, tol_spec),
jtu.tolerance(q_dtype, tol_spec))
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False,
tol=tol)
self._CompileAndCheck(jnp_fun, args_maker, rtol=tol)
@unittest.skipIf(not config.jax_enable_x64, "test requires X64")
@unittest.skipIf(jtu.device_under_test() != 'cpu', "test is for CPU float64 precision")
def testPercentilePrecision(self):
# Regression test for https://github.com/google/jax/issues/8513
x = jnp.float64([1, 2, 3, 4, 7, 10])
self.assertEqual(jnp.percentile(x, 50), 3.5)
@jtu.sample_product(
[dict(a_shape=a_shape, axis=axis)
for a_shape, axis in (
((7,), None),
((47, 7), 0),
((4, 101), 1),
)
],
a_dtype=default_dtypes,
keepdims=[False, True],
op=["median", "nanmedian"],
)
def testMedian(self, op, a_shape, a_dtype, axis, keepdims):
if op == "median":
a_rng = jtu.rand_default(self.rng())
else:
a_rng = jtu.rand_some_nan(self.rng())
args_maker = lambda: [a_rng(a_shape, a_dtype)]
def np_fun(*args):
args = [x if jnp.result_type(x) != jnp.bfloat16 else
np.asarray(x, np.float32) for x in args]
return getattr(np, op)(*args, axis=axis, keepdims=keepdims)
jnp_fun = partial(getattr(jnp, op), axis=axis, keepdims=keepdims)
# TODO(phawkins): we currently set dtype=False because we aren't as
# aggressive about promoting to float64. It's not clear we want to mimic
# Numpy here.
tol_spec = {np.float32: 2e-4, np.float64: 5e-6}
tol = jtu.tolerance(a_dtype, tol_spec)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False,
tol=tol)
self._CompileAndCheck(jnp_fun, args_maker, rtol=tol)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -3912,104 +3912,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
[dict(op=op, q_rng=q_rng)
for (op, q_rng) in (
("percentile", partial(jtu.rand_uniform, low=0., high=100.)),
("quantile", partial(jtu.rand_uniform, low=0., high=1.)),
("nanpercentile", partial(jtu.rand_uniform, low=0., high=100.)),
("nanquantile", partial(jtu.rand_uniform, low=0., high=1.)),
)
],
[dict(a_shape=a_shape, axis=axis)
for a_shape, axis in (
((7,), None),
((47, 7), 0),
((47, 7), ()),
((4, 101), 1),
((4, 47, 7), (1, 2)),
((4, 47, 7), (0, 2)),
((4, 47, 7), (1, 0, 2)),
)
],
a_dtype=default_dtypes,
q_dtype=[np.float32],
q_shape=scalar_shapes + [(1,), (4,)],
keepdims=[False, True],
method=['linear', 'lower', 'higher', 'nearest', 'midpoint'],
)
def testQuantile(self, op, q_rng, a_shape, a_dtype, q_shape, q_dtype,
axis, keepdims, method):
a_rng = jtu.rand_some_nan(self.rng())
q_rng = q_rng(self.rng())
if "median" in op:
args_maker = lambda: [a_rng(a_shape, a_dtype)]
else:
args_maker = lambda: [a_rng(a_shape, a_dtype), q_rng(q_shape, q_dtype)]
@jtu.ignore_warning(category=RuntimeWarning,
message="All-NaN slice encountered")
def np_fun(*args):
args = [x if jnp.result_type(x) != jnp.bfloat16 else
np.asarray(x, np.float32) for x in args]
if numpy_version <= (1, 22):
return getattr(np, op)(*args, axis=axis, keepdims=keepdims,
interpolation=method)
else:
return getattr(np, op)(*args, axis=axis, keepdims=keepdims,
method=method)
jnp_fun = partial(getattr(jnp, op), axis=axis, keepdims=keepdims,
method=method)
# TODO(phawkins): we currently set dtype=False because we aren't as
# aggressive about promoting to float64. It's not clear we want to mimic
# Numpy here.
tol_spec = {np.float16: 1E-2, np.float32: 2e-4, np.float64: 5e-6}
tol = max(jtu.tolerance(a_dtype, tol_spec),
jtu.tolerance(q_dtype, tol_spec))
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False,
tol=tol)
self._CompileAndCheck(jnp_fun, args_maker, rtol=tol)
@unittest.skipIf(not config.jax_enable_x64, "test requires X64")
@unittest.skipIf(jtu.device_under_test() != 'cpu', "test is for CPU float64 precision")
def testPercentilePrecision(self):
# Regression test for https://github.com/google/jax/issues/8513
x = jnp.float64([1, 2, 3, 4, 7, 10])
self.assertEqual(jnp.percentile(x, 50), 3.5)
@jtu.sample_product(
[dict(a_shape=a_shape, axis=axis)
for a_shape, axis in (
((7,), None),
((47, 7), 0),
((4, 101), 1),
)
],
a_dtype=default_dtypes,
keepdims=[False, True],
op=["median", "nanmedian"],
)
def testMedian(self, op, a_shape, a_dtype, axis, keepdims):
if op == "median":
a_rng = jtu.rand_default(self.rng())
else:
a_rng = jtu.rand_some_nan(self.rng())
args_maker = lambda: [a_rng(a_shape, a_dtype)]
def np_fun(*args):
args = [x if jnp.result_type(x) != jnp.bfloat16 else
np.asarray(x, np.float32) for x in args]
return getattr(np, op)(*args, axis=axis, keepdims=keepdims)
jnp_fun = partial(getattr(jnp, op), axis=axis, keepdims=keepdims)
# TODO(phawkins): we currently set dtype=False because we aren't as
# aggressive about promoting to float64. It's not clear we want to mimic
# Numpy here.
tol_spec = {np.float32: 2e-4, np.float64: 5e-6}
tol = jtu.tolerance(a_dtype, tol_spec)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False,
tol=tol)
self._CompileAndCheck(jnp_fun, args_maker, rtol=tol)
@jtu.sample_product(
shape=all_shapes,
dtype=all_dtypes,