Split parts of lax_numpy_test.py into separate test files.

Why? The main test file is getting too big, and this hinders iteration on individual tests.

PiperOrigin-RevId: 478130215
Parent: 849f837b6a
Commit: 439217644a
tests/BUILD (28 changed lines: 25 additions, 3 deletions)
@@ -320,9 +320,31 @@ jax_test(
     srcs = ["lax_numpy_test.py"],
     pjrt_c_api_bypass = True,
     shard_count = {
-        "cpu": 40,
-        "gpu": 40,
-        "tpu": 20,
+        "cpu": 20,
+        "gpu": 20,
+        "tpu": 10,
     },
 )
+
+jax_test(
+    name = "lax_numpy_operators_test",
+    srcs = ["lax_numpy_operators_test.py"],
+    pjrt_c_api_bypass = True,
+    shard_count = {
+        "cpu": 10,
+        "gpu": 10,
+        "tpu": 5,
+    },
+)
+
+jax_test(
+    name = "lax_numpy_reducers_test",
+    srcs = ["lax_numpy_reducers_test.py"],
+    pjrt_c_api_bypass = True,
+    shard_count = {
+        "cpu": 10,
+        "gpu": 10,
+        "tpu": 5,
+    },
+)
 
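Because each new file is a self-contained absltest program (it ends in absltest.main), a single suite can now be loaded and iterated on without importing the whole of lax_numpy_test.py. A minimal sketch, not part of the commit — it assumes a JAX source checkout with tests/ on sys.path; the module and class names come from the diff below:

    import unittest

    import lax_numpy_operators_test as ops_test  # hypothetical import path

    # Load and run just the operators suite with the stock unittest loader.
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(
        ops_test.JaxNumpyOperatorTests)
    unittest.TextTestRunner(verbosity=1).run(suite)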
tests/lax_numpy_operators_test.py (new file, 636 lines)
@@ -0,0 +1,636 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import collections
import functools
from functools import partial
import itertools
import operator
from unittest import SkipTest

from absl.testing import absltest
from absl.testing import parameterized

import numpy as np

import jax
import jax.ops
from jax import lax
from jax import numpy as jnp
from jax import tree_util

from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src.numpy.lax_numpy import _promote_dtypes, _promote_dtypes_inexact

from jax.config import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS

nonempty_nonscalar_array_shapes = [(4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)]
nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes
one_dim_array_shapes = [(1,), (6,), (12,)]
empty_array_shapes = [(0,), (0, 4), (3, 0),]

scalar_shapes = [jtu.NUMPY_SCALAR_SHAPE, jtu.PYTHON_SCALAR_SHAPE]
array_shapes = nonempty_array_shapes + empty_array_shapes
nonzerodim_shapes = nonempty_nonscalar_array_shapes + empty_array_shapes
nonempty_shapes = scalar_shapes + nonempty_array_shapes
all_shapes = scalar_shapes + array_shapes

float_dtypes = jtu.dtypes.all_floating
complex_dtypes = jtu.dtypes.complex
int_dtypes = jtu.dtypes.all_integer
unsigned_dtypes = jtu.dtypes.all_unsigned
bool_dtypes = jtu.dtypes.boolean
default_dtypes = float_dtypes + int_dtypes
inexact_dtypes = float_dtypes + complex_dtypes
number_dtypes = float_dtypes + complex_dtypes + int_dtypes + unsigned_dtypes
all_dtypes = number_dtypes + bool_dtypes


python_scalar_dtypes = [jnp.bool_, jnp.int_, jnp.float_, jnp.complex_]

# uint64 is problematic because with any uint type it promotes to float:
int_dtypes_no_uint64 = [d for d in int_dtypes + unsigned_dtypes if d != np.uint64]

def _indexer_with_default_outputs(indexer, use_defaults=True):
  """Like jtu.with_jax_dtype_defaults, but for __getitem__ APIs."""
  class Indexer:
    @partial(jtu.with_jax_dtype_defaults, use_defaults=use_defaults)
    def __getitem__(self, *args):
      return indexer.__getitem__(*args)
  return Indexer()

def _valid_dtypes_for_shape(shape, dtypes):
  # Not all (shape, dtype) pairs are valid. In particular, Python scalars only
  # have one type in each category (float, bool, etc.)
  if shape is jtu.PYTHON_SCALAR_SHAPE:
    return [t for t in dtypes if t in python_scalar_dtypes]
  return dtypes


OpRecord = collections.namedtuple(
  "OpRecord",
  ["name", "nargs", "dtypes", "shapes", "rng_factory", "diff_modes",
   "test_name", "check_dtypes", "tolerance", "inexact", "kwargs"])

def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes,
              test_name=None, check_dtypes=True,
              tolerance=None, inexact=False, kwargs=None):
  test_name = test_name or name
  return OpRecord(name, nargs, dtypes, shapes, rng_factory, diff_modes,
                  test_name, check_dtypes, tolerance, inexact, kwargs)
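
# Illustrative note (annotation, not part of the commit): each op_record below
# expands into many generated test cases. For instance,
#   op_record("add", 2, all_dtypes, all_shapes, jtu.rand_default, ["rev"])
# makes testOp compare jnp.add against np.add for every broadcast-compatible
# pair of shapes and every valid dtype pair, drawing inputs from
# jtu.rand_default; the diff_modes field (here ["rev"]) is metadata carried
# over from the original file marking reverse-mode-differentiable ops.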
JAX_ONE_TO_ONE_OP_RECORDS = [
    op_record("abs", 1, all_dtypes,
              all_shapes, jtu.rand_default, ["rev"]),
    op_record("add", 2, all_dtypes, all_shapes, jtu.rand_default, ["rev"]),
    op_record("ceil", 1, float_dtypes, all_shapes, jtu.rand_default, []),
    op_record("ceil", 1, int_dtypes + unsigned_dtypes, all_shapes,
              jtu.rand_default, [], check_dtypes=False),
    op_record("conj", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
    op_record("equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
    op_record("exp", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"],
              inexact=True),
    op_record("fabs", 1, float_dtypes, all_shapes, jtu.rand_default, ["rev"]),
    op_record("float_power", 2, inexact_dtypes, all_shapes,
              partial(jtu.rand_default, scale=1), ["rev"],
              tolerance={jnp.bfloat16: 1e-2, np.float32: 1e-3,
                         np.float64: 1e-12, np.complex64: 2e-4,
                         np.complex128: 1e-12}, check_dtypes=False),
    op_record("floor", 1, float_dtypes, all_shapes, jtu.rand_default, []),
    op_record("floor", 1, int_dtypes + unsigned_dtypes, all_shapes,
              jtu.rand_default, [], check_dtypes=False),
    op_record("greater", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
    op_record("greater_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
    op_record("i0", 1, float_dtypes, all_shapes, jtu.rand_default, [],
              check_dtypes=False),
    op_record("ldexp", 2, int_dtypes, all_shapes, jtu.rand_default, [], check_dtypes=False),
    op_record("less", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
    op_record("less_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
    op_record("log", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"],
              inexact=True),
    op_record("logical_and", 2, all_dtypes, all_shapes, jtu.rand_bool, []),
    op_record("logical_not", 1, all_dtypes, all_shapes, jtu.rand_bool, []),
    op_record("logical_or", 2, all_dtypes, all_shapes, jtu.rand_bool, []),
    op_record("logical_xor", 2, all_dtypes, all_shapes, jtu.rand_bool, []),
    op_record("maximum", 2, all_dtypes, all_shapes, jtu.rand_some_inf, []),
    op_record("minimum", 2, all_dtypes, all_shapes, jtu.rand_some_inf, []),
    op_record("multiply", 2, all_dtypes, all_shapes, jtu.rand_default, ["rev"]),
    op_record("negative", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
    op_record("nextafter", 2, [f for f in float_dtypes if f != jnp.bfloat16],
              all_shapes, jtu.rand_default, ["rev"], inexact=True, tolerance=0),
    op_record("not_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]),
    op_record("array_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]),
    op_record("array_equiv", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]),
    op_record("reciprocal", 1, inexact_dtypes, all_shapes, jtu.rand_default, []),
    op_record("subtract", 2, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
    op_record("signbit", 1, default_dtypes + bool_dtypes, all_shapes,
              jtu.rand_some_inf_and_nan, ["rev"]),
    op_record("trunc", 1, float_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []),
    op_record("trunc", 1, int_dtypes + unsigned_dtypes, all_shapes,
              jtu.rand_some_inf_and_nan, [], check_dtypes=False),
    op_record("sin", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"],
              inexact=True),
    op_record("cos", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"],
              inexact=True),
    op_record("tan", 1, number_dtypes, all_shapes,
              partial(jtu.rand_uniform, low=-1.5, high=1.5), ["rev"],
              inexact=True),
    op_record("sinh", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"],
              inexact=True),
    op_record("cosh", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"],
              inexact=True),
    # TODO(b/142975473): on CPU, tanh for complex128 is only accurate to
    # ~float32 precision.
    # TODO(b/143135720): on GPU, tanh has only ~float32 precision.
    op_record("tanh", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"],
              tolerance={np.float64: 1e-7, np.complex128: 1e-7},
              inexact=True),
    op_record("arcsin", 1, number_dtypes, all_shapes, jtu.rand_small, ["rev"],
              inexact=True),
    op_record("arccos", 1, number_dtypes, all_shapes, jtu.rand_small, ["rev"],
              inexact=True),
    op_record("arctan", 1, number_dtypes, all_shapes, jtu.rand_small, ["rev"],
              inexact=True),
    op_record("arctan2", 2, float_dtypes, all_shapes, jtu.rand_small, ["rev"],
              inexact=True),
    op_record("arcsinh", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"],
              inexact=True, tolerance={np.complex64: 2E-4, np.complex128: 2E-14}),
    op_record("arccosh", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"],
              inexact=True, tolerance={np.complex64: 2E-2, np.complex128: 2E-12}),
    op_record("arctanh", 1, number_dtypes, all_shapes, jtu.rand_small, ["rev"],
              inexact=True, tolerance={np.float64: 1e-9}),
]

JAX_COMPOUND_OP_RECORDS = [
    # angle has inconsistent 32/64-bit return types across numpy versions.
    op_record("angle", 1, number_dtypes, all_shapes, jtu.rand_default, [],
              check_dtypes=False, inexact=True),
    op_record("angle", 1, number_dtypes, all_shapes, jtu.rand_default, [],
              check_dtypes=False, inexact=True, test_name="angle_deg", kwargs={'deg': True}),
    op_record("atleast_1d", 1, default_dtypes, all_shapes, jtu.rand_default, []),
    op_record("atleast_2d", 1, default_dtypes, all_shapes, jtu.rand_default, []),
    op_record("atleast_3d", 1, default_dtypes, all_shapes, jtu.rand_default, []),
    op_record("cbrt", 1, default_dtypes, all_shapes, jtu.rand_some_inf, ["rev"],
              inexact=True),
    op_record("conjugate", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
    op_record("deg2rad", 1, float_dtypes, all_shapes, jtu.rand_default, []),
    op_record("divide", 2, number_dtypes, all_shapes, jtu.rand_nonzero, ["rev"],
              inexact=True),
    op_record("divmod", 2, int_dtypes + float_dtypes, all_shapes,
              jtu.rand_nonzero, []),
    op_record("exp2", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"],
              tolerance={jnp.bfloat16: 4e-2, np.float16: 1e-2}, inexact=True),
    # TODO(b/142975473): on CPU, expm1 for float64 is only accurate to ~float32
    # precision.
    op_record("expm1", 1, number_dtypes, all_shapes, jtu.rand_positive, [],
              test_name="expm1_large", tolerance={np.float64: 1e-8}, inexact=True),
    op_record("expm1", 1, number_dtypes, all_shapes, jtu.rand_small_positive,
              [], tolerance={np.float64: 1e-8}, inexact=True),
    op_record("fix", 1, float_dtypes, all_shapes, jtu.rand_default, []),
    op_record("fix", 1, int_dtypes + unsigned_dtypes, all_shapes,
              jtu.rand_default, [], check_dtypes=False),
    op_record("floor_divide", 2, default_dtypes + unsigned_dtypes,
              all_shapes, jtu.rand_nonzero, ["rev"]),
    op_record("fmin", 2, number_dtypes, all_shapes, jtu.rand_some_nan, []),
    op_record("fmax", 2, number_dtypes, all_shapes, jtu.rand_some_nan, []),
    op_record("fmod", 2, default_dtypes, all_shapes, jtu.rand_some_nan, []),
    op_record("heaviside", 2, default_dtypes, all_shapes, jtu.rand_default, [],
              inexact=True),
    op_record("hypot", 2, default_dtypes, all_shapes, jtu.rand_default, [],
              inexact=True),
    op_record("kron", 2, number_dtypes, nonempty_shapes, jtu.rand_default, []),
    op_record("outer", 2, number_dtypes, all_shapes, jtu.rand_default, []),
    op_record("imag", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []),
    op_record("iscomplex", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []),
    op_record("isfinite", 1, inexact_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []),
    op_record("isinf", 1, inexact_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []),
    op_record("isnan", 1, inexact_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []),
    op_record("isneginf", 1, float_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []),
    op_record("isposinf", 1, float_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []),
    op_record("isreal", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []),
    op_record("isrealobj", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []),
    op_record("log2", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"],
              inexact=True),
    op_record("log10", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"],
              inexact=True),
    op_record("log1p", 1, number_dtypes, all_shapes, jtu.rand_positive, [],
              test_name="log1p_large", tolerance={np.float64: 1e-12},
              inexact=True),
    op_record("log1p", 1, number_dtypes, all_shapes, jtu.rand_small_positive, [],
              tolerance={np.float64: 1e-12}, inexact=True),
    op_record("logaddexp", 2, float_dtypes, all_shapes,
              jtu.rand_some_inf_and_nan, ["rev"],
              tolerance={np.float64: 1e-12}, inexact=True),
    op_record("logaddexp2", 2, float_dtypes, all_shapes,
              jtu.rand_some_inf_and_nan, ["rev"],
              tolerance={np.float16: 1e-2, np.float64: 2e-14}, inexact=True),
    op_record("polyval", 2, number_dtypes, nonempty_nonscalar_array_shapes,
              jtu.rand_default, [], check_dtypes=False,
              tolerance={dtypes.bfloat16: 4e-2, np.float16: 1e-2,
                         np.float64: 1e-12}),
    op_record("positive", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
    op_record("power", 2, number_dtypes, all_shapes, jtu.rand_positive, ["rev"],
              tolerance={np.complex128: 1e-14}, check_dtypes=False),
    op_record("rad2deg", 1, float_dtypes, all_shapes, jtu.rand_default, []),
    op_record("ravel", 1, all_dtypes, all_shapes, jtu.rand_default, ["rev"]),
    op_record("real", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []),
    op_record("remainder", 2, default_dtypes, all_shapes, jtu.rand_nonzero, [],
              tolerance={np.float16: 1e-2}),
    op_record("mod", 2, default_dtypes, all_shapes, jtu.rand_nonzero, []),
    op_record("modf", 1, float_dtypes, all_shapes, jtu.rand_default, []),
    op_record("modf", 1, int_dtypes + unsigned_dtypes, all_shapes,
              jtu.rand_default, [], check_dtypes=False),
    op_record("rint", 1, inexact_dtypes, all_shapes, jtu.rand_some_inf_and_nan,
              []),
    op_record("rint", 1, int_dtypes + unsigned_dtypes, all_shapes,
              jtu.rand_default, [], check_dtypes=False),
    op_record("sign", 1, number_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []),
    # numpy 1.16 has trouble mixing uint and bfloat16, so we test these separately.
    op_record("copysign", 2, default_dtypes + unsigned_dtypes,
              all_shapes, jtu.rand_some_inf_and_nan, [], check_dtypes=False),
    op_record("sinc", 1, [t for t in number_dtypes if t != jnp.bfloat16],
              all_shapes, jtu.rand_default, ["rev"],
              tolerance={np.complex64: 1e-5}, inexact=True,
              check_dtypes=False),
    op_record("square", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
    op_record("sqrt", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"],
              inexact=True),
    op_record("transpose", 1, all_dtypes, all_shapes, jtu.rand_default, ["rev"],
              check_dtypes=False),
    op_record("true_divide", 2, all_dtypes, all_shapes, jtu.rand_nonzero,
              ["rev"], inexact=True),
    op_record("ediff1d", 3, [np.int32], all_shapes, jtu.rand_default, [], check_dtypes=False),
    # TODO(phawkins): np.unwrap does not correctly promote its default period
    # argument under NumPy 1.21 for bfloat16 inputs. It works fine if we
    # explicitly pass a bfloat16 value that does not need promotion. We should
    # probably add a custom test harness for unwrap that tests the period
    # argument anyway.
    op_record("unwrap", 1, [t for t in float_dtypes if t != dtypes.bfloat16],
              nonempty_nonscalar_array_shapes,
              jtu.rand_default, ["rev"],
              # numpy.unwrap always returns float64
              check_dtypes=False,
              # numpy cumsum is inaccurate, see issue #3517
              tolerance={dtypes.bfloat16: 1e-1, np.float16: 1e-1}),
    op_record("isclose", 2, [t for t in all_dtypes if t != jnp.bfloat16],
              all_shapes, jtu.rand_small_positive, []),
    op_record("gcd", 2, int_dtypes_no_uint64, all_shapes, jtu.rand_default, []),
    op_record("lcm", 2, int_dtypes_no_uint64, all_shapes, jtu.rand_default, []),
]

JAX_BITWISE_OP_RECORDS = [
    op_record("bitwise_and", 2, int_dtypes + unsigned_dtypes, all_shapes,
              jtu.rand_fullrange, []),
    op_record("bitwise_not", 1, int_dtypes + unsigned_dtypes, all_shapes,
              jtu.rand_fullrange, []),
    op_record("invert", 1, int_dtypes + unsigned_dtypes, all_shapes,
              jtu.rand_fullrange, []),
    op_record("bitwise_or", 2, int_dtypes + unsigned_dtypes, all_shapes,
              jtu.rand_fullrange, []),
    op_record("bitwise_xor", 2, int_dtypes + unsigned_dtypes, all_shapes,
              jtu.rand_fullrange, []),
]

JAX_OPERATOR_OVERLOADS = [
    op_record("__add__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__sub__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__mul__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__eq__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__ne__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__lt__", 2, default_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__le__", 2, default_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__gt__", 2, default_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__ge__", 2, default_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__pos__", 1, number_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__neg__", 1, number_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__pow__", 2, inexact_dtypes, all_shapes, jtu.rand_positive, [],
              tolerance={np.float32: 2e-4, np.complex64: 2e-4, np.complex128: 1e-14}),
    op_record("__mod__", 2, default_dtypes, all_shapes, jtu.rand_nonzero, [],
              tolerance={np.float16: 1e-1}),
    op_record("__floordiv__", 2, default_dtypes, all_shapes,
              jtu.rand_nonzero, []),
    op_record("__truediv__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, [],
              inexact=True),
    op_record("__abs__", 1, number_dtypes, all_shapes, jtu.rand_default, []),
    # TODO(mattjj): __invert__ fails on bool dtypes because ~True == -2
    op_record("__invert__", 1, int_dtypes, all_shapes, jtu.rand_default, []),
    # TODO(mattjj): investigate these failures
    # op_record("__or__", 2, number_dtypes, all_shapes, jtu.rand_bool, []),
    # op_record("__and__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
    # op_record("__xor__", 2, number_dtypes, all_shapes, jtu.rand_bool, []),
    # op_record("__divmod__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, []),
    op_record("__lshift__", 2, int_dtypes_no_uint64, all_shapes, partial(jtu.rand_int, high=8), []),
    op_record("__rshift__", 2, int_dtypes_no_uint64, all_shapes, partial(jtu.rand_int, high=8), []),
]

JAX_RIGHT_OPERATOR_OVERLOADS = [
    op_record("__radd__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__rsub__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__rmul__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__rpow__", 2, inexact_dtypes, all_shapes, jtu.rand_positive, [],
              tolerance={np.float32: 2e-4, np.complex64: 1e-3}),
    op_record("__rmod__", 2, default_dtypes, all_shapes, jtu.rand_nonzero, [],
              tolerance={np.float16: 1e-1}),
    op_record("__rfloordiv__", 2, default_dtypes, all_shapes,
              jtu.rand_nonzero, []),
    op_record("__rtruediv__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, [],
              inexact=True),
    # op_record("__ror__", 2, number_dtypes, all_shapes, jtu.rand_bool, []),
    # op_record("__rand__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
    # op_record("__rxor__", 2, number_dtypes, all_shapes, jtu.rand_bool, []),
    # op_record("__rdivmod__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, []),
    op_record("__rlshift__", 2, int_dtypes_no_uint64, all_shapes, partial(jtu.rand_int, high=8), []),
    op_record("__rrshift__", 2, int_dtypes_no_uint64, all_shapes, partial(jtu.rand_int, high=8), [])
]

class _OverrideEverything:
  pass

for rec in JAX_OPERATOR_OVERLOADS + JAX_RIGHT_OPERATOR_OVERLOADS:
  if rec.nargs == 2:
    setattr(_OverrideEverything, rec.name, lambda self, other: self)

class _OverrideNothing:
  pass

for rec in JAX_OPERATOR_OVERLOADS + JAX_RIGHT_OPERATOR_OVERLOADS:
  if rec.nargs == 2:
    setattr(_OverrideNothing, rec.name, lambda self, other: NotImplemented)
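
# Sketch of the protocol exercised here (annotation, not part of the commit):
# jnp array operators return NotImplemented for operand types they do not
# recognize, so Python falls back to the other operand's reflected method.
# With x a jnp scalar array:
#   x + _OverrideEverything()  # -> the _OverrideEverything instance (__radd__ wins)
#   x + _OverrideNothing()     # -> TypeError (both sides decline)
# testBinaryOperatorDefers below asserts exactly this behavior.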

def _dtypes_are_compatible_for_bitwise_ops(args):
  if len(args) <= 1:
    return True
  is_signed = lambda dtype: jnp.issubdtype(dtype, np.signedinteger)
  width = lambda dtype: jnp.iinfo(dtype).bits
  x, y = args
  if width(x) > width(y):
    x, y = y, x
  # The following condition seems a little ad hoc, but seems to capture what
  # numpy actually implements.
  return (
      is_signed(x) == is_signed(y)
      or (width(x) == 32 and width(y) == 32)
      or (width(x) == 32 and width(y) == 64 and is_signed(y)))
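
# Worked examples for the predicate above (annotation, not part of the
# commit): (np.int8, np.int32) passes (same signedness); (np.uint32,
# np.int32) passes via the 32/32 clause; (np.uint32, np.int64) passes via
# the 32-vs-64-signed clause; (np.uint64, np.int64) fails every clause,
# mirroring NumPy's promotion of that pair to float64.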

def _shapes_are_broadcast_compatible(shapes):
  try:
    lax.broadcast_shapes(*(() if s in scalar_shapes else s for s in shapes))
  except ValueError:
    return False
  else:
    return True
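
# For example (annotation, not part of the commit): ((3, 1), (3, 4)) is
# broadcast compatible, ((3, 4), (4, 3)) is not, and scalar shapes are
# treated as rank-0, hence compatible with everything.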

def _shapes_are_equal_length(shapes):
  return all(len(shape) == len(shapes[0]) for shape in shapes[1:])


def _promote_like_jnp(fun, inexact=False):
  """Decorator that promotes the arguments of `fun` to `jnp.result_type(*args)`.

  jnp and np have different type promotion semantics; this decorator allows
  tests to make an np reference implementation act more like a jnp
  implementation.
  """
  _promote = _promote_dtypes_inexact if inexact else _promote_dtypes
  def wrapper(*args, **kw):
    flat_args, tree = tree_util.tree_flatten(args)
    args = tree_util.tree_unflatten(tree, _promote(*flat_args))
    return fun(*args, **kw)
  return wrapper
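
# Example of the gap this bridges (annotation, not part of the commit):
# np.promote_types(np.int32, np.float16) is float64, while
# jnp.promote_types(jnp.int32, jnp.float16) is float16. Promoting the np
# reference's inputs with jnp's rules first makes the two outputs comparable.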


class JaxNumpyOperatorTests(jtu.JaxTestCase):
  """Tests for LAX-backed Numpy operators."""

  def _GetArgsMaker(self, rng, shapes, dtypes, np_arrays=True):
    def f():
      out = [rng(shape, dtype or jnp.float_)
             for shape, dtype in zip(shapes, dtypes)]
      if np_arrays:
        return out
      return [jnp.asarray(a) if isinstance(a, (np.ndarray, np.generic)) else a
              for a in out]
    return f

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
        {"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes,
                                                      dtypes),
         "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes,
         "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name),
         "check_dtypes": rec.check_dtypes, "tolerance": rec.tolerance,
         "inexact": rec.inexact, "kwargs": rec.kwargs or {}}
        for shapes in filter(
          _shapes_are_broadcast_compatible,
          itertools.combinations_with_replacement(rec.shapes, rec.nargs))
        for dtypes in itertools.product(
          *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes)))
      for rec in itertools.chain(JAX_ONE_TO_ONE_OP_RECORDS,
                                 JAX_COMPOUND_OP_RECORDS)))
  @jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
  def testOp(self, np_op, jnp_op, rng_factory, shapes, dtypes, check_dtypes,
             tolerance, inexact, kwargs):
    np_op = partial(np_op, **kwargs)
    jnp_op = partial(jnp_op, **kwargs)
    np_op = jtu.ignore_warning(category=RuntimeWarning,
                               message="invalid value.*")(np_op)
    np_op = jtu.ignore_warning(category=RuntimeWarning,
                               message="divide by zero.*")(np_op)

    rng = rng_factory(self.rng())
    args_maker = self._GetArgsMaker(rng, shapes, dtypes, np_arrays=False)
    tol = max(jtu.tolerance(dtype, tolerance) for dtype in dtypes)
    tol = functools.reduce(jtu.join_tolerance,
                           [tolerance, tol, jtu.default_tolerance()])

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(_promote_like_jnp(np_op, inexact), jnp_op,
                              args_maker, check_dtypes=check_dtypes, tol=tol)
      self._CompileAndCheck(jnp_op, args_maker, check_dtypes=check_dtypes,
                            atol=tol, rtol=tol)

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
        {"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes,
                                                      dtypes),
         "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes, "name": rec.name,
         "tol": rec.tolerance}
        for shapes in filter(
          _shapes_are_broadcast_compatible,
          itertools.combinations_with_replacement(rec.shapes, rec.nargs))
        for dtypes in itertools.product(
          *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes)))
      for rec in JAX_OPERATOR_OVERLOADS))
  @jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
  def testOperatorOverload(self, name, rng_factory, shapes, dtypes, tol):
    rng = rng_factory(self.rng())
    # np and jnp arrays have different type promotion rules; force the use of
    # jnp arrays.
    args_maker = self._GetArgsMaker(rng, shapes, dtypes, np_arrays=False)
    fun = lambda *xs: getattr(operator, name.strip('_'))(*xs)
    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CompileAndCheck(fun, args_maker, atol=tol, rtol=tol)

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
        {"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes,
                                                      dtypes),
         "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes, "name": rec.name,
         "op_tolerance": rec.tolerance}
        for shapes in filter(
          _shapes_are_broadcast_compatible,
          itertools.combinations_with_replacement(rec.shapes, rec.nargs))
        for dtypes in itertools.product(
          *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes)))
      for rec in JAX_RIGHT_OPERATOR_OVERLOADS))
  @jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
  def testRightOperatorOverload(self, name, rng_factory, shapes, dtypes,
                                op_tolerance):
    if shapes[1] is jtu.PYTHON_SCALAR_SHAPE:
      raise SkipTest("scalars not implemented")  # TODO(mattjj): clean up
    rng = rng_factory(self.rng())
    args_maker = self._GetArgsMaker(rng, shapes, dtypes, np_arrays=False)
    fun = lambda fst, snd: getattr(snd, name)(fst)
    tol = max(jtu.tolerance(dtype, op_tolerance) for dtype in dtypes)
    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CompileAndCheck(fun, args_maker, atol=tol, rtol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": f"{rec.test_name}_{othertype}", "name": rec.name, "othertype": othertype}
      for rec in JAX_OPERATOR_OVERLOADS if rec.nargs == 2
      for othertype in [dict, list, tuple, set]))
  def testOperatorOverloadErrors(self, name, othertype):
    # Test that binary operators with builtin collections raise a TypeError
    # and report the types in the correct order.
    data = [(1, 2), (2, 3)]
    arr = jnp.array(data)
    other = othertype(data)

    if config.jax_array:
      val_str = 'Array'
    else:
      val_str = 'DeviceArray'
    msg = f"unsupported operand type.* '{val_str}' and '{othertype.__name__}'"
    with self.assertRaisesRegex(TypeError, msg):
      getattr(arr, name)(other)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": f"{rec.test_name}_{othertype}", "name": rec.name, "othertype": othertype}
      for rec in JAX_RIGHT_OPERATOR_OVERLOADS if rec.nargs == 2
      for othertype in [dict, list, tuple, set]))
  def testRightOperatorOverloadErrors(self, name, othertype):
    # Test that binary operators with builtin collections raise a TypeError
    # and report the types in the correct order.
    data = [(1, 2), (2, 3)]
    arr = jnp.array(data)
    other = othertype(data)

    if config.jax_array:
      val_str = 'Array'
    else:
      val_str = 'DeviceArray'
    msg = f"unsupported operand type.* '{othertype.__name__}' and '{val_str}'"
    with self.assertRaisesRegex(TypeError, msg):
      getattr(arr, name)(other)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": rec.test_name + f"_{dtype}",
       "rng_factory": rec.rng_factory,
       "op_name": rec.name, "dtype": dtype}
      for rec in JAX_OPERATOR_OVERLOADS if rec.nargs == 2
      for dtype in rec.dtypes))
  def testBinaryOperatorDefers(self, op_name, rng_factory, dtype):
    rng = rng_factory(self.rng())
    arg = jax.device_put(rng((), dtype))
    op = getattr(operator, op_name)

    other = _OverrideEverything()
    assert op(other, arg) is other
    assert op(arg, other) is other

    other = _OverrideNothing()
    if op_name == "__eq__":
      assert op(other, arg) is False
      assert op(arg, other) is False
    elif op_name == "__ne__":
      assert op(other, arg) is True
      assert op(arg, other) is True
    else:
      with self.assertRaises(TypeError):
        op(other, arg)
      with self.assertRaises(TypeError):
        op(arg, other)

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
        {"testcase_name": jtu.format_test_name_suffix(
            rec.test_name, shapes, dtypes),
         "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes,
         "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name)}
        for shapes in filter(
          _shapes_are_broadcast_compatible,
          itertools.combinations_with_replacement(rec.shapes, rec.nargs))
        for dtypes in filter(
          _dtypes_are_compatible_for_bitwise_ops,
          itertools.combinations_with_replacement(rec.dtypes, rec.nargs)))
      for rec in JAX_BITWISE_OP_RECORDS))
  @jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
  def testBitwiseOp(self, np_op, jnp_op, rng_factory, shapes, dtypes):
    rng = rng_factory(self.rng())
    args_maker = self._GetArgsMaker(rng, shapes, dtypes)
    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(_promote_like_jnp(np_op), jnp_op, args_maker)
      self._CompileAndCheck(jnp_op, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": jtu.format_test_name_suffix(op.__name__, shapes, dtypes),
       "op": op, "dtypes": dtypes, "shapes": shapes}
      for op in [jnp.left_shift, jnp.right_shift]
      for shapes in filter(
        _shapes_are_broadcast_compatible,
        # TODO numpy always promotes to shift dtype for zero-dim shapes:
        itertools.combinations_with_replacement(nonzerodim_shapes, 2))
      for dtypes in itertools.product(
        *(_valid_dtypes_for_shape(s, int_dtypes_no_uint64) for s in shapes))))
  @jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
  def testShiftOpAgainstNumpy(self, op, dtypes, shapes):
    dtype, shift_dtype = dtypes
    signed_mix = np.issubdtype(dtype, np.signedinteger) != \
                 np.issubdtype(shift_dtype, np.signedinteger)
    has_32 = any(np.iinfo(d).bits == 32 for d in dtypes)
    promoting_to_64 = has_32 and signed_mix
    if promoting_to_64 and not config.x64_enabled:
      self.skipTest("np.right_shift/left_shift promoting to int64 "
                    "differs from jnp in 32 bit mode.")

    info, shift_info = map(np.iinfo, dtypes)
    x_rng = jtu.rand_int(self.rng(), low=info.min, high=info.max + 1)
    # NumPy requires shifts to be non-negative and below the bit width:
    shift_rng = jtu.rand_int(self.rng(), high=max(info.bits, shift_info.bits))
    args_maker = lambda: (x_rng(shapes[0], dtype), shift_rng(shapes[1], shift_dtype))

    np_op = getattr(np, op.__name__)

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CompileAndCheck(op, args_maker)
      self._CheckAgainstNumpy(np_op, op, args_maker)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())
tests/lax_numpy_reducers_test.py (new file, 584 lines)
@@ -0,0 +1,584 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import collections
from functools import partial
import itertools

from absl.testing import absltest
from absl.testing import parameterized

import numpy as np

import jax
from jax import numpy as jnp

from jax._src import dtypes
from jax._src import test_util as jtu

from jax.config import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS

numpy_version = tuple(map(int, np.__version__.split('.')[:3]))

nonempty_nonscalar_array_shapes = [(4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)]
nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes
one_dim_array_shapes = [(1,), (6,), (12,)]
empty_array_shapes = [(0,), (0, 4), (3, 0),]

scalar_shapes = [jtu.NUMPY_SCALAR_SHAPE, jtu.PYTHON_SCALAR_SHAPE]
array_shapes = nonempty_array_shapes + empty_array_shapes
nonzerodim_shapes = nonempty_nonscalar_array_shapes + empty_array_shapes
nonempty_shapes = scalar_shapes + nonempty_array_shapes
all_shapes = scalar_shapes + array_shapes

float_dtypes = jtu.dtypes.all_floating
complex_dtypes = jtu.dtypes.complex
int_dtypes = jtu.dtypes.all_integer
unsigned_dtypes = jtu.dtypes.all_unsigned
bool_dtypes = jtu.dtypes.boolean
default_dtypes = float_dtypes + int_dtypes
inexact_dtypes = float_dtypes + complex_dtypes
number_dtypes = float_dtypes + complex_dtypes + int_dtypes + unsigned_dtypes
all_dtypes = number_dtypes + bool_dtypes

python_scalar_dtypes = [jnp.bool_, jnp.int_, jnp.float_, jnp.complex_]

def _compatible_shapes(shape):
  if np.ndim(shape) == 0 or shape in scalar_shapes:
    return [shape]
  return (shape[n:] for n in range(len(shape) + 1))
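
# For example (annotation, not part of the commit): _compatible_shapes((2, 3, 4))
# yields (2, 3, 4), (3, 4), (4,), and () -- every trailing sub-shape, each of
# which broadcasts against the original shape.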

def _get_y_shapes(y_dtype, shape, rowvar):
  # Helper function for testCov.
  if y_dtype is None:
    return [None]
  if len(shape) == 1:
    return [shape]
  elif rowvar or shape[0] == 1:
    return [(1, shape[-1]), (2, shape[-1]), (5, shape[-1])]
  return [(shape[0], 1), (shape[0], 2), (shape[0], 5)]
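
# For example (annotation, not part of the commit):
# _get_y_shapes(np.float32, (10, 5), rowvar=True) returns
# [(1, 5), (2, 5), (5, 5)] -- y candidates whose trailing (observation)
# dimension matches x, as np.cov requires.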


OpRecord = collections.namedtuple(
  "OpRecord",
  ["name", "nargs", "dtypes", "shapes", "rng_factory", "diff_modes",
   "test_name", "check_dtypes", "tolerance", "inexact", "kwargs"])

def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes,
              test_name=None, check_dtypes=True,
              tolerance=None, inexact=False, kwargs=None):
  test_name = test_name or name
  return OpRecord(name, nargs, dtypes, shapes, rng_factory, diff_modes,
                  test_name, check_dtypes, tolerance, inexact, kwargs)

JAX_REDUCER_RECORDS = [
    op_record("mean", 1, number_dtypes, nonempty_shapes, jtu.rand_default, [],
              inexact=True),
    op_record("prod", 1, all_dtypes, all_shapes, jtu.rand_small_positive, []),
    op_record("sum", 1, all_dtypes, all_shapes, jtu.rand_default, []),
    op_record("nanmean", 1, inexact_dtypes, nonempty_shapes, jtu.rand_some_nan,
              [], inexact=True),
    op_record("nanprod", 1, all_dtypes, all_shapes, jtu.rand_some_nan, []),
    op_record("nansum", 1, number_dtypes, all_shapes, jtu.rand_some_nan, []),
]

JAX_REDUCER_INITIAL_RECORDS = [
    op_record("prod", 1, all_dtypes, all_shapes, jtu.rand_small_positive, []),
    op_record("sum", 1, all_dtypes, all_shapes, jtu.rand_default, []),
    op_record("max", 1, all_dtypes, all_shapes, jtu.rand_default, []),
    op_record("min", 1, all_dtypes, all_shapes, jtu.rand_default, []),
]
if numpy_version >= (1, 22):  # initial & where keywords added in numpy 1.22
  JAX_REDUCER_INITIAL_RECORDS += [
      op_record("nanprod", 1, inexact_dtypes, all_shapes, jtu.rand_small_positive, []),
      op_record("nansum", 1, inexact_dtypes, all_shapes, jtu.rand_default, []),
      op_record("nanmax", 1, inexact_dtypes, all_shapes, jtu.rand_default, []),
      op_record("nanmin", 1, inexact_dtypes, all_shapes, jtu.rand_default, []),
  ]

JAX_REDUCER_WHERE_NO_INITIAL_RECORDS = [
    op_record("all", 1, bool_dtypes, all_shapes, jtu.rand_some_zero, []),
    op_record("any", 1, bool_dtypes, all_shapes, jtu.rand_some_zero, []),
    op_record("mean", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [],
              inexact=True),
    op_record("var", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [],
              inexact=True),
    op_record("std", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [],
              inexact=True),
]
if numpy_version >= (1, 22):  # where keyword added in numpy 1.22
  JAX_REDUCER_WHERE_NO_INITIAL_RECORDS += [
      op_record("nanmean", 1, inexact_dtypes, nonempty_shapes, jtu.rand_default, [],
                inexact=True),
      op_record("nanvar", 1, inexact_dtypes, nonempty_shapes, jtu.rand_default, [],
                inexact=True),
      op_record("nanstd", 1, inexact_dtypes, nonempty_shapes, jtu.rand_default, [],
                inexact=True),
  ]

JAX_REDUCER_NO_DTYPE_RECORDS = [
    op_record("all", 1, all_dtypes, all_shapes, jtu.rand_some_zero, []),
    op_record("any", 1, all_dtypes, all_shapes, jtu.rand_some_zero, []),
    op_record("max", 1, all_dtypes, nonempty_shapes, jtu.rand_default, []),
    op_record("min", 1, all_dtypes, nonempty_shapes, jtu.rand_default, []),
    op_record("var", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [],
              inexact=True),
    op_record("std", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [],
              inexact=True),
    op_record("nanmax", 1, all_dtypes, nonempty_shapes, jtu.rand_some_nan, []),
    op_record("nanmin", 1, all_dtypes, nonempty_shapes, jtu.rand_some_nan, []),
    op_record("nanvar", 1, all_dtypes, nonempty_shapes, jtu.rand_some_nan,
              [], inexact=True),
    op_record("nanstd", 1, all_dtypes, nonempty_shapes, jtu.rand_some_nan,
              [], inexact=True),
    op_record("ptp", 1, number_dtypes, nonempty_shapes, jtu.rand_default, []),
]

JAX_REDUCER_PROMOTE_INT_RECORDS = [
    op_record("prod", 1, all_dtypes, all_shapes, jtu.rand_small_positive, []),
    op_record("sum", 1, all_dtypes, all_shapes, jtu.rand_default, []),
]
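
# What the promote_integers flag under test means (annotation, not part of
# the commit; default-dtype behavior assumed):
#   jnp.sum(jnp.ones(4, jnp.int8)).dtype                          # int32 (int64 with x64 enabled)
#   jnp.sum(jnp.ones(4, jnp.int8), promote_integers=False).dtype  # int8
# i.e. sums/products of narrow integer types accumulate in the default
# integer width unless promotion is disabled.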


class JaxNumpyReducerTests(jtu.JaxTestCase):
  """Tests for LAX-backed Numpy reduction operations."""

  def _GetArgsMaker(self, rng, shapes, dtypes, np_arrays=True):
    def f():
      out = [rng(shape, dtype or jnp.float_)
             for shape, dtype in zip(shapes, dtypes)]
      if np_arrays:
        return out
      return [jnp.asarray(a) if isinstance(a, (np.ndarray, np.generic)) else a
              for a in out]
    return f

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
        {"testcase_name": "{}_inshape={}_axis={}_dtype={}_keepdims={}".format(
            rec.test_name.capitalize(),
            jtu.format_shape_dtype_string(shape, dtype), axis,
            "None" if out_dtype is None else np.dtype(out_dtype).name, keepdims),
         "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, "out_dtype": out_dtype,
         "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name),
         "axis": axis, "keepdims": keepdims, "inexact": rec.inexact}
        for shape in rec.shapes for dtype in rec.dtypes
        for out_dtype in [None] + rec.dtypes if out_dtype not in unsigned_dtypes
        for axis in list(range(-len(shape), len(shape))) + [None]
        for keepdims in [False, True]
        if jtu.is_valid_shape(shape, dtype))
      for rec in JAX_REDUCER_RECORDS))
  def testReducer(self, np_op, jnp_op, rng_factory, shape, dtype, out_dtype,
                  axis, keepdims, inexact):
    rng = rng_factory(self.rng())
    @jtu.ignore_warning(category=np.ComplexWarning)
    @jtu.ignore_warning(category=RuntimeWarning,
                        message="mean of empty slice.*")
    @jtu.ignore_warning(category=RuntimeWarning,
                        message="overflow encountered.*")
    def np_fun(x):
      x = np.asarray(x)
      if inexact:
        x = x.astype(dtypes.to_inexact_dtype(x.dtype))
      x_cast = x if dtype != jnp.bfloat16 else x.astype(np.float32)
      t = out_dtype if out_dtype != jnp.bfloat16 else np.float32
      return np_op(x_cast, axis, dtype=t, keepdims=keepdims)

    jnp_fun = lambda x: jnp_op(x, axis, dtype=out_dtype, keepdims=keepdims)
    jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
    args_maker = lambda: [rng(shape, dtype)]
    tol_spec = {np.float16: 1e-2, np.int32: 1E-3, np.float32: 1e-3,
                np.complex64: 1e-3, np.float64: 1e-5, np.complex128: 1e-5}
    tol = jtu.tolerance(dtype, tol_spec)
    tol = max(tol, jtu.tolerance(out_dtype, tol_spec)) if out_dtype else tol
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
                            check_dtypes=jnp.bfloat16 not in (dtype, out_dtype),
                            tol=tol)
    self._CompileAndCheck(jnp_fun, args_maker, atol=tol,
                          rtol=tol)

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
        {"testcase_name": "{}_inshape={}_axis={}_keepdims={}".format(
            rec.test_name.capitalize(),
            jtu.format_shape_dtype_string(shape, dtype), axis, keepdims),
         "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype,
         "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name),
         "axis": axis, "keepdims": keepdims, "inexact": rec.inexact}
        for shape in rec.shapes for dtype in rec.dtypes
        for axis in list(range(-len(shape), len(shape))) + [None]
        for keepdims in [False, True]
        if jtu.is_valid_shape(shape, dtype))
      for rec in JAX_REDUCER_NO_DTYPE_RECORDS))
  def testReducerNoDtype(self, np_op, jnp_op, rng_factory, shape, dtype, axis,
                         keepdims, inexact):
    rng = rng_factory(self.rng())
    is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan'
    @jtu.ignore_warning(category=RuntimeWarning,
                        message="Degrees of freedom <= 0 for slice.*")
    @jtu.ignore_warning(category=RuntimeWarning,
                        message="All-NaN (slice|axis) encountered.*")
    def np_fun(x):
      x = np.asarray(x)
      if inexact:
        x = x.astype(dtypes.to_inexact_dtype(x.dtype))
      x_cast = x if not is_bf16_nan_test else x.astype(np.float32)
      res = np_op(x_cast, axis, keepdims=keepdims)
      res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16)
      return res

    jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims)
    args_maker = lambda: [rng(shape, dtype)]
    tol = {np.float16: 0.002}
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol)
    self._CompileAndCheck(jnp_fun, args_maker, rtol=tol, atol=tol)

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
        {"testcase_name": "{}_inshape={}_axis={}_keepdims={}_initial={}".format(
            rec.test_name.capitalize(),
            jtu.format_shape_dtype_string(shape, dtype), axis, keepdims, initial),
         "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype,
         "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name),
         "initial": initial, "axis": axis, "keepdims": keepdims, "inexact": rec.inexact}
        for shape in rec.shapes for dtype in rec.dtypes
        for axis in list(range(-len(shape), len(shape))) + [None]
        for initial in [0, 1] for keepdims in [False, True]
        if jtu.is_valid_shape(shape, dtype))
      for rec in JAX_REDUCER_INITIAL_RECORDS))
  def testReducerInitial(self, np_op, jnp_op, rng_factory, shape, dtype, axis,
                         keepdims, initial, inexact):
    rng = rng_factory(self.rng())
    is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan'
    @jtu.ignore_warning(category=RuntimeWarning,
                        message="Degrees of freedom <= 0 for slice.*")
    @jtu.ignore_warning(category=np.ComplexWarning)
    def np_fun(x):
      x = np.asarray(x)
      if inexact:
        x = x.astype(dtypes.to_inexact_dtype(x.dtype))
      x_cast = x if not is_bf16_nan_test else x.astype(np.float32)
      res = np_op(x_cast, axis, keepdims=keepdims, initial=initial)
      res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16)
      return res

    jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, initial=initial)
    jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
    args_maker = lambda: [rng(shape, dtype)]
    tol = {jnp.bfloat16: 3E-2}
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=tol)
    self._CompileAndCheck(jnp_fun, args_maker)

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
        {"testcase_name": "{}_inshape={}_axis={}_keepdims={}_initial={}_promote_integers={}".format(
            rec.test_name.capitalize(),
            jtu.format_shape_dtype_string(shape, dtype), axis, keepdims, initial, promote_integers),
         "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype,
         "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name),
         "initial": initial, "axis": axis, "keepdims": keepdims, "inexact": rec.inexact,
         "promote_integers": promote_integers}
        for shape in rec.shapes for dtype in rec.dtypes
        for axis in list(range(-len(shape), len(shape))) + [None]
        for initial in [0, 1] for keepdims in [False, True]
        for promote_integers in [True, False]
        if jtu.is_valid_shape(shape, dtype))
      for rec in JAX_REDUCER_PROMOTE_INT_RECORDS))
  def testReducerPromoteInt(self, np_op, jnp_op, rng_factory, shape, dtype, axis,
                            keepdims, initial, inexact, promote_integers):
    rng = rng_factory(self.rng())
    is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan'
    @jtu.ignore_warning(category=RuntimeWarning,
                        message="Degrees of freedom <= 0 for slice.*")
    @jtu.ignore_warning(category=np.ComplexWarning)
    def np_fun(x):
      x = np.asarray(x)
      if inexact:
        x = x.astype(dtypes.to_inexact_dtype(x.dtype))
      x_cast = x if not is_bf16_nan_test else x.astype(np.float32)
      res = np_op(x_cast, axis, keepdims=keepdims, initial=initial)
      res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16)
      if not promote_integers and dtypes.issubdtype(res.dtype, np.integer):
        res = res.astype(dtypes.to_numeric_dtype(x.dtype))
      return res

    jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, initial=initial, promote_integers=promote_integers)
    jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
    args_maker = lambda: [rng(shape, dtype)]
    tol = {jnp.bfloat16: 3E-2}
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=tol)
    self._CompileAndCheck(jnp_fun, args_maker)

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
        {"testcase_name": "{}_inshape={}_axis={}_keepdims={}".format(
            rec.test_name.capitalize(),
            jtu.format_shape_dtype_string(shape, dtype), axis, keepdims),
         "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype,
         "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name),
         "axis": axis, "keepdims": keepdims, "inexact": rec.inexact}
        for shape in rec.shapes if np.prod(shape) == 0
        for dtype in rec.dtypes
        for keepdims in [False, True]
        for axis in range(-len(shape), len(shape)) if shape[axis] >= 1)
      for rec in JAX_REDUCER_INITIAL_RECORDS))
  def testReducerNoInitialZeroDims(self, np_op, jnp_op, rng_factory, shape, dtype, axis,
                                   keepdims, inexact):
    rng = rng_factory(self.rng())
    is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan'
    @jtu.ignore_warning(category=RuntimeWarning,
                        message="Degrees of freedom <= 0 for slice.*")
    @jtu.ignore_warning(category=np.ComplexWarning)
    def np_fun(x):
      x = np.asarray(x)
      if inexact:
        x = x.astype(dtypes.to_inexact_dtype(x.dtype))
      x_cast = x if not is_bf16_nan_test else x.astype(np.float32)
      res = np_op(x_cast, axis, keepdims=keepdims)
      res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16)
      return res

    jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims)
    jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
    args_maker = lambda: [rng(shape, dtype)]
    tol = {jnp.bfloat16: 3E-2}
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=tol)
    self._CompileAndCheck(jnp_fun, args_maker)

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
        {"testcase_name": "{}_inshape={}_axis={}_keepdims={}_initial={}_whereshape={}".format(
            rec.test_name.capitalize(),
            jtu.format_shape_dtype_string(shape, dtype), axis, keepdims, initial,
            jtu.format_shape_dtype_string(whereshape, bool)),
         "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype,
         "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name), "whereshape": whereshape,
         "initial": initial, "axis": axis, "keepdims": keepdims, "inexact": rec.inexact}
        for shape in rec.shapes for dtype in rec.dtypes
        for whereshape in _compatible_shapes(shape)
        for axis in list(range(-len(shape), len(shape))) + [None]
        for initial in [0, 1] for keepdims in [False, True]
        if jtu.is_valid_shape(shape, dtype))
      for rec in JAX_REDUCER_INITIAL_RECORDS))
  def testReducerWhere(self, np_op, jnp_op, rng_factory, shape, dtype, axis,
                       keepdims, initial, inexact, whereshape):
    if (shape in [()] + scalar_shapes and
        dtype in [jnp.int16, jnp.uint16] and
        jnp_op in [jnp.min, jnp.max]):
      self.skipTest("Known XLA failure; see https://github.com/google/jax/issues/4971.")
    rng = rng_factory(self.rng())
    is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan'
    # Do not pass where via args_maker as that is incompatible with _promote_like_jnp.
    where = jtu.rand_bool(self.rng())(whereshape, np.bool_)
    @jtu.ignore_warning(category=RuntimeWarning,
                        message="Degrees of freedom <= 0 for slice.*")
    @jtu.ignore_warning(category=np.ComplexWarning)
    def np_fun(x):
      x = np.asarray(x)
      if inexact:
        x = x.astype(dtypes.to_inexact_dtype(x.dtype))
      x_cast = x if not is_bf16_nan_test else x.astype(np.float32)
      res = np_op(x_cast, axis, keepdims=keepdims, initial=initial, where=where)
      res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16)
      return res

    jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, initial=initial, where=where)
    jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
    self._CompileAndCheck(jnp_fun, args_maker)

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
        {"testcase_name": "{}_inshape={}_axis={}_keepdims={}_whereshape={}".format(
            rec.test_name.capitalize(),
            jtu.format_shape_dtype_string(shape, dtype), axis, keepdims,
            jtu.format_shape_dtype_string(whereshape, bool)),
         "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype,
         "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name), "whereshape": whereshape,
         "axis": axis, "keepdims": keepdims, "inexact": rec.inexact}
        for shape in rec.shapes for dtype in rec.dtypes
        for whereshape in _compatible_shapes(shape)
        for axis in list(range(-len(shape), len(shape))) + [None]
        for keepdims in [False, True]
        if jtu.is_valid_shape(shape, dtype))
      for rec in JAX_REDUCER_WHERE_NO_INITIAL_RECORDS))
  def testReducerWhereNoInitial(self, np_op, jnp_op, rng_factory, shape, dtype, axis,
                                keepdims, inexact, whereshape):
    rng = rng_factory(self.rng())
    is_bf16_nan_test = dtype == jnp.bfloat16
    # Do not pass where via args_maker as that is incompatible with _promote_like_jnp.
    where = jtu.rand_bool(self.rng())(whereshape, np.bool_)
    @jtu.ignore_warning(category=RuntimeWarning,
                        message="Degrees of freedom <= 0 for slice.*")
    @jtu.ignore_warning(category=RuntimeWarning,
                        message="Mean of empty slice.*")
    @jtu.ignore_warning(category=RuntimeWarning,
                        message="invalid value encountered.*")
    @jtu.ignore_warning(category=np.ComplexWarning)
    def np_fun(x):
      x = np.asarray(x)
      if inexact:
        x = x.astype(dtypes.to_inexact_dtype(x.dtype))
      x_cast = x if not is_bf16_nan_test else x.astype(np.float32)
      res = np_op(x_cast, axis, keepdims=keepdims, where=where)
      res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16)
      return res

    jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, where=where)
    jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
    args_maker = lambda: [rng(shape, dtype)]
    if numpy_version >= (1, 20, 2) or np_op.__name__ in ("all", "any"):
      self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
    self._CompileAndCheck(jnp_fun, args_maker)

  def testReductionOfOutOfBoundsAxis(self):  # Issue 888
    x = jnp.ones((3, 4))
    self.assertRaises(ValueError, lambda: jnp.sum(x, axis=2))

  def testReductionWithRepeatedAxisError(self):
    with self.assertRaisesRegex(ValueError, r"duplicate value in 'axis': \(0, 0\)"):
      jnp.sum(jnp.arange(3), (0, 0))

  @parameterized.named_parameters(
      jtu.cases_from_list(
        {"testcase_name":
           "_shape={}_dtype={}_out_dtype={}_axis={}_ddof={}_keepdims={}"
           .format(shape, dtype.__name__, out_dtype.__name__, axis, ddof, keepdims),
         "shape": shape, "dtype": dtype, "out_dtype": out_dtype, "axis": axis,
         "ddof": ddof, "keepdims": keepdims}
        for shape in [(5,), (10, 5)]
        for dtype in all_dtypes
        for out_dtype in inexact_dtypes
        for axis in [None, 0, -1]
        for ddof in [0, 1, 2]
        for keepdims in [False, True]))
  def testVar(self, shape, dtype, out_dtype, axis, ddof, keepdims):
    rng = jtu.rand_default(self.rng())
    args_maker = self._GetArgsMaker(rng, [shape], [dtype])
    @jtu.ignore_warning(category=RuntimeWarning,
                        message="Degrees of freedom <= 0 for slice.")
    @jtu.ignore_warning(category=np.ComplexWarning)
    def np_fun(x):
      # Numpy fails with bfloat16 inputs
      out = np.var(x.astype(np.float32 if dtype == dtypes.bfloat16 else dtype),
                   dtype=np.float32 if out_dtype == dtypes.bfloat16 else out_dtype,
                   axis=axis, ddof=ddof, keepdims=keepdims)
      return out.astype(out_dtype)
    jnp_fun = partial(jnp.var, dtype=out_dtype, axis=axis, ddof=ddof, keepdims=keepdims)
    tol = jtu.tolerance(out_dtype, {np.float16: 1e-1, np.float32: 1e-3,
                                    np.float64: 1e-3, np.complex128: 1e-6})
    if (jnp.issubdtype(dtype, jnp.complexfloating) and
        not jnp.issubdtype(out_dtype, jnp.complexfloating)):
      self.assertRaises(ValueError, lambda: jnp_fun(*args_maker()))
    else:
      self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
                              tol=tol)
      self._CompileAndCheck(jnp_fun, args_maker, rtol=tol,
                            atol=tol)
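
  # Note on ddof in the variance tests above (standard definition, not part
  # of the commit): the divisor is N - ddof, i.e.
  #   var = sum(abs(x - x.mean())**2) / (N - ddof)
  # so ddof=0 gives the population variance and ddof=1 the unbiased estimator.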

  @parameterized.named_parameters(
      jtu.cases_from_list(
        {"testcase_name":
           "_shape={}_dtype={}_out_dtype={}_axis={}_ddof={}_keepdims={}"
           .format(shape, dtype, out_dtype, axis, ddof, keepdims),
         "shape": shape, "dtype": dtype, "out_dtype": out_dtype, "axis": axis,
         "ddof": ddof, "keepdims": keepdims}
        for shape in [(5,), (10, 5)]
        for dtype in all_dtypes
        for out_dtype in inexact_dtypes
        for axis in [None, 0, -1]
        for ddof in [0, 1, 2]
        for keepdims in [False, True]))
  def testNanVar(self, shape, dtype, out_dtype, axis, ddof, keepdims):
    rng = jtu.rand_some_nan(self.rng())
    args_maker = self._GetArgsMaker(rng, [shape], [dtype])
    @jtu.ignore_warning(category=RuntimeWarning,
                        message="Degrees of freedom <= 0 for slice.")
    @jtu.ignore_warning(category=np.ComplexWarning)
    def np_fun(x):
      # Numpy fails with bfloat16 inputs
      out = np.nanvar(x.astype(np.float32 if dtype == dtypes.bfloat16 else dtype),
                      dtype=np.float32 if out_dtype == dtypes.bfloat16 else out_dtype,
                      axis=axis, ddof=ddof, keepdims=keepdims)
      return out.astype(out_dtype)
    jnp_fun = partial(jnp.nanvar, dtype=out_dtype, axis=axis, ddof=ddof, keepdims=keepdims)
    tol = jtu.tolerance(out_dtype, {np.float16: 1e-1, np.float32: 1e-3,
                                    np.float64: 1e-3, np.complex128: 1e-6})
    if (jnp.issubdtype(dtype, jnp.complexfloating) and
        not jnp.issubdtype(out_dtype, jnp.complexfloating)):
      self.assertRaises(ValueError, lambda: jnp_fun(*args_maker()))
    else:
      self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
                              tol=tol)
      self._CompileAndCheck(jnp_fun, args_maker, rtol=tol,
                            atol=tol)

  def testNanStdGrad(self):
    # Regression test for https://github.com/google/jax/issues/8128
    x = jnp.arange(5.0).at[0].set(jnp.nan)
    y = jax.grad(jnp.nanvar)(x)
    self.assertAllClose(y, jnp.array([0.0, -0.75, -0.25, 0.25, 0.75]))

    z = jax.grad(jnp.nanstd)(x)
    self.assertEqual(jnp.isnan(z).sum(), 0)

  @parameterized.named_parameters(
      jtu.cases_from_list(
        {"testcase_name":
           "_shape={}_dtype={}_y_shape={}_y_dtype={}_rowvar={}_ddof={}_bias={}_fweights={}_aweights={}".format(
             shape, dtype, y_shape, y_dtype, rowvar, ddof, bias, fweights, aweights),
         "shape": shape, "y_shape": y_shape, "dtype": dtype, "y_dtype": y_dtype, "rowvar": rowvar, "ddof": ddof,
         "bias": bias, "fweights": fweights, "aweights": aweights}
        for shape in [(5,), (10, 5), (5, 10)]
        for dtype in all_dtypes
        for y_dtype in [None, dtype]
        for rowvar in [True, False]
        for y_shape in _get_y_shapes(y_dtype, shape, rowvar)
        for bias in [True, False]
        for ddof in [None, 2, 3]
        for fweights in [True, False]
        for aweights in [True, False]))
  @jax.numpy_dtype_promotion('standard')  # This test explicitly exercises mixed type promotion
  def testCov(self, shape, dtype, y_shape, y_dtype, rowvar, ddof, bias, fweights, aweights):
    rng = jtu.rand_default(self.rng())
    wrng = jtu.rand_positive(self.rng())
    wdtype = np.real(dtype(0)).dtype
    wshape = shape[-1:] if rowvar or shape[0] == 1 else shape[:1]

    args_maker = lambda: [rng(shape, dtype),
                          rng(y_shape, y_dtype) if y_dtype else None,
                          wrng(wshape, int) if fweights else None,
                          wrng(wshape, wdtype) if aweights else None]
    kwargs = dict(rowvar=rowvar, ddof=ddof, bias=bias)
    np_fun = lambda m, y, f, a: np.cov(m, y, fweights=f, aweights=a, **kwargs)
    jnp_fun = lambda m, y, f, a: jnp.cov(m, y, fweights=f, aweights=a, **kwargs)
    tol = {jnp.bfloat16: 5E-2, np.float16: 1E-2, np.float32: 1e-5,
           np.float64: 1e-13, np.complex64: 1e-5, np.complex128: 1e-13}
    tol = 7e-2 if jtu.device_under_test() == "tpu" else tol
    tol = jtu.join_tolerance(tol, jtu.tolerance(dtype))
    self._CheckAgainstNumpy(
        np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol)
    self._CompileAndCheck(jnp_fun, args_maker, atol=tol,
                          rtol=tol)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())
File diff suppressed because it is too large. (The remaining change removes the split-out tests from the original tests/lax_numpy_test.py.)