diff --git a/tests/BUILD b/tests/BUILD index 4dd0404fd..cda4bd854 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -320,9 +320,31 @@ jax_test( srcs = ["lax_numpy_test.py"], pjrt_c_api_bypass = True, shard_count = { - "cpu": 40, - "gpu": 40, - "tpu": 20, + "cpu": 20, + "gpu": 20, + "tpu": 10, + }, +) + +jax_test( + name = "lax_numpy_operators_test", + srcs = ["lax_numpy_operators_test.py"], + pjrt_c_api_bypass = True, + shard_count = { + "cpu": 10, + "gpu": 10, + "tpu": 5, + }, +) + +jax_test( + name = "lax_numpy_reducers_test", + srcs = ["lax_numpy_reducers_test.py"], + pjrt_c_api_bypass = True, + shard_count = { + "cpu": 10, + "gpu": 10, + "tpu": 5, }, ) diff --git a/tests/lax_numpy_operators_test.py b/tests/lax_numpy_operators_test.py new file mode 100644 index 000000000..0b3d4e4db --- /dev/null +++ b/tests/lax_numpy_operators_test.py @@ -0,0 +1,636 @@ +# Copyright 2018 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import collections +import functools +from functools import partial +import itertools +import operator +from unittest import SkipTest + +from absl.testing import absltest +from absl.testing import parameterized + +import numpy as np + +import jax +import jax.ops +from jax import lax +from jax import numpy as jnp +from jax import tree_util + +from jax._src import dtypes +from jax._src import test_util as jtu +from jax._src.numpy.lax_numpy import _promote_dtypes, _promote_dtypes_inexact + +from jax.config import config +config.parse_flags_with_absl() +FLAGS = config.FLAGS + +nonempty_nonscalar_array_shapes = [(4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)] +nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes +one_dim_array_shapes = [(1,), (6,), (12,)] +empty_array_shapes = [(0,), (0, 4), (3, 0),] + +scalar_shapes = [jtu.NUMPY_SCALAR_SHAPE, jtu.PYTHON_SCALAR_SHAPE] +array_shapes = nonempty_array_shapes + empty_array_shapes +nonzerodim_shapes = nonempty_nonscalar_array_shapes + empty_array_shapes +nonempty_shapes = scalar_shapes + nonempty_array_shapes +all_shapes = scalar_shapes + array_shapes + +float_dtypes = jtu.dtypes.all_floating +complex_dtypes = jtu.dtypes.complex +int_dtypes = jtu.dtypes.all_integer +unsigned_dtypes = jtu.dtypes.all_unsigned +bool_dtypes = jtu.dtypes.boolean +default_dtypes = float_dtypes + int_dtypes +inexact_dtypes = float_dtypes + complex_dtypes +number_dtypes = float_dtypes + complex_dtypes + int_dtypes + unsigned_dtypes +all_dtypes = number_dtypes + bool_dtypes + + +python_scalar_dtypes = [jnp.bool_, jnp.int_, jnp.float_, jnp.complex_] + +# uint64 is problematic because with any uint type it promotes to float: +int_dtypes_no_uint64 = [d for d in int_dtypes + unsigned_dtypes if d != np.uint64] + +def _indexer_with_default_outputs(indexer, use_defaults=True): + """Like jtu.with_jax_dtype_defaults, but for __getitem__ APIs""" + class Indexer: + @partial(jtu.with_jax_dtype_defaults, use_defaults=use_defaults) + def __getitem__(self, *args): + return indexer.__getitem__(*args) + return Indexer() + +def _valid_dtypes_for_shape(shape, dtypes): + # Not all (shape, dtype) pairs are valid. In particular, Python scalars only + # have one type in each category (float, bool, etc.) + if shape is jtu.PYTHON_SCALAR_SHAPE: + return [t for t in dtypes if t in python_scalar_dtypes] + return dtypes + + +OpRecord = collections.namedtuple( + "OpRecord", + ["name", "nargs", "dtypes", "shapes", "rng_factory", "diff_modes", + "test_name", "check_dtypes", "tolerance", "inexact", "kwargs"]) + +def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes, + test_name=None, check_dtypes=True, + tolerance=None, inexact=False, kwargs=None): + test_name = test_name or name + return OpRecord(name, nargs, dtypes, shapes, rng_factory, diff_modes, + test_name, check_dtypes, tolerance, inexact, kwargs) + +JAX_ONE_TO_ONE_OP_RECORDS = [ + op_record("abs", 1, all_dtypes, + all_shapes, jtu.rand_default, ["rev"]), + op_record("add", 2, all_dtypes, all_shapes, jtu.rand_default, ["rev"]), + op_record("ceil", 1, float_dtypes, all_shapes, jtu.rand_default, []), + op_record("ceil", 1, int_dtypes + unsigned_dtypes, all_shapes, + jtu.rand_default, [], check_dtypes=False), + op_record("conj", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), + op_record("equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []), + op_record("exp", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"], + inexact=True), + op_record("fabs", 1, float_dtypes, all_shapes, jtu.rand_default, ["rev"]), + op_record("float_power", 2, inexact_dtypes, all_shapes, + partial(jtu.rand_default, scale=1), ["rev"], + tolerance={jnp.bfloat16: 1e-2, np.float32: 1e-3, + np.float64: 1e-12, np.complex64: 2e-4, + np.complex128: 1e-12}, check_dtypes=False), + op_record("floor", 1, float_dtypes, all_shapes, jtu.rand_default, []), + op_record("floor", 1, int_dtypes + unsigned_dtypes, all_shapes, + jtu.rand_default, [], check_dtypes=False), + op_record("greater", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []), + op_record("greater_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []), + op_record("i0", 1, float_dtypes, all_shapes, jtu.rand_default, [], + check_dtypes=False), + op_record("ldexp", 2, int_dtypes, all_shapes, jtu.rand_default, [], check_dtypes=False), + op_record("less", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []), + op_record("less_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []), + op_record("log", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"], + inexact=True), + op_record("logical_and", 2, all_dtypes, all_shapes, jtu.rand_bool, []), + op_record("logical_not", 1, all_dtypes, all_shapes, jtu.rand_bool, []), + op_record("logical_or", 2, all_dtypes, all_shapes, jtu.rand_bool, []), + op_record("logical_xor", 2, all_dtypes, all_shapes, jtu.rand_bool, []), + op_record("maximum", 2, all_dtypes, all_shapes, jtu.rand_some_inf, []), + op_record("minimum", 2, all_dtypes, all_shapes, jtu.rand_some_inf, []), + op_record("multiply", 2, all_dtypes, all_shapes, jtu.rand_default, ["rev"]), + op_record("negative", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), + op_record("nextafter", 2, [f for f in float_dtypes if f != jnp.bfloat16], + all_shapes, jtu.rand_default, ["rev"], inexact=True, tolerance=0), + op_record("not_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]), + op_record("array_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]), + op_record("array_equiv", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]), + op_record("reciprocal", 1, inexact_dtypes, all_shapes, jtu.rand_default, []), + op_record("subtract", 2, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), + op_record("signbit", 1, default_dtypes + bool_dtypes, all_shapes, + jtu.rand_some_inf_and_nan, ["rev"]), + op_record("trunc", 1, float_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []), + op_record("trunc", 1, int_dtypes + unsigned_dtypes, all_shapes, + jtu.rand_some_inf_and_nan, [], check_dtypes=False), + op_record("sin", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"], + inexact=True), + op_record("cos", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"], + inexact=True), + op_record("tan", 1, number_dtypes, all_shapes, + partial(jtu.rand_uniform, low=-1.5, high=1.5), ["rev"], + inexact=True), + op_record("sinh", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"], + inexact=True), + op_record("cosh", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"], + inexact=True), + # TODO(b/142975473): on CPU, tanh for complex128 is only accurate to + # ~float32 precision. + # TODO(b/143135720): on GPU, tanh has only ~float32 precision. + op_record("tanh", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"], + tolerance={np.float64: 1e-7, np.complex128: 1e-7}, + inexact=True), + op_record("arcsin", 1, number_dtypes, all_shapes, jtu.rand_small, ["rev"], + inexact=True), + op_record("arccos", 1, number_dtypes, all_shapes, jtu.rand_small, ["rev"], + inexact=True), + op_record("arctan", 1, number_dtypes, all_shapes, jtu.rand_small, ["rev"], + inexact=True), + op_record("arctan2", 2, float_dtypes, all_shapes, jtu.rand_small, ["rev"], + inexact=True), + op_record("arcsinh", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"], + inexact=True, tolerance={np.complex64: 2E-4, np.complex128: 2E-14}), + op_record("arccosh", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"], + inexact=True, tolerance={np.complex64: 2E-2, np.complex128: 2E-12}), + op_record("arctanh", 1, number_dtypes, all_shapes, jtu.rand_small, ["rev"], + inexact=True, tolerance={np.float64: 1e-9}), +] + +JAX_COMPOUND_OP_RECORDS = [ + # angle has inconsistent 32/64-bit return types across numpy versions. + op_record("angle", 1, number_dtypes, all_shapes, jtu.rand_default, [], + check_dtypes=False, inexact=True), + op_record("angle", 1, number_dtypes, all_shapes, jtu.rand_default, [], + check_dtypes=False, inexact=True, test_name="angle_deg", kwargs={'deg': True}), + op_record("atleast_1d", 1, default_dtypes, all_shapes, jtu.rand_default, []), + op_record("atleast_2d", 1, default_dtypes, all_shapes, jtu.rand_default, []), + op_record("atleast_3d", 1, default_dtypes, all_shapes, jtu.rand_default, []), + op_record("cbrt", 1, default_dtypes, all_shapes, jtu.rand_some_inf, ["rev"], + inexact=True), + op_record("conjugate", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), + op_record("deg2rad", 1, float_dtypes, all_shapes, jtu.rand_default, []), + op_record("divide", 2, number_dtypes, all_shapes, jtu.rand_nonzero, ["rev"], + inexact=True), + op_record("divmod", 2, int_dtypes + float_dtypes, all_shapes, + jtu.rand_nonzero, []), + op_record("exp2", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"], + tolerance={jnp.bfloat16: 4e-2, np.float16: 1e-2}, inexact=True), + # TODO(b/142975473): on CPU, expm1 for float64 is only accurate to ~float32 + # precision. + op_record("expm1", 1, number_dtypes, all_shapes, jtu.rand_positive, [], + test_name="expm1_large", tolerance={np.float64: 1e-8}, inexact=True), + op_record("expm1", 1, number_dtypes, all_shapes, jtu.rand_small_positive, + [], tolerance={np.float64: 1e-8}, inexact=True), + op_record("fix", 1, float_dtypes, all_shapes, jtu.rand_default, []), + op_record("fix", 1, int_dtypes + unsigned_dtypes, all_shapes, + jtu.rand_default, [], check_dtypes=False), + op_record("floor_divide", 2, default_dtypes + unsigned_dtypes, + all_shapes, jtu.rand_nonzero, ["rev"]), + op_record("fmin", 2, number_dtypes, all_shapes, jtu.rand_some_nan, []), + op_record("fmax", 2, number_dtypes, all_shapes, jtu.rand_some_nan, []), + op_record("fmod", 2, default_dtypes, all_shapes, jtu.rand_some_nan, []), + op_record("heaviside", 2, default_dtypes, all_shapes, jtu.rand_default, [], + inexact=True), + op_record("hypot", 2, default_dtypes, all_shapes, jtu.rand_default, [], + inexact=True), + op_record("kron", 2, number_dtypes, nonempty_shapes, jtu.rand_default, []), + op_record("outer", 2, number_dtypes, all_shapes, jtu.rand_default, []), + op_record("imag", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []), + op_record("iscomplex", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []), + op_record("isfinite", 1, inexact_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []), + op_record("isinf", 1, inexact_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []), + op_record("isnan", 1, inexact_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []), + op_record("isneginf", 1, float_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []), + op_record("isposinf", 1, float_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []), + op_record("isreal", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []), + op_record("isrealobj", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []), + op_record("log2", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"], + inexact=True), + op_record("log10", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"], + inexact=True), + op_record("log1p", 1, number_dtypes, all_shapes, jtu.rand_positive, [], + test_name="log1p_large", tolerance={np.float64: 1e-12}, + inexact=True), + op_record("log1p", 1, number_dtypes, all_shapes, jtu.rand_small_positive, [], + tolerance={np.float64: 1e-12}, inexact=True), + op_record("logaddexp", 2, float_dtypes, all_shapes, + jtu.rand_some_inf_and_nan, ["rev"], + tolerance={np.float64: 1e-12}, inexact=True), + op_record("logaddexp2", 2, float_dtypes, all_shapes, + jtu.rand_some_inf_and_nan, ["rev"], + tolerance={np.float16: 1e-2, np.float64: 2e-14}, inexact=True), + op_record("polyval", 2, number_dtypes, nonempty_nonscalar_array_shapes, + jtu.rand_default, [], check_dtypes=False, + tolerance={dtypes.bfloat16: 4e-2, np.float16: 1e-2, + np.float64: 1e-12}), + op_record("positive", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), + op_record("power", 2, number_dtypes, all_shapes, jtu.rand_positive, ["rev"], + tolerance={np.complex128: 1e-14}, check_dtypes=False), + op_record("rad2deg", 1, float_dtypes, all_shapes, jtu.rand_default, []), + op_record("ravel", 1, all_dtypes, all_shapes, jtu.rand_default, ["rev"]), + op_record("real", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []), + op_record("remainder", 2, default_dtypes, all_shapes, jtu.rand_nonzero, [], + tolerance={np.float16: 1e-2}), + op_record("mod", 2, default_dtypes, all_shapes, jtu.rand_nonzero, []), + op_record("modf", 1, float_dtypes, all_shapes, jtu.rand_default, []), + op_record("modf", 1, int_dtypes + unsigned_dtypes, all_shapes, + jtu.rand_default, [], check_dtypes=False), + op_record("rint", 1, inexact_dtypes, all_shapes, jtu.rand_some_inf_and_nan, + []), + op_record("rint", 1, int_dtypes + unsigned_dtypes, all_shapes, + jtu.rand_default, [], check_dtypes=False), + op_record("sign", 1, number_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []), + # numpy 1.16 has trouble mixing uint and bfloat16, so we test these separately. + op_record("copysign", 2, default_dtypes + unsigned_dtypes, + all_shapes, jtu.rand_some_inf_and_nan, [], check_dtypes=False), + op_record("sinc", 1, [t for t in number_dtypes if t != jnp.bfloat16], + all_shapes, jtu.rand_default, ["rev"], + tolerance={np.complex64: 1e-5}, inexact=True, + check_dtypes=False), + op_record("square", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), + op_record("sqrt", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"], + inexact=True), + op_record("transpose", 1, all_dtypes, all_shapes, jtu.rand_default, ["rev"], + check_dtypes=False), + op_record("true_divide", 2, all_dtypes, all_shapes, jtu.rand_nonzero, + ["rev"], inexact=True), + op_record("ediff1d", 3, [np.int32], all_shapes, jtu.rand_default, [], check_dtypes=False), + # TODO(phawkins): np.unwrap does not correctly promote its default period + # argument under NumPy 1.21 for bfloat16 inputs. It works fine if we + # explicitly pass a bfloat16 value that does not need promition. We should + # probably add a custom test harness for unwrap that tests the period + # argument anyway. + op_record("unwrap", 1, [t for t in float_dtypes if t != dtypes.bfloat16], + nonempty_nonscalar_array_shapes, + jtu.rand_default, ["rev"], + # numpy.unwrap always returns float64 + check_dtypes=False, + # numpy cumsum is inaccurate, see issue #3517 + tolerance={dtypes.bfloat16: 1e-1, np.float16: 1e-1}), + op_record("isclose", 2, [t for t in all_dtypes if t != jnp.bfloat16], + all_shapes, jtu.rand_small_positive, []), + op_record("gcd", 2, int_dtypes_no_uint64, all_shapes, jtu.rand_default, []), + op_record("lcm", 2, int_dtypes_no_uint64, all_shapes, jtu.rand_default, []), +] + +JAX_BITWISE_OP_RECORDS = [ + op_record("bitwise_and", 2, int_dtypes + unsigned_dtypes, all_shapes, + jtu.rand_fullrange, []), + op_record("bitwise_not", 1, int_dtypes + unsigned_dtypes, all_shapes, + jtu.rand_fullrange, []), + op_record("invert", 1, int_dtypes + unsigned_dtypes, all_shapes, + jtu.rand_fullrange, []), + op_record("bitwise_or", 2, int_dtypes + unsigned_dtypes, all_shapes, + jtu.rand_fullrange, []), + op_record("bitwise_xor", 2, int_dtypes + unsigned_dtypes, all_shapes, + jtu.rand_fullrange, []), +] + +JAX_OPERATOR_OVERLOADS = [ + op_record("__add__", 2, number_dtypes, all_shapes, jtu.rand_default, []), + op_record("__sub__", 2, number_dtypes, all_shapes, jtu.rand_default, []), + op_record("__mul__", 2, number_dtypes, all_shapes, jtu.rand_default, []), + op_record("__eq__", 2, number_dtypes, all_shapes, jtu.rand_default, []), + op_record("__ne__", 2, number_dtypes, all_shapes, jtu.rand_default, []), + op_record("__lt__", 2, default_dtypes, all_shapes, jtu.rand_default, []), + op_record("__le__", 2, default_dtypes, all_shapes, jtu.rand_default, []), + op_record("__gt__", 2, default_dtypes, all_shapes, jtu.rand_default, []), + op_record("__ge__", 2, default_dtypes, all_shapes, jtu.rand_default, []), + op_record("__pos__", 1, number_dtypes, all_shapes, jtu.rand_default, []), + op_record("__neg__", 1, number_dtypes, all_shapes, jtu.rand_default, []), + op_record("__pow__", 2, inexact_dtypes, all_shapes, jtu.rand_positive, [], + tolerance={np.float32: 2e-4, np.complex64: 2e-4, np.complex128: 1e-14}), + op_record("__mod__", 2, default_dtypes, all_shapes, jtu.rand_nonzero, [], + tolerance={np.float16: 1e-1}), + op_record("__floordiv__", 2, default_dtypes, all_shapes, + jtu.rand_nonzero, []), + op_record("__truediv__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, [], + inexact=True), + op_record("__abs__", 1, number_dtypes, all_shapes, jtu.rand_default, []), + # TODO(mattjj): __invert__ fails on bool dtypes because ~True == -2 + op_record("__invert__", 1, int_dtypes, all_shapes, jtu.rand_default, []), + # TODO(mattjj): investigate these failures + # op_record("__or__", 2, number_dtypes, all_shapes, jtu.rand_bool, []), + # op_record("__and__", 2, number_dtypes, all_shapes, jtu.rand_default, []), + # op_record("__xor__", 2, number_dtypes, all_shapes, jtu.rand_bool, []), + # op_record("__divmod__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, []), + op_record("__lshift__", 2, int_dtypes_no_uint64, all_shapes, partial(jtu.rand_int, high=8), []), + op_record("__rshift__", 2, int_dtypes_no_uint64, all_shapes, partial(jtu.rand_int, high=8), []), +] + +JAX_RIGHT_OPERATOR_OVERLOADS = [ + op_record("__radd__", 2, number_dtypes, all_shapes, jtu.rand_default, []), + op_record("__rsub__", 2, number_dtypes, all_shapes, jtu.rand_default, []), + op_record("__rmul__", 2, number_dtypes, all_shapes, jtu.rand_default, []), + op_record("__rpow__", 2, inexact_dtypes, all_shapes, jtu.rand_positive, [], + tolerance={np.float32: 2e-4, np.complex64: 1e-3}), + op_record("__rmod__", 2, default_dtypes, all_shapes, jtu.rand_nonzero, [], + tolerance={np.float16: 1e-1}), + op_record("__rfloordiv__", 2, default_dtypes, all_shapes, + jtu.rand_nonzero, []), + op_record("__rtruediv__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, [], + inexact=True), + # op_record("__ror__", 2, number_dtypes, all_shapes, jtu.rand_bool, []), + # op_record("__rand__", 2, number_dtypes, all_shapes, jtu.rand_default, []), + # op_record("__rxor__", 2, number_dtypes, all_shapes, jtu.rand_bool, []), + # op_record("__rdivmod__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, []), + op_record("__rlshift__", 2, int_dtypes_no_uint64, all_shapes, partial(jtu.rand_int, high=8), []), + op_record("__rrshift__", 2, int_dtypes_no_uint64, all_shapes, partial(jtu.rand_int, high=8), []) +] + +class _OverrideEverything: + pass + +for rec in JAX_OPERATOR_OVERLOADS + JAX_RIGHT_OPERATOR_OVERLOADS: + if rec.nargs == 2: + setattr(_OverrideEverything, rec.name, lambda self, other: self) + +class _OverrideNothing: + pass + +for rec in JAX_OPERATOR_OVERLOADS + JAX_RIGHT_OPERATOR_OVERLOADS: + if rec.nargs == 2: + setattr(_OverrideNothing, rec.name, lambda self, other: NotImplemented) + + +def _dtypes_are_compatible_for_bitwise_ops(args): + if len(args) <= 1: + return True + is_signed = lambda dtype: jnp.issubdtype(dtype, np.signedinteger) + width = lambda dtype: jnp.iinfo(dtype).bits + x, y = args + if width(x) > width(y): + x, y = y, x + # The following condition seems a little ad hoc, but seems to capture what + # numpy actually implements. + return ( + is_signed(x) == is_signed(y) + or (width(x) == 32 and width(y) == 32) + or (width(x) == 32 and width(y) == 64 and is_signed(y))) + +def _shapes_are_broadcast_compatible(shapes): + try: + lax.broadcast_shapes(*(() if s in scalar_shapes else s for s in shapes)) + except ValueError: + return False + else: + return True + +def _shapes_are_equal_length(shapes): + return all(len(shape) == len(shapes[0]) for shape in shapes[1:]) + + +def _promote_like_jnp(fun, inexact=False): + """Decorator that promotes the arguments of `fun` to `jnp.result_type(*args)`. + + jnp and np have different type promotion semantics; this decorator allows + tests make an np reference implementation act more like an jnp + implementation. + """ + _promote = _promote_dtypes_inexact if inexact else _promote_dtypes + def wrapper(*args, **kw): + flat_args, tree = tree_util.tree_flatten(args) + args = tree_util.tree_unflatten(tree, _promote(*flat_args)) + return fun(*args, **kw) + return wrapper + + +class JaxNumpyOperatorTests(jtu.JaxTestCase): + """Tests for LAX-backed Numpy operators.""" + + def _GetArgsMaker(self, rng, shapes, dtypes, np_arrays=True): + def f(): + out = [rng(shape, dtype or jnp.float_) + for shape, dtype in zip(shapes, dtypes)] + if np_arrays: + return out + return [jnp.asarray(a) if isinstance(a, (np.ndarray, np.generic)) else a + for a in out] + return f + + @parameterized.named_parameters(itertools.chain.from_iterable( + jtu.cases_from_list( + {"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes, + dtypes), + "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes, + "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name), + "check_dtypes": rec.check_dtypes, "tolerance": rec.tolerance, + "inexact": rec.inexact, "kwargs": rec.kwargs or {}} + for shapes in filter( + _shapes_are_broadcast_compatible, + itertools.combinations_with_replacement(rec.shapes, rec.nargs)) + for dtypes in itertools.product( + *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes))) + for rec in itertools.chain(JAX_ONE_TO_ONE_OP_RECORDS, + JAX_COMPOUND_OP_RECORDS))) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. + def testOp(self, np_op, jnp_op, rng_factory, shapes, dtypes, check_dtypes, + tolerance, inexact, kwargs): + np_op = partial(np_op, **kwargs) + jnp_op = partial(jnp_op, **kwargs) + np_op = jtu.ignore_warning(category=RuntimeWarning, + message="invalid value.*")(np_op) + np_op = jtu.ignore_warning(category=RuntimeWarning, + message="divide by zero.*")(np_op) + + rng = rng_factory(self.rng()) + args_maker = self._GetArgsMaker(rng, shapes, dtypes, np_arrays=False) + tol = max(jtu.tolerance(dtype, tolerance) for dtype in dtypes) + tol = functools.reduce(jtu.join_tolerance, + [tolerance, tol, jtu.default_tolerance()]) + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(_promote_like_jnp(np_op, inexact), jnp_op, + args_maker, check_dtypes=check_dtypes, tol=tol) + self._CompileAndCheck(jnp_op, args_maker, check_dtypes=check_dtypes, + atol=tol, rtol=tol) + + @parameterized.named_parameters(itertools.chain.from_iterable( + jtu.cases_from_list( + {"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes, + dtypes), + "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes, "name": rec.name, + "tol": rec.tolerance} + for shapes in filter( + _shapes_are_broadcast_compatible, + itertools.combinations_with_replacement(rec.shapes, rec.nargs)) + for dtypes in itertools.product( + *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes))) + for rec in JAX_OPERATOR_OVERLOADS)) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. + def testOperatorOverload(self, name, rng_factory, shapes, dtypes, tol): + rng = rng_factory(self.rng()) + # np and jnp arrays have different type promotion rules; force the use of + # jnp arrays. + args_maker = self._GetArgsMaker(rng, shapes, dtypes, np_arrays=False) + fun = lambda *xs: getattr(operator, name.strip('_'))(*xs) + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CompileAndCheck(fun, args_maker, atol=tol, rtol=tol) + + @parameterized.named_parameters(itertools.chain.from_iterable( + jtu.cases_from_list( + {"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes, + dtypes), + "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes, "name": rec.name, + "op_tolerance": rec.tolerance} + for shapes in filter( + _shapes_are_broadcast_compatible, + itertools.combinations_with_replacement(rec.shapes, rec.nargs)) + for dtypes in itertools.product( + *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes))) + for rec in JAX_RIGHT_OPERATOR_OVERLOADS)) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. + def testRightOperatorOverload(self, name, rng_factory, shapes, dtypes, + op_tolerance): + if shapes[1] is jtu.PYTHON_SCALAR_SHAPE: + raise SkipTest("scalars not implemented") # TODO(mattjj): clean up + rng = rng_factory(self.rng()) + args_maker = self._GetArgsMaker(rng, shapes, dtypes, np_arrays=False) + fun = lambda fst, snd: getattr(snd, name)(fst) + tol = max(jtu.tolerance(dtype, op_tolerance) for dtype in dtypes) + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CompileAndCheck( fun, args_maker, atol=tol, rtol=tol) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": f"{rec.test_name}_{othertype}", "name": rec.name, "othertype": othertype} + for rec in JAX_OPERATOR_OVERLOADS if rec.nargs == 2 + for othertype in [dict, list, tuple, set])) + def testOperatorOverloadErrors(self, name, othertype): + # Test that binary operators with builtin collections raise a TypeError + # and report the types in the correct order. + data = [(1, 2), (2, 3)] + arr = jnp.array(data) + other = othertype(data) + + if config.jax_array: + val_str = 'Array' + else: + val_str = 'DeviceArray' + msg = f"unsupported operand type.* '{val_str}' and '{othertype.__name__}'" + with self.assertRaisesRegex(TypeError, msg): + getattr(arr, name)(other) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": f"{rec.test_name}_{othertype}", "name": rec.name, "othertype": othertype} + for rec in JAX_RIGHT_OPERATOR_OVERLOADS if rec.nargs == 2 + for othertype in [dict, list, tuple, set])) + def testRightOperatorOverloadErrors(self, name, othertype): + # Test that binary operators with builtin collections raise a TypeError + # and report the types in the correct order. + data = [(1, 2), (2, 3)] + arr = jnp.array(data) + other = othertype(data) + + if config.jax_array: + val_str = 'Array' + else: + val_str = 'DeviceArray' + msg = f"unsupported operand type.* '{othertype.__name__}' and '{val_str}'" + with self.assertRaisesRegex(TypeError, msg): + getattr(arr, name)(other) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": rec.test_name + f"_{dtype}", + "rng_factory": rec.rng_factory, + "op_name": rec.name, "dtype": dtype} + for rec in JAX_OPERATOR_OVERLOADS if rec.nargs == 2 + for dtype in rec.dtypes)) + def testBinaryOperatorDefers(self, op_name, rng_factory, dtype): + rng = rng_factory(self.rng()) + arg = jax.device_put(rng((), dtype)) + op = getattr(operator, op_name) + + other = _OverrideEverything() + assert op(other, arg) is other + assert op(arg, other) is other + + other = _OverrideNothing() + if op_name == "__eq__": + assert op(other, arg) is False + assert op(arg, other) is False + elif op_name == "__ne__": + assert op(other, arg) is True + assert op(arg, other) is True + else: + with self.assertRaises(TypeError): + op(other, arg) + with self.assertRaises(TypeError): + op(arg, other) + + @parameterized.named_parameters(itertools.chain.from_iterable( + jtu.cases_from_list( + {"testcase_name": jtu.format_test_name_suffix( + rec.test_name, shapes, dtypes), + "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes, + "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name)} + for shapes in filter( + _shapes_are_broadcast_compatible, + itertools.combinations_with_replacement(rec.shapes, rec.nargs)) + for dtypes in filter( + _dtypes_are_compatible_for_bitwise_ops, + itertools.combinations_with_replacement(rec.dtypes, rec.nargs))) + for rec in JAX_BITWISE_OP_RECORDS)) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. + def testBitwiseOp(self, np_op, jnp_op, rng_factory, shapes, dtypes): + rng = rng_factory(self.rng()) + args_maker = self._GetArgsMaker(rng, shapes, dtypes) + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(_promote_like_jnp(np_op), jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": jtu.format_test_name_suffix(op.__name__, shapes, dtypes), + "op": op, "dtypes": dtypes, "shapes": shapes} + for op in [jnp.left_shift, jnp.right_shift] + for shapes in filter( + _shapes_are_broadcast_compatible, + # TODO numpy always promotes to shift dtype for zero-dim shapes: + itertools.combinations_with_replacement(nonzerodim_shapes, 2)) + for dtypes in itertools.product( + *(_valid_dtypes_for_shape(s, int_dtypes_no_uint64) for s in shapes)))) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. + def testShiftOpAgainstNumpy(self, op, dtypes, shapes): + dtype, shift_dtype = dtypes + signed_mix = np.issubdtype(dtype, np.signedinteger) != \ + np.issubdtype(shift_dtype, np.signedinteger) + has_32 = any(np.iinfo(d).bits == 32 for d in dtypes) + promoting_to_64 = has_32 and signed_mix + if promoting_to_64 and not config.x64_enabled: + self.skipTest("np.right_shift/left_shift promoting to int64" + "differs from jnp in 32 bit mode.") + + info, shift_info = map(np.iinfo, dtypes) + x_rng = jtu.rand_int(self.rng(), low=info.min, high=info.max + 1) + # NumPy requires shifts to be non-negative and below the bit width: + shift_rng = jtu.rand_int(self.rng(), high=max(info.bits, shift_info.bits)) + args_maker = lambda: (x_rng(shapes[0], dtype), shift_rng(shapes[1], shift_dtype)) + + np_op = getattr(np, op.__name__) + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CompileAndCheck(op, args_maker) + self._CheckAgainstNumpy(np_op, op, args_maker) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py new file mode 100644 index 000000000..6444b81e4 --- /dev/null +++ b/tests/lax_numpy_reducers_test.py @@ -0,0 +1,584 @@ +# Copyright 2018 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import collections +from functools import partial +import itertools + +from absl.testing import absltest +from absl.testing import parameterized + +import numpy as np + +import jax +from jax import numpy as jnp + +from jax._src import dtypes +from jax._src import test_util as jtu + +from jax.config import config +config.parse_flags_with_absl() +FLAGS = config.FLAGS + +numpy_version = tuple(map(int, np.__version__.split('.')[:3])) + +nonempty_nonscalar_array_shapes = [(4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)] +nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes +one_dim_array_shapes = [(1,), (6,), (12,)] +empty_array_shapes = [(0,), (0, 4), (3, 0),] + +scalar_shapes = [jtu.NUMPY_SCALAR_SHAPE, jtu.PYTHON_SCALAR_SHAPE] +array_shapes = nonempty_array_shapes + empty_array_shapes +nonzerodim_shapes = nonempty_nonscalar_array_shapes + empty_array_shapes +nonempty_shapes = scalar_shapes + nonempty_array_shapes +all_shapes = scalar_shapes + array_shapes + +float_dtypes = jtu.dtypes.all_floating +complex_dtypes = jtu.dtypes.complex +int_dtypes = jtu.dtypes.all_integer +unsigned_dtypes = jtu.dtypes.all_unsigned +bool_dtypes = jtu.dtypes.boolean +default_dtypes = float_dtypes + int_dtypes +inexact_dtypes = float_dtypes + complex_dtypes +number_dtypes = float_dtypes + complex_dtypes + int_dtypes + unsigned_dtypes +all_dtypes = number_dtypes + bool_dtypes + +python_scalar_dtypes = [jnp.bool_, jnp.int_, jnp.float_, jnp.complex_] + +def _compatible_shapes(shape): + if np.ndim(shape) == 0 or shape in scalar_shapes: + return [shape] + return (shape[n:] for n in range(len(shape) + 1)) + +def _get_y_shapes(y_dtype, shape, rowvar): + # Helper function for testCov. + if y_dtype is None: + return [None] + if len(shape) == 1: + return [shape] + elif rowvar or shape[0] == 1: + return [(1, shape[-1]), (2, shape[-1]), (5, shape[-1])] + return [(shape[0], 1), (shape[0], 2), (shape[0], 5)] + + +OpRecord = collections.namedtuple( + "OpRecord", + ["name", "nargs", "dtypes", "shapes", "rng_factory", "diff_modes", + "test_name", "check_dtypes", "tolerance", "inexact", "kwargs"]) + +def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes, + test_name=None, check_dtypes=True, + tolerance=None, inexact=False, kwargs=None): + test_name = test_name or name + return OpRecord(name, nargs, dtypes, shapes, rng_factory, diff_modes, + test_name, check_dtypes, tolerance, inexact, kwargs) + +JAX_REDUCER_RECORDS = [ + op_record("mean", 1, number_dtypes, nonempty_shapes, jtu.rand_default, [], + inexact=True), + op_record("prod", 1, all_dtypes, all_shapes, jtu.rand_small_positive, []), + op_record("sum", 1, all_dtypes, all_shapes, jtu.rand_default, []), + op_record("nanmean", 1, inexact_dtypes, nonempty_shapes, jtu.rand_some_nan, + [], inexact=True), + op_record("nanprod", 1, all_dtypes, all_shapes, jtu.rand_some_nan, []), + op_record("nansum", 1, number_dtypes, all_shapes, jtu.rand_some_nan, []), +] + +JAX_REDUCER_INITIAL_RECORDS = [ + op_record("prod", 1, all_dtypes, all_shapes, jtu.rand_small_positive, []), + op_record("sum", 1, all_dtypes, all_shapes, jtu.rand_default, []), + op_record("max", 1, all_dtypes, all_shapes, jtu.rand_default, []), + op_record("min", 1, all_dtypes, all_shapes, jtu.rand_default, []), +] +if numpy_version >= (1, 22): # initial & where keywords added in numpy 1.22 + JAX_REDUCER_INITIAL_RECORDS += [ + op_record("nanprod", 1, inexact_dtypes, all_shapes, jtu.rand_small_positive, []), + op_record("nansum", 1, inexact_dtypes, all_shapes, jtu.rand_default, []), + op_record("nanmax", 1, inexact_dtypes, all_shapes, jtu.rand_default, []), + op_record("nanmin", 1, inexact_dtypes, all_shapes, jtu.rand_default, []), + ] + +JAX_REDUCER_WHERE_NO_INITIAL_RECORDS = [ + op_record("all", 1, bool_dtypes, all_shapes, jtu.rand_some_zero, []), + op_record("any", 1, bool_dtypes, all_shapes, jtu.rand_some_zero, []), + op_record("mean", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [], + inexact=True), + op_record("var", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [], + inexact=True), + op_record("std", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [], + inexact=True), +] +if numpy_version >= (1, 22): # where keyword added in numpy 1.22 + JAX_REDUCER_WHERE_NO_INITIAL_RECORDS += [ + op_record("nanmean", 1, inexact_dtypes, nonempty_shapes, jtu.rand_default, [], + inexact=True), + op_record("nanvar", 1, inexact_dtypes, nonempty_shapes, jtu.rand_default, [], + inexact=True), + op_record("nanstd", 1, inexact_dtypes, nonempty_shapes, jtu.rand_default, [], + inexact=True), + ] + +JAX_REDUCER_NO_DTYPE_RECORDS = [ + op_record("all", 1, all_dtypes, all_shapes, jtu.rand_some_zero, []), + op_record("any", 1, all_dtypes, all_shapes, jtu.rand_some_zero, []), + op_record("max", 1, all_dtypes, nonempty_shapes, jtu.rand_default, []), + op_record("min", 1, all_dtypes, nonempty_shapes, jtu.rand_default, []), + op_record("var", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [], + inexact=True), + op_record("std", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [], + inexact=True), + op_record("nanmax", 1, all_dtypes, nonempty_shapes, jtu.rand_some_nan, []), + op_record("nanmin", 1, all_dtypes, nonempty_shapes, jtu.rand_some_nan, []), + op_record("nanvar", 1, all_dtypes, nonempty_shapes, jtu.rand_some_nan, + [], inexact=True), + op_record("nanstd", 1, all_dtypes, nonempty_shapes, jtu.rand_some_nan, + [], inexact=True), + op_record("ptp", 1, number_dtypes, nonempty_shapes, jtu.rand_default, []), +] + +JAX_REDUCER_PROMOTE_INT_RECORDS = [ + op_record("prod", 1, all_dtypes, all_shapes, jtu.rand_small_positive, []), + op_record("sum", 1, all_dtypes, all_shapes, jtu.rand_default, []), +] + + +class JaxNumpyReducerTests(jtu.JaxTestCase): + """Tests for LAX-backed Numpy reduction operations.""" + + def _GetArgsMaker(self, rng, shapes, dtypes, np_arrays=True): + def f(): + out = [rng(shape, dtype or jnp.float_) + for shape, dtype in zip(shapes, dtypes)] + if np_arrays: + return out + return [jnp.asarray(a) if isinstance(a, (np.ndarray, np.generic)) else a + for a in out] + return f + + @parameterized.named_parameters(itertools.chain.from_iterable( + jtu.cases_from_list( + {"testcase_name": "{}_inshape={}_axis={}_dtype={}_keepdims={}".format( + rec.test_name.capitalize(), + jtu.format_shape_dtype_string(shape, dtype), axis, + "None" if out_dtype is None else np.dtype(out_dtype).name, keepdims), + "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, "out_dtype": out_dtype, + "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name), + "axis": axis, "keepdims": keepdims, "inexact": rec.inexact} + for shape in rec.shapes for dtype in rec.dtypes + for out_dtype in [None] + rec.dtypes if out_dtype not in unsigned_dtypes + for axis in list(range(-len(shape), len(shape))) + [None] + for keepdims in [False, True] + if jtu.is_valid_shape(shape, dtype)) + for rec in JAX_REDUCER_RECORDS)) + def testReducer(self, np_op, jnp_op, rng_factory, shape, dtype, out_dtype, + axis, keepdims, inexact): + rng = rng_factory(self.rng()) + @jtu.ignore_warning(category=np.ComplexWarning) + @jtu.ignore_warning(category=RuntimeWarning, + message="mean of empty slice.*") + @jtu.ignore_warning(category=RuntimeWarning, + message="overflow encountered.*") + def np_fun(x): + x = np.asarray(x) + if inexact: + x = x.astype(dtypes.to_inexact_dtype(x.dtype)) + x_cast = x if dtype != jnp.bfloat16 else x.astype(np.float32) + t = out_dtype if out_dtype != jnp.bfloat16 else np.float32 + return np_op(x_cast, axis, dtype=t, keepdims=keepdims) + + jnp_fun = lambda x: jnp_op(x, axis, dtype=out_dtype, keepdims=keepdims) + jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) + args_maker = lambda: [rng(shape, dtype)] + tol_spec = {np.float16: 1e-2, np.int32: 1E-3, np.float32: 1e-3, + np.complex64: 1e-3, np.float64: 1e-5, np.complex128: 1e-5} + tol = jtu.tolerance(dtype, tol_spec) + tol = max(tol, jtu.tolerance(out_dtype, tol_spec)) if out_dtype else tol + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, + check_dtypes=jnp.bfloat16 not in (dtype, out_dtype), + tol=tol) + self._CompileAndCheck(jnp_fun, args_maker, atol=tol, + rtol=tol) + + @parameterized.named_parameters(itertools.chain.from_iterable( + jtu.cases_from_list( + {"testcase_name": "{}_inshape={}_axis={}_keepdims={}".format( + rec.test_name.capitalize(), + jtu.format_shape_dtype_string(shape, dtype), axis, keepdims), + "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, + "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name), + "axis": axis, "keepdims": keepdims, "inexact": rec.inexact} + for shape in rec.shapes for dtype in rec.dtypes + for axis in list(range(-len(shape), len(shape))) + [None] + for keepdims in [False, True] + if jtu.is_valid_shape(shape, dtype)) + for rec in JAX_REDUCER_NO_DTYPE_RECORDS)) + def testReducerNoDtype(self, np_op, jnp_op, rng_factory, shape, dtype, axis, + keepdims, inexact): + rng = rng_factory(self.rng()) + is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan' + @jtu.ignore_warning(category=RuntimeWarning, + message="Degrees of freedom <= 0 for slice.*") + @jtu.ignore_warning(category=RuntimeWarning, + message="All-NaN (slice|axis) encountered.*") + def np_fun(x): + x = np.asarray(x) + if inexact: + x = x.astype(dtypes.to_inexact_dtype(x.dtype)) + x_cast = x if not is_bf16_nan_test else x.astype(np.float32) + res = np_op(x_cast, axis, keepdims=keepdims) + res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16) + return res + + jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims) + args_maker = lambda: [rng(shape, dtype)] + tol = {np.float16: 0.002} + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol) + self._CompileAndCheck(jnp_fun, args_maker, rtol=tol, atol=tol) + + @parameterized.named_parameters(itertools.chain.from_iterable( + jtu.cases_from_list( + {"testcase_name": "{}_inshape={}_axis={}_keepdims={}_initial={}".format( + rec.test_name.capitalize(), + jtu.format_shape_dtype_string(shape, dtype), axis, keepdims, initial), + "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, + "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name), + "initial": initial, "axis": axis, "keepdims": keepdims, "inexact": rec.inexact} + for shape in rec.shapes for dtype in rec.dtypes + for axis in list(range(-len(shape), len(shape))) + [None] + for initial in [0, 1] for keepdims in [False, True] + if jtu.is_valid_shape(shape, dtype)) + for rec in JAX_REDUCER_INITIAL_RECORDS)) + def testReducerInitial(self, np_op, jnp_op, rng_factory, shape, dtype, axis, + keepdims, initial, inexact): + rng = rng_factory(self.rng()) + is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan' + @jtu.ignore_warning(category=RuntimeWarning, + message="Degrees of freedom <= 0 for slice.*") + @jtu.ignore_warning(category=np.ComplexWarning) + def np_fun(x): + x = np.asarray(x) + if inexact: + x = x.astype(dtypes.to_inexact_dtype(x.dtype)) + x_cast = x if not is_bf16_nan_test else x.astype(np.float32) + res = np_op(x_cast, axis, keepdims=keepdims, initial=initial) + res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16) + return res + + jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, initial=initial) + jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) + args_maker = lambda: [rng(shape, dtype)] + tol = {jnp.bfloat16: 3E-2} + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=tol) + self._CompileAndCheck(jnp_fun, args_maker) + + @parameterized.named_parameters(itertools.chain.from_iterable( + jtu.cases_from_list( + {"testcase_name": "{}_inshape={}_axis={}_keepdims={}_initial={}_promote_integers={}".format( + rec.test_name.capitalize(), + jtu.format_shape_dtype_string(shape, dtype), axis, keepdims, initial, promote_integers), + "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, + "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name), + "initial": initial, "axis": axis, "keepdims": keepdims, "inexact": rec.inexact, + "promote_integers": promote_integers} + for shape in rec.shapes for dtype in rec.dtypes + for axis in list(range(-len(shape), len(shape))) + [None] + for initial in [0, 1] for keepdims in [False, True] + for promote_integers in [True, False] + if jtu.is_valid_shape(shape, dtype)) + for rec in JAX_REDUCER_PROMOTE_INT_RECORDS)) + def testReducerPromoteInt(self, np_op, jnp_op, rng_factory, shape, dtype, axis, + keepdims, initial, inexact, promote_integers): + rng = rng_factory(self.rng()) + is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan' + @jtu.ignore_warning(category=RuntimeWarning, + message="Degrees of freedom <= 0 for slice.*") + @jtu.ignore_warning(category=np.ComplexWarning) + def np_fun(x): + x = np.asarray(x) + if inexact: + x = x.astype(dtypes.to_inexact_dtype(x.dtype)) + x_cast = x if not is_bf16_nan_test else x.astype(np.float32) + res = np_op(x_cast, axis, keepdims=keepdims, initial=initial) + res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16) + print(f"res.dtype = {res.dtype}") + if not promote_integers and dtypes.issubdtype(res.dtype, np.integer): + res = res.astype(dtypes.to_numeric_dtype(x.dtype)) + return res + + jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, initial=initial, promote_integers=promote_integers) + jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) + args_maker = lambda: [rng(shape, dtype)] + tol = {jnp.bfloat16: 3E-2} + print(jnp_fun(*args_maker())) + print(np_fun(*args_maker())) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=tol) + self._CompileAndCheck(jnp_fun, args_maker) + + @parameterized.named_parameters(itertools.chain.from_iterable( + jtu.cases_from_list( + {"testcase_name": "{}_inshape={}_axis={}_keepdims={}".format( + rec.test_name.capitalize(), + jtu.format_shape_dtype_string(shape, dtype), axis, keepdims), + "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, + "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name), + "axis": axis, "keepdims": keepdims, "inexact": rec.inexact} + for shape in rec.shapes if np.prod(shape) == 0 + for dtype in rec.dtypes + for keepdims in [False, True] + for axis in range(-len(shape), len(shape)) if shape[axis] >= 1) + for rec in JAX_REDUCER_INITIAL_RECORDS)) + def testReducerNoInitialZeroDims(self, np_op, jnp_op, rng_factory, shape, dtype, axis, + keepdims, inexact): + rng = rng_factory(self.rng()) + is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan' + @jtu.ignore_warning(category=RuntimeWarning, + message="Degrees of freedom <= 0 for slice.*") + @jtu.ignore_warning(category=np.ComplexWarning) + def np_fun(x): + x = np.asarray(x) + if inexact: + x = x.astype(dtypes.to_inexact_dtype(x.dtype)) + x_cast = x if not is_bf16_nan_test else x.astype(np.float32) + res = np_op(x_cast, axis, keepdims=keepdims) + res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16) + return res + + jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims) + jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) + args_maker = lambda: [rng(shape, dtype)] + tol = {jnp.bfloat16: 3E-2} + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=tol) + self._CompileAndCheck(jnp_fun, args_maker) + + @parameterized.named_parameters(itertools.chain.from_iterable( + jtu.cases_from_list( + {"testcase_name": "{}_inshape={}_axis={}_keepdims={}_initial={}_whereshape={}".format( + rec.test_name.capitalize(), + jtu.format_shape_dtype_string(shape, dtype), axis, keepdims, initial, + jtu.format_shape_dtype_string(whereshape, bool)), + "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, + "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name), "whereshape": whereshape, + "initial": initial, "axis": axis, "keepdims": keepdims, "inexact": rec.inexact} + for shape in rec.shapes for dtype in rec.dtypes + for whereshape in _compatible_shapes(shape) + for axis in list(range(-len(shape), len(shape))) + [None] + for initial in [0, 1] for keepdims in [False, True] + if jtu.is_valid_shape(shape, dtype)) + for rec in JAX_REDUCER_INITIAL_RECORDS)) + def testReducerWhere(self, np_op, jnp_op, rng_factory, shape, dtype, axis, + keepdims, initial, inexact, whereshape): + if (shape in [()] + scalar_shapes and + dtype in [jnp.int16, jnp.uint16] and + jnp_op in [jnp.min, jnp.max]): + self.skipTest("Known XLA failure; see https://github.com/google/jax/issues/4971.") + rng = rng_factory(self.rng()) + is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan' + # Do not pass where via args_maker as that is incompatible with _promote_like_jnp. + where = jtu.rand_bool(self.rng())(whereshape, np.bool_) + @jtu.ignore_warning(category=RuntimeWarning, + message="Degrees of freedom <= 0 for slice.*") + @jtu.ignore_warning(category=np.ComplexWarning) + def np_fun(x): + x = np.asarray(x) + if inexact: + x = x.astype(dtypes.to_inexact_dtype(x.dtype)) + x_cast = x if not is_bf16_nan_test else x.astype(np.float32) + res = np_op(x_cast, axis, keepdims=keepdims, initial=initial, where=where) + res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16) + return res + + jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, initial=initial, where=where) + jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @parameterized.named_parameters(itertools.chain.from_iterable( + jtu.cases_from_list( + {"testcase_name": "{}_inshape={}_axis={}_keepdims={}_whereshape={}".format( + rec.test_name.capitalize(), + jtu.format_shape_dtype_string(shape, dtype), axis, keepdims, + jtu.format_shape_dtype_string(whereshape, bool)), + "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, + "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name), "whereshape": whereshape, + "axis": axis, "keepdims": keepdims, "inexact": rec.inexact} + for shape in rec.shapes for dtype in rec.dtypes + for whereshape in _compatible_shapes(shape) + for axis in list(range(-len(shape), len(shape))) + [None] + for keepdims in [False, True] + if jtu.is_valid_shape(shape, dtype)) + for rec in JAX_REDUCER_WHERE_NO_INITIAL_RECORDS)) + def testReducerWhereNoInitial(self, np_op, jnp_op, rng_factory, shape, dtype, axis, + keepdims, inexact, whereshape): + rng = rng_factory(self.rng()) + is_bf16_nan_test = dtype == jnp.bfloat16 + # Do not pass where via args_maker as that is incompatible with _promote_like_jnp. + where = jtu.rand_bool(self.rng())(whereshape, np.bool_) + @jtu.ignore_warning(category=RuntimeWarning, + message="Degrees of freedom <= 0 for slice.*") + @jtu.ignore_warning(category=RuntimeWarning, + message="Mean of empty slice.*") + @jtu.ignore_warning(category=RuntimeWarning, + message="invalid value encountered.*") + @jtu.ignore_warning(category=np.ComplexWarning) + def np_fun(x): + x = np.asarray(x) + if inexact: + x = x.astype(dtypes.to_inexact_dtype(x.dtype)) + x_cast = x if not is_bf16_nan_test else x.astype(np.float32) + res = np_op(x_cast, axis, keepdims=keepdims, where=where) + res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16) + return res + + jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, where=where) + jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) + args_maker = lambda: [rng(shape, dtype)] + if numpy_version >= (1, 20, 2) or np_op.__name__ in ("all", "any"): + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + def testReductionOfOutOfBoundsAxis(self): # Issue 888 + x = jnp.ones((3, 4)) + self.assertRaises(ValueError, lambda: jnp.sum(x, axis=2)) + + def testReductionWithRepeatedAxisError(self): + with self.assertRaisesRegex(ValueError, r"duplicate value in 'axis': \(0, 0\)"): + jnp.sum(jnp.arange(3), (0, 0)) + + @parameterized.named_parameters( + jtu.cases_from_list( + {"testcase_name": + "_shape={}_dtype={}_out_dtype={}_axis={}_ddof={}_keepdims={}" + .format(shape, dtype.__name__, out_dtype.__name__, axis, ddof, keepdims), + "shape": shape, "dtype": dtype, "out_dtype": out_dtype, "axis": axis, + "ddof": ddof, "keepdims": keepdims} + for shape in [(5,), (10, 5)] + for dtype in all_dtypes + for out_dtype in inexact_dtypes + for axis in [None, 0, -1] + for ddof in [0, 1, 2] + for keepdims in [False, True])) + def testVar(self, shape, dtype, out_dtype, axis, ddof, keepdims): + rng = jtu.rand_default(self.rng()) + args_maker = self._GetArgsMaker(rng, [shape], [dtype]) + @jtu.ignore_warning(category=RuntimeWarning, + message="Degrees of freedom <= 0 for slice.") + @jtu.ignore_warning(category=np.ComplexWarning) + def np_fun(x): + # Numpy fails with bfloat16 inputs + out = np.var(x.astype(np.float32 if dtype == dtypes.bfloat16 else dtype), + dtype=np.float32 if out_dtype == dtypes.bfloat16 else out_dtype, + axis=axis, ddof=ddof, keepdims=keepdims) + return out.astype(out_dtype) + jnp_fun = partial(jnp.var, dtype=out_dtype, axis=axis, ddof=ddof, keepdims=keepdims) + tol = jtu.tolerance(out_dtype, {np.float16: 1e-1, np.float32: 1e-3, + np.float64: 1e-3, np.complex128: 1e-6}) + if (jnp.issubdtype(dtype, jnp.complexfloating) and + not jnp.issubdtype(out_dtype, jnp.complexfloating)): + self.assertRaises(ValueError, lambda: jnp_fun(*args_maker())) + else: + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, + tol=tol) + self._CompileAndCheck(jnp_fun, args_maker, rtol=tol, + atol=tol) + + @parameterized.named_parameters( + jtu.cases_from_list( + {"testcase_name": + "_shape={}_dtype={}_out_dtype={}_axis={}_ddof={}_keepdims={}" + .format(shape, dtype, out_dtype, axis, ddof, keepdims), + "shape": shape, "dtype": dtype, "out_dtype": out_dtype, "axis": axis, + "ddof": ddof, "keepdims": keepdims} + for shape in [(5,), (10, 5)] + for dtype in all_dtypes + for out_dtype in inexact_dtypes + for axis in [None, 0, -1] + for ddof in [0, 1, 2] + for keepdims in [False, True])) + def testNanVar(self, shape, dtype, out_dtype, axis, ddof, keepdims): + rng = jtu.rand_some_nan(self.rng()) + args_maker = self._GetArgsMaker(rng, [shape], [dtype]) + @jtu.ignore_warning(category=RuntimeWarning, + message="Degrees of freedom <= 0 for slice.") + @jtu.ignore_warning(category=np.ComplexWarning) + def np_fun(x): + # Numpy fails with bfloat16 inputs + out = np.nanvar(x.astype(np.float32 if dtype == dtypes.bfloat16 else dtype), + dtype=np.float32 if out_dtype == dtypes.bfloat16 else out_dtype, + axis=axis, ddof=ddof, keepdims=keepdims) + return out.astype(out_dtype) + jnp_fun = partial(jnp.nanvar, dtype=out_dtype, axis=axis, ddof=ddof, keepdims=keepdims) + tol = jtu.tolerance(out_dtype, {np.float16: 1e-1, np.float32: 1e-3, + np.float64: 1e-3, np.complex128: 1e-6}) + if (jnp.issubdtype(dtype, jnp.complexfloating) and + not jnp.issubdtype(out_dtype, jnp.complexfloating)): + self.assertRaises(ValueError, lambda: jnp_fun(*args_maker())) + else: + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, + tol=tol) + self._CompileAndCheck(jnp_fun, args_maker, rtol=tol, + atol=tol) + + def testNanStdGrad(self): + # Regression test for https://github.com/google/jax/issues/8128 + x = jnp.arange(5.0).at[0].set(jnp.nan) + y = jax.grad(jnp.nanvar)(x) + self.assertAllClose(y, jnp.array([0.0, -0.75, -0.25, 0.25, 0.75])) + + z = jax.grad(jnp.nanstd)(x) + self.assertEqual(jnp.isnan(z).sum(), 0) + + @parameterized.named_parameters( + jtu.cases_from_list( + {"testcase_name": + "_shape={}_dtype={}_y_shape={}_y_dtype={}_rowvar={}_ddof={}_bias={}_fweights={}_aweights={}".format( + shape, dtype, y_shape, y_dtype, rowvar, ddof, bias, fweights, aweights), + "shape": shape, "y_shape": y_shape, "dtype": dtype, "y_dtype": y_dtype,"rowvar": rowvar, "ddof": ddof, + "bias": bias, "fweights": fweights, "aweights": aweights} + for shape in [(5,), (10, 5), (5, 10)] + for dtype in all_dtypes + for y_dtype in [None, dtype] + for rowvar in [True, False] + for y_shape in _get_y_shapes(y_dtype, shape, rowvar) + for bias in [True, False] + for ddof in [None, 2, 3] + for fweights in [True, False] + for aweights in [True, False])) + @jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion + def testCov(self, shape, dtype, y_shape, y_dtype, rowvar, ddof, bias, fweights, aweights): + rng = jtu.rand_default(self.rng()) + wrng = jtu.rand_positive(self.rng()) + wdtype = np.real(dtype(0)).dtype + wshape = shape[-1:] if rowvar or shape[0] == 1 else shape[:1] + + args_maker = lambda: [rng(shape, dtype), + rng(y_shape, y_dtype) if y_dtype else None, + wrng(wshape, int) if fweights else None, + wrng(wshape, wdtype) if aweights else None] + kwargs = dict(rowvar=rowvar, ddof=ddof, bias=bias) + np_fun = lambda m, y, f, a: np.cov(m, y, fweights=f, aweights=a, **kwargs) + jnp_fun = lambda m, y, f, a: jnp.cov(m, y, fweights=f, aweights=a, **kwargs) + tol = {jnp.bfloat16: 5E-2, np.float16: 1E-2, np.float32: 1e-5, + np.float64: 1e-13, np.complex64: 1e-5, np.complex128: 1e-13} + tol = 7e-2 if jtu.device_under_test() == "tpu" else tol + tol = jtu.join_tolerance(tol, jtu.tolerance(dtype)) + self._CheckAgainstNumpy( + np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol) + self._CompileAndCheck(jnp_fun, args_maker, atol=tol, + rtol=tol) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 0619af632..18c526f44 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -15,12 +15,10 @@ import collections import copy -import functools from functools import partial import inspect import io import itertools -import operator from typing import cast, Iterator, Optional, List, Tuple import unittest from unittest import SkipTest @@ -109,17 +107,6 @@ def _compatible_shapes(shape): return [shape] return (shape[n:] for n in range(len(shape) + 1)) -def _get_y_shapes(y_dtype, shape, rowvar): - # Helper function for testCov. - if y_dtype is None: - return [None] - if len(shape) == 1: - return [shape] - elif rowvar or shape[0] == 1: - return [(1, shape[-1]), (2, shape[-1]), (5, shape[-1])] - return [(shape[0], 1), (shape[0], 2), (shape[0], 5)] - - OpRecord = collections.namedtuple( "OpRecord", ["name", "nargs", "dtypes", "shapes", "rng_factory", "diff_modes", @@ -132,285 +119,6 @@ def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes, return OpRecord(name, nargs, dtypes, shapes, rng_factory, diff_modes, test_name, check_dtypes, tolerance, inexact, kwargs) -JAX_ONE_TO_ONE_OP_RECORDS = [ - op_record("abs", 1, all_dtypes, - all_shapes, jtu.rand_default, ["rev"]), - op_record("add", 2, all_dtypes, all_shapes, jtu.rand_default, ["rev"]), - op_record("ceil", 1, float_dtypes, all_shapes, jtu.rand_default, []), - op_record("ceil", 1, int_dtypes + unsigned_dtypes, all_shapes, - jtu.rand_default, [], check_dtypes=False), - op_record("conj", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), - op_record("equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []), - op_record("exp", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"], - inexact=True), - op_record("fabs", 1, float_dtypes, all_shapes, jtu.rand_default, ["rev"]), - op_record("float_power", 2, inexact_dtypes, all_shapes, - partial(jtu.rand_default, scale=1), ["rev"], - tolerance={jnp.bfloat16: 1e-2, np.float32: 1e-3, - np.float64: 1e-12, np.complex64: 2e-4, - np.complex128: 1e-12}, check_dtypes=False), - op_record("floor", 1, float_dtypes, all_shapes, jtu.rand_default, []), - op_record("floor", 1, int_dtypes + unsigned_dtypes, all_shapes, - jtu.rand_default, [], check_dtypes=False), - op_record("greater", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []), - op_record("greater_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []), - op_record("i0", 1, float_dtypes, all_shapes, jtu.rand_default, [], - check_dtypes=False), - op_record("ldexp", 2, int_dtypes, all_shapes, jtu.rand_default, [], check_dtypes=False), - op_record("less", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []), - op_record("less_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []), - op_record("log", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"], - inexact=True), - op_record("logical_and", 2, all_dtypes, all_shapes, jtu.rand_bool, []), - op_record("logical_not", 1, all_dtypes, all_shapes, jtu.rand_bool, []), - op_record("logical_or", 2, all_dtypes, all_shapes, jtu.rand_bool, []), - op_record("logical_xor", 2, all_dtypes, all_shapes, jtu.rand_bool, []), - op_record("maximum", 2, all_dtypes, all_shapes, jtu.rand_some_inf, []), - op_record("minimum", 2, all_dtypes, all_shapes, jtu.rand_some_inf, []), - op_record("multiply", 2, all_dtypes, all_shapes, jtu.rand_default, ["rev"]), - op_record("negative", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), - op_record("nextafter", 2, [f for f in float_dtypes if f != jnp.bfloat16], - all_shapes, jtu.rand_default, ["rev"], inexact=True, tolerance=0), - op_record("not_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]), - op_record("array_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]), - op_record("array_equiv", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]), - op_record("reciprocal", 1, inexact_dtypes, all_shapes, jtu.rand_default, []), - op_record("subtract", 2, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), - op_record("signbit", 1, default_dtypes + bool_dtypes, all_shapes, - jtu.rand_some_inf_and_nan, ["rev"]), - op_record("trunc", 1, float_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []), - op_record("trunc", 1, int_dtypes + unsigned_dtypes, all_shapes, - jtu.rand_some_inf_and_nan, [], check_dtypes=False), - op_record("sin", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"], - inexact=True), - op_record("cos", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"], - inexact=True), - op_record("tan", 1, number_dtypes, all_shapes, - partial(jtu.rand_uniform, low=-1.5, high=1.5), ["rev"], - inexact=True), - op_record("sinh", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"], - inexact=True), - op_record("cosh", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"], - inexact=True), - # TODO(b/142975473): on CPU, tanh for complex128 is only accurate to - # ~float32 precision. - # TODO(b/143135720): on GPU, tanh has only ~float32 precision. - op_record("tanh", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"], - tolerance={np.float64: 1e-7, np.complex128: 1e-7}, - inexact=True), - op_record("arcsin", 1, number_dtypes, all_shapes, jtu.rand_small, ["rev"], - inexact=True), - op_record("arccos", 1, number_dtypes, all_shapes, jtu.rand_small, ["rev"], - inexact=True), - op_record("arctan", 1, number_dtypes, all_shapes, jtu.rand_small, ["rev"], - inexact=True), - op_record("arctan2", 2, float_dtypes, all_shapes, jtu.rand_small, ["rev"], - inexact=True), - op_record("arcsinh", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"], - inexact=True, tolerance={np.complex64: 2E-4, np.complex128: 2E-14}), - op_record("arccosh", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"], - inexact=True, tolerance={np.complex64: 2E-2, np.complex128: 2E-12}), - op_record("arctanh", 1, number_dtypes, all_shapes, jtu.rand_small, ["rev"], - inexact=True, tolerance={np.float64: 1e-9}), -] - -JAX_COMPOUND_OP_RECORDS = [ - # angle has inconsistent 32/64-bit return types across numpy versions. - op_record("angle", 1, number_dtypes, all_shapes, jtu.rand_default, [], - check_dtypes=False, inexact=True), - op_record("angle", 1, number_dtypes, all_shapes, jtu.rand_default, [], - check_dtypes=False, inexact=True, test_name="angle_deg", kwargs={'deg': True}), - op_record("atleast_1d", 1, default_dtypes, all_shapes, jtu.rand_default, []), - op_record("atleast_2d", 1, default_dtypes, all_shapes, jtu.rand_default, []), - op_record("atleast_3d", 1, default_dtypes, all_shapes, jtu.rand_default, []), - op_record("cbrt", 1, default_dtypes, all_shapes, jtu.rand_some_inf, ["rev"], - inexact=True), - op_record("conjugate", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), - op_record("deg2rad", 1, float_dtypes, all_shapes, jtu.rand_default, []), - op_record("divide", 2, number_dtypes, all_shapes, jtu.rand_nonzero, ["rev"], - inexact=True), - op_record("divmod", 2, int_dtypes + float_dtypes, all_shapes, - jtu.rand_nonzero, []), - op_record("exp2", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"], - tolerance={jnp.bfloat16: 4e-2, np.float16: 1e-2}, inexact=True), - # TODO(b/142975473): on CPU, expm1 for float64 is only accurate to ~float32 - # precision. - op_record("expm1", 1, number_dtypes, all_shapes, jtu.rand_positive, [], - test_name="expm1_large", tolerance={np.float64: 1e-8}, inexact=True), - op_record("expm1", 1, number_dtypes, all_shapes, jtu.rand_small_positive, - [], tolerance={np.float64: 1e-8}, inexact=True), - op_record("fix", 1, float_dtypes, all_shapes, jtu.rand_default, []), - op_record("fix", 1, int_dtypes + unsigned_dtypes, all_shapes, - jtu.rand_default, [], check_dtypes=False), - op_record("floor_divide", 2, default_dtypes + unsigned_dtypes, - all_shapes, jtu.rand_nonzero, ["rev"]), - op_record("fmin", 2, number_dtypes, all_shapes, jtu.rand_some_nan, []), - op_record("fmax", 2, number_dtypes, all_shapes, jtu.rand_some_nan, []), - op_record("fmod", 2, default_dtypes, all_shapes, jtu.rand_some_nan, []), - op_record("heaviside", 2, default_dtypes, all_shapes, jtu.rand_default, [], - inexact=True), - op_record("hypot", 2, default_dtypes, all_shapes, jtu.rand_default, [], - inexact=True), - op_record("kron", 2, number_dtypes, nonempty_shapes, jtu.rand_default, []), - op_record("outer", 2, number_dtypes, all_shapes, jtu.rand_default, []), - op_record("imag", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []), - op_record("iscomplex", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []), - op_record("isfinite", 1, inexact_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []), - op_record("isinf", 1, inexact_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []), - op_record("isnan", 1, inexact_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []), - op_record("isneginf", 1, float_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []), - op_record("isposinf", 1, float_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []), - op_record("isreal", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []), - op_record("isrealobj", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []), - op_record("log2", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"], - inexact=True), - op_record("log10", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"], - inexact=True), - op_record("log1p", 1, number_dtypes, all_shapes, jtu.rand_positive, [], - test_name="log1p_large", tolerance={np.float64: 1e-12}, - inexact=True), - op_record("log1p", 1, number_dtypes, all_shapes, jtu.rand_small_positive, [], - tolerance={np.float64: 1e-12}, inexact=True), - op_record("logaddexp", 2, float_dtypes, all_shapes, - jtu.rand_some_inf_and_nan, ["rev"], - tolerance={np.float64: 1e-12}, inexact=True), - op_record("logaddexp2", 2, float_dtypes, all_shapes, - jtu.rand_some_inf_and_nan, ["rev"], - tolerance={np.float16: 1e-2, np.float64: 2e-14}, inexact=True), - op_record("polyval", 2, number_dtypes, nonempty_nonscalar_array_shapes, - jtu.rand_default, [], check_dtypes=False, - tolerance={dtypes.bfloat16: 4e-2, np.float16: 1e-2, - np.float64: 1e-12}), - op_record("positive", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), - op_record("power", 2, number_dtypes, all_shapes, jtu.rand_positive, ["rev"], - tolerance={np.complex128: 1e-14}, check_dtypes=False), - op_record("rad2deg", 1, float_dtypes, all_shapes, jtu.rand_default, []), - op_record("ravel", 1, all_dtypes, all_shapes, jtu.rand_default, ["rev"]), - op_record("real", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []), - op_record("remainder", 2, default_dtypes, all_shapes, jtu.rand_nonzero, [], - tolerance={np.float16: 1e-2}), - op_record("mod", 2, default_dtypes, all_shapes, jtu.rand_nonzero, []), - op_record("modf", 1, float_dtypes, all_shapes, jtu.rand_default, []), - op_record("modf", 1, int_dtypes + unsigned_dtypes, all_shapes, - jtu.rand_default, [], check_dtypes=False), - op_record("rint", 1, inexact_dtypes, all_shapes, jtu.rand_some_inf_and_nan, - []), - op_record("rint", 1, int_dtypes + unsigned_dtypes, all_shapes, - jtu.rand_default, [], check_dtypes=False), - op_record("sign", 1, number_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []), - # numpy 1.16 has trouble mixing uint and bfloat16, so we test these separately. - op_record("copysign", 2, default_dtypes + unsigned_dtypes, - all_shapes, jtu.rand_some_inf_and_nan, [], check_dtypes=False), - op_record("sinc", 1, [t for t in number_dtypes if t != jnp.bfloat16], - all_shapes, jtu.rand_default, ["rev"], - tolerance={np.complex64: 1e-5}, inexact=True, - check_dtypes=False), - op_record("square", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), - op_record("sqrt", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"], - inexact=True), - op_record("transpose", 1, all_dtypes, all_shapes, jtu.rand_default, ["rev"], - check_dtypes=False), - op_record("true_divide", 2, all_dtypes, all_shapes, jtu.rand_nonzero, - ["rev"], inexact=True), - op_record("ediff1d", 3, [np.int32], all_shapes, jtu.rand_default, [], check_dtypes=False), - # TODO(phawkins): np.unwrap does not correctly promote its default period - # argument under NumPy 1.21 for bfloat16 inputs. It works fine if we - # explicitly pass a bfloat16 value that does not need promition. We should - # probably add a custom test harness for unwrap that tests the period - # argument anyway. - op_record("unwrap", 1, [t for t in float_dtypes if t != dtypes.bfloat16], - nonempty_nonscalar_array_shapes, - jtu.rand_default, ["rev"], - # numpy.unwrap always returns float64 - check_dtypes=False, - # numpy cumsum is inaccurate, see issue #3517 - tolerance={dtypes.bfloat16: 1e-1, np.float16: 1e-1}), - op_record("isclose", 2, [t for t in all_dtypes if t != jnp.bfloat16], - all_shapes, jtu.rand_small_positive, []), - op_record("gcd", 2, int_dtypes_no_uint64, all_shapes, jtu.rand_default, []), - op_record("lcm", 2, int_dtypes_no_uint64, all_shapes, jtu.rand_default, []), -] - -JAX_BITWISE_OP_RECORDS = [ - op_record("bitwise_and", 2, int_dtypes + unsigned_dtypes, all_shapes, - jtu.rand_fullrange, []), - op_record("bitwise_not", 1, int_dtypes + unsigned_dtypes, all_shapes, - jtu.rand_fullrange, []), - op_record("invert", 1, int_dtypes + unsigned_dtypes, all_shapes, - jtu.rand_fullrange, []), - op_record("bitwise_or", 2, int_dtypes + unsigned_dtypes, all_shapes, - jtu.rand_fullrange, []), - op_record("bitwise_xor", 2, int_dtypes + unsigned_dtypes, all_shapes, - jtu.rand_fullrange, []), -] - -JAX_REDUCER_RECORDS = [ - op_record("mean", 1, number_dtypes, nonempty_shapes, jtu.rand_default, [], - inexact=True), - op_record("prod", 1, all_dtypes, all_shapes, jtu.rand_small_positive, []), - op_record("sum", 1, all_dtypes, all_shapes, jtu.rand_default, []), - op_record("nanmean", 1, inexact_dtypes, nonempty_shapes, jtu.rand_some_nan, - [], inexact=True), - op_record("nanprod", 1, all_dtypes, all_shapes, jtu.rand_some_nan, []), - op_record("nansum", 1, number_dtypes, all_shapes, jtu.rand_some_nan, []), -] - -JAX_REDUCER_INITIAL_RECORDS = [ - op_record("prod", 1, all_dtypes, all_shapes, jtu.rand_small_positive, []), - op_record("sum", 1, all_dtypes, all_shapes, jtu.rand_default, []), - op_record("max", 1, all_dtypes, all_shapes, jtu.rand_default, []), - op_record("min", 1, all_dtypes, all_shapes, jtu.rand_default, []), -] -if numpy_version >= (1, 22): # initial & where keywords added in numpy 1.22 - JAX_REDUCER_INITIAL_RECORDS += [ - op_record("nanprod", 1, inexact_dtypes, all_shapes, jtu.rand_small_positive, []), - op_record("nansum", 1, inexact_dtypes, all_shapes, jtu.rand_default, []), - op_record("nanmax", 1, inexact_dtypes, all_shapes, jtu.rand_default, []), - op_record("nanmin", 1, inexact_dtypes, all_shapes, jtu.rand_default, []), - ] - -JAX_REDUCER_WHERE_NO_INITIAL_RECORDS = [ - op_record("all", 1, bool_dtypes, all_shapes, jtu.rand_some_zero, []), - op_record("any", 1, bool_dtypes, all_shapes, jtu.rand_some_zero, []), - op_record("mean", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [], - inexact=True), - op_record("var", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [], - inexact=True), - op_record("std", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [], - inexact=True), -] -if numpy_version >= (1, 22): # where keyword added in numpy 1.22 - JAX_REDUCER_WHERE_NO_INITIAL_RECORDS += [ - op_record("nanmean", 1, inexact_dtypes, nonempty_shapes, jtu.rand_default, [], - inexact=True), - op_record("nanvar", 1, inexact_dtypes, nonempty_shapes, jtu.rand_default, [], - inexact=True), - op_record("nanstd", 1, inexact_dtypes, nonempty_shapes, jtu.rand_default, [], - inexact=True), - ] - -JAX_REDUCER_NO_DTYPE_RECORDS = [ - op_record("all", 1, all_dtypes, all_shapes, jtu.rand_some_zero, []), - op_record("any", 1, all_dtypes, all_shapes, jtu.rand_some_zero, []), - op_record("max", 1, all_dtypes, nonempty_shapes, jtu.rand_default, []), - op_record("min", 1, all_dtypes, nonempty_shapes, jtu.rand_default, []), - op_record("var", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [], - inexact=True), - op_record("std", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [], - inexact=True), - op_record("nanmax", 1, all_dtypes, nonempty_shapes, jtu.rand_some_nan, []), - op_record("nanmin", 1, all_dtypes, nonempty_shapes, jtu.rand_some_nan, []), - op_record("nanvar", 1, all_dtypes, nonempty_shapes, jtu.rand_some_nan, - [], inexact=True), - op_record("nanstd", 1, all_dtypes, nonempty_shapes, jtu.rand_some_nan, - [], inexact=True), - op_record("ptp", 1, number_dtypes, nonempty_shapes, jtu.rand_default, []), -] - -JAX_REDUCER_PROMOTE_INT_RECORDS = [ - op_record("prod", 1, all_dtypes, all_shapes, jtu.rand_small_positive, []), - op_record("sum", 1, all_dtypes, all_shapes, jtu.rand_default, []), -] JAX_ARGMINMAX_RECORDS = [ op_record("argmin", 1, default_dtypes, nonempty_shapes, jtu.rand_some_equal, []), @@ -419,88 +127,6 @@ JAX_ARGMINMAX_RECORDS = [ op_record("nanargmax", 1, default_dtypes, nonempty_shapes, jtu.rand_some_nan, []), ] -JAX_OPERATOR_OVERLOADS = [ - op_record("__add__", 2, number_dtypes, all_shapes, jtu.rand_default, []), - op_record("__sub__", 2, number_dtypes, all_shapes, jtu.rand_default, []), - op_record("__mul__", 2, number_dtypes, all_shapes, jtu.rand_default, []), - op_record("__eq__", 2, number_dtypes, all_shapes, jtu.rand_default, []), - op_record("__ne__", 2, number_dtypes, all_shapes, jtu.rand_default, []), - op_record("__lt__", 2, default_dtypes, all_shapes, jtu.rand_default, []), - op_record("__le__", 2, default_dtypes, all_shapes, jtu.rand_default, []), - op_record("__gt__", 2, default_dtypes, all_shapes, jtu.rand_default, []), - op_record("__ge__", 2, default_dtypes, all_shapes, jtu.rand_default, []), - op_record("__pos__", 1, number_dtypes, all_shapes, jtu.rand_default, []), - op_record("__neg__", 1, number_dtypes, all_shapes, jtu.rand_default, []), - op_record("__pow__", 2, inexact_dtypes, all_shapes, jtu.rand_positive, [], - tolerance={np.float32: 2e-4, np.complex64: 2e-4, np.complex128: 1e-14}), - op_record("__mod__", 2, default_dtypes, all_shapes, jtu.rand_nonzero, [], - tolerance={np.float16: 1e-1}), - op_record("__floordiv__", 2, default_dtypes, all_shapes, - jtu.rand_nonzero, []), - op_record("__truediv__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, [], - inexact=True), - op_record("__abs__", 1, number_dtypes, all_shapes, jtu.rand_default, []), - # TODO(mattjj): __invert__ fails on bool dtypes because ~True == -2 - op_record("__invert__", 1, int_dtypes, all_shapes, jtu.rand_default, []), - # TODO(mattjj): investigate these failures - # op_record("__or__", 2, number_dtypes, all_shapes, jtu.rand_bool, []), - # op_record("__and__", 2, number_dtypes, all_shapes, jtu.rand_default, []), - # op_record("__xor__", 2, number_dtypes, all_shapes, jtu.rand_bool, []), - # op_record("__divmod__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, []), - op_record("__lshift__", 2, int_dtypes_no_uint64, all_shapes, partial(jtu.rand_int, high=8), []), - op_record("__rshift__", 2, int_dtypes_no_uint64, all_shapes, partial(jtu.rand_int, high=8), []), -] - -JAX_RIGHT_OPERATOR_OVERLOADS = [ - op_record("__radd__", 2, number_dtypes, all_shapes, jtu.rand_default, []), - op_record("__rsub__", 2, number_dtypes, all_shapes, jtu.rand_default, []), - op_record("__rmul__", 2, number_dtypes, all_shapes, jtu.rand_default, []), - op_record("__rpow__", 2, inexact_dtypes, all_shapes, jtu.rand_positive, [], - tolerance={np.float32: 2e-4, np.complex64: 1e-3}), - op_record("__rmod__", 2, default_dtypes, all_shapes, jtu.rand_nonzero, [], - tolerance={np.float16: 1e-1}), - op_record("__rfloordiv__", 2, default_dtypes, all_shapes, - jtu.rand_nonzero, []), - op_record("__rtruediv__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, [], - inexact=True), - # op_record("__ror__", 2, number_dtypes, all_shapes, jtu.rand_bool, []), - # op_record("__rand__", 2, number_dtypes, all_shapes, jtu.rand_default, []), - # op_record("__rxor__", 2, number_dtypes, all_shapes, jtu.rand_bool, []), - # op_record("__rdivmod__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, []), - op_record("__rlshift__", 2, int_dtypes_no_uint64, all_shapes, partial(jtu.rand_int, high=8), []), - op_record("__rrshift__", 2, int_dtypes_no_uint64, all_shapes, partial(jtu.rand_int, high=8), []) -] - -class _OverrideEverything: - pass - -for rec in JAX_OPERATOR_OVERLOADS + JAX_RIGHT_OPERATOR_OVERLOADS: - if rec.nargs == 2: - setattr(_OverrideEverything, rec.name, lambda self, other: self) - -class _OverrideNothing: - pass - -for rec in JAX_OPERATOR_OVERLOADS + JAX_RIGHT_OPERATOR_OVERLOADS: - if rec.nargs == 2: - setattr(_OverrideNothing, rec.name, lambda self, other: NotImplemented) - - -def _dtypes_are_compatible_for_bitwise_ops(args): - if len(args) <= 1: - return True - is_signed = lambda dtype: jnp.issubdtype(dtype, np.signedinteger) - width = lambda dtype: jnp.iinfo(dtype).bits - x, y = args - if width(x) > width(y): - x, y = y, x - # The following condition seems a little ad hoc, but seems to capture what - # numpy actually implements. - return ( - is_signed(x) == is_signed(y) - or (width(x) == 32 and width(y) == 32) - or (width(x) == 32 and width(y) == 64 and is_signed(y))) - def _shapes_are_broadcast_compatible(shapes): try: lax.broadcast_shapes(*(() if s in scalar_shapes else s for s in shapes)) @@ -581,155 +207,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): arr_out = jnp.load(f, allow_pickle=allow_pickle) self.assertArraysEqual(arr, arr_out) - @parameterized.named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes, - dtypes), - "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes, - "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name), - "check_dtypes": rec.check_dtypes, "tolerance": rec.tolerance, - "inexact": rec.inexact, "kwargs": rec.kwargs or {}} - for shapes in filter( - _shapes_are_broadcast_compatible, - itertools.combinations_with_replacement(rec.shapes, rec.nargs)) - for dtypes in itertools.product( - *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes))) - for rec in itertools.chain(JAX_ONE_TO_ONE_OP_RECORDS, - JAX_COMPOUND_OP_RECORDS))) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - def testOp(self, np_op, jnp_op, rng_factory, shapes, dtypes, check_dtypes, - tolerance, inexact, kwargs): - np_op = partial(np_op, **kwargs) - jnp_op = partial(jnp_op, **kwargs) - np_op = jtu.ignore_warning(category=RuntimeWarning, - message="invalid value.*")(np_op) - np_op = jtu.ignore_warning(category=RuntimeWarning, - message="divide by zero.*")(np_op) - - rng = rng_factory(self.rng()) - args_maker = self._GetArgsMaker(rng, shapes, dtypes, np_arrays=False) - tol = max(jtu.tolerance(dtype, tolerance) for dtype in dtypes) - tol = functools.reduce(jtu.join_tolerance, - [tolerance, tol, jtu.default_tolerance()]) - - with jtu.strict_promotion_if_dtypes_match(dtypes): - self._CheckAgainstNumpy(_promote_like_jnp(np_op, inexact), jnp_op, - args_maker, check_dtypes=check_dtypes, tol=tol) - self._CompileAndCheck(jnp_op, args_maker, check_dtypes=check_dtypes, - atol=tol, rtol=tol) - - @parameterized.named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes, - dtypes), - "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes, "name": rec.name, - "tol": rec.tolerance} - for shapes in filter( - _shapes_are_broadcast_compatible, - itertools.combinations_with_replacement(rec.shapes, rec.nargs)) - for dtypes in itertools.product( - *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes))) - for rec in JAX_OPERATOR_OVERLOADS)) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - def testOperatorOverload(self, name, rng_factory, shapes, dtypes, tol): - rng = rng_factory(self.rng()) - # np and jnp arrays have different type promotion rules; force the use of - # jnp arrays. - args_maker = self._GetArgsMaker(rng, shapes, dtypes, np_arrays=False) - fun = lambda *xs: getattr(operator, name.strip('_'))(*xs) - with jtu.strict_promotion_if_dtypes_match(dtypes): - self._CompileAndCheck(fun, args_maker, atol=tol, rtol=tol) - - @parameterized.named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes, - dtypes), - "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes, "name": rec.name, - "op_tolerance": rec.tolerance} - for shapes in filter( - _shapes_are_broadcast_compatible, - itertools.combinations_with_replacement(rec.shapes, rec.nargs)) - for dtypes in itertools.product( - *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes))) - for rec in JAX_RIGHT_OPERATOR_OVERLOADS)) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - def testRightOperatorOverload(self, name, rng_factory, shapes, dtypes, - op_tolerance): - if shapes[1] is jtu.PYTHON_SCALAR_SHAPE: - raise SkipTest("scalars not implemented") # TODO(mattjj): clean up - rng = rng_factory(self.rng()) - args_maker = self._GetArgsMaker(rng, shapes, dtypes, np_arrays=False) - fun = lambda fst, snd: getattr(snd, name)(fst) - tol = max(jtu.tolerance(dtype, op_tolerance) for dtype in dtypes) - with jtu.strict_promotion_if_dtypes_match(dtypes): - self._CompileAndCheck( fun, args_maker, atol=tol, rtol=tol) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"{rec.test_name}_{othertype}", "name": rec.name, "othertype": othertype} - for rec in JAX_OPERATOR_OVERLOADS if rec.nargs == 2 - for othertype in [dict, list, tuple, set])) - def testOperatorOverloadErrors(self, name, othertype): - # Test that binary operators with builtin collections raise a TypeError - # and report the types in the correct order. - data = [(1, 2), (2, 3)] - arr = jnp.array(data) - other = othertype(data) - - if config.jax_array: - val_str = 'Array' - else: - val_str = 'DeviceArray' - msg = f"unsupported operand type.* '{val_str}' and '{othertype.__name__}'" - with self.assertRaisesRegex(TypeError, msg): - getattr(arr, name)(other) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"{rec.test_name}_{othertype}", "name": rec.name, "othertype": othertype} - for rec in JAX_RIGHT_OPERATOR_OVERLOADS if rec.nargs == 2 - for othertype in [dict, list, tuple, set])) - def testRightOperatorOverloadErrors(self, name, othertype): - # Test that binary operators with builtin collections raise a TypeError - # and report the types in the correct order. - data = [(1, 2), (2, 3)] - arr = jnp.array(data) - other = othertype(data) - - if config.jax_array: - val_str = 'Array' - else: - val_str = 'DeviceArray' - msg = f"unsupported operand type.* '{othertype.__name__}' and '{val_str}'" - with self.assertRaisesRegex(TypeError, msg): - getattr(arr, name)(other) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": rec.test_name + f"_{dtype}", - "rng_factory": rec.rng_factory, - "op_name": rec.name, "dtype": dtype} - for rec in JAX_OPERATOR_OVERLOADS if rec.nargs == 2 - for dtype in rec.dtypes)) - def testBinaryOperatorDefers(self, op_name, rng_factory, dtype): - rng = rng_factory(self.rng()) - arg = jax.device_put(rng((), dtype)) - op = getattr(operator, op_name) - - other = _OverrideEverything() - assert op(other, arg) is other - assert op(arg, other) is other - - other = _OverrideNothing() - if op_name == "__eq__": - assert op(other, arg) is False - assert op(arg, other) is False - elif op_name == "__ne__": - assert op(other, arg) is True - assert op(arg, other) is True - else: - with self.assertRaises(TypeError): - op(other, arg) - with self.assertRaises(TypeError): - op(arg, other) - def testArrayEqualExamples(self): # examples from the array_equal() docstring. self.assertTrue(jnp.array_equal([1, 2], [1, 2])) @@ -776,342 +253,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): jax.jit(f)(jnp_array) jax.grad(f)(jnp_array) - @parameterized.named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": jtu.format_test_name_suffix( - rec.test_name, shapes, dtypes), - "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes, - "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name)} - for shapes in filter( - _shapes_are_broadcast_compatible, - itertools.combinations_with_replacement(rec.shapes, rec.nargs)) - for dtypes in filter( - _dtypes_are_compatible_for_bitwise_ops, - itertools.combinations_with_replacement(rec.dtypes, rec.nargs))) - for rec in JAX_BITWISE_OP_RECORDS)) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - def testBitwiseOp(self, np_op, jnp_op, rng_factory, shapes, dtypes): - rng = rng_factory(self.rng()) - args_maker = self._GetArgsMaker(rng, shapes, dtypes) - with jtu.strict_promotion_if_dtypes_match(dtypes): - self._CheckAgainstNumpy(_promote_like_jnp(np_op), jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": jtu.format_test_name_suffix(op.__name__, shapes, dtypes), - "op": op, "dtypes": dtypes, "shapes": shapes} - for op in [jnp.left_shift, jnp.right_shift] - for shapes in filter( - _shapes_are_broadcast_compatible, - # TODO numpy always promotes to shift dtype for zero-dim shapes: - itertools.combinations_with_replacement(nonzerodim_shapes, 2)) - for dtypes in itertools.product( - *(_valid_dtypes_for_shape(s, int_dtypes_no_uint64) for s in shapes)))) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - def testShiftOpAgainstNumpy(self, op, dtypes, shapes): - dtype, shift_dtype = dtypes - signed_mix = np.issubdtype(dtype, np.signedinteger) != \ - np.issubdtype(shift_dtype, np.signedinteger) - has_32 = any(np.iinfo(d).bits == 32 for d in dtypes) - promoting_to_64 = has_32 and signed_mix - if promoting_to_64 and not config.x64_enabled: - self.skipTest("np.right_shift/left_shift promoting to int64" - "differs from jnp in 32 bit mode.") - - info, shift_info = map(np.iinfo, dtypes) - x_rng = jtu.rand_int(self.rng(), low=info.min, high=info.max + 1) - # NumPy requires shifts to be non-negative and below the bit width: - shift_rng = jtu.rand_int(self.rng(), high=max(info.bits, shift_info.bits)) - args_maker = lambda: (x_rng(shapes[0], dtype), shift_rng(shapes[1], shift_dtype)) - - np_op = getattr(np, op.__name__) - - with jtu.strict_promotion_if_dtypes_match(dtypes): - self._CompileAndCheck(op, args_maker) - self._CheckAgainstNumpy(np_op, op, args_maker) - - @parameterized.named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": "{}_inshape={}_axis={}_dtype={}_keepdims={}".format( - rec.test_name.capitalize(), - jtu.format_shape_dtype_string(shape, dtype), axis, - "None" if out_dtype is None else np.dtype(out_dtype).name, keepdims), - "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, "out_dtype": out_dtype, - "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name), - "axis": axis, "keepdims": keepdims, "inexact": rec.inexact} - for shape in rec.shapes for dtype in rec.dtypes - for out_dtype in [None] + rec.dtypes if out_dtype not in unsigned_dtypes - for axis in list(range(-len(shape), len(shape))) + [None] - for keepdims in [False, True] - if jtu.is_valid_shape(shape, dtype)) - for rec in JAX_REDUCER_RECORDS)) - def testReducer(self, np_op, jnp_op, rng_factory, shape, dtype, out_dtype, - axis, keepdims, inexact): - rng = rng_factory(self.rng()) - @jtu.ignore_warning(category=np.ComplexWarning) - @jtu.ignore_warning(category=RuntimeWarning, - message="mean of empty slice.*") - @jtu.ignore_warning(category=RuntimeWarning, - message="overflow encountered.*") - def np_fun(x): - x = np.asarray(x) - if inexact: - x = x.astype(dtypes.to_inexact_dtype(x.dtype)) - x_cast = x if dtype != jnp.bfloat16 else x.astype(np.float32) - t = out_dtype if out_dtype != jnp.bfloat16 else np.float32 - return np_op(x_cast, axis, dtype=t, keepdims=keepdims) - - jnp_fun = lambda x: jnp_op(x, axis, dtype=out_dtype, keepdims=keepdims) - jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) - args_maker = lambda: [rng(shape, dtype)] - tol_spec = {np.float16: 1e-2, np.int32: 1E-3, np.float32: 1e-3, - np.complex64: 1e-3, np.float64: 1e-5, np.complex128: 1e-5} - tol = jtu.tolerance(dtype, tol_spec) - tol = max(tol, jtu.tolerance(out_dtype, tol_spec)) if out_dtype else tol - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, - check_dtypes=jnp.bfloat16 not in (dtype, out_dtype), - tol=tol) - self._CompileAndCheck(jnp_fun, args_maker, atol=tol, - rtol=tol) - - @parameterized.named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": "{}_inshape={}_axis={}_keepdims={}".format( - rec.test_name.capitalize(), - jtu.format_shape_dtype_string(shape, dtype), axis, keepdims), - "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, - "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name), - "axis": axis, "keepdims": keepdims, "inexact": rec.inexact} - for shape in rec.shapes for dtype in rec.dtypes - for axis in list(range(-len(shape), len(shape))) + [None] - for keepdims in [False, True] - if jtu.is_valid_shape(shape, dtype)) - for rec in JAX_REDUCER_NO_DTYPE_RECORDS)) - def testReducerNoDtype(self, np_op, jnp_op, rng_factory, shape, dtype, axis, - keepdims, inexact): - rng = rng_factory(self.rng()) - is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan' - @jtu.ignore_warning(category=RuntimeWarning, - message="Degrees of freedom <= 0 for slice.*") - @jtu.ignore_warning(category=RuntimeWarning, - message="All-NaN (slice|axis) encountered.*") - def np_fun(x): - x = np.asarray(x) - if inexact: - x = x.astype(dtypes.to_inexact_dtype(x.dtype)) - x_cast = x if not is_bf16_nan_test else x.astype(np.float32) - res = np_op(x_cast, axis, keepdims=keepdims) - res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16) - return res - - jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims) - args_maker = lambda: [rng(shape, dtype)] - tol = {np.float16: 0.002} - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol) - self._CompileAndCheck(jnp_fun, args_maker, rtol=tol, atol=tol) - - @parameterized.named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": "{}_inshape={}_axis={}_keepdims={}_initial={}".format( - rec.test_name.capitalize(), - jtu.format_shape_dtype_string(shape, dtype), axis, keepdims, initial), - "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, - "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name), - "initial": initial, "axis": axis, "keepdims": keepdims, "inexact": rec.inexact} - for shape in rec.shapes for dtype in rec.dtypes - for axis in list(range(-len(shape), len(shape))) + [None] - for initial in [0, 1] for keepdims in [False, True] - if jtu.is_valid_shape(shape, dtype)) - for rec in JAX_REDUCER_INITIAL_RECORDS)) - def testReducerInitial(self, np_op, jnp_op, rng_factory, shape, dtype, axis, - keepdims, initial, inexact): - rng = rng_factory(self.rng()) - is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan' - @jtu.ignore_warning(category=RuntimeWarning, - message="Degrees of freedom <= 0 for slice.*") - @jtu.ignore_warning(category=np.ComplexWarning) - def np_fun(x): - x = np.asarray(x) - if inexact: - x = x.astype(dtypes.to_inexact_dtype(x.dtype)) - x_cast = x if not is_bf16_nan_test else x.astype(np.float32) - res = np_op(x_cast, axis, keepdims=keepdims, initial=initial) - res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16) - return res - - jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, initial=initial) - jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) - args_maker = lambda: [rng(shape, dtype)] - tol = {jnp.bfloat16: 3E-2} - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=tol) - self._CompileAndCheck(jnp_fun, args_maker) - - @parameterized.named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": "{}_inshape={}_axis={}_keepdims={}_initial={}_promote_integers={}".format( - rec.test_name.capitalize(), - jtu.format_shape_dtype_string(shape, dtype), axis, keepdims, initial, promote_integers), - "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, - "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name), - "initial": initial, "axis": axis, "keepdims": keepdims, "inexact": rec.inexact, - "promote_integers": promote_integers} - for shape in rec.shapes for dtype in rec.dtypes - for axis in list(range(-len(shape), len(shape))) + [None] - for initial in [0, 1] for keepdims in [False, True] - for promote_integers in [True, False] - if jtu.is_valid_shape(shape, dtype)) - for rec in JAX_REDUCER_PROMOTE_INT_RECORDS)) - def testReducerPromoteInt(self, np_op, jnp_op, rng_factory, shape, dtype, axis, - keepdims, initial, inexact, promote_integers): - rng = rng_factory(self.rng()) - is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan' - @jtu.ignore_warning(category=RuntimeWarning, - message="Degrees of freedom <= 0 for slice.*") - @jtu.ignore_warning(category=np.ComplexWarning) - def np_fun(x): - x = np.asarray(x) - if inexact: - x = x.astype(dtypes.to_inexact_dtype(x.dtype)) - x_cast = x if not is_bf16_nan_test else x.astype(np.float32) - res = np_op(x_cast, axis, keepdims=keepdims, initial=initial) - res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16) - print(f"res.dtype = {res.dtype}") - if not promote_integers and dtypes.issubdtype(res.dtype, np.integer): - res = res.astype(dtypes.to_numeric_dtype(x.dtype)) - return res - - jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, initial=initial, promote_integers=promote_integers) - jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) - args_maker = lambda: [rng(shape, dtype)] - tol = {jnp.bfloat16: 3E-2} - print(jnp_fun(*args_maker())) - print(np_fun(*args_maker())) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=tol) - self._CompileAndCheck(jnp_fun, args_maker) - - @parameterized.named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": "{}_inshape={}_axis={}_keepdims={}".format( - rec.test_name.capitalize(), - jtu.format_shape_dtype_string(shape, dtype), axis, keepdims), - "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, - "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name), - "axis": axis, "keepdims": keepdims, "inexact": rec.inexact} - for shape in rec.shapes if np.prod(shape) == 0 - for dtype in rec.dtypes - for keepdims in [False, True] - for axis in range(-len(shape), len(shape)) if shape[axis] >= 1) - for rec in JAX_REDUCER_INITIAL_RECORDS)) - def testReducerNoInitialZeroDims(self, np_op, jnp_op, rng_factory, shape, dtype, axis, - keepdims, inexact): - rng = rng_factory(self.rng()) - is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan' - @jtu.ignore_warning(category=RuntimeWarning, - message="Degrees of freedom <= 0 for slice.*") - @jtu.ignore_warning(category=np.ComplexWarning) - def np_fun(x): - x = np.asarray(x) - if inexact: - x = x.astype(dtypes.to_inexact_dtype(x.dtype)) - x_cast = x if not is_bf16_nan_test else x.astype(np.float32) - res = np_op(x_cast, axis, keepdims=keepdims) - res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16) - return res - - jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims) - jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) - args_maker = lambda: [rng(shape, dtype)] - tol = {jnp.bfloat16: 3E-2} - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=tol) - self._CompileAndCheck(jnp_fun, args_maker) - - @parameterized.named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": "{}_inshape={}_axis={}_keepdims={}_initial={}_whereshape={}".format( - rec.test_name.capitalize(), - jtu.format_shape_dtype_string(shape, dtype), axis, keepdims, initial, - jtu.format_shape_dtype_string(whereshape, bool)), - "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, - "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name), "whereshape": whereshape, - "initial": initial, "axis": axis, "keepdims": keepdims, "inexact": rec.inexact} - for shape in rec.shapes for dtype in rec.dtypes - for whereshape in _compatible_shapes(shape) - for axis in list(range(-len(shape), len(shape))) + [None] - for initial in [0, 1] for keepdims in [False, True] - if jtu.is_valid_shape(shape, dtype)) - for rec in JAX_REDUCER_INITIAL_RECORDS)) - def testReducerWhere(self, np_op, jnp_op, rng_factory, shape, dtype, axis, - keepdims, initial, inexact, whereshape): - if (shape in [()] + scalar_shapes and - dtype in [jnp.int16, jnp.uint16] and - jnp_op in [jnp.min, jnp.max]): - self.skipTest("Known XLA failure; see https://github.com/google/jax/issues/4971.") - rng = rng_factory(self.rng()) - is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan' - # Do not pass where via args_maker as that is incompatible with _promote_like_jnp. - where = jtu.rand_bool(self.rng())(whereshape, np.bool_) - @jtu.ignore_warning(category=RuntimeWarning, - message="Degrees of freedom <= 0 for slice.*") - @jtu.ignore_warning(category=np.ComplexWarning) - def np_fun(x): - x = np.asarray(x) - if inexact: - x = x.astype(dtypes.to_inexact_dtype(x.dtype)) - x_cast = x if not is_bf16_nan_test else x.astype(np.float32) - res = np_op(x_cast, axis, keepdims=keepdims, initial=initial, where=where) - res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16) - return res - - jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, initial=initial, where=where) - jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @parameterized.named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": "{}_inshape={}_axis={}_keepdims={}_whereshape={}".format( - rec.test_name.capitalize(), - jtu.format_shape_dtype_string(shape, dtype), axis, keepdims, - jtu.format_shape_dtype_string(whereshape, bool)), - "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, - "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name), "whereshape": whereshape, - "axis": axis, "keepdims": keepdims, "inexact": rec.inexact} - for shape in rec.shapes for dtype in rec.dtypes - for whereshape in _compatible_shapes(shape) - for axis in list(range(-len(shape), len(shape))) + [None] - for keepdims in [False, True] - if jtu.is_valid_shape(shape, dtype)) - for rec in JAX_REDUCER_WHERE_NO_INITIAL_RECORDS)) - def testReducerWhereNoInitial(self, np_op, jnp_op, rng_factory, shape, dtype, axis, - keepdims, inexact, whereshape): - rng = rng_factory(self.rng()) - is_bf16_nan_test = dtype == jnp.bfloat16 - # Do not pass where via args_maker as that is incompatible with _promote_like_jnp. - where = jtu.rand_bool(self.rng())(whereshape, np.bool_) - @jtu.ignore_warning(category=RuntimeWarning, - message="Degrees of freedom <= 0 for slice.*") - @jtu.ignore_warning(category=RuntimeWarning, - message="Mean of empty slice.*") - @jtu.ignore_warning(category=RuntimeWarning, - message="invalid value encountered.*") - @jtu.ignore_warning(category=np.ComplexWarning) - def np_fun(x): - x = np.asarray(x) - if inexact: - x = x.astype(dtypes.to_inexact_dtype(x.dtype)) - x_cast = x if not is_bf16_nan_test else x.astype(np.float32) - res = np_op(x_cast, axis, keepdims=keepdims, where=where) - res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16) - return res - - jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, where=where) - jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) - args_maker = lambda: [rng(shape, dtype)] - if numpy_version >= (1, 20, 2) or np_op.__name__ in ("all", "any"): - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_axis={}_discont={}_period={}".format( @@ -5444,136 +4585,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self.assertAllClose(expected, actual, atol=tol, rtol=tol) - - def testReductionOfOutOfBoundsAxis(self): # Issue 888 - x = jnp.ones((3, 4)) - self.assertRaises(ValueError, lambda: jnp.sum(x, axis=2)) - def testIssue956(self): self.assertRaises(TypeError, lambda: jnp.ndarray((1, 1))) - @parameterized.named_parameters( - jtu.cases_from_list( - {"testcase_name": - "_shape={}_dtype={}_out_dtype={}_axis={}_ddof={}_keepdims={}" - .format(shape, dtype.__name__, out_dtype.__name__, axis, ddof, keepdims), - "shape": shape, "dtype": dtype, "out_dtype": out_dtype, "axis": axis, - "ddof": ddof, "keepdims": keepdims} - for shape in [(5,), (10, 5)] - for dtype in all_dtypes - for out_dtype in inexact_dtypes - for axis in [None, 0, -1] - for ddof in [0, 1, 2] - for keepdims in [False, True])) - def testVar(self, shape, dtype, out_dtype, axis, ddof, keepdims): - rng = jtu.rand_default(self.rng()) - args_maker = self._GetArgsMaker(rng, [shape], [dtype]) - @jtu.ignore_warning(category=RuntimeWarning, - message="Degrees of freedom <= 0 for slice.") - @jtu.ignore_warning(category=np.ComplexWarning) - def np_fun(x): - # Numpy fails with bfloat16 inputs - out = np.var(x.astype(np.float32 if dtype == dtypes.bfloat16 else dtype), - dtype=np.float32 if out_dtype == dtypes.bfloat16 else out_dtype, - axis=axis, ddof=ddof, keepdims=keepdims) - return out.astype(out_dtype) - jnp_fun = partial(jnp.var, dtype=out_dtype, axis=axis, ddof=ddof, keepdims=keepdims) - tol = jtu.tolerance(out_dtype, {np.float16: 1e-1, np.float32: 1e-3, - np.float64: 1e-3, np.complex128: 1e-6}) - if (jnp.issubdtype(dtype, jnp.complexfloating) and - not jnp.issubdtype(out_dtype, jnp.complexfloating)): - self.assertRaises(ValueError, lambda: jnp_fun(*args_maker())) - else: - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, - tol=tol) - self._CompileAndCheck(jnp_fun, args_maker, rtol=tol, - atol=tol) - - @parameterized.named_parameters( - jtu.cases_from_list( - {"testcase_name": - "_shape={}_dtype={}_out_dtype={}_axis={}_ddof={}_keepdims={}" - .format(shape, dtype, out_dtype, axis, ddof, keepdims), - "shape": shape, "dtype": dtype, "out_dtype": out_dtype, "axis": axis, - "ddof": ddof, "keepdims": keepdims} - for shape in [(5,), (10, 5)] - for dtype in all_dtypes - for out_dtype in inexact_dtypes - for axis in [None, 0, -1] - for ddof in [0, 1, 2] - for keepdims in [False, True])) - def testNanVar(self, shape, dtype, out_dtype, axis, ddof, keepdims): - rng = jtu.rand_some_nan(self.rng()) - args_maker = self._GetArgsMaker(rng, [shape], [dtype]) - @jtu.ignore_warning(category=RuntimeWarning, - message="Degrees of freedom <= 0 for slice.") - @jtu.ignore_warning(category=np.ComplexWarning) - def np_fun(x): - # Numpy fails with bfloat16 inputs - out = np.nanvar(x.astype(np.float32 if dtype == dtypes.bfloat16 else dtype), - dtype=np.float32 if out_dtype == dtypes.bfloat16 else out_dtype, - axis=axis, ddof=ddof, keepdims=keepdims) - return out.astype(out_dtype) - jnp_fun = partial(jnp.nanvar, dtype=out_dtype, axis=axis, ddof=ddof, keepdims=keepdims) - tol = jtu.tolerance(out_dtype, {np.float16: 1e-1, np.float32: 1e-3, - np.float64: 1e-3, np.complex128: 1e-6}) - if (jnp.issubdtype(dtype, jnp.complexfloating) and - not jnp.issubdtype(out_dtype, jnp.complexfloating)): - self.assertRaises(ValueError, lambda: jnp_fun(*args_maker())) - else: - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, - tol=tol) - self._CompileAndCheck(jnp_fun, args_maker, rtol=tol, - atol=tol) - - def testNanStdGrad(self): - # Regression test for https://github.com/google/jax/issues/8128 - x = jnp.arange(5.0).at[0].set(jnp.nan) - y = jax.grad(jnp.nanvar)(x) - self.assertAllClose(y, jnp.array([0.0, -0.75, -0.25, 0.25, 0.75])) - - z = jax.grad(jnp.nanstd)(x) - self.assertEqual(jnp.isnan(z).sum(), 0) - - @parameterized.named_parameters( - jtu.cases_from_list( - {"testcase_name": - "_shape={}_dtype={}_y_shape={}_y_dtype={}_rowvar={}_ddof={}_bias={}_fweights={}_aweights={}".format( - shape, dtype, y_shape, y_dtype, rowvar, ddof, bias, fweights, aweights), - "shape": shape, "y_shape": y_shape, "dtype": dtype, "y_dtype": y_dtype,"rowvar": rowvar, "ddof": ddof, - "bias": bias, "fweights": fweights, "aweights": aweights} - for shape in [(5,), (10, 5), (5, 10)] - for dtype in all_dtypes - for y_dtype in [None, dtype] - for rowvar in [True, False] - for y_shape in _get_y_shapes(y_dtype, shape, rowvar) - for bias in [True, False] - for ddof in [None, 2, 3] - for fweights in [True, False] - for aweights in [True, False])) - @jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion - def testCov(self, shape, dtype, y_shape, y_dtype, rowvar, ddof, bias, fweights, aweights): - rng = jtu.rand_default(self.rng()) - wrng = jtu.rand_positive(self.rng()) - wdtype = np.real(dtype(0)).dtype - wshape = shape[-1:] if rowvar or shape[0] == 1 else shape[:1] - - args_maker = lambda: [rng(shape, dtype), - rng(y_shape, y_dtype) if y_dtype else None, - wrng(wshape, int) if fweights else None, - wrng(wshape, wdtype) if aweights else None] - kwargs = dict(rowvar=rowvar, ddof=ddof, bias=bias) - np_fun = lambda m, y, f, a: np.cov(m, y, fweights=f, aweights=a, **kwargs) - jnp_fun = lambda m, y, f, a: jnp.cov(m, y, fweights=f, aweights=a, **kwargs) - tol = {jnp.bfloat16: 5E-2, np.float16: 1E-2, np.float32: 1e-5, - np.float64: 1e-13, np.complex64: 1e-5, np.complex128: 1e-13} - tol = 7e-2 if jtu.device_under_test() == "tpu" else tol - tol = jtu.join_tolerance(tol, jtu.tolerance(dtype)) - self._CheckAgainstNumpy( - np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol) - self._CompileAndCheck(jnp_fun, args_maker, atol=tol, - rtol=tol) - def testIssue967(self): self.assertRaises(TypeError, lambda: jnp.zeros(1.5)) @@ -6229,10 +5243,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): v = np.arange(12, dtype=np.int32).reshape(3, 4) self.assertEqual(jnp.asarray(v).tolist(), v.tolist()) - def testReductionWithRepeatedAxisError(self): - with self.assertRaisesRegex(ValueError, r"duplicate value in 'axis': \(0, 0\)"): - jnp.sum(jnp.arange(3), (0, 0)) - def testArangeConcretizationError(self): msg = r"It arose in jax.numpy.arange argument `{}`".format with self.assertRaisesRegex(jax.core.ConcretizationTypeError, msg('stop')):