mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Fix test failures under Numpy 1.22.
This commit is contained in:
parent
2de145374c
commit
3c193613ce
@ -136,6 +136,12 @@ def coerce_to_array(x, dtype=None):
|
||||
|
||||
iinfo = np.iinfo
|
||||
|
||||
class _Bfloat16MachArLike:
|
||||
def __init__(self):
|
||||
smallest_normal = float.fromhex("0x1p-126")
|
||||
self.smallest_normal = bfloat16(smallest_normal)
|
||||
|
||||
|
||||
class finfo(np.finfo):
|
||||
__doc__ = np.finfo.__doc__
|
||||
_finfo_cache: Dict[np.dtype, np.finfo] = {}
|
||||
@ -165,10 +171,12 @@ class finfo(np.finfo):
|
||||
obj.iexp = obj.nexp
|
||||
obj.precision = 2
|
||||
obj.resolution = bfloat16(resolution)
|
||||
obj.tiny = bfloat16(tiny)
|
||||
obj.machar = None # np.core.getlimits.MachArLike does not support bfloat16.
|
||||
obj._machar = _Bfloat16MachArLike()
|
||||
if not hasattr(obj, "tiny"):
|
||||
obj.tiny = bfloat16(tiny)
|
||||
|
||||
obj._str_tiny = float_to_str(tiny)
|
||||
obj._str_smallest_normal = float_to_str(tiny)
|
||||
obj._str_max = float_to_str(max)
|
||||
obj._str_epsneg = float_to_str(epsneg)
|
||||
obj._str_eps = float_to_str(eps)
|
||||
|
@ -2910,11 +2910,14 @@ def nanmax(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
|
||||
@_wraps(np.nansum, skip_params=['out'])
|
||||
@partial(jit, static_argnames=('axis', 'dtype', 'keepdims'))
|
||||
def nansum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
|
||||
out=None, keepdims=None):
|
||||
out=None, keepdims=None, initial=None, where=None):
|
||||
lax._check_user_dtype_supported(dtype, "nanprod")
|
||||
return _nan_reduction(a, 'nansum', sum, 0, nan_if_all_nan=False,
|
||||
axis=axis, dtype=dtype, out=out, keepdims=keepdims)
|
||||
|
||||
# Work around a sphinx documentation warning in NumPy 1.22.
|
||||
nansum.__doc__ = nansum.__doc__.replace("\n\n\n", "\n\n")
|
||||
|
||||
@_wraps(np.nanprod, skip_params=['out'])
|
||||
@partial(jit, static_argnames=('axis', 'dtype', 'keepdims'))
|
||||
def nanprod(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
|
||||
@ -6322,28 +6325,35 @@ def corrcoef(x, y=None, rowvar=True):
|
||||
|
||||
@_wraps(np.quantile, skip_params=['out', 'overwrite_input'])
|
||||
@partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
|
||||
'keepdims'))
|
||||
'keepdims', 'method'))
|
||||
def quantile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
|
||||
overwrite_input=False, interpolation="linear", keepdims=False):
|
||||
overwrite_input=False, method="linear", keepdims=False,
|
||||
interpolation=None):
|
||||
_check_arraylike("quantile", a, q)
|
||||
if overwrite_input or out is not None:
|
||||
msg = ("jax.numpy.quantile does not support overwrite_input=True or "
|
||||
"out != None")
|
||||
raise ValueError(msg)
|
||||
return _quantile(a, q, axis, interpolation, keepdims, False)
|
||||
if interpolation is not None:
|
||||
warnings.warn("The interpolation= argument to 'quantile' is deprecated. "
|
||||
"Use 'method=' instead.", DeprecationWarning)
|
||||
return _quantile(a, q, axis, interpolation or method, keepdims, False)
|
||||
|
||||
@_wraps(np.nanquantile, skip_params=['out', 'overwrite_input'])
|
||||
@partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
|
||||
'keepdims'))
|
||||
'keepdims', 'method'))
|
||||
def nanquantile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
out=None, overwrite_input=False, interpolation="linear",
|
||||
keepdims=False):
|
||||
out=None, overwrite_input=False, method="linear",
|
||||
keepdims=False, interpolation=None):
|
||||
_check_arraylike("nanquantile", a, q)
|
||||
if overwrite_input or out is not None:
|
||||
msg = ("jax.numpy.nanquantile does not support overwrite_input=True or "
|
||||
"out != None")
|
||||
raise ValueError(msg)
|
||||
return _quantile(a, q, axis, interpolation, keepdims, True)
|
||||
if interpolation is not None:
|
||||
warnings.warn("The interpolation= argument to 'nanquantile' is deprecated. "
|
||||
"Use 'method=' instead.", DeprecationWarning)
|
||||
return _quantile(a, q, axis, interpolation or method, keepdims, True)
|
||||
|
||||
def _quantile(a, q, axis, interpolation, keepdims, squash_nans):
|
||||
if interpolation not in ["linear", "lower", "higher", "midpoint", "nearest"]:
|
||||
@ -6531,26 +6541,27 @@ def _piecewise(x, condlist, consts, funcs, *args, **kw):
|
||||
|
||||
@_wraps(np.percentile, skip_params=['out', 'overwrite_input'])
|
||||
@partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
|
||||
'keepdims'))
|
||||
'keepdims', 'method'))
|
||||
def percentile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
out=None, overwrite_input=False, interpolation="linear",
|
||||
keepdims=False):
|
||||
out=None, overwrite_input=False, method="linear",
|
||||
keepdims=False, interpolation=None):
|
||||
_check_arraylike("percentile", a, q)
|
||||
a, q = _promote_dtypes_inexact(a, q)
|
||||
q = true_divide(q, 100.0)
|
||||
return quantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input,
|
||||
interpolation=interpolation, keepdims=keepdims)
|
||||
interpolation=interpolation, method=method, keepdims=keepdims)
|
||||
|
||||
@_wraps(np.nanpercentile, skip_params=['out', 'overwrite_input'])
|
||||
@partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
|
||||
'keepdims'))
|
||||
'keepdims', 'method'))
|
||||
def nanpercentile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
out=None, overwrite_input=False, interpolation="linear",
|
||||
keepdims=False):
|
||||
out=None, overwrite_input=False, method="linear",
|
||||
keepdims=False, interpolation=None):
|
||||
_check_arraylike("nanpercentile", a, q)
|
||||
q = true_divide(q, float32(100.0))
|
||||
return nanquantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input,
|
||||
interpolation=interpolation, keepdims=keepdims)
|
||||
interpolation=interpolation, method=method,
|
||||
keepdims=keepdims)
|
||||
|
||||
@_wraps(np.median, skip_params=['out', 'overwrite_input'])
|
||||
@partial(jit, static_argnames=('axis', 'overwrite_input', 'keepdims'))
|
||||
@ -6558,7 +6569,7 @@ def median(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
|
||||
overwrite_input=False, keepdims=False):
|
||||
_check_arraylike("median", a)
|
||||
return quantile(a, 0.5, axis=axis, out=out, overwrite_input=overwrite_input,
|
||||
keepdims=keepdims, interpolation='midpoint')
|
||||
keepdims=keepdims, method='midpoint')
|
||||
|
||||
@_wraps(np.nanmedian, skip_params=['out', 'overwrite_input'])
|
||||
@partial(jit, static_argnames=('axis', 'overwrite_input', 'keepdims'))
|
||||
@ -6567,7 +6578,7 @@ def nanmedian(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
|
||||
_check_arraylike("nanmedian", a)
|
||||
return nanquantile(a, 0.5, axis=axis, out=out,
|
||||
overwrite_input=overwrite_input, keepdims=keepdims,
|
||||
interpolation='midpoint')
|
||||
method='midpoint')
|
||||
|
||||
|
||||
def _astype(arr, dtype):
|
||||
|
@ -14,5 +14,7 @@ filterwarnings =
|
||||
ignore:.*experimental feature
|
||||
ignore:index.*is deprecated.*:DeprecationWarning
|
||||
ignore:jax.experimental.* is deprecated, import jax.example_libraries.* instead:FutureWarning
|
||||
# numpy uses distutils which is deprecated
|
||||
ignore:The distutils.* is deprecated.*:DeprecationWarning
|
||||
doctest_optionflags = NUMBER NORMALIZE_WHITESPACE
|
||||
addopts = --doctest-glob="*.rst"
|
||||
|
@ -235,8 +235,8 @@ JAX_COMPOUND_OP_RECORDS = [
|
||||
op_record("fix", 1, float_dtypes, all_shapes, jtu.rand_default, []),
|
||||
op_record("fix", 1, int_dtypes + unsigned_dtypes, all_shapes,
|
||||
jtu.rand_default, [], check_dtypes=False),
|
||||
op_record("floor_divide", 2, number_dtypes, all_shapes,
|
||||
jtu.rand_nonzero, ["rev"]),
|
||||
op_record("floor_divide", 2, float_dtypes + int_dtypes,
|
||||
all_shapes, jtu.rand_nonzero, ["rev"]),
|
||||
op_record("floor_divide", 2, unsigned_dtypes, all_shapes,
|
||||
jtu.rand_nonzero, ["rev"]),
|
||||
op_record("fmin", 2, number_dtypes, all_shapes, jtu.rand_some_nan, []),
|
||||
@ -4430,17 +4430,17 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
"_op={}_a_shape={}_q_shape={}_axis={}_keepdims={}_interpolation={}".format(
|
||||
"_op={}_a_shape={}_q_shape={}_axis={}_keepdims={}_method={}".format(
|
||||
op,
|
||||
jtu.format_shape_dtype_string(a_shape, a_dtype),
|
||||
jtu.format_shape_dtype_string(q_shape, q_dtype),
|
||||
axis, keepdims, interpolation),
|
||||
axis, keepdims, method),
|
||||
"a_rng": jtu.rand_some_nan,
|
||||
"q_rng": q_rng, "op": op,
|
||||
"a_shape": a_shape, "a_dtype": a_dtype,
|
||||
"q_shape": q_shape, "q_dtype": q_dtype, "axis": axis,
|
||||
"keepdims": keepdims,
|
||||
"interpolation": interpolation}
|
||||
"method": method}
|
||||
for (op, q_rng) in (
|
||||
("percentile", partial(jtu.rand_uniform, low=0., high=100.)),
|
||||
("quantile", partial(jtu.rand_uniform, low=0., high=1.)),
|
||||
@ -4456,10 +4456,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
for q_dtype in [np.float32]
|
||||
for q_shape in scalar_shapes + [(4,)]
|
||||
for keepdims in [False, True]
|
||||
for interpolation in ['linear', 'lower', 'higher', 'nearest',
|
||||
'midpoint']))
|
||||
for method in ['linear', 'lower', 'higher', 'nearest', 'midpoint']))
|
||||
def testQuantile(self, op, a_rng, q_rng, a_shape, a_dtype, q_shape, q_dtype,
|
||||
axis, keepdims, interpolation):
|
||||
axis, keepdims, method):
|
||||
a_rng = a_rng(self.rng())
|
||||
q_rng = q_rng(self.rng())
|
||||
if "median" in op:
|
||||
@ -4470,10 +4469,14 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
def np_fun(*args):
|
||||
args = [x if jnp.result_type(x) != jnp.bfloat16 else
|
||||
np.asarray(x, np.float32) for x in args]
|
||||
return getattr(np, op)(*args, axis=axis, keepdims=keepdims,
|
||||
interpolation=interpolation)
|
||||
if numpy_version <= (1, 22):
|
||||
return getattr(np, op)(*args, axis=axis, keepdims=keepdims,
|
||||
interpolation=method)
|
||||
else:
|
||||
return getattr(np, op)(*args, axis=axis, keepdims=keepdims,
|
||||
method=method)
|
||||
jnp_fun = partial(getattr(jnp, op), axis=axis, keepdims=keepdims,
|
||||
interpolation=interpolation)
|
||||
method=method)
|
||||
|
||||
# TODO(phawkins): we currently set dtype=False because we aren't as
|
||||
# aggressive about promoting to float64. It's not clear we want to mimic
|
||||
@ -5897,6 +5900,8 @@ class NumpySignaturesTest(jtu.JaxTestCase):
|
||||
# TODO(jakevdp): fix some of the following signatures. Some are due to wrong argument names.
|
||||
unsupported_params = {
|
||||
'angle': ['deg'],
|
||||
'argmax': ['keepdims'],
|
||||
'argmin': ['keepdims'],
|
||||
'asarray': ['like'],
|
||||
'broadcast_to': ['subok', 'array'],
|
||||
'clip': ['kwargs'],
|
||||
@ -5912,6 +5917,14 @@ class NumpySignaturesTest(jtu.JaxTestCase):
|
||||
'histogram': ['normed'],
|
||||
'histogram2d': ['normed'],
|
||||
'histogramdd': ['normed'],
|
||||
'nanargmax': ['out', 'keepdims'],
|
||||
'nanargmin': ['out', 'keepdims'],
|
||||
'nanmax': ['initial', 'where'],
|
||||
'nanmean': ['where'],
|
||||
'nanmin': ['initial', 'where'],
|
||||
'nanprod': ['initial', 'where'],
|
||||
'nanstd': ['where'],
|
||||
'nanvar': ['where'],
|
||||
'ones': ['order', 'like'],
|
||||
'ones_like': ['subok', 'order'],
|
||||
'tri': ['like'],
|
||||
@ -5934,6 +5947,9 @@ class NumpySignaturesTest(jtu.JaxTestCase):
|
||||
# Some signatures have changed; skip for older numpy versions.
|
||||
if numpy_version < (1, 19) and name in ['einsum_path', 'gradient', 'isscalar']:
|
||||
continue
|
||||
if numpy_version < (1, 22) and name in ['quantile', 'nanquantile',
|
||||
'percentile', 'nanpercentile']:
|
||||
continue
|
||||
# Note: can't use inspect.getfullargspec due to numpy issue
|
||||
# https://github.com/numpy/numpy/issues/12225
|
||||
try:
|
||||
|
Loading…
x
Reference in New Issue
Block a user