mirror of
https://github.com/ROCm/jax.git
synced 2025-04-25 10:16:07 +00:00

* Allow rc2 in NumPy versions when parsed by tests.
* Don't cast np.empty(), which can lead to cast errors.
* NumPy 1.24 now warns on overflowing scalar int to array casts in more places.
652 lines
28 KiB
Python
652 lines
28 KiB
Python
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
from functools import partial
import itertools

from absl.testing import absltest
from absl.testing import parameterized

import numpy as np

import jax
from jax import numpy as jnp

from jax._src import dtypes
from jax._src import test_util as jtu

from jax.config import config
# Let absl parse the command-line flags into the JAX config before any test
# parameters are generated.
config.parse_flags_with_absl()
FLAGS = config.FLAGS

# Version value returned by jtu.numpy_version(); compared against tuples such
# as (1, 22) below to gate NumPy-version-specific test cases.
numpy_version = jtu.numpy_version()

# Shape families used to parameterize the tests.
nonempty_nonscalar_array_shapes = [(4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)]
nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes
one_dim_array_shapes = [(1,), (6,), (12,)]
empty_array_shapes = [(0,), (0, 4), (3, 0),]

# Sentinel shapes provided by test_util for NumPy scalars and Python scalars.
scalar_shapes = [jtu.NUMPY_SCALAR_SHAPE, jtu.PYTHON_SCALAR_SHAPE]
array_shapes = nonempty_array_shapes + empty_array_shapes
nonzerodim_shapes = nonempty_nonscalar_array_shapes + empty_array_shapes
nonempty_shapes = scalar_shapes + nonempty_array_shapes
all_shapes = scalar_shapes + array_shapes

# Dtype families used to parameterize the tests.
float_dtypes = jtu.dtypes.all_floating
complex_dtypes = jtu.dtypes.complex
int_dtypes = jtu.dtypes.all_integer
unsigned_dtypes = jtu.dtypes.all_unsigned
bool_dtypes = jtu.dtypes.boolean
default_dtypes = float_dtypes + int_dtypes
inexact_dtypes = float_dtypes + complex_dtypes
number_dtypes = float_dtypes + complex_dtypes + int_dtypes + unsigned_dtypes
all_dtypes = number_dtypes + bool_dtypes

# The only dtypes considered valid for jtu.PYTHON_SCALAR_SHAPE inputs
# (see _valid_dtypes_for_shape).
python_scalar_dtypes = [jnp.bool_, jnp.int_, jnp.float_, jnp.complex_]
def _valid_dtypes_for_shape(shape, dtypes):
  """Return the subset of `dtypes` that may be paired with `shape`.

  Args:
    shape: a shape tuple or one of the scalar sentinel shapes.
    dtypes: candidate dtypes.

  Returns:
    `dtypes` unchanged, except for `jtu.PYTHON_SCALAR_SHAPE`, where only the
    entries that also appear in `python_scalar_dtypes` are kept.
  """
  # Not all (shape, dtype) pairs are valid. In particular, Python scalars only
  # have one type in each category (float, bool, etc.)
  if shape is jtu.PYTHON_SCALAR_SHAPE:
    return [t for t in dtypes if t in python_scalar_dtypes]
  return dtypes
def _shape_and_dtypes(shapes, dtypes):
  """Yield every valid `(shape, dtype)` pair.

  For each entry of `shapes`, the dtypes are first filtered through
  `_valid_dtypes_for_shape`, so invalid combinations are never produced.
  """
  for shape in shapes:
    for dtype in _valid_dtypes_for_shape(shape, dtypes):
      yield (shape, dtype)
def _compatible_shapes(shape):
|
|
if np.ndim(shape) == 0 or shape in scalar_shapes:
|
|
return [shape]
|
|
return (shape[n:] for n in range(len(shape) + 1))
|
|
|
|
def _get_y_shapes(y_dtype, shape, rowvar):
|
|
# Helper function for testCov.
|
|
if y_dtype is None:
|
|
return [None]
|
|
if len(shape) == 1:
|
|
return [shape]
|
|
elif rowvar or shape[0] == 1:
|
|
return [(1, shape[-1]), (2, shape[-1]), (5, shape[-1])]
|
|
return [(shape[0], 1), (shape[0], 2), (shape[0], 5)]
|
|
|
|
|
|
# Description of one reducer under test: its name, arity, the dtypes and shapes
# to sweep, the random-input factory, plus per-op testing options.
OpRecord = collections.namedtuple(
  "OpRecord",
  ["name", "nargs", "dtypes", "shapes", "rng_factory", "diff_modes",
   "test_name", "check_dtypes", "tolerance", "inexact", "kwargs"])


def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes,
              test_name=None, check_dtypes=True,
              tolerance=None, inexact=False, kwargs=None):
  """Build an `OpRecord`, filling in defaults for the optional fields.

  `test_name` falls back to `name` when it is not given (or is falsy); all
  other arguments are stored unchanged.
  """
  test_name = test_name or name
  return OpRecord(name, nargs, dtypes, shapes, rng_factory, diff_modes,
                  test_name, check_dtypes, tolerance, inexact, kwargs)
# Reducers exercised by testReducer, which also sweeps an explicit output
# `dtype` argument.
JAX_REDUCER_RECORDS = [
  op_record("mean", 1, number_dtypes, nonempty_shapes, jtu.rand_default, [],
            inexact=True),
  op_record("prod", 1, all_dtypes, all_shapes, jtu.rand_small_positive, []),
  op_record("sum", 1, all_dtypes, all_shapes, jtu.rand_default, []),
  op_record("nanmean", 1, inexact_dtypes, nonempty_shapes, jtu.rand_some_nan,
            [], inexact=True),
  op_record("nanprod", 1, all_dtypes, all_shapes, jtu.rand_some_nan, []),
  op_record("nansum", 1, number_dtypes, all_shapes, jtu.rand_some_nan, []),
]
# Reducers exercised with the `initial` keyword (testReducerInitial,
# testReducerWhere) and by testReducerNoInitialZeroDims.
JAX_REDUCER_INITIAL_RECORDS = [
  op_record("prod", 1, all_dtypes, all_shapes, jtu.rand_small_positive, []),
  op_record("sum", 1, all_dtypes, all_shapes, jtu.rand_default, [],
            tolerance={jnp.bfloat16: 2e-2}),
  op_record("max", 1, all_dtypes, all_shapes, jtu.rand_default, []),
  op_record("min", 1, all_dtypes, all_shapes, jtu.rand_default, []),
]
if numpy_version >= (1, 22):  # initial & where keywords added in numpy 1.22
  JAX_REDUCER_INITIAL_RECORDS += [
    op_record("nanprod", 1, inexact_dtypes, all_shapes, jtu.rand_small_positive, []),
    op_record("nansum", 1, inexact_dtypes, all_shapes, jtu.rand_default, [],
              tolerance={jnp.bfloat16: 3e-2}),
    op_record("nanmax", 1, inexact_dtypes, all_shapes, jtu.rand_default, []),
    op_record("nanmin", 1, inexact_dtypes, all_shapes, jtu.rand_default, []),
  ]
# Reducers exercised with the `where` keyword but without `initial`
# (testReducerWhereNoInitial).
JAX_REDUCER_WHERE_NO_INITIAL_RECORDS = [
  op_record("all", 1, bool_dtypes, all_shapes, jtu.rand_some_zero, []),
  op_record("any", 1, bool_dtypes, all_shapes, jtu.rand_some_zero, []),
  op_record("mean", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [],
            inexact=True),
  op_record("var", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [],
            inexact=True),
  op_record("std", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [],
            inexact=True),
]
if numpy_version >= (1, 22):  # where keyword added in numpy 1.22
  JAX_REDUCER_WHERE_NO_INITIAL_RECORDS += [
    op_record("nanmean", 1, inexact_dtypes, nonempty_shapes, jtu.rand_default, [],
              inexact=True, tolerance={np.float16: 3e-3}),
    op_record("nanvar", 1, inexact_dtypes, nonempty_shapes, jtu.rand_default, [],
              inexact=True, tolerance={np.float16: 3e-3}),
    op_record("nanstd", 1, inexact_dtypes, nonempty_shapes, jtu.rand_default, [],
              inexact=True, tolerance={np.float16: 1e-3}),
  ]
# Reducers exercised by testReducerNoDtype, which calls them without an
# explicit `dtype` argument.
JAX_REDUCER_NO_DTYPE_RECORDS = [
  op_record("all", 1, all_dtypes, all_shapes, jtu.rand_some_zero, []),
  op_record("any", 1, all_dtypes, all_shapes, jtu.rand_some_zero, []),
  op_record("max", 1, all_dtypes, nonempty_shapes, jtu.rand_default, []),
  op_record("min", 1, all_dtypes, nonempty_shapes, jtu.rand_default, []),
  op_record("var", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [],
            inexact=True, tolerance={jnp.bfloat16: 2e-2}),
  op_record("std", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [],
            inexact=True),
  op_record("nanmax", 1, all_dtypes, nonempty_shapes, jtu.rand_some_nan, []),
  op_record("nanmin", 1, all_dtypes, nonempty_shapes, jtu.rand_some_nan, []),
  op_record("nanvar", 1, all_dtypes, nonempty_shapes, jtu.rand_some_nan,
            [], inexact=True),
  op_record("nanstd", 1, all_dtypes, nonempty_shapes, jtu.rand_some_nan,
            [], inexact=True),
  op_record("ptp", 1, number_dtypes, nonempty_shapes, jtu.rand_default, []),
]
# Reducers exercised by testReducerPromoteInt, which sweeps the
# `promote_integers` keyword.
JAX_REDUCER_PROMOTE_INT_RECORDS = [
  op_record("prod", 1, all_dtypes, all_shapes, jtu.rand_small_positive, []),
  op_record("sum", 1, all_dtypes, all_shapes, jtu.rand_default, []),
]
def _reducer_output_dtype(name: str, input_dtype: np.dtype, promote_integers: bool = True) -> np.dtype:
|
|
if name in ['sum', 'prod', 'nansum', 'nanprod']:
|
|
if input_dtype == bool:
|
|
input_dtype = dtypes.to_numeric_dtype(input_dtype)
|
|
if promote_integers:
|
|
if dtypes.issubdtype(input_dtype, np.integer):
|
|
default_int = dtypes.canonicalize_dtype(
|
|
dtypes.uint if dtypes.issubdtype(input_dtype, np.unsignedinteger) else dtypes.int_)
|
|
if np.iinfo(input_dtype).bits < np.iinfo(default_int).bits:
|
|
return default_int
|
|
return input_dtype
|
|
|
|
|
|
class JaxNumpyReducerTests(jtu.JaxTestCase):
  """Tests for LAX-backed Numpy reduction operations."""

  def _GetArgsMaker(self, rng, shapes, dtypes, np_arrays=True):
    """Return a zero-argument callable that draws fresh random test inputs.

    Args:
      rng: callable `rng(shape, dtype)` producing one random value.
      shapes: one shape per argument.
      dtypes: one dtype per argument; a falsy entry falls back to
        `jnp.float_`. (This parameter shadows the module-level `dtypes`
        import inside this method.)
      np_arrays: if True the raw `rng` outputs are returned; otherwise NumPy
        arrays and NumPy scalars are converted with `jnp.asarray`.
    """
    def f():
      out = [rng(shape, dtype or jnp.float_)
             for shape, dtype in zip(shapes, dtypes)]
      if np_arrays:
        return out
      return [jnp.asarray(a) if isinstance(a, (np.ndarray, np.generic)) else a
              for a in out]
    return f

  # Compares each reducer in JAX_REDUCER_RECORDS against NumPy while sweeping
  # shape, input dtype, axis, an explicit (non-unsigned) output dtype and
  # keepdims.
  @parameterized.parameters(itertools.chain.from_iterable(
    jtu.sample_product_testcases(
      [dict(name=rec.name, rng_factory=rec.rng_factory, inexact=rec.inexact)],
      [dict(shape=shape, axis=axis, dtype=dtype)
        for shape in rec.shapes
        for dtype in rec.dtypes
        for axis in list(range(-len(shape), len(shape))) + [None]
        if jtu.is_valid_shape(shape, dtype)
      ],
      out_dtype=[out_dtype for out_dtype in [None] + rec.dtypes
                 if out_dtype not in unsigned_dtypes],
      keepdims=[False, True],
    )
    for rec in JAX_REDUCER_RECORDS
  ))
  def testReducer(self, name, rng_factory, shape, dtype, out_dtype,
                  axis, keepdims, inexact):
    np_op = getattr(np, name)
    jnp_op = getattr(jnp, name)
    rng = rng_factory(self.rng())
    @jtu.ignore_warning(category=np.ComplexWarning)
    @jtu.ignore_warning(category=RuntimeWarning,
                        message="mean of empty slice.*")
    @jtu.ignore_warning(category=RuntimeWarning,
                        message="overflow encountered.*")
    def np_fun(x):
      x = np.asarray(x)
      if inexact:
        x = x.astype(dtypes.to_inexact_dtype(x.dtype))
      # The NumPy reference is evaluated in float32 wherever bfloat16 is
      # requested, for both the input and the output dtype.
      x_cast = x if dtype != jnp.bfloat16 else x.astype(np.float32)
      t = out_dtype if out_dtype != jnp.bfloat16 else np.float32
      if t is None:
        t = _reducer_output_dtype(name, x_cast.dtype)
      return np_op(x_cast, axis, dtype=t, keepdims=keepdims)

    jnp_fun = lambda x: jnp_op(x, axis, dtype=out_dtype, keepdims=keepdims)
    jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
    args_maker = lambda: [rng(shape, dtype)]
    tol_spec = {np.float16: 1e-2, np.int16: 2e-7, np.int32: 1E-3,
                np.float32: 1e-3, np.complex64: 1e-3, np.float64: 1e-5,
                np.complex128: 1e-5}
    tol = jtu.tolerance(dtype, tol_spec)
    tol = max(tol, jtu.tolerance(out_dtype, tol_spec)) if out_dtype else tol
    # Dtypes are not compared when bfloat16 is involved, since the reference
    # above is computed in float32.
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
                            check_dtypes=jnp.bfloat16 not in (dtype, out_dtype),
                            tol=tol)
    self._CompileAndCheck(jnp_fun, args_maker, atol=tol,
                          rtol=tol)

  # Compares each reducer in JAX_REDUCER_NO_DTYPE_RECORDS against NumPy
  # without passing an explicit output dtype.
  @parameterized.parameters(itertools.chain.from_iterable(
    jtu.sample_product_testcases(
      [dict(name=rec.name, rng_factory=rec.rng_factory, inexact=rec.inexact,
            tolerance=rec.tolerance)],
      [dict(shape=shape, axis=axis, dtype=dtype)
        for shape in rec.shapes for dtype in rec.dtypes
        for axis in list(range(-len(shape), len(shape))) + [None]
        if jtu.is_valid_shape(shape, dtype)
      ],
      keepdims=[False, True],
    )
    for rec in JAX_REDUCER_NO_DTYPE_RECORDS
  ))
  def testReducerNoDtype(self, name, rng_factory, shape, dtype, axis,
                         keepdims, inexact, tolerance):
    np_op = getattr(np, name)
    jnp_op = getattr(jnp, name)
    rng = rng_factory(self.rng())
    # bfloat16 inputs that may contain NaNs are reduced in float32 on the
    # NumPy side and the result is cast back to bfloat16.
    is_bf16_nan_test = (dtype == jnp.bfloat16 and
                        rng_factory.__name__ == 'rand_some_nan')
    @jtu.ignore_warning(category=RuntimeWarning,
                        message="Degrees of freedom <= 0 for slice.*")
    @jtu.ignore_warning(category=RuntimeWarning,
                        message="All-NaN (slice|axis) encountered.*")
    def np_fun(x):
      x = np.asarray(x)
      if inexact:
        x = x.astype(dtypes.to_inexact_dtype(x.dtype))
      x_cast = x if not is_bf16_nan_test else x.astype(np.float32)
      res = np_op(x_cast, axis, keepdims=keepdims)
      res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16)
      return res

    jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims)
    args_maker = lambda: [rng(shape, dtype)]
    tol = jtu.join_tolerance({np.float16: 0.002},
                             tolerance or jtu.default_tolerance())
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol)
    self._CompileAndCheck(jnp_fun, args_maker, rtol=tol, atol=tol)

  # Compares each reducer in JAX_REDUCER_INITIAL_RECORDS against NumPy with
  # the `initial` keyword set to 0 or 1.
  @parameterized.parameters(itertools.chain.from_iterable(
    jtu.sample_product_testcases(
      [dict(name=rec.name, rng_factory=rec.rng_factory, inexact=rec.inexact)],
      [dict(shape=shape, axis=axis, dtype=dtype)
        for shape in rec.shapes for dtype in rec.dtypes
        for axis in list(range(-len(shape), len(shape))) + [None]
        if jtu.is_valid_shape(shape, dtype)
      ],
      initial=[0, 1],
      keepdims=[False, True],
    )
    for rec in JAX_REDUCER_INITIAL_RECORDS
  ))
  def testReducerInitial(self, name, rng_factory, shape, dtype, axis,
                         keepdims, initial, inexact):
    np_op = getattr(np, name)
    jnp_op = getattr(jnp, name)
    rng = rng_factory(self.rng())
    is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan'
    @jtu.ignore_warning(category=RuntimeWarning,
                        message="Degrees of freedom <= 0 for slice.*")
    @jtu.ignore_warning(category=np.ComplexWarning)
    def np_fun(x):
      x = np.asarray(x)
      if inexact:
        x = x.astype(dtypes.to_inexact_dtype(x.dtype))
      x_cast = x if not is_bf16_nan_test else x.astype(np.float32)
      res = np_op(x_cast, axis, keepdims=keepdims, initial=initial)
      res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16)
      # Cast the NumPy result to the dtype expected from the jnp reducer.
      return res.astype(_reducer_output_dtype(name, x.dtype))

    jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, initial=initial)
    jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
    args_maker = lambda: [rng(shape, dtype)]
    tol = {jnp.bfloat16: 3E-2}
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=tol, atol=tol)
    self._CompileAndCheck(jnp_fun, args_maker)

  # Compares prod/sum against NumPy with `promote_integers` both off and on;
  # the expected dtype comes from _reducer_output_dtype.
  @parameterized.parameters(itertools.chain.from_iterable(
    jtu.sample_product_testcases(
      [dict(name=rec.name, rng_factory=rec.rng_factory, inexact=rec.inexact)],
      [dict(shape=shape, axis=axis, dtype=dtype)
        for shape in rec.shapes for dtype in rec.dtypes
        for axis in list(range(-len(shape), len(shape))) + [None]
        if jtu.is_valid_shape(shape, dtype)
      ],
      initial=[0, 1],
      keepdims=[False, True],
      promote_integers=[False, True],
    )
    for rec in JAX_REDUCER_PROMOTE_INT_RECORDS
  ))
  def testReducerPromoteInt(self, name, rng_factory, shape, dtype, axis,
                            keepdims, initial, inexact, promote_integers):
    np_op = getattr(np, name)
    jnp_op = getattr(jnp, name)
    rng = rng_factory(self.rng())
    is_bf16_nan_test = (dtype == jnp.bfloat16 and
                        rng_factory.__name__ == 'rand_some_nan')
    @jtu.ignore_warning(category=RuntimeWarning,
                        message="Degrees of freedom <= 0 for slice.*")
    @jtu.ignore_warning(category=np.ComplexWarning)
    def np_fun(x):
      x = np.asarray(x)
      if inexact:
        x = x.astype(dtypes.to_inexact_dtype(x.dtype))
      x_cast = x if not is_bf16_nan_test else x.astype(np.float32)
      res = np_op(x_cast, axis, keepdims=keepdims, initial=initial)
      res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16)
      return res.astype(_reducer_output_dtype(name, x.dtype, promote_integers))

    jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, initial=initial, promote_integers=promote_integers)
    jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
    args_maker = lambda: [rng(shape, dtype)]
    tol = {jnp.bfloat16: 3E-2}
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=tol)
    self._CompileAndCheck(jnp_fun, args_maker)

  # Reduces zero-size inputs without `initial`, only along axes whose own
  # length is at least 1.
  @parameterized.parameters(itertools.chain.from_iterable(
    jtu.sample_product_testcases(
      [dict(name=rec.name, rng_factory=rec.rng_factory, inexact=rec.inexact)],
      [dict(shape=shape, axis=axis)
        for shape in rec.shapes if np.prod(shape) == 0
        for axis in range(-len(shape), len(shape)) if shape[axis] >= 1
      ],
      dtype=rec.dtypes,
      keepdims=[False, True],
    )
    for rec in JAX_REDUCER_INITIAL_RECORDS
  ))
  def testReducerNoInitialZeroDims(self, name, rng_factory, shape, dtype, axis,
                                   keepdims, inexact):
    np_op = getattr(np, name)
    jnp_op = getattr(jnp, name)
    rng = rng_factory(self.rng())
    is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan'
    @jtu.ignore_warning(category=RuntimeWarning,
                        message="Degrees of freedom <= 0 for slice.*")
    @jtu.ignore_warning(category=np.ComplexWarning)
    def np_fun(x):
      x = np.asarray(x)
      if inexact:
        x = x.astype(dtypes.to_inexact_dtype(x.dtype))
      x_cast = x if not is_bf16_nan_test else x.astype(np.float32)
      res = np_op(x_cast, axis, keepdims=keepdims)
      res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16)
      return res.astype(_reducer_output_dtype(name, x.dtype))

    jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims)
    jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
    args_maker = lambda: [rng(shape, dtype)]
    tol = {jnp.bfloat16: 3E-2}
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=tol)
    self._CompileAndCheck(jnp_fun, args_maker)

  # Compares each reducer in JAX_REDUCER_INITIAL_RECORDS against NumPy with
  # both `initial` and a random boolean `where` mask.
  @parameterized.parameters(itertools.chain.from_iterable(
    jtu.sample_product_testcases(
      [dict(name=rec.name, rng_factory=rec.rng_factory, inexact=rec.inexact,
            tol=rec.tolerance)],
      [dict(shape=shape, axis=axis, dtype=dtype, whereshape=whereshape)
        for shape in rec.shapes for dtype in rec.dtypes
        for axis in list(range(-len(shape), len(shape))) + [None]
        if jtu.is_valid_shape(shape, dtype)
        for whereshape in _compatible_shapes(shape)
      ],
      initial=[0, 1],
      keepdims=[False, True],
    )
    for rec in JAX_REDUCER_INITIAL_RECORDS
  ))
  def testReducerWhere(self, name, rng_factory, shape, dtype, axis,
                       keepdims, initial, inexact, whereshape, tol):
    np_op = getattr(np, name)
    jnp_op = getattr(jnp, name)
    if (shape in [()] + scalar_shapes and
        dtype in [jnp.int16, jnp.uint16] and
        jnp_op in [jnp.min, jnp.max]):
      self.skipTest("Known XLA failure; see https://github.com/google/jax/issues/4971.")
    rng = rng_factory(self.rng())
    is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan'
    # Do not pass where via args_maker as that is incompatible with _promote_like_jnp.
    where = jtu.rand_bool(self.rng())(whereshape, np.bool_)
    @jtu.ignore_warning(category=RuntimeWarning,
                        message="Degrees of freedom <= 0 for slice.*")
    @jtu.ignore_warning(category=np.ComplexWarning)
    def np_fun(x):
      x = np.asarray(x)
      if inexact:
        x = x.astype(dtypes.to_inexact_dtype(x.dtype))
      x_cast = x if not is_bf16_nan_test else x.astype(np.float32)
      res = np_op(x_cast, axis, keepdims=keepdims, initial=initial, where=where)
      res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16)
      return res.astype(_reducer_output_dtype(name, x.dtype))

    jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, initial=initial, where=where)
    jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, atol=tol, rtol=tol)
    self._CompileAndCheck(jnp_fun, args_maker)

  # Compares each reducer in JAX_REDUCER_WHERE_NO_INITIAL_RECORDS against
  # NumPy with a random boolean `where` mask and no `initial`.
  @parameterized.parameters(itertools.chain.from_iterable(
    jtu.sample_product_testcases(
      [dict(name=rec.name, rng_factory=rec.rng_factory, inexact=rec.inexact,
            tol=rec.tolerance)],
      [dict(shape=shape, axis=axis, dtype=dtype, whereshape=whereshape)
        for shape in rec.shapes for dtype in rec.dtypes
        for whereshape in _compatible_shapes(shape)
        for axis in list(range(-len(shape), len(shape))) + [None]
        if jtu.is_valid_shape(shape, dtype)
      ],
      keepdims=[False, True],
    ) for rec in JAX_REDUCER_WHERE_NO_INITIAL_RECORDS
  ))
  def testReducerWhereNoInitial(self, name, rng_factory, shape, dtype, axis,
                                keepdims, inexact, whereshape, tol):
    np_op = getattr(np, name)
    jnp_op = getattr(jnp, name)
    rng = rng_factory(self.rng())
    # Unlike the other tests, every bfloat16 input takes the float32
    # round trip here, regardless of the rng factory.
    is_bf16_nan_test = dtype == jnp.bfloat16
    # Do not pass where via args_maker as that is incompatible with _promote_like_jnp.
    where = jtu.rand_bool(self.rng())(whereshape, np.bool_)
    @jtu.ignore_warning(category=RuntimeWarning,
                        message="Degrees of freedom <= 0 for slice.*")
    @jtu.ignore_warning(category=RuntimeWarning,
                        message="Mean of empty slice.*")
    @jtu.ignore_warning(category=RuntimeWarning,
                        message="invalid value encountered.*")
    @jtu.ignore_warning(category=np.ComplexWarning)
    def np_fun(x):
      x = np.asarray(x)
      if inexact:
        x = x.astype(dtypes.to_inexact_dtype(x.dtype))
      x_cast = x if not is_bf16_nan_test else x.astype(np.float32)
      res = np_op(x_cast, axis, keepdims=keepdims, where=where)
      res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16)
      return res

    jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, where=where)
    jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
    args_maker = lambda: [rng(shape, dtype)]
    # The NumPy comparison is skipped on older NumPy versions except for
    # all/any; the compile check below always runs.
    if numpy_version >= (1, 20, 2) or np_op.__name__ in ("all", "any"):
      self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, atol=tol, rtol=tol)
    self._CompileAndCheck(jnp_fun, args_maker)

  def testReductionOfOutOfBoundsAxis(self):  # Issue 888
    x = jnp.ones((3, 4))
    self.assertRaises(ValueError, lambda: jnp.sum(x, axis=2))

  def testReductionWithRepeatedAxisError(self):
    with self.assertRaisesRegex(ValueError, r"duplicate value in 'axis': \(0, 0\)"):
      jnp.sum(jnp.arange(3), (0, 0))

  # Compares jnp.average against np.average with and without weights, over
  # single axes, axis=None and the tuple of all axes.
  @jtu.sample_product(
    [dict(shape=shape, dtype=dtype, axis=axis, weights_shape=weights_shape)
      for shape, dtype in _shape_and_dtypes(nonempty_shapes, number_dtypes)
      for axis in list(range(-len(shape), len(shape))) + [None] + [tuple(range(len(shape)))]
      # `weights_shape` is either `None`, same as the averaged axis, or same as
      # that of the input
      for weights_shape in ([None, shape] if axis is None or len(shape) == 1 or isinstance(axis, tuple)
                            else [None, (shape[axis],), shape])
    ],
    keepdims=([False, True] if numpy_version >= (1, 23) else [None]),
    returned=[False, True],
  )
  def testAverage(self, shape, dtype, axis, weights_shape, returned, keepdims):
    rng = jtu.rand_default(self.rng())
    kwds = dict(returned=returned)
    # `keepdims` is None on NumPy versions where the sweep above does not
    # include it; the keyword is then not passed at all.
    if keepdims is not None:
      kwds['keepdims'] = keepdims
    if weights_shape is None:
      np_fun = lambda x: np.average(x, axis, **kwds)
      jnp_fun = lambda x: jnp.average(x, axis, **kwds)
      args_maker = lambda: [rng(shape, dtype)]
    else:
      np_fun = lambda x, weights: np.average(x, axis, weights, **kwds)
      jnp_fun = lambda x, weights: jnp.average(x, axis, weights, **kwds)
      args_maker = lambda: [rng(shape, dtype), rng(weights_shape, dtype)]
    np_fun = jtu.promote_like_jnp(np_fun, inexact=True)
    tol = {dtypes.bfloat16: 2e-1, np.float16: 1e-2, np.float32: 1e-5,
           np.float64: 1e-12, np.complex64: 1e-5}
    check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE and numpy_version >= (1, 22)
    if numpy_version == (1, 23, 0) and keepdims and weights_shape is not None and axis is not None:
      # Known failure: https://github.com/numpy/numpy/issues/21850
      pass
    else:
      try:
        self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
                                check_dtypes=check_dtypes, tol=tol)
      except ZeroDivisionError:
        self.skipTest("don't support checking for ZeroDivisionError")
    self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=check_dtypes,
                          rtol=tol, atol=tol)

  # Compares jnp.var against np.var over input dtype, output dtype, axis,
  # ddof and keepdims.
  @jtu.sample_product(
    shape=[(5,), (10, 5)],
    dtype=all_dtypes,
    out_dtype=inexact_dtypes,
    axis=[None, 0, -1],
    ddof=[0, 1, 2],
    keepdims=[False, True],
  )
  def testVar(self, shape, dtype, out_dtype, axis, ddof, keepdims):
    rng = jtu.rand_default(self.rng())
    args_maker = self._GetArgsMaker(rng, [shape], [dtype])
    @jtu.ignore_warning(category=RuntimeWarning,
                        message="Degrees of freedom <= 0 for slice.")
    @jtu.ignore_warning(category=np.ComplexWarning)
    def np_fun(x):
      # Numpy fails with bfloat16 inputs
      out = np.var(x.astype(np.float32 if dtype == dtypes.bfloat16 else dtype),
                   dtype=np.float32 if out_dtype == dtypes.bfloat16 else out_dtype,
                   axis=axis, ddof=ddof, keepdims=keepdims)
      return out.astype(out_dtype)
    jnp_fun = partial(jnp.var, dtype=out_dtype, axis=axis, ddof=ddof, keepdims=keepdims)
    tol = jtu.tolerance(out_dtype, {np.float16: 1e-1, np.float32: 1e-3,
                                    np.float64: 1e-3, np.complex128: 1e-6})
    # Complex input with a non-complex output dtype must raise ValueError.
    if (jnp.issubdtype(dtype, jnp.complexfloating) and
        not jnp.issubdtype(out_dtype, jnp.complexfloating)):
      self.assertRaises(ValueError, lambda: jnp_fun(*args_maker()))
    else:
      self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
                              tol=tol)
      self._CompileAndCheck(jnp_fun, args_maker, rtol=tol,
                            atol=tol)

  # Same sweep as testVar, for jnp.nanvar with inputs that may contain NaNs.
  @jtu.sample_product(
    shape=[(5,), (10, 5)],
    dtype=all_dtypes,
    out_dtype=inexact_dtypes,
    axis=[None, 0, -1],
    ddof=[0, 1, 2],
    keepdims=[False, True],
  )
  def testNanVar(self, shape, dtype, out_dtype, axis, ddof, keepdims):
    rng = jtu.rand_some_nan(self.rng())
    args_maker = self._GetArgsMaker(rng, [shape], [dtype])
    @jtu.ignore_warning(category=RuntimeWarning,
                        message="Degrees of freedom <= 0 for slice.")
    @jtu.ignore_warning(category=np.ComplexWarning)
    def np_fun(x):
      # Numpy fails with bfloat16 inputs
      out = np.nanvar(x.astype(np.float32 if dtype == dtypes.bfloat16 else dtype),
                      dtype=np.float32 if out_dtype == dtypes.bfloat16 else out_dtype,
                      axis=axis, ddof=ddof, keepdims=keepdims)
      return out.astype(out_dtype)
    jnp_fun = partial(jnp.nanvar, dtype=out_dtype, axis=axis, ddof=ddof, keepdims=keepdims)
    tol = jtu.tolerance(out_dtype, {np.float16: 1e-1, np.float32: 1e-3,
                                    np.float64: 1e-3, np.complex64: 1e-3,
                                    np.complex128: 3e-4})
    # Complex input with a non-complex output dtype must raise ValueError.
    if (jnp.issubdtype(dtype, jnp.complexfloating) and
        not jnp.issubdtype(out_dtype, jnp.complexfloating)):
      self.assertRaises(ValueError, lambda: jnp_fun(*args_maker()))
    else:
      self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
                              tol=tol)
      self._CompileAndCheck(jnp_fun, args_maker, rtol=tol,
                            atol=tol)

  def testNanStdGrad(self):
    # Regression test for https://github.com/google/jax/issues/8128
    x = jnp.arange(5.0).at[0].set(jnp.nan)
    y = jax.grad(jnp.nanvar)(x)
    self.assertAllClose(y, jnp.array([0.0, -0.75, -0.25, 0.25, 0.75]), check_dtypes=False)

    # The gradient of nanstd must not contain any NaN.
    z = jax.grad(jnp.nanstd)(x)
    self.assertEqual(jnp.isnan(z).sum(), 0)

  # Compares jnp.cov against np.cov with an optional second operand and
  # optional frequency / observation weights.
  @jtu.sample_product(
    [dict(shape=shape, dtype=dtype, y_dtype=y_dtype, rowvar=rowvar,
          y_shape=y_shape)
      for shape in [(5,), (10, 5), (5, 10)]
      for dtype in all_dtypes
      for y_dtype in [None, dtype]
      for rowvar in [True, False]
      for y_shape in _get_y_shapes(y_dtype, shape, rowvar)
    ],
    bias=[True, False],
    ddof=[None, 2, 3],
    fweights=[True, False],
    aweights=[True, False],
  )
  @jax.numpy_dtype_promotion('standard')  # This test explicitly exercises mixed type promotion
  @jax.default_matmul_precision('float32')
  def testCov(self, shape, dtype, y_shape, y_dtype, rowvar, ddof, bias, fweights, aweights):
    rng = jtu.rand_default(self.rng())
    wrng = jtu.rand_positive(self.rng())
    # Weights use the real dtype corresponding to `dtype`.
    wdtype = np.real(dtype(0)).dtype
    # One weight per entry along the last axis (rowvar, or a single row),
    # otherwise along the first axis.
    wshape = shape[-1:] if rowvar or shape[0] == 1 else shape[:1]

    args_maker = lambda: [rng(shape, dtype),
                          rng(y_shape, y_dtype) if y_dtype else None,
                          wrng(wshape, int) if fweights else None,
                          wrng(wshape, wdtype) if aweights else None]
    kwargs = dict(rowvar=rowvar, ddof=ddof, bias=bias)
    np_fun = lambda m, y, f, a: np.cov(m, y, fweights=f, aweights=a, **kwargs)
    jnp_fun = lambda m, y, f, a: jnp.cov(m, y, fweights=f, aweights=a, **kwargs)
    tol = {jnp.bfloat16: 5E-2, np.float16: 1E-2, np.float32: 1e-5,
           np.float64: 1e-13, np.complex64: 1e-5, np.complex128: 1e-13}
    tol = jtu.join_tolerance(tol, jtu.tolerance(dtype))
    self._CheckAgainstNumpy(
        np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol)
    self._CompileAndCheck(jnp_fun, args_maker, atol=tol,
                          rtol=tol)
# Run the tests through absl's runner with the JAX test loader when this
# module is executed as a script.
if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())