Split parts of lax_numpy_test.py into separate test files.

Why? The main test file is getting too big and this hinders iteration on individual tests

PiperOrigin-RevId: 478130215
This commit is contained in:
Jake VanderPlas 2022-09-30 19:37:42 -07:00 committed by jax authors
parent 849f837b6a
commit 439217644a
4 changed files with 1245 additions and 993 deletions

View File

@ -320,9 +320,31 @@ jax_test(
srcs = ["lax_numpy_test.py"],
pjrt_c_api_bypass = True,
shard_count = {
"cpu": 40,
"gpu": 40,
"tpu": 20,
"cpu": 20,
"gpu": 20,
"tpu": 10,
},
)
jax_test(
name = "lax_numpy_operators_test",
srcs = ["lax_numpy_operators_test.py"],
pjrt_c_api_bypass = True,
shard_count = {
"cpu": 10,
"gpu": 10,
"tpu": 5,
},
)
jax_test(
name = "lax_numpy_reducers_test",
srcs = ["lax_numpy_reducers_test.py"],
pjrt_c_api_bypass = True,
shard_count = {
"cpu": 10,
"gpu": 10,
"tpu": 5,
},
)

View File

@ -0,0 +1,636 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import functools
from functools import partial
import itertools
import operator
from unittest import SkipTest
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import jax
import jax.ops
from jax import lax
from jax import numpy as jnp
from jax import tree_util
from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src.numpy.lax_numpy import _promote_dtypes, _promote_dtypes_inexact
from jax.config import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS
nonempty_nonscalar_array_shapes = [(4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)]
nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes
one_dim_array_shapes = [(1,), (6,), (12,)]
empty_array_shapes = [(0,), (0, 4), (3, 0),]
scalar_shapes = [jtu.NUMPY_SCALAR_SHAPE, jtu.PYTHON_SCALAR_SHAPE]
array_shapes = nonempty_array_shapes + empty_array_shapes
nonzerodim_shapes = nonempty_nonscalar_array_shapes + empty_array_shapes
nonempty_shapes = scalar_shapes + nonempty_array_shapes
all_shapes = scalar_shapes + array_shapes
float_dtypes = jtu.dtypes.all_floating
complex_dtypes = jtu.dtypes.complex
int_dtypes = jtu.dtypes.all_integer
unsigned_dtypes = jtu.dtypes.all_unsigned
bool_dtypes = jtu.dtypes.boolean
default_dtypes = float_dtypes + int_dtypes
inexact_dtypes = float_dtypes + complex_dtypes
number_dtypes = float_dtypes + complex_dtypes + int_dtypes + unsigned_dtypes
all_dtypes = number_dtypes + bool_dtypes
python_scalar_dtypes = [jnp.bool_, jnp.int_, jnp.float_, jnp.complex_]
# uint64 is problematic because with any uint type it promotes to float:
int_dtypes_no_uint64 = [d for d in int_dtypes + unsigned_dtypes if d != np.uint64]
def _indexer_with_default_outputs(indexer, use_defaults=True):
"""Like jtu.with_jax_dtype_defaults, but for __getitem__ APIs"""
class Indexer:
@partial(jtu.with_jax_dtype_defaults, use_defaults=use_defaults)
def __getitem__(self, *args):
return indexer.__getitem__(*args)
return Indexer()
def _valid_dtypes_for_shape(shape, dtypes):
# Not all (shape, dtype) pairs are valid. In particular, Python scalars only
# have one type in each category (float, bool, etc.)
if shape is jtu.PYTHON_SCALAR_SHAPE:
return [t for t in dtypes if t in python_scalar_dtypes]
return dtypes
OpRecord = collections.namedtuple(
"OpRecord",
["name", "nargs", "dtypes", "shapes", "rng_factory", "diff_modes",
"test_name", "check_dtypes", "tolerance", "inexact", "kwargs"])
def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes,
test_name=None, check_dtypes=True,
tolerance=None, inexact=False, kwargs=None):
test_name = test_name or name
return OpRecord(name, nargs, dtypes, shapes, rng_factory, diff_modes,
test_name, check_dtypes, tolerance, inexact, kwargs)
JAX_ONE_TO_ONE_OP_RECORDS = [
op_record("abs", 1, all_dtypes,
all_shapes, jtu.rand_default, ["rev"]),
op_record("add", 2, all_dtypes, all_shapes, jtu.rand_default, ["rev"]),
op_record("ceil", 1, float_dtypes, all_shapes, jtu.rand_default, []),
op_record("ceil", 1, int_dtypes + unsigned_dtypes, all_shapes,
jtu.rand_default, [], check_dtypes=False),
op_record("conj", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
op_record("equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
op_record("exp", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"],
inexact=True),
op_record("fabs", 1, float_dtypes, all_shapes, jtu.rand_default, ["rev"]),
op_record("float_power", 2, inexact_dtypes, all_shapes,
partial(jtu.rand_default, scale=1), ["rev"],
tolerance={jnp.bfloat16: 1e-2, np.float32: 1e-3,
np.float64: 1e-12, np.complex64: 2e-4,
np.complex128: 1e-12}, check_dtypes=False),
op_record("floor", 1, float_dtypes, all_shapes, jtu.rand_default, []),
op_record("floor", 1, int_dtypes + unsigned_dtypes, all_shapes,
jtu.rand_default, [], check_dtypes=False),
op_record("greater", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
op_record("greater_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
op_record("i0", 1, float_dtypes, all_shapes, jtu.rand_default, [],
check_dtypes=False),
op_record("ldexp", 2, int_dtypes, all_shapes, jtu.rand_default, [], check_dtypes=False),
op_record("less", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
op_record("less_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
op_record("log", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"],
inexact=True),
op_record("logical_and", 2, all_dtypes, all_shapes, jtu.rand_bool, []),
op_record("logical_not", 1, all_dtypes, all_shapes, jtu.rand_bool, []),
op_record("logical_or", 2, all_dtypes, all_shapes, jtu.rand_bool, []),
op_record("logical_xor", 2, all_dtypes, all_shapes, jtu.rand_bool, []),
op_record("maximum", 2, all_dtypes, all_shapes, jtu.rand_some_inf, []),
op_record("minimum", 2, all_dtypes, all_shapes, jtu.rand_some_inf, []),
op_record("multiply", 2, all_dtypes, all_shapes, jtu.rand_default, ["rev"]),
op_record("negative", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
op_record("nextafter", 2, [f for f in float_dtypes if f != jnp.bfloat16],
all_shapes, jtu.rand_default, ["rev"], inexact=True, tolerance=0),
op_record("not_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]),
op_record("array_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]),
op_record("array_equiv", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]),
op_record("reciprocal", 1, inexact_dtypes, all_shapes, jtu.rand_default, []),
op_record("subtract", 2, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
op_record("signbit", 1, default_dtypes + bool_dtypes, all_shapes,
jtu.rand_some_inf_and_nan, ["rev"]),
op_record("trunc", 1, float_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []),
op_record("trunc", 1, int_dtypes + unsigned_dtypes, all_shapes,
jtu.rand_some_inf_and_nan, [], check_dtypes=False),
op_record("sin", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"],
inexact=True),
op_record("cos", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"],
inexact=True),
op_record("tan", 1, number_dtypes, all_shapes,
partial(jtu.rand_uniform, low=-1.5, high=1.5), ["rev"],
inexact=True),
op_record("sinh", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"],
inexact=True),
op_record("cosh", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"],
inexact=True),
# TODO(b/142975473): on CPU, tanh for complex128 is only accurate to
# ~float32 precision.
# TODO(b/143135720): on GPU, tanh has only ~float32 precision.
op_record("tanh", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"],
tolerance={np.float64: 1e-7, np.complex128: 1e-7},
inexact=True),
op_record("arcsin", 1, number_dtypes, all_shapes, jtu.rand_small, ["rev"],
inexact=True),
op_record("arccos", 1, number_dtypes, all_shapes, jtu.rand_small, ["rev"],
inexact=True),
op_record("arctan", 1, number_dtypes, all_shapes, jtu.rand_small, ["rev"],
inexact=True),
op_record("arctan2", 2, float_dtypes, all_shapes, jtu.rand_small, ["rev"],
inexact=True),
op_record("arcsinh", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"],
inexact=True, tolerance={np.complex64: 2E-4, np.complex128: 2E-14}),
op_record("arccosh", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"],
inexact=True, tolerance={np.complex64: 2E-2, np.complex128: 2E-12}),
op_record("arctanh", 1, number_dtypes, all_shapes, jtu.rand_small, ["rev"],
inexact=True, tolerance={np.float64: 1e-9}),
]
JAX_COMPOUND_OP_RECORDS = [
# angle has inconsistent 32/64-bit return types across numpy versions.
op_record("angle", 1, number_dtypes, all_shapes, jtu.rand_default, [],
check_dtypes=False, inexact=True),
op_record("angle", 1, number_dtypes, all_shapes, jtu.rand_default, [],
check_dtypes=False, inexact=True, test_name="angle_deg", kwargs={'deg': True}),
op_record("atleast_1d", 1, default_dtypes, all_shapes, jtu.rand_default, []),
op_record("atleast_2d", 1, default_dtypes, all_shapes, jtu.rand_default, []),
op_record("atleast_3d", 1, default_dtypes, all_shapes, jtu.rand_default, []),
op_record("cbrt", 1, default_dtypes, all_shapes, jtu.rand_some_inf, ["rev"],
inexact=True),
op_record("conjugate", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
op_record("deg2rad", 1, float_dtypes, all_shapes, jtu.rand_default, []),
op_record("divide", 2, number_dtypes, all_shapes, jtu.rand_nonzero, ["rev"],
inexact=True),
op_record("divmod", 2, int_dtypes + float_dtypes, all_shapes,
jtu.rand_nonzero, []),
op_record("exp2", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"],
tolerance={jnp.bfloat16: 4e-2, np.float16: 1e-2}, inexact=True),
# TODO(b/142975473): on CPU, expm1 for float64 is only accurate to ~float32
# precision.
op_record("expm1", 1, number_dtypes, all_shapes, jtu.rand_positive, [],
test_name="expm1_large", tolerance={np.float64: 1e-8}, inexact=True),
op_record("expm1", 1, number_dtypes, all_shapes, jtu.rand_small_positive,
[], tolerance={np.float64: 1e-8}, inexact=True),
op_record("fix", 1, float_dtypes, all_shapes, jtu.rand_default, []),
op_record("fix", 1, int_dtypes + unsigned_dtypes, all_shapes,
jtu.rand_default, [], check_dtypes=False),
op_record("floor_divide", 2, default_dtypes + unsigned_dtypes,
all_shapes, jtu.rand_nonzero, ["rev"]),
op_record("fmin", 2, number_dtypes, all_shapes, jtu.rand_some_nan, []),
op_record("fmax", 2, number_dtypes, all_shapes, jtu.rand_some_nan, []),
op_record("fmod", 2, default_dtypes, all_shapes, jtu.rand_some_nan, []),
op_record("heaviside", 2, default_dtypes, all_shapes, jtu.rand_default, [],
inexact=True),
op_record("hypot", 2, default_dtypes, all_shapes, jtu.rand_default, [],
inexact=True),
op_record("kron", 2, number_dtypes, nonempty_shapes, jtu.rand_default, []),
op_record("outer", 2, number_dtypes, all_shapes, jtu.rand_default, []),
op_record("imag", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []),
op_record("iscomplex", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []),
op_record("isfinite", 1, inexact_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []),
op_record("isinf", 1, inexact_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []),
op_record("isnan", 1, inexact_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []),
op_record("isneginf", 1, float_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []),
op_record("isposinf", 1, float_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []),
op_record("isreal", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []),
op_record("isrealobj", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []),
op_record("log2", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"],
inexact=True),
op_record("log10", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"],
inexact=True),
op_record("log1p", 1, number_dtypes, all_shapes, jtu.rand_positive, [],
test_name="log1p_large", tolerance={np.float64: 1e-12},
inexact=True),
op_record("log1p", 1, number_dtypes, all_shapes, jtu.rand_small_positive, [],
tolerance={np.float64: 1e-12}, inexact=True),
op_record("logaddexp", 2, float_dtypes, all_shapes,
jtu.rand_some_inf_and_nan, ["rev"],
tolerance={np.float64: 1e-12}, inexact=True),
op_record("logaddexp2", 2, float_dtypes, all_shapes,
jtu.rand_some_inf_and_nan, ["rev"],
tolerance={np.float16: 1e-2, np.float64: 2e-14}, inexact=True),
op_record("polyval", 2, number_dtypes, nonempty_nonscalar_array_shapes,
jtu.rand_default, [], check_dtypes=False,
tolerance={dtypes.bfloat16: 4e-2, np.float16: 1e-2,
np.float64: 1e-12}),
op_record("positive", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
op_record("power", 2, number_dtypes, all_shapes, jtu.rand_positive, ["rev"],
tolerance={np.complex128: 1e-14}, check_dtypes=False),
op_record("rad2deg", 1, float_dtypes, all_shapes, jtu.rand_default, []),
op_record("ravel", 1, all_dtypes, all_shapes, jtu.rand_default, ["rev"]),
op_record("real", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []),
op_record("remainder", 2, default_dtypes, all_shapes, jtu.rand_nonzero, [],
tolerance={np.float16: 1e-2}),
op_record("mod", 2, default_dtypes, all_shapes, jtu.rand_nonzero, []),
op_record("modf", 1, float_dtypes, all_shapes, jtu.rand_default, []),
op_record("modf", 1, int_dtypes + unsigned_dtypes, all_shapes,
jtu.rand_default, [], check_dtypes=False),
op_record("rint", 1, inexact_dtypes, all_shapes, jtu.rand_some_inf_and_nan,
[]),
op_record("rint", 1, int_dtypes + unsigned_dtypes, all_shapes,
jtu.rand_default, [], check_dtypes=False),
op_record("sign", 1, number_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []),
# numpy 1.16 has trouble mixing uint and bfloat16, so we test these separately.
op_record("copysign", 2, default_dtypes + unsigned_dtypes,
all_shapes, jtu.rand_some_inf_and_nan, [], check_dtypes=False),
op_record("sinc", 1, [t for t in number_dtypes if t != jnp.bfloat16],
all_shapes, jtu.rand_default, ["rev"],
tolerance={np.complex64: 1e-5}, inexact=True,
check_dtypes=False),
op_record("square", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
op_record("sqrt", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"],
inexact=True),
op_record("transpose", 1, all_dtypes, all_shapes, jtu.rand_default, ["rev"],
check_dtypes=False),
op_record("true_divide", 2, all_dtypes, all_shapes, jtu.rand_nonzero,
["rev"], inexact=True),
op_record("ediff1d", 3, [np.int32], all_shapes, jtu.rand_default, [], check_dtypes=False),
# TODO(phawkins): np.unwrap does not correctly promote its default period
# argument under NumPy 1.21 for bfloat16 inputs. It works fine if we
# explicitly pass a bfloat16 value that does not need promition. We should
# probably add a custom test harness for unwrap that tests the period
# argument anyway.
op_record("unwrap", 1, [t for t in float_dtypes if t != dtypes.bfloat16],
nonempty_nonscalar_array_shapes,
jtu.rand_default, ["rev"],
# numpy.unwrap always returns float64
check_dtypes=False,
# numpy cumsum is inaccurate, see issue #3517
tolerance={dtypes.bfloat16: 1e-1, np.float16: 1e-1}),
op_record("isclose", 2, [t for t in all_dtypes if t != jnp.bfloat16],
all_shapes, jtu.rand_small_positive, []),
op_record("gcd", 2, int_dtypes_no_uint64, all_shapes, jtu.rand_default, []),
op_record("lcm", 2, int_dtypes_no_uint64, all_shapes, jtu.rand_default, []),
]
JAX_BITWISE_OP_RECORDS = [
op_record("bitwise_and", 2, int_dtypes + unsigned_dtypes, all_shapes,
jtu.rand_fullrange, []),
op_record("bitwise_not", 1, int_dtypes + unsigned_dtypes, all_shapes,
jtu.rand_fullrange, []),
op_record("invert", 1, int_dtypes + unsigned_dtypes, all_shapes,
jtu.rand_fullrange, []),
op_record("bitwise_or", 2, int_dtypes + unsigned_dtypes, all_shapes,
jtu.rand_fullrange, []),
op_record("bitwise_xor", 2, int_dtypes + unsigned_dtypes, all_shapes,
jtu.rand_fullrange, []),
]
JAX_OPERATOR_OVERLOADS = [
op_record("__add__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
op_record("__sub__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
op_record("__mul__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
op_record("__eq__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
op_record("__ne__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
op_record("__lt__", 2, default_dtypes, all_shapes, jtu.rand_default, []),
op_record("__le__", 2, default_dtypes, all_shapes, jtu.rand_default, []),
op_record("__gt__", 2, default_dtypes, all_shapes, jtu.rand_default, []),
op_record("__ge__", 2, default_dtypes, all_shapes, jtu.rand_default, []),
op_record("__pos__", 1, number_dtypes, all_shapes, jtu.rand_default, []),
op_record("__neg__", 1, number_dtypes, all_shapes, jtu.rand_default, []),
op_record("__pow__", 2, inexact_dtypes, all_shapes, jtu.rand_positive, [],
tolerance={np.float32: 2e-4, np.complex64: 2e-4, np.complex128: 1e-14}),
op_record("__mod__", 2, default_dtypes, all_shapes, jtu.rand_nonzero, [],
tolerance={np.float16: 1e-1}),
op_record("__floordiv__", 2, default_dtypes, all_shapes,
jtu.rand_nonzero, []),
op_record("__truediv__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, [],
inexact=True),
op_record("__abs__", 1, number_dtypes, all_shapes, jtu.rand_default, []),
# TODO(mattjj): __invert__ fails on bool dtypes because ~True == -2
op_record("__invert__", 1, int_dtypes, all_shapes, jtu.rand_default, []),
# TODO(mattjj): investigate these failures
# op_record("__or__", 2, number_dtypes, all_shapes, jtu.rand_bool, []),
# op_record("__and__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
# op_record("__xor__", 2, number_dtypes, all_shapes, jtu.rand_bool, []),
# op_record("__divmod__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, []),
op_record("__lshift__", 2, int_dtypes_no_uint64, all_shapes, partial(jtu.rand_int, high=8), []),
op_record("__rshift__", 2, int_dtypes_no_uint64, all_shapes, partial(jtu.rand_int, high=8), []),
]
JAX_RIGHT_OPERATOR_OVERLOADS = [
op_record("__radd__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
op_record("__rsub__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
op_record("__rmul__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
op_record("__rpow__", 2, inexact_dtypes, all_shapes, jtu.rand_positive, [],
tolerance={np.float32: 2e-4, np.complex64: 1e-3}),
op_record("__rmod__", 2, default_dtypes, all_shapes, jtu.rand_nonzero, [],
tolerance={np.float16: 1e-1}),
op_record("__rfloordiv__", 2, default_dtypes, all_shapes,
jtu.rand_nonzero, []),
op_record("__rtruediv__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, [],
inexact=True),
# op_record("__ror__", 2, number_dtypes, all_shapes, jtu.rand_bool, []),
# op_record("__rand__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
# op_record("__rxor__", 2, number_dtypes, all_shapes, jtu.rand_bool, []),
# op_record("__rdivmod__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, []),
op_record("__rlshift__", 2, int_dtypes_no_uint64, all_shapes, partial(jtu.rand_int, high=8), []),
op_record("__rrshift__", 2, int_dtypes_no_uint64, all_shapes, partial(jtu.rand_int, high=8), [])
]
class _OverrideEverything:
pass
for rec in JAX_OPERATOR_OVERLOADS + JAX_RIGHT_OPERATOR_OVERLOADS:
if rec.nargs == 2:
setattr(_OverrideEverything, rec.name, lambda self, other: self)
class _OverrideNothing:
pass
for rec in JAX_OPERATOR_OVERLOADS + JAX_RIGHT_OPERATOR_OVERLOADS:
if rec.nargs == 2:
setattr(_OverrideNothing, rec.name, lambda self, other: NotImplemented)
def _dtypes_are_compatible_for_bitwise_ops(args):
if len(args) <= 1:
return True
is_signed = lambda dtype: jnp.issubdtype(dtype, np.signedinteger)
width = lambda dtype: jnp.iinfo(dtype).bits
x, y = args
if width(x) > width(y):
x, y = y, x
# The following condition seems a little ad hoc, but seems to capture what
# numpy actually implements.
return (
is_signed(x) == is_signed(y)
or (width(x) == 32 and width(y) == 32)
or (width(x) == 32 and width(y) == 64 and is_signed(y)))
def _shapes_are_broadcast_compatible(shapes):
try:
lax.broadcast_shapes(*(() if s in scalar_shapes else s for s in shapes))
except ValueError:
return False
else:
return True
def _shapes_are_equal_length(shapes):
return all(len(shape) == len(shapes[0]) for shape in shapes[1:])
def _promote_like_jnp(fun, inexact=False):
"""Decorator that promotes the arguments of `fun` to `jnp.result_type(*args)`.
jnp and np have different type promotion semantics; this decorator allows
tests make an np reference implementation act more like an jnp
implementation.
"""
_promote = _promote_dtypes_inexact if inexact else _promote_dtypes
def wrapper(*args, **kw):
flat_args, tree = tree_util.tree_flatten(args)
args = tree_util.tree_unflatten(tree, _promote(*flat_args))
return fun(*args, **kw)
return wrapper
class JaxNumpyOperatorTests(jtu.JaxTestCase):
"""Tests for LAX-backed Numpy operators."""
def _GetArgsMaker(self, rng, shapes, dtypes, np_arrays=True):
def f():
out = [rng(shape, dtype or jnp.float_)
for shape, dtype in zip(shapes, dtypes)]
if np_arrays:
return out
return [jnp.asarray(a) if isinstance(a, (np.ndarray, np.generic)) else a
for a in out]
return f
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes,
dtypes),
"rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes,
"np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name),
"check_dtypes": rec.check_dtypes, "tolerance": rec.tolerance,
"inexact": rec.inexact, "kwargs": rec.kwargs or {}}
for shapes in filter(
_shapes_are_broadcast_compatible,
itertools.combinations_with_replacement(rec.shapes, rec.nargs))
for dtypes in itertools.product(
*(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes)))
for rec in itertools.chain(JAX_ONE_TO_ONE_OP_RECORDS,
JAX_COMPOUND_OP_RECORDS)))
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def testOp(self, np_op, jnp_op, rng_factory, shapes, dtypes, check_dtypes,
tolerance, inexact, kwargs):
np_op = partial(np_op, **kwargs)
jnp_op = partial(jnp_op, **kwargs)
np_op = jtu.ignore_warning(category=RuntimeWarning,
message="invalid value.*")(np_op)
np_op = jtu.ignore_warning(category=RuntimeWarning,
message="divide by zero.*")(np_op)
rng = rng_factory(self.rng())
args_maker = self._GetArgsMaker(rng, shapes, dtypes, np_arrays=False)
tol = max(jtu.tolerance(dtype, tolerance) for dtype in dtypes)
tol = functools.reduce(jtu.join_tolerance,
[tolerance, tol, jtu.default_tolerance()])
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(_promote_like_jnp(np_op, inexact), jnp_op,
args_maker, check_dtypes=check_dtypes, tol=tol)
self._CompileAndCheck(jnp_op, args_maker, check_dtypes=check_dtypes,
atol=tol, rtol=tol)
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes,
dtypes),
"rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes, "name": rec.name,
"tol": rec.tolerance}
for shapes in filter(
_shapes_are_broadcast_compatible,
itertools.combinations_with_replacement(rec.shapes, rec.nargs))
for dtypes in itertools.product(
*(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes)))
for rec in JAX_OPERATOR_OVERLOADS))
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def testOperatorOverload(self, name, rng_factory, shapes, dtypes, tol):
rng = rng_factory(self.rng())
# np and jnp arrays have different type promotion rules; force the use of
# jnp arrays.
args_maker = self._GetArgsMaker(rng, shapes, dtypes, np_arrays=False)
fun = lambda *xs: getattr(operator, name.strip('_'))(*xs)
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CompileAndCheck(fun, args_maker, atol=tol, rtol=tol)
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes,
dtypes),
"rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes, "name": rec.name,
"op_tolerance": rec.tolerance}
for shapes in filter(
_shapes_are_broadcast_compatible,
itertools.combinations_with_replacement(rec.shapes, rec.nargs))
for dtypes in itertools.product(
*(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes)))
for rec in JAX_RIGHT_OPERATOR_OVERLOADS))
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def testRightOperatorOverload(self, name, rng_factory, shapes, dtypes,
op_tolerance):
if shapes[1] is jtu.PYTHON_SCALAR_SHAPE:
raise SkipTest("scalars not implemented") # TODO(mattjj): clean up
rng = rng_factory(self.rng())
args_maker = self._GetArgsMaker(rng, shapes, dtypes, np_arrays=False)
fun = lambda fst, snd: getattr(snd, name)(fst)
tol = max(jtu.tolerance(dtype, op_tolerance) for dtype in dtypes)
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CompileAndCheck( fun, args_maker, atol=tol, rtol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"{rec.test_name}_{othertype}", "name": rec.name, "othertype": othertype}
for rec in JAX_OPERATOR_OVERLOADS if rec.nargs == 2
for othertype in [dict, list, tuple, set]))
def testOperatorOverloadErrors(self, name, othertype):
# Test that binary operators with builtin collections raise a TypeError
# and report the types in the correct order.
data = [(1, 2), (2, 3)]
arr = jnp.array(data)
other = othertype(data)
if config.jax_array:
val_str = 'Array'
else:
val_str = 'DeviceArray'
msg = f"unsupported operand type.* '{val_str}' and '{othertype.__name__}'"
with self.assertRaisesRegex(TypeError, msg):
getattr(arr, name)(other)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"{rec.test_name}_{othertype}", "name": rec.name, "othertype": othertype}
for rec in JAX_RIGHT_OPERATOR_OVERLOADS if rec.nargs == 2
for othertype in [dict, list, tuple, set]))
def testRightOperatorOverloadErrors(self, name, othertype):
# Test that binary operators with builtin collections raise a TypeError
# and report the types in the correct order.
data = [(1, 2), (2, 3)]
arr = jnp.array(data)
other = othertype(data)
if config.jax_array:
val_str = 'Array'
else:
val_str = 'DeviceArray'
msg = f"unsupported operand type.* '{othertype.__name__}' and '{val_str}'"
with self.assertRaisesRegex(TypeError, msg):
getattr(arr, name)(other)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": rec.test_name + f"_{dtype}",
"rng_factory": rec.rng_factory,
"op_name": rec.name, "dtype": dtype}
for rec in JAX_OPERATOR_OVERLOADS if rec.nargs == 2
for dtype in rec.dtypes))
def testBinaryOperatorDefers(self, op_name, rng_factory, dtype):
rng = rng_factory(self.rng())
arg = jax.device_put(rng((), dtype))
op = getattr(operator, op_name)
other = _OverrideEverything()
assert op(other, arg) is other
assert op(arg, other) is other
other = _OverrideNothing()
if op_name == "__eq__":
assert op(other, arg) is False
assert op(arg, other) is False
elif op_name == "__ne__":
assert op(other, arg) is True
assert op(arg, other) is True
else:
with self.assertRaises(TypeError):
op(other, arg)
with self.assertRaises(TypeError):
op(arg, other)
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix(
rec.test_name, shapes, dtypes),
"rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes,
"np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name)}
for shapes in filter(
_shapes_are_broadcast_compatible,
itertools.combinations_with_replacement(rec.shapes, rec.nargs))
for dtypes in filter(
_dtypes_are_compatible_for_bitwise_ops,
itertools.combinations_with_replacement(rec.dtypes, rec.nargs)))
for rec in JAX_BITWISE_OP_RECORDS))
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def testBitwiseOp(self, np_op, jnp_op, rng_factory, shapes, dtypes):
rng = rng_factory(self.rng())
args_maker = self._GetArgsMaker(rng, shapes, dtypes)
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(_promote_like_jnp(np_op), jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix(op.__name__, shapes, dtypes),
"op": op, "dtypes": dtypes, "shapes": shapes}
for op in [jnp.left_shift, jnp.right_shift]
for shapes in filter(
_shapes_are_broadcast_compatible,
# TODO numpy always promotes to shift dtype for zero-dim shapes:
itertools.combinations_with_replacement(nonzerodim_shapes, 2))
for dtypes in itertools.product(
*(_valid_dtypes_for_shape(s, int_dtypes_no_uint64) for s in shapes))))
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def testShiftOpAgainstNumpy(self, op, dtypes, shapes):
dtype, shift_dtype = dtypes
signed_mix = np.issubdtype(dtype, np.signedinteger) != \
np.issubdtype(shift_dtype, np.signedinteger)
has_32 = any(np.iinfo(d).bits == 32 for d in dtypes)
promoting_to_64 = has_32 and signed_mix
if promoting_to_64 and not config.x64_enabled:
self.skipTest("np.right_shift/left_shift promoting to int64"
"differs from jnp in 32 bit mode.")
info, shift_info = map(np.iinfo, dtypes)
x_rng = jtu.rand_int(self.rng(), low=info.min, high=info.max + 1)
# NumPy requires shifts to be non-negative and below the bit width:
shift_rng = jtu.rand_int(self.rng(), high=max(info.bits, shift_info.bits))
args_maker = lambda: (x_rng(shapes[0], dtype), shift_rng(shapes[1], shift_dtype))
np_op = getattr(np, op.__name__)
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CompileAndCheck(op, args_maker)
self._CheckAgainstNumpy(np_op, op, args_maker)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -0,0 +1,584 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
from functools import partial
import itertools
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import jax
from jax import numpy as jnp
from jax._src import dtypes
from jax._src import test_util as jtu
from jax.config import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS
numpy_version = tuple(map(int, np.__version__.split('.')[:3]))
nonempty_nonscalar_array_shapes = [(4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)]
nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes
one_dim_array_shapes = [(1,), (6,), (12,)]
empty_array_shapes = [(0,), (0, 4), (3, 0),]
scalar_shapes = [jtu.NUMPY_SCALAR_SHAPE, jtu.PYTHON_SCALAR_SHAPE]
array_shapes = nonempty_array_shapes + empty_array_shapes
nonzerodim_shapes = nonempty_nonscalar_array_shapes + empty_array_shapes
nonempty_shapes = scalar_shapes + nonempty_array_shapes
all_shapes = scalar_shapes + array_shapes
float_dtypes = jtu.dtypes.all_floating
complex_dtypes = jtu.dtypes.complex
int_dtypes = jtu.dtypes.all_integer
unsigned_dtypes = jtu.dtypes.all_unsigned
bool_dtypes = jtu.dtypes.boolean
default_dtypes = float_dtypes + int_dtypes
inexact_dtypes = float_dtypes + complex_dtypes
number_dtypes = float_dtypes + complex_dtypes + int_dtypes + unsigned_dtypes
all_dtypes = number_dtypes + bool_dtypes
python_scalar_dtypes = [jnp.bool_, jnp.int_, jnp.float_, jnp.complex_]
def _compatible_shapes(shape):
if np.ndim(shape) == 0 or shape in scalar_shapes:
return [shape]
return (shape[n:] for n in range(len(shape) + 1))
def _get_y_shapes(y_dtype, shape, rowvar):
# Helper function for testCov.
if y_dtype is None:
return [None]
if len(shape) == 1:
return [shape]
elif rowvar or shape[0] == 1:
return [(1, shape[-1]), (2, shape[-1]), (5, shape[-1])]
return [(shape[0], 1), (shape[0], 2), (shape[0], 5)]
OpRecord = collections.namedtuple(
"OpRecord",
["name", "nargs", "dtypes", "shapes", "rng_factory", "diff_modes",
"test_name", "check_dtypes", "tolerance", "inexact", "kwargs"])
def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes,
test_name=None, check_dtypes=True,
tolerance=None, inexact=False, kwargs=None):
test_name = test_name or name
return OpRecord(name, nargs, dtypes, shapes, rng_factory, diff_modes,
test_name, check_dtypes, tolerance, inexact, kwargs)
JAX_REDUCER_RECORDS = [
op_record("mean", 1, number_dtypes, nonempty_shapes, jtu.rand_default, [],
inexact=True),
op_record("prod", 1, all_dtypes, all_shapes, jtu.rand_small_positive, []),
op_record("sum", 1, all_dtypes, all_shapes, jtu.rand_default, []),
op_record("nanmean", 1, inexact_dtypes, nonempty_shapes, jtu.rand_some_nan,
[], inexact=True),
op_record("nanprod", 1, all_dtypes, all_shapes, jtu.rand_some_nan, []),
op_record("nansum", 1, number_dtypes, all_shapes, jtu.rand_some_nan, []),
]
JAX_REDUCER_INITIAL_RECORDS = [
op_record("prod", 1, all_dtypes, all_shapes, jtu.rand_small_positive, []),
op_record("sum", 1, all_dtypes, all_shapes, jtu.rand_default, []),
op_record("max", 1, all_dtypes, all_shapes, jtu.rand_default, []),
op_record("min", 1, all_dtypes, all_shapes, jtu.rand_default, []),
]
if numpy_version >= (1, 22): # initial & where keywords added in numpy 1.22
JAX_REDUCER_INITIAL_RECORDS += [
op_record("nanprod", 1, inexact_dtypes, all_shapes, jtu.rand_small_positive, []),
op_record("nansum", 1, inexact_dtypes, all_shapes, jtu.rand_default, []),
op_record("nanmax", 1, inexact_dtypes, all_shapes, jtu.rand_default, []),
op_record("nanmin", 1, inexact_dtypes, all_shapes, jtu.rand_default, []),
]
JAX_REDUCER_WHERE_NO_INITIAL_RECORDS = [
op_record("all", 1, bool_dtypes, all_shapes, jtu.rand_some_zero, []),
op_record("any", 1, bool_dtypes, all_shapes, jtu.rand_some_zero, []),
op_record("mean", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [],
inexact=True),
op_record("var", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [],
inexact=True),
op_record("std", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [],
inexact=True),
]
if numpy_version >= (1, 22): # where keyword added in numpy 1.22
JAX_REDUCER_WHERE_NO_INITIAL_RECORDS += [
op_record("nanmean", 1, inexact_dtypes, nonempty_shapes, jtu.rand_default, [],
inexact=True),
op_record("nanvar", 1, inexact_dtypes, nonempty_shapes, jtu.rand_default, [],
inexact=True),
op_record("nanstd", 1, inexact_dtypes, nonempty_shapes, jtu.rand_default, [],
inexact=True),
]
JAX_REDUCER_NO_DTYPE_RECORDS = [
op_record("all", 1, all_dtypes, all_shapes, jtu.rand_some_zero, []),
op_record("any", 1, all_dtypes, all_shapes, jtu.rand_some_zero, []),
op_record("max", 1, all_dtypes, nonempty_shapes, jtu.rand_default, []),
op_record("min", 1, all_dtypes, nonempty_shapes, jtu.rand_default, []),
op_record("var", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [],
inexact=True),
op_record("std", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [],
inexact=True),
op_record("nanmax", 1, all_dtypes, nonempty_shapes, jtu.rand_some_nan, []),
op_record("nanmin", 1, all_dtypes, nonempty_shapes, jtu.rand_some_nan, []),
op_record("nanvar", 1, all_dtypes, nonempty_shapes, jtu.rand_some_nan,
[], inexact=True),
op_record("nanstd", 1, all_dtypes, nonempty_shapes, jtu.rand_some_nan,
[], inexact=True),
op_record("ptp", 1, number_dtypes, nonempty_shapes, jtu.rand_default, []),
]
JAX_REDUCER_PROMOTE_INT_RECORDS = [
op_record("prod", 1, all_dtypes, all_shapes, jtu.rand_small_positive, []),
op_record("sum", 1, all_dtypes, all_shapes, jtu.rand_default, []),
]
class JaxNumpyReducerTests(jtu.JaxTestCase):
"""Tests for LAX-backed Numpy reduction operations."""
def _GetArgsMaker(self, rng, shapes, dtypes, np_arrays=True):
def f():
out = [rng(shape, dtype or jnp.float_)
for shape, dtype in zip(shapes, dtypes)]
if np_arrays:
return out
return [jnp.asarray(a) if isinstance(a, (np.ndarray, np.generic)) else a
for a in out]
return f
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": "{}_inshape={}_axis={}_dtype={}_keepdims={}".format(
rec.test_name.capitalize(),
jtu.format_shape_dtype_string(shape, dtype), axis,
"None" if out_dtype is None else np.dtype(out_dtype).name, keepdims),
"rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, "out_dtype": out_dtype,
"np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name),
"axis": axis, "keepdims": keepdims, "inexact": rec.inexact}
for shape in rec.shapes for dtype in rec.dtypes
for out_dtype in [None] + rec.dtypes if out_dtype not in unsigned_dtypes
for axis in list(range(-len(shape), len(shape))) + [None]
for keepdims in [False, True]
if jtu.is_valid_shape(shape, dtype))
for rec in JAX_REDUCER_RECORDS))
def testReducer(self, np_op, jnp_op, rng_factory, shape, dtype, out_dtype,
axis, keepdims, inexact):
rng = rng_factory(self.rng())
@jtu.ignore_warning(category=np.ComplexWarning)
@jtu.ignore_warning(category=RuntimeWarning,
message="mean of empty slice.*")
@jtu.ignore_warning(category=RuntimeWarning,
message="overflow encountered.*")
def np_fun(x):
x = np.asarray(x)
if inexact:
x = x.astype(dtypes.to_inexact_dtype(x.dtype))
x_cast = x if dtype != jnp.bfloat16 else x.astype(np.float32)
t = out_dtype if out_dtype != jnp.bfloat16 else np.float32
return np_op(x_cast, axis, dtype=t, keepdims=keepdims)
jnp_fun = lambda x: jnp_op(x, axis, dtype=out_dtype, keepdims=keepdims)
jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
args_maker = lambda: [rng(shape, dtype)]
tol_spec = {np.float16: 1e-2, np.int32: 1E-3, np.float32: 1e-3,
np.complex64: 1e-3, np.float64: 1e-5, np.complex128: 1e-5}
tol = jtu.tolerance(dtype, tol_spec)
tol = max(tol, jtu.tolerance(out_dtype, tol_spec)) if out_dtype else tol
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
check_dtypes=jnp.bfloat16 not in (dtype, out_dtype),
tol=tol)
self._CompileAndCheck(jnp_fun, args_maker, atol=tol,
rtol=tol)
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": "{}_inshape={}_axis={}_keepdims={}".format(
rec.test_name.capitalize(),
jtu.format_shape_dtype_string(shape, dtype), axis, keepdims),
"rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype,
"np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name),
"axis": axis, "keepdims": keepdims, "inexact": rec.inexact}
for shape in rec.shapes for dtype in rec.dtypes
for axis in list(range(-len(shape), len(shape))) + [None]
for keepdims in [False, True]
if jtu.is_valid_shape(shape, dtype))
for rec in JAX_REDUCER_NO_DTYPE_RECORDS))
def testReducerNoDtype(self, np_op, jnp_op, rng_factory, shape, dtype, axis,
keepdims, inexact):
rng = rng_factory(self.rng())
is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan'
@jtu.ignore_warning(category=RuntimeWarning,
message="Degrees of freedom <= 0 for slice.*")
@jtu.ignore_warning(category=RuntimeWarning,
message="All-NaN (slice|axis) encountered.*")
def np_fun(x):
x = np.asarray(x)
if inexact:
x = x.astype(dtypes.to_inexact_dtype(x.dtype))
x_cast = x if not is_bf16_nan_test else x.astype(np.float32)
res = np_op(x_cast, axis, keepdims=keepdims)
res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16)
return res
jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims)
args_maker = lambda: [rng(shape, dtype)]
tol = {np.float16: 0.002}
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol)
self._CompileAndCheck(jnp_fun, args_maker, rtol=tol, atol=tol)
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": "{}_inshape={}_axis={}_keepdims={}_initial={}".format(
rec.test_name.capitalize(),
jtu.format_shape_dtype_string(shape, dtype), axis, keepdims, initial),
"rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype,
"np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name),
"initial": initial, "axis": axis, "keepdims": keepdims, "inexact": rec.inexact}
for shape in rec.shapes for dtype in rec.dtypes
for axis in list(range(-len(shape), len(shape))) + [None]
for initial in [0, 1] for keepdims in [False, True]
if jtu.is_valid_shape(shape, dtype))
for rec in JAX_REDUCER_INITIAL_RECORDS))
def testReducerInitial(self, np_op, jnp_op, rng_factory, shape, dtype, axis,
keepdims, initial, inexact):
rng = rng_factory(self.rng())
is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan'
@jtu.ignore_warning(category=RuntimeWarning,
message="Degrees of freedom <= 0 for slice.*")
@jtu.ignore_warning(category=np.ComplexWarning)
def np_fun(x):
x = np.asarray(x)
if inexact:
x = x.astype(dtypes.to_inexact_dtype(x.dtype))
x_cast = x if not is_bf16_nan_test else x.astype(np.float32)
res = np_op(x_cast, axis, keepdims=keepdims, initial=initial)
res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16)
return res
jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, initial=initial)
jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
args_maker = lambda: [rng(shape, dtype)]
tol = {jnp.bfloat16: 3E-2}
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=tol)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": "{}_inshape={}_axis={}_keepdims={}_initial={}_promote_integers={}".format(
rec.test_name.capitalize(),
jtu.format_shape_dtype_string(shape, dtype), axis, keepdims, initial, promote_integers),
"rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype,
"np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name),
"initial": initial, "axis": axis, "keepdims": keepdims, "inexact": rec.inexact,
"promote_integers": promote_integers}
for shape in rec.shapes for dtype in rec.dtypes
for axis in list(range(-len(shape), len(shape))) + [None]
for initial in [0, 1] for keepdims in [False, True]
for promote_integers in [True, False]
if jtu.is_valid_shape(shape, dtype))
for rec in JAX_REDUCER_PROMOTE_INT_RECORDS))
def testReducerPromoteInt(self, np_op, jnp_op, rng_factory, shape, dtype, axis,
keepdims, initial, inexact, promote_integers):
rng = rng_factory(self.rng())
is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan'
@jtu.ignore_warning(category=RuntimeWarning,
message="Degrees of freedom <= 0 for slice.*")
@jtu.ignore_warning(category=np.ComplexWarning)
def np_fun(x):
x = np.asarray(x)
if inexact:
x = x.astype(dtypes.to_inexact_dtype(x.dtype))
x_cast = x if not is_bf16_nan_test else x.astype(np.float32)
res = np_op(x_cast, axis, keepdims=keepdims, initial=initial)
res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16)
print(f"res.dtype = {res.dtype}")
if not promote_integers and dtypes.issubdtype(res.dtype, np.integer):
res = res.astype(dtypes.to_numeric_dtype(x.dtype))
return res
jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, initial=initial, promote_integers=promote_integers)
jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
args_maker = lambda: [rng(shape, dtype)]
tol = {jnp.bfloat16: 3E-2}
print(jnp_fun(*args_maker()))
print(np_fun(*args_maker()))
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=tol)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": "{}_inshape={}_axis={}_keepdims={}".format(
rec.test_name.capitalize(),
jtu.format_shape_dtype_string(shape, dtype), axis, keepdims),
"rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype,
"np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name),
"axis": axis, "keepdims": keepdims, "inexact": rec.inexact}
for shape in rec.shapes if np.prod(shape) == 0
for dtype in rec.dtypes
for keepdims in [False, True]
for axis in range(-len(shape), len(shape)) if shape[axis] >= 1)
for rec in JAX_REDUCER_INITIAL_RECORDS))
def testReducerNoInitialZeroDims(self, np_op, jnp_op, rng_factory, shape, dtype, axis,
keepdims, inexact):
rng = rng_factory(self.rng())
is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan'
@jtu.ignore_warning(category=RuntimeWarning,
message="Degrees of freedom <= 0 for slice.*")
@jtu.ignore_warning(category=np.ComplexWarning)
def np_fun(x):
x = np.asarray(x)
if inexact:
x = x.astype(dtypes.to_inexact_dtype(x.dtype))
x_cast = x if not is_bf16_nan_test else x.astype(np.float32)
res = np_op(x_cast, axis, keepdims=keepdims)
res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16)
return res
jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims)
jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
args_maker = lambda: [rng(shape, dtype)]
tol = {jnp.bfloat16: 3E-2}
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=tol)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": "{}_inshape={}_axis={}_keepdims={}_initial={}_whereshape={}".format(
rec.test_name.capitalize(),
jtu.format_shape_dtype_string(shape, dtype), axis, keepdims, initial,
jtu.format_shape_dtype_string(whereshape, bool)),
"rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype,
"np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name), "whereshape": whereshape,
"initial": initial, "axis": axis, "keepdims": keepdims, "inexact": rec.inexact}
for shape in rec.shapes for dtype in rec.dtypes
for whereshape in _compatible_shapes(shape)
for axis in list(range(-len(shape), len(shape))) + [None]
for initial in [0, 1] for keepdims in [False, True]
if jtu.is_valid_shape(shape, dtype))
for rec in JAX_REDUCER_INITIAL_RECORDS))
def testReducerWhere(self, np_op, jnp_op, rng_factory, shape, dtype, axis,
keepdims, initial, inexact, whereshape):
if (shape in [()] + scalar_shapes and
dtype in [jnp.int16, jnp.uint16] and
jnp_op in [jnp.min, jnp.max]):
self.skipTest("Known XLA failure; see https://github.com/google/jax/issues/4971.")
rng = rng_factory(self.rng())
is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan'
# Do not pass where via args_maker as that is incompatible with _promote_like_jnp.
where = jtu.rand_bool(self.rng())(whereshape, np.bool_)
@jtu.ignore_warning(category=RuntimeWarning,
message="Degrees of freedom <= 0 for slice.*")
@jtu.ignore_warning(category=np.ComplexWarning)
def np_fun(x):
x = np.asarray(x)
if inexact:
x = x.astype(dtypes.to_inexact_dtype(x.dtype))
x_cast = x if not is_bf16_nan_test else x.astype(np.float32)
res = np_op(x_cast, axis, keepdims=keepdims, initial=initial, where=where)
res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16)
return res
jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, initial=initial, where=where)
jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": "{}_inshape={}_axis={}_keepdims={}_whereshape={}".format(
rec.test_name.capitalize(),
jtu.format_shape_dtype_string(shape, dtype), axis, keepdims,
jtu.format_shape_dtype_string(whereshape, bool)),
"rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype,
"np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name), "whereshape": whereshape,
"axis": axis, "keepdims": keepdims, "inexact": rec.inexact}
for shape in rec.shapes for dtype in rec.dtypes
for whereshape in _compatible_shapes(shape)
for axis in list(range(-len(shape), len(shape))) + [None]
for keepdims in [False, True]
if jtu.is_valid_shape(shape, dtype))
for rec in JAX_REDUCER_WHERE_NO_INITIAL_RECORDS))
def testReducerWhereNoInitial(self, np_op, jnp_op, rng_factory, shape, dtype, axis,
keepdims, inexact, whereshape):
rng = rng_factory(self.rng())
is_bf16_nan_test = dtype == jnp.bfloat16
# Do not pass where via args_maker as that is incompatible with _promote_like_jnp.
where = jtu.rand_bool(self.rng())(whereshape, np.bool_)
@jtu.ignore_warning(category=RuntimeWarning,
message="Degrees of freedom <= 0 for slice.*")
@jtu.ignore_warning(category=RuntimeWarning,
message="Mean of empty slice.*")
@jtu.ignore_warning(category=RuntimeWarning,
message="invalid value encountered.*")
@jtu.ignore_warning(category=np.ComplexWarning)
def np_fun(x):
x = np.asarray(x)
if inexact:
x = x.astype(dtypes.to_inexact_dtype(x.dtype))
x_cast = x if not is_bf16_nan_test else x.astype(np.float32)
res = np_op(x_cast, axis, keepdims=keepdims, where=where)
res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16)
return res
jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, where=where)
jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
args_maker = lambda: [rng(shape, dtype)]
if numpy_version >= (1, 20, 2) or np_op.__name__ in ("all", "any"):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
def testReductionOfOutOfBoundsAxis(self): # Issue 888
x = jnp.ones((3, 4))
self.assertRaises(ValueError, lambda: jnp.sum(x, axis=2))
def testReductionWithRepeatedAxisError(self):
with self.assertRaisesRegex(ValueError, r"duplicate value in 'axis': \(0, 0\)"):
jnp.sum(jnp.arange(3), (0, 0))
@parameterized.named_parameters(
jtu.cases_from_list(
{"testcase_name":
"_shape={}_dtype={}_out_dtype={}_axis={}_ddof={}_keepdims={}"
.format(shape, dtype.__name__, out_dtype.__name__, axis, ddof, keepdims),
"shape": shape, "dtype": dtype, "out_dtype": out_dtype, "axis": axis,
"ddof": ddof, "keepdims": keepdims}
for shape in [(5,), (10, 5)]
for dtype in all_dtypes
for out_dtype in inexact_dtypes
for axis in [None, 0, -1]
for ddof in [0, 1, 2]
for keepdims in [False, True]))
def testVar(self, shape, dtype, out_dtype, axis, ddof, keepdims):
rng = jtu.rand_default(self.rng())
args_maker = self._GetArgsMaker(rng, [shape], [dtype])
@jtu.ignore_warning(category=RuntimeWarning,
message="Degrees of freedom <= 0 for slice.")
@jtu.ignore_warning(category=np.ComplexWarning)
def np_fun(x):
# Numpy fails with bfloat16 inputs
out = np.var(x.astype(np.float32 if dtype == dtypes.bfloat16 else dtype),
dtype=np.float32 if out_dtype == dtypes.bfloat16 else out_dtype,
axis=axis, ddof=ddof, keepdims=keepdims)
return out.astype(out_dtype)
jnp_fun = partial(jnp.var, dtype=out_dtype, axis=axis, ddof=ddof, keepdims=keepdims)
tol = jtu.tolerance(out_dtype, {np.float16: 1e-1, np.float32: 1e-3,
np.float64: 1e-3, np.complex128: 1e-6})
if (jnp.issubdtype(dtype, jnp.complexfloating) and
not jnp.issubdtype(out_dtype, jnp.complexfloating)):
self.assertRaises(ValueError, lambda: jnp_fun(*args_maker()))
else:
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
tol=tol)
self._CompileAndCheck(jnp_fun, args_maker, rtol=tol,
atol=tol)
@parameterized.named_parameters(
jtu.cases_from_list(
{"testcase_name":
"_shape={}_dtype={}_out_dtype={}_axis={}_ddof={}_keepdims={}"
.format(shape, dtype, out_dtype, axis, ddof, keepdims),
"shape": shape, "dtype": dtype, "out_dtype": out_dtype, "axis": axis,
"ddof": ddof, "keepdims": keepdims}
for shape in [(5,), (10, 5)]
for dtype in all_dtypes
for out_dtype in inexact_dtypes
for axis in [None, 0, -1]
for ddof in [0, 1, 2]
for keepdims in [False, True]))
def testNanVar(self, shape, dtype, out_dtype, axis, ddof, keepdims):
rng = jtu.rand_some_nan(self.rng())
args_maker = self._GetArgsMaker(rng, [shape], [dtype])
@jtu.ignore_warning(category=RuntimeWarning,
message="Degrees of freedom <= 0 for slice.")
@jtu.ignore_warning(category=np.ComplexWarning)
def np_fun(x):
# Numpy fails with bfloat16 inputs
out = np.nanvar(x.astype(np.float32 if dtype == dtypes.bfloat16 else dtype),
dtype=np.float32 if out_dtype == dtypes.bfloat16 else out_dtype,
axis=axis, ddof=ddof, keepdims=keepdims)
return out.astype(out_dtype)
jnp_fun = partial(jnp.nanvar, dtype=out_dtype, axis=axis, ddof=ddof, keepdims=keepdims)
tol = jtu.tolerance(out_dtype, {np.float16: 1e-1, np.float32: 1e-3,
np.float64: 1e-3, np.complex128: 1e-6})
if (jnp.issubdtype(dtype, jnp.complexfloating) and
not jnp.issubdtype(out_dtype, jnp.complexfloating)):
self.assertRaises(ValueError, lambda: jnp_fun(*args_maker()))
else:
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
tol=tol)
self._CompileAndCheck(jnp_fun, args_maker, rtol=tol,
atol=tol)
def testNanStdGrad(self):
# Regression test for https://github.com/google/jax/issues/8128
x = jnp.arange(5.0).at[0].set(jnp.nan)
y = jax.grad(jnp.nanvar)(x)
self.assertAllClose(y, jnp.array([0.0, -0.75, -0.25, 0.25, 0.75]))
z = jax.grad(jnp.nanstd)(x)
self.assertEqual(jnp.isnan(z).sum(), 0)
@parameterized.named_parameters(
jtu.cases_from_list(
{"testcase_name":
"_shape={}_dtype={}_y_shape={}_y_dtype={}_rowvar={}_ddof={}_bias={}_fweights={}_aweights={}".format(
shape, dtype, y_shape, y_dtype, rowvar, ddof, bias, fweights, aweights),
"shape": shape, "y_shape": y_shape, "dtype": dtype, "y_dtype": y_dtype,"rowvar": rowvar, "ddof": ddof,
"bias": bias, "fweights": fweights, "aweights": aweights}
for shape in [(5,), (10, 5), (5, 10)]
for dtype in all_dtypes
for y_dtype in [None, dtype]
for rowvar in [True, False]
for y_shape in _get_y_shapes(y_dtype, shape, rowvar)
for bias in [True, False]
for ddof in [None, 2, 3]
for fweights in [True, False]
for aweights in [True, False]))
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion
def testCov(self, shape, dtype, y_shape, y_dtype, rowvar, ddof, bias, fweights, aweights):
rng = jtu.rand_default(self.rng())
wrng = jtu.rand_positive(self.rng())
wdtype = np.real(dtype(0)).dtype
wshape = shape[-1:] if rowvar or shape[0] == 1 else shape[:1]
args_maker = lambda: [rng(shape, dtype),
rng(y_shape, y_dtype) if y_dtype else None,
wrng(wshape, int) if fweights else None,
wrng(wshape, wdtype) if aweights else None]
kwargs = dict(rowvar=rowvar, ddof=ddof, bias=bias)
np_fun = lambda m, y, f, a: np.cov(m, y, fweights=f, aweights=a, **kwargs)
jnp_fun = lambda m, y, f, a: jnp.cov(m, y, fweights=f, aweights=a, **kwargs)
tol = {jnp.bfloat16: 5E-2, np.float16: 1E-2, np.float32: 1e-5,
np.float64: 1e-13, np.complex64: 1e-5, np.complex128: 1e-13}
tol = 7e-2 if jtu.device_under_test() == "tpu" else tol
tol = jtu.join_tolerance(tol, jtu.tolerance(dtype))
self._CheckAgainstNumpy(
np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol)
self._CompileAndCheck(jnp_fun, args_maker, atol=tol,
rtol=tol)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

File diff suppressed because it is too large Load Diff