Fix test failures under Numpy 1.22.

This commit is contained in:
Peter Hawkins 2022-01-04 11:47:54 -05:00
parent 2de145374c
commit 3c193613ce
4 changed files with 68 additions and 31 deletions

View File

@ -136,6 +136,12 @@ def coerce_to_array(x, dtype=None):
iinfo = np.iinfo
class _Bfloat16MachArLike:
def __init__(self):
smallest_normal = float.fromhex("0x1p-126")
self.smallest_normal = bfloat16(smallest_normal)
class finfo(np.finfo):
__doc__ = np.finfo.__doc__
_finfo_cache: Dict[np.dtype, np.finfo] = {}
@ -165,10 +171,12 @@ class finfo(np.finfo):
obj.iexp = obj.nexp
obj.precision = 2
obj.resolution = bfloat16(resolution)
obj.tiny = bfloat16(tiny)
obj.machar = None # np.core.getlimits.MachArLike does not support bfloat16.
obj._machar = _Bfloat16MachArLike()
if not hasattr(obj, "tiny"):
obj.tiny = bfloat16(tiny)
obj._str_tiny = float_to_str(tiny)
obj._str_smallest_normal = float_to_str(tiny)
obj._str_max = float_to_str(max)
obj._str_epsneg = float_to_str(epsneg)
obj._str_eps = float_to_str(eps)

View File

@ -2910,11 +2910,14 @@ def nanmax(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
@_wraps(np.nansum, skip_params=['out'])
@partial(jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nansum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
out=None, keepdims=None):
out=None, keepdims=None, initial=None, where=None):
lax._check_user_dtype_supported(dtype, "nanprod")
return _nan_reduction(a, 'nansum', sum, 0, nan_if_all_nan=False,
axis=axis, dtype=dtype, out=out, keepdims=keepdims)
# Work around a sphinx documentation warning in NumPy 1.22.
nansum.__doc__ = nansum.__doc__.replace("\n\n\n", "\n\n")
@_wraps(np.nanprod, skip_params=['out'])
@partial(jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanprod(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
@ -6322,28 +6325,35 @@ def corrcoef(x, y=None, rowvar=True):
@_wraps(np.quantile, skip_params=['out', 'overwrite_input'])
@partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
'keepdims'))
'keepdims', 'method'))
def quantile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
overwrite_input=False, interpolation="linear", keepdims=False):
overwrite_input=False, method="linear", keepdims=False,
interpolation=None):
_check_arraylike("quantile", a, q)
if overwrite_input or out is not None:
msg = ("jax.numpy.quantile does not support overwrite_input=True or "
"out != None")
raise ValueError(msg)
return _quantile(a, q, axis, interpolation, keepdims, False)
if interpolation is not None:
warnings.warn("The interpolation= argument to 'quantile' is deprecated. "
"Use 'method=' instead.", DeprecationWarning)
return _quantile(a, q, axis, interpolation or method, keepdims, False)
@_wraps(np.nanquantile, skip_params=['out', 'overwrite_input'])
@partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
'keepdims'))
'keepdims', 'method'))
def nanquantile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None,
out=None, overwrite_input=False, interpolation="linear",
keepdims=False):
out=None, overwrite_input=False, method="linear",
keepdims=False, interpolation=None):
_check_arraylike("nanquantile", a, q)
if overwrite_input or out is not None:
msg = ("jax.numpy.nanquantile does not support overwrite_input=True or "
"out != None")
raise ValueError(msg)
return _quantile(a, q, axis, interpolation, keepdims, True)
if interpolation is not None:
warnings.warn("The interpolation= argument to 'nanquantile' is deprecated. "
"Use 'method=' instead.", DeprecationWarning)
return _quantile(a, q, axis, interpolation or method, keepdims, True)
def _quantile(a, q, axis, interpolation, keepdims, squash_nans):
if interpolation not in ["linear", "lower", "higher", "midpoint", "nearest"]:
@ -6531,26 +6541,27 @@ def _piecewise(x, condlist, consts, funcs, *args, **kw):
@_wraps(np.percentile, skip_params=['out', 'overwrite_input'])
@partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
'keepdims'))
'keepdims', 'method'))
def percentile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None,
out=None, overwrite_input=False, interpolation="linear",
keepdims=False):
out=None, overwrite_input=False, method="linear",
keepdims=False, interpolation=None):
_check_arraylike("percentile", a, q)
a, q = _promote_dtypes_inexact(a, q)
q = true_divide(q, 100.0)
return quantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input,
interpolation=interpolation, keepdims=keepdims)
interpolation=interpolation, method=method, keepdims=keepdims)
@_wraps(np.nanpercentile, skip_params=['out', 'overwrite_input'])
@partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
'keepdims'))
'keepdims', 'method'))
def nanpercentile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None,
out=None, overwrite_input=False, interpolation="linear",
keepdims=False):
out=None, overwrite_input=False, method="linear",
keepdims=False, interpolation=None):
_check_arraylike("nanpercentile", a, q)
q = true_divide(q, float32(100.0))
return nanquantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input,
interpolation=interpolation, keepdims=keepdims)
interpolation=interpolation, method=method,
keepdims=keepdims)
@_wraps(np.median, skip_params=['out', 'overwrite_input'])
@partial(jit, static_argnames=('axis', 'overwrite_input', 'keepdims'))
@ -6558,7 +6569,7 @@ def median(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
overwrite_input=False, keepdims=False):
_check_arraylike("median", a)
return quantile(a, 0.5, axis=axis, out=out, overwrite_input=overwrite_input,
keepdims=keepdims, interpolation='midpoint')
keepdims=keepdims, method='midpoint')
@_wraps(np.nanmedian, skip_params=['out', 'overwrite_input'])
@partial(jit, static_argnames=('axis', 'overwrite_input', 'keepdims'))
@ -6567,7 +6578,7 @@ def nanmedian(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
_check_arraylike("nanmedian", a)
return nanquantile(a, 0.5, axis=axis, out=out,
overwrite_input=overwrite_input, keepdims=keepdims,
interpolation='midpoint')
method='midpoint')
def _astype(arr, dtype):

View File

@ -14,5 +14,7 @@ filterwarnings =
ignore:.*experimental feature
ignore:index.*is deprecated.*:DeprecationWarning
ignore:jax.experimental.* is deprecated, import jax.example_libraries.* instead:FutureWarning
# numpy uses distutils which is deprecated
ignore:The distutils.* is deprecated.*:DeprecationWarning
doctest_optionflags = NUMBER NORMALIZE_WHITESPACE
addopts = --doctest-glob="*.rst"

View File

@ -235,8 +235,8 @@ JAX_COMPOUND_OP_RECORDS = [
op_record("fix", 1, float_dtypes, all_shapes, jtu.rand_default, []),
op_record("fix", 1, int_dtypes + unsigned_dtypes, all_shapes,
jtu.rand_default, [], check_dtypes=False),
op_record("floor_divide", 2, number_dtypes, all_shapes,
jtu.rand_nonzero, ["rev"]),
op_record("floor_divide", 2, float_dtypes + int_dtypes,
all_shapes, jtu.rand_nonzero, ["rev"]),
op_record("floor_divide", 2, unsigned_dtypes, all_shapes,
jtu.rand_nonzero, ["rev"]),
op_record("fmin", 2, number_dtypes, all_shapes, jtu.rand_some_nan, []),
@ -4430,17 +4430,17 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_op={}_a_shape={}_q_shape={}_axis={}_keepdims={}_interpolation={}".format(
"_op={}_a_shape={}_q_shape={}_axis={}_keepdims={}_method={}".format(
op,
jtu.format_shape_dtype_string(a_shape, a_dtype),
jtu.format_shape_dtype_string(q_shape, q_dtype),
axis, keepdims, interpolation),
axis, keepdims, method),
"a_rng": jtu.rand_some_nan,
"q_rng": q_rng, "op": op,
"a_shape": a_shape, "a_dtype": a_dtype,
"q_shape": q_shape, "q_dtype": q_dtype, "axis": axis,
"keepdims": keepdims,
"interpolation": interpolation}
"method": method}
for (op, q_rng) in (
("percentile", partial(jtu.rand_uniform, low=0., high=100.)),
("quantile", partial(jtu.rand_uniform, low=0., high=1.)),
@ -4456,10 +4456,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
for q_dtype in [np.float32]
for q_shape in scalar_shapes + [(4,)]
for keepdims in [False, True]
for interpolation in ['linear', 'lower', 'higher', 'nearest',
'midpoint']))
for method in ['linear', 'lower', 'higher', 'nearest', 'midpoint']))
def testQuantile(self, op, a_rng, q_rng, a_shape, a_dtype, q_shape, q_dtype,
axis, keepdims, interpolation):
axis, keepdims, method):
a_rng = a_rng(self.rng())
q_rng = q_rng(self.rng())
if "median" in op:
@ -4470,10 +4469,14 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
def np_fun(*args):
args = [x if jnp.result_type(x) != jnp.bfloat16 else
np.asarray(x, np.float32) for x in args]
return getattr(np, op)(*args, axis=axis, keepdims=keepdims,
interpolation=interpolation)
if numpy_version <= (1, 22):
return getattr(np, op)(*args, axis=axis, keepdims=keepdims,
interpolation=method)
else:
return getattr(np, op)(*args, axis=axis, keepdims=keepdims,
method=method)
jnp_fun = partial(getattr(jnp, op), axis=axis, keepdims=keepdims,
interpolation=interpolation)
method=method)
# TODO(phawkins): we currently set dtype=False because we aren't as
# aggressive about promoting to float64. It's not clear we want to mimic
@ -5897,6 +5900,8 @@ class NumpySignaturesTest(jtu.JaxTestCase):
# TODO(jakevdp): fix some of the following signatures. Some are due to wrong argument names.
unsupported_params = {
'angle': ['deg'],
'argmax': ['keepdims'],
'argmin': ['keepdims'],
'asarray': ['like'],
'broadcast_to': ['subok', 'array'],
'clip': ['kwargs'],
@ -5912,6 +5917,14 @@ class NumpySignaturesTest(jtu.JaxTestCase):
'histogram': ['normed'],
'histogram2d': ['normed'],
'histogramdd': ['normed'],
'nanargmax': ['out', 'keepdims'],
'nanargmin': ['out', 'keepdims'],
'nanmax': ['initial', 'where'],
'nanmean': ['where'],
'nanmin': ['initial', 'where'],
'nanprod': ['initial', 'where'],
'nanstd': ['where'],
'nanvar': ['where'],
'ones': ['order', 'like'],
'ones_like': ['subok', 'order'],
'tri': ['like'],
@ -5934,6 +5947,9 @@ class NumpySignaturesTest(jtu.JaxTestCase):
# Some signatures have changed; skip for older numpy versions.
if numpy_version < (1, 19) and name in ['einsum_path', 'gradient', 'isscalar']:
continue
if numpy_version < (1, 22) and name in ['quantile', 'nanquantile',
'percentile', 'nanpercentile']:
continue
# Note: can't use inspect.getfullargspec due to numpy issue
# https://github.com/numpy/numpy/issues/12225
try: