rocm_jax/tests/lax_numpy_test.py
Peter Hawkins a52f07a21b Add an optional mode= argument to jnp.take_along_axis.
This allows users of jnp.take_along_axis to override the out-of-bounds indexing behavior.
Default to "clip", which for the forward computation is identical to the current behavior. In a future change, we will change this to "fill".
2022-04-19 16:07:00 -04:00

6401 lines
284 KiB
Python

# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import functools
from functools import partial
import inspect
import io
import itertools
import operator
from typing import cast, Iterator, Optional, List, Tuple
import unittest
from unittest import SkipTest
import warnings
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
try:
import numpy_dispatch
except ImportError:
numpy_dispatch = None
import jax
import jax.ops
from jax import lax
from jax import numpy as jnp
from jax import tree_util
from jax.test_util import check_grads
from jax._src import device_array
from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src.lax import lax as lax_internal
from jax._src.numpy.lax_numpy import _promote_dtypes, _promote_dtypes_inexact
from jax._src.numpy.util import _parse_numpydoc, ParsedDoc, _wraps
from jax._src.util import prod, safe_zip
from jax.config import config
# Parse absl command-line flags into the JAX config and keep a handle on them.
config.parse_flags_with_absl()
FLAGS = config.FLAGS
# Installed NumPy version as a tuple of ints (at most three components); used
# further down to gate records on keywords that only newer NumPy supports.
numpy_version = tuple(map(int, np.__version__.split('.')[:3]))
# Shape families used to parameterize the tests below.
nonempty_nonscalar_array_shapes = [(4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)]
nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes
one_dim_array_shapes = [(1,), (6,), (12,)]
empty_array_shapes = [(0,), (0, 4), (3, 0),]
# Sentinel shapes provided by jtu for NumPy scalars and Python scalars.
scalar_shapes = [jtu.NUMPY_SCALAR_SHAPE, jtu.PYTHON_SCALAR_SHAPE]
array_shapes = nonempty_array_shapes + empty_array_shapes
nonzerodim_shapes = nonempty_nonscalar_array_shapes + empty_array_shapes
nonempty_shapes = scalar_shapes + nonempty_array_shapes
all_shapes = scalar_shapes + array_shapes
# Dtype families, taken from jtu.dtypes and combined with `+`.
float_dtypes = jtu.dtypes.all_floating
complex_dtypes = jtu.dtypes.complex
int_dtypes = jtu.dtypes.all_integer
unsigned_dtypes = jtu.dtypes.all_unsigned
bool_dtypes = jtu.dtypes.boolean
default_dtypes = float_dtypes + int_dtypes
inexact_dtypes = float_dtypes + complex_dtypes
number_dtypes = float_dtypes + complex_dtypes + int_dtypes + unsigned_dtypes
all_dtypes = number_dtypes + bool_dtypes
# The only dtypes considered valid for Python scalars (see
# _valid_dtypes_for_shape).
python_scalar_dtypes = [jnp.bool_, jnp.int_, jnp.float_, jnp.complex_]
# uint64 is problematic because with any uint type it promotes to float:
# NOTE(review): NumPy promotes uint64 to float64 when mixed with *signed*
# integer types; the wording above presumably means that -- confirm.
int_dtypes_no_uint64 = [d for d in int_dtypes + unsigned_dtypes if d != np.uint64]
def _indexer_with_default_outputs(indexer, use_defaults=True):
  """Like jtu.with_jax_dtype_defaults, but for __getitem__ APIs.

  Returns a fresh object whose `__getitem__` forwards to
  `indexer.__getitem__` and is decorated with
  `jtu.with_jax_dtype_defaults(..., use_defaults=use_defaults)`.
  """
  apply_dtype_defaults = partial(jtu.with_jax_dtype_defaults,
                                 use_defaults=use_defaults)

  class Indexer:
    @apply_dtype_defaults
    def __getitem__(self, *key):
      return indexer.__getitem__(*key)

  return Indexer()
def _valid_dtypes_for_shape(shape, dtypes):
  """Return the subset of `dtypes` that can be paired with `shape`.

  Not every (shape, dtype) pair is valid: a Python scalar only has one type
  per category (float, bool, etc.), so for `jtu.PYTHON_SCALAR_SHAPE` the
  result is restricted to entries of `python_scalar_dtypes`. Every other
  shape gets `dtypes` back unchanged.
  """
  if shape is not jtu.PYTHON_SCALAR_SHAPE:
    return dtypes
  return [candidate for candidate in dtypes
          if candidate in python_scalar_dtypes]
def _shape_and_dtypes(shapes, dtypes):
  """Lazily yield each valid (shape, dtype) pair, grouped by shape."""
  for shape in shapes:
    yield from ((shape, dtype)
                for dtype in _valid_dtypes_for_shape(shape, dtypes))
def _compatible_shapes(shape):
  """Return `shape` alone for scalars, otherwise every trailing slice of it.

  A shape counts as scalar when it is one of `scalar_shapes` or when
  `np.ndim(shape)` is 0; in that case a one-element list is returned.
  Otherwise the result is a generator over `shape[n:]` for
  n = 0 .. len(shape) inclusive.
  """
  treat_as_scalar = shape in scalar_shapes or np.ndim(shape) == 0
  if not treat_as_scalar:
    return (shape[start:] for start in range(len(shape) + 1))
  return [shape]
def _get_y_shapes(y_dtype, shape, rowvar):
# Helper function for testCov.
if y_dtype is None:
return [None]
if len(shape) == 1:
return [shape]
elif rowvar or shape[0] == 1:
return [(1, shape[-1]), (2, shape[-1]), (5, shape[-1])]
return [(shape[0], 1), (shape[0], 2), (shape[0], 5)]
# Immutable description of one operation under test.
OpRecord = collections.namedtuple(
    "OpRecord",
    ["name", "nargs", "dtypes", "shapes", "rng_factory", "diff_modes",
     "test_name", "check_dtypes", "tolerance", "inexact", "kwargs"])


def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes,
              test_name=None, check_dtypes=True,
              tolerance=None, inexact=False, kwargs=None):
  """Build an `OpRecord`, defaulting `test_name` to `name` when it is falsy."""
  return OpRecord(
      name=name,
      nargs=nargs,
      dtypes=dtypes,
      shapes=shapes,
      rng_factory=rng_factory,
      diff_modes=diff_modes,
      test_name=test_name or name,
      check_dtypes=check_dtypes,
      tolerance=tolerance,
      inexact=inexact,
      kwargs=kwargs)
# Op records consumed by testOp (chained with JAX_COMPOUND_OP_RECORDS).
# Positional fields follow op_record: name, nargs, dtypes, shapes,
# rng_factory, diff_modes; the keyword fields override test_name,
# check_dtypes, tolerance, inexact and kwargs.
JAX_ONE_TO_ONE_OP_RECORDS = [
    op_record("abs", 1, all_dtypes,
              all_shapes, jtu.rand_default, ["rev"]),
    op_record("add", 2, all_dtypes, all_shapes, jtu.rand_default, ["rev"]),
    op_record("ceil", 1, float_dtypes, all_shapes, jtu.rand_default, []),
    op_record("ceil", 1, int_dtypes + unsigned_dtypes, all_shapes,
              jtu.rand_default, [], check_dtypes=False),
    op_record("conj", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
    op_record("equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
    op_record("exp", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"],
              inexact=True),
    op_record("fabs", 1, float_dtypes, all_shapes, jtu.rand_default, ["rev"]),
    op_record("float_power", 2, inexact_dtypes, all_shapes,
              partial(jtu.rand_default, scale=1), ["rev"],
              tolerance={jnp.bfloat16: 1e-2, np.float32: 1e-3,
                         np.float64: 1e-12, np.complex64: 2e-4,
                         np.complex128: 1e-12}, check_dtypes=False),
    op_record("floor", 1, float_dtypes, all_shapes, jtu.rand_default, []),
    op_record("floor", 1, int_dtypes + unsigned_dtypes, all_shapes,
              jtu.rand_default, [], check_dtypes=False),
    op_record("greater", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
    op_record("greater_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
    op_record("i0", 1, float_dtypes, all_shapes, jtu.rand_default, [],
              check_dtypes=False),
    op_record("ldexp", 2, int_dtypes, all_shapes, jtu.rand_default, [], check_dtypes=False),
    op_record("less", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
    op_record("less_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
    op_record("log", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"],
              inexact=True),
    op_record("logical_and", 2, all_dtypes, all_shapes, jtu.rand_bool, []),
    op_record("logical_not", 1, all_dtypes, all_shapes, jtu.rand_bool, []),
    op_record("logical_or", 2, all_dtypes, all_shapes, jtu.rand_bool, []),
    op_record("logical_xor", 2, all_dtypes, all_shapes, jtu.rand_bool, []),
    op_record("maximum", 2, all_dtypes, all_shapes, jtu.rand_some_inf, []),
    op_record("minimum", 2, all_dtypes, all_shapes, jtu.rand_some_inf, []),
    op_record("multiply", 2, all_dtypes, all_shapes, jtu.rand_default, ["rev"]),
    op_record("negative", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
    op_record("nextafter", 2, [f for f in float_dtypes if f != jnp.bfloat16],
              all_shapes, jtu.rand_default, ["rev"], inexact=True, tolerance=0),
    op_record("not_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]),
    op_record("array_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]),
    op_record("array_equiv", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]),
    op_record("reciprocal", 1, inexact_dtypes, all_shapes, jtu.rand_default, []),
    op_record("subtract", 2, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
    op_record("signbit", 1, default_dtypes + bool_dtypes, all_shapes,
              jtu.rand_some_inf_and_nan, ["rev"]),
    op_record("trunc", 1, float_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []),
    op_record("trunc", 1, int_dtypes + unsigned_dtypes, all_shapes,
              jtu.rand_some_inf_and_nan, [], check_dtypes=False),
    op_record("sin", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"],
              inexact=True),
    op_record("cos", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"],
              inexact=True),
    op_record("tan", 1, number_dtypes, all_shapes,
              partial(jtu.rand_uniform, low=-1.5, high=1.5), ["rev"],
              inexact=True),
    op_record("sinh", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"],
              inexact=True),
    op_record("cosh", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"],
              inexact=True),
    # TODO(b/142975473): on CPU, tanh for complex128 is only accurate to
    # ~float32 precision.
    # TODO(b/143135720): on GPU, tanh has only ~float32 precision.
    op_record("tanh", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"],
              tolerance={np.float64: 1e-7, np.complex128: 1e-7},
              inexact=True),
    op_record("arcsin", 1, number_dtypes, all_shapes, jtu.rand_small, ["rev"],
              inexact=True),
    op_record("arccos", 1, number_dtypes, all_shapes, jtu.rand_small, ["rev"],
              inexact=True),
    op_record("arctan", 1, number_dtypes, all_shapes, jtu.rand_small, ["rev"],
              inexact=True),
    op_record("arctan2", 2, float_dtypes, all_shapes, jtu.rand_small, ["rev"],
              inexact=True),
    op_record("arcsinh", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"],
              inexact=True, tolerance={np.complex64: 2E-4, np.complex128: 2E-14}),
    op_record("arccosh", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"],
              inexact=True, tolerance={np.complex64: 2E-2, np.complex128: 2E-12}),
    op_record("arctanh", 1, number_dtypes, all_shapes, jtu.rand_small, ["rev"],
              inexact=True, tolerance={np.float64: 1e-9}),
]
# Second table of op records consumed by testOp (chained after
# JAX_ONE_TO_ONE_OP_RECORDS). Entries that share a `name` are told apart by
# `test_name` or by the dtype list they cover.
JAX_COMPOUND_OP_RECORDS = [
    # angle has inconsistent 32/64-bit return types across numpy versions.
    op_record("angle", 1, number_dtypes, all_shapes, jtu.rand_default, [],
              check_dtypes=False, inexact=True),
    op_record("angle", 1, number_dtypes, all_shapes, jtu.rand_default, [],
              check_dtypes=False, inexact=True, test_name="angle_deg", kwargs={'deg': True}),
    op_record("atleast_1d", 1, default_dtypes, all_shapes, jtu.rand_default, []),
    op_record("atleast_2d", 1, default_dtypes, all_shapes, jtu.rand_default, []),
    op_record("atleast_3d", 1, default_dtypes, all_shapes, jtu.rand_default, []),
    op_record("cbrt", 1, default_dtypes, all_shapes, jtu.rand_some_inf, ["rev"],
              inexact=True),
    op_record("conjugate", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
    op_record("deg2rad", 1, float_dtypes, all_shapes, jtu.rand_default, []),
    op_record("divide", 2, number_dtypes, all_shapes, jtu.rand_nonzero, ["rev"],
              inexact=True),
    op_record("divmod", 2, int_dtypes + float_dtypes, all_shapes,
              jtu.rand_nonzero, []),
    op_record("exp2", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"],
              tolerance={jnp.bfloat16: 4e-2, np.float16: 1e-2}, inexact=True),
    # TODO(b/142975473): on CPU, expm1 for float64 is only accurate to ~float32
    # precision.
    op_record("expm1", 1, number_dtypes, all_shapes, jtu.rand_positive, [],
              test_name="expm1_large", tolerance={np.float64: 1e-8}, inexact=True),
    op_record("expm1", 1, number_dtypes, all_shapes, jtu.rand_small_positive,
              [], tolerance={np.float64: 1e-8}, inexact=True),
    op_record("fix", 1, float_dtypes, all_shapes, jtu.rand_default, []),
    op_record("fix", 1, int_dtypes + unsigned_dtypes, all_shapes,
              jtu.rand_default, [], check_dtypes=False),
    op_record("floor_divide", 2, default_dtypes + unsigned_dtypes,
              all_shapes, jtu.rand_nonzero, ["rev"]),
    op_record("fmin", 2, number_dtypes, all_shapes, jtu.rand_some_nan, []),
    op_record("fmax", 2, number_dtypes, all_shapes, jtu.rand_some_nan, []),
    op_record("fmod", 2, default_dtypes, all_shapes, jtu.rand_some_nan, []),
    op_record("heaviside", 2, default_dtypes, all_shapes, jtu.rand_default, [],
              inexact=True),
    op_record("hypot", 2, default_dtypes, all_shapes, jtu.rand_default, [],
              inexact=True),
    op_record("kron", 2, number_dtypes, nonempty_shapes, jtu.rand_default, []),
    op_record("outer", 2, number_dtypes, all_shapes, jtu.rand_default, []),
    op_record("imag", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []),
    op_record("iscomplex", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []),
    op_record("isfinite", 1, inexact_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []),
    op_record("isinf", 1, inexact_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []),
    op_record("isnan", 1, inexact_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []),
    op_record("isneginf", 1, float_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []),
    op_record("isposinf", 1, float_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []),
    op_record("isreal", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []),
    op_record("isrealobj", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []),
    op_record("log2", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"],
              inexact=True),
    op_record("log10", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"],
              inexact=True),
    op_record("log1p", 1, number_dtypes, all_shapes, jtu.rand_positive, [],
              test_name="log1p_large", tolerance={np.float64: 1e-12},
              inexact=True),
    op_record("log1p", 1, number_dtypes, all_shapes, jtu.rand_small_positive, [],
              tolerance={np.float64: 1e-12}, inexact=True),
    op_record("logaddexp", 2, float_dtypes, all_shapes,
              jtu.rand_some_inf_and_nan, ["rev"],
              tolerance={np.float64: 1e-12}, inexact=True),
    op_record("logaddexp2", 2, float_dtypes, all_shapes,
              jtu.rand_some_inf_and_nan, ["rev"],
              tolerance={np.float16: 1e-2, np.float64: 2e-14}, inexact=True),
    op_record("polyval", 2, number_dtypes, nonempty_nonscalar_array_shapes,
              jtu.rand_default, [], check_dtypes=False,
              tolerance={dtypes.bfloat16: 4e-2, np.float16: 1e-2,
                         np.float64: 1e-12}),
    op_record("positive", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
    op_record("power", 2, number_dtypes, all_shapes, jtu.rand_positive, ["rev"],
              tolerance={np.complex128: 1e-14}, check_dtypes=False),
    op_record("rad2deg", 1, float_dtypes, all_shapes, jtu.rand_default, []),
    op_record("ravel", 1, all_dtypes, all_shapes, jtu.rand_default, ["rev"]),
    op_record("real", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []),
    op_record("remainder", 2, default_dtypes, all_shapes, jtu.rand_nonzero, [],
              tolerance={np.float16: 1e-2}),
    op_record("mod", 2, default_dtypes, all_shapes, jtu.rand_nonzero, []),
    op_record("modf", 1, float_dtypes, all_shapes, jtu.rand_default, []),
    op_record("modf", 1, int_dtypes + unsigned_dtypes, all_shapes,
              jtu.rand_default, [], check_dtypes=False),
    op_record("rint", 1, inexact_dtypes, all_shapes, jtu.rand_some_inf_and_nan,
              []),
    op_record("rint", 1, int_dtypes + unsigned_dtypes, all_shapes,
              jtu.rand_default, [], check_dtypes=False),
    op_record("sign", 1, number_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []),
    # numpy 1.16 has trouble mixing uint and bfloat16, so we test these separately.
    op_record("copysign", 2, default_dtypes + unsigned_dtypes,
              all_shapes, jtu.rand_some_inf_and_nan, [], check_dtypes=False),
    op_record("sinc", 1, [t for t in number_dtypes if t != jnp.bfloat16],
              all_shapes, jtu.rand_default, ["rev"],
              tolerance={np.complex64: 1e-5}, inexact=True,
              check_dtypes=False),
    op_record("square", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
    op_record("sqrt", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"],
              inexact=True),
    op_record("transpose", 1, all_dtypes, all_shapes, jtu.rand_default, ["rev"],
              check_dtypes=False),
    op_record("true_divide", 2, all_dtypes, all_shapes, jtu.rand_nonzero,
              ["rev"], inexact=True),
    op_record("ediff1d", 3, [np.int32], all_shapes, jtu.rand_default, [], check_dtypes=False),
    # TODO(phawkins): np.unwrap does not correctly promote its default period
    # argument under NumPy 1.21 for bfloat16 inputs. It works fine if we
    # explicitly pass a bfloat16 value that does not need promotion. We should
    # probably add a custom test harness for unwrap that tests the period
    # argument anyway.
    op_record("unwrap", 1, [t for t in float_dtypes if t != dtypes.bfloat16],
              nonempty_nonscalar_array_shapes,
              jtu.rand_default, ["rev"],
              # numpy.unwrap always returns float64
              check_dtypes=False,
              # numpy cumsum is inaccurate, see issue #3517
              tolerance={dtypes.bfloat16: 1e-1, np.float16: 1e-1}),
    op_record("isclose", 2, [t for t in all_dtypes if t != jnp.bfloat16],
              all_shapes, jtu.rand_small_positive, []),
    op_record("gcd", 2, int_dtypes_no_uint64, all_shapes, jtu.rand_default, []),
    op_record("lcm", 2, int_dtypes_no_uint64, all_shapes, jtu.rand_default, []),
]
# Op records consumed by testBitwiseOp. Every entry covers the signed and
# unsigned integer dtypes and draws its inputs from jtu.rand_bool.
JAX_BITWISE_OP_RECORDS = [
    op_record("bitwise_and", 2, int_dtypes + unsigned_dtypes, all_shapes,
              jtu.rand_bool, []),
    op_record("bitwise_not", 1, int_dtypes + unsigned_dtypes, all_shapes,
              jtu.rand_bool, []),
    op_record("invert", 1, int_dtypes + unsigned_dtypes, all_shapes,
              jtu.rand_bool, []),
    op_record("bitwise_or", 2, int_dtypes + unsigned_dtypes, all_shapes,
              jtu.rand_bool, []),
    op_record("bitwise_xor", 2, int_dtypes + unsigned_dtypes, all_shapes,
              jtu.rand_bool, []),
]
# Records for single-argument reductions. The test that consumes this table
# is not visible in this part of the file.
JAX_REDUCER_RECORDS = [
    op_record("mean", 1, number_dtypes, nonempty_shapes, jtu.rand_default, [],
              inexact=True),
    op_record("prod", 1, all_dtypes, all_shapes, jtu.rand_small_positive, []),
    op_record("sum", 1, all_dtypes, all_shapes, jtu.rand_default, []),
    op_record("nanmean", 1, inexact_dtypes, nonempty_shapes, jtu.rand_some_nan,
              [], inexact=True),
    op_record("nanprod", 1, all_dtypes, all_shapes, jtu.rand_some_nan, []),
    op_record("nansum", 1, number_dtypes, all_shapes, jtu.rand_some_nan, []),
]
# Base list of reducer records; it is extended in place further down when the
# installed NumPy is new enough. Presumably these are the reducers exercised
# with an `initial` argument -- the consuming test is not visible here.
JAX_REDUCER_INITIAL_RECORDS = [
    op_record("prod", 1, all_dtypes, all_shapes, jtu.rand_small_positive, []),
    op_record("sum", 1, all_dtypes, all_shapes, jtu.rand_default, []),
    op_record("max", 1, all_dtypes, all_shapes, jtu.rand_default, []),
    op_record("min", 1, all_dtypes, all_shapes, jtu.rand_default, []),
]
# initial & where keywords added in numpy 1.22, so the nan-reducers are only
# appended to the table when the installed NumPy is at least that new.
if numpy_version >= (1, 22):
  JAX_REDUCER_INITIAL_RECORDS.extend([
      op_record("nanprod", 1, inexact_dtypes, all_shapes,
                jtu.rand_small_positive, []),
      op_record("nansum", 1, inexact_dtypes, all_shapes,
                jtu.rand_default, []),
      op_record("nanmax", 1, inexact_dtypes, all_shapes,
                jtu.rand_default, []),
      op_record("nanmin", 1, inexact_dtypes, all_shapes,
                jtu.rand_default, []),
  ])
# Base list of reducer records; it is extended in place further down when the
# installed NumPy is new enough. Presumably these are the reducers exercised
# with `where` but without `initial` -- the consuming test is not visible here.
JAX_REDUCER_WHERE_NO_INITIAL_RECORDS = [
    op_record("all", 1, bool_dtypes, all_shapes, jtu.rand_some_zero, []),
    op_record("any", 1, bool_dtypes, all_shapes, jtu.rand_some_zero, []),
    op_record("mean", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [],
              inexact=True),
    op_record("var", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [],
              inexact=True),
    op_record("std", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [],
              inexact=True),
]
# where keyword added in numpy 1.22, so the nan-statistics are only appended
# to the table when the installed NumPy is at least that new.
if numpy_version >= (1, 22):
  JAX_REDUCER_WHERE_NO_INITIAL_RECORDS.extend([
      op_record("nanmean", 1, inexact_dtypes, nonempty_shapes,
                jtu.rand_default, [], inexact=True),
      op_record("nanvar", 1, inexact_dtypes, nonempty_shapes,
                jtu.rand_default, [], inexact=True),
      op_record("nanstd", 1, inexact_dtypes, nonempty_shapes,
                jtu.rand_default, [], inexact=True),
  ])
# Reducer records; judging by the name these are reducers tested without a
# `dtype` argument -- the consuming test is not visible here, so confirm.
JAX_REDUCER_NO_DTYPE_RECORDS = [
    op_record("all", 1, all_dtypes, all_shapes, jtu.rand_some_zero, []),
    op_record("any", 1, all_dtypes, all_shapes, jtu.rand_some_zero, []),
    op_record("max", 1, all_dtypes, nonempty_shapes, jtu.rand_default, []),
    op_record("min", 1, all_dtypes, nonempty_shapes, jtu.rand_default, []),
    op_record("var", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [],
              inexact=True),
    op_record("std", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [],
              inexact=True),
    op_record("nanmax", 1, all_dtypes, nonempty_shapes, jtu.rand_some_nan, []),
    op_record("nanmin", 1, all_dtypes, nonempty_shapes, jtu.rand_some_nan, []),
    op_record("nanvar", 1, all_dtypes, nonempty_shapes, jtu.rand_some_nan,
              [], inexact=True),
    op_record("nanstd", 1, all_dtypes, nonempty_shapes, jtu.rand_some_nan,
              [], inexact=True),
    op_record("ptp", 1, number_dtypes, nonempty_shapes, jtu.rand_default, []),
]
# Records for the arg-reductions; the plain variants draw inputs with repeated
# values (rand_some_equal) and the nan variants draw inputs containing NaNs
# (rand_some_nan). The consuming test is not visible in this part of the file.
JAX_ARGMINMAX_RECORDS = [
    op_record("argmin", 1, default_dtypes, nonempty_shapes, jtu.rand_some_equal, []),
    op_record("argmax", 1, default_dtypes, nonempty_shapes, jtu.rand_some_equal, []),
    op_record("nanargmin", 1, default_dtypes, nonempty_shapes, jtu.rand_some_nan, []),
    op_record("nanargmax", 1, default_dtypes, nonempty_shapes, jtu.rand_some_nan, []),
]
# Records keyed by dunder-method name. They drive testOperatorOverload and
# testBinaryOperatorDefers, and the two-argument entries also decide which
# methods get installed on _OverrideEverything / _OverrideNothing.
JAX_OPERATOR_OVERLOADS = [
    op_record("__add__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__sub__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__mul__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__eq__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__ne__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__lt__", 2, default_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__le__", 2, default_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__gt__", 2, default_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__ge__", 2, default_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__pos__", 1, number_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__neg__", 1, number_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__pow__", 2, inexact_dtypes, all_shapes, jtu.rand_positive, [],
              tolerance={np.float32: 2e-4, np.complex64: 2e-4, np.complex128: 1e-14}),
    op_record("__mod__", 2, default_dtypes, all_shapes, jtu.rand_nonzero, [],
              tolerance={np.float16: 1e-1}),
    op_record("__floordiv__", 2, default_dtypes, all_shapes,
              jtu.rand_nonzero, []),
    op_record("__truediv__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, [],
              inexact=True),
    op_record("__abs__", 1, number_dtypes, all_shapes, jtu.rand_default, []),
    # TODO(mattjj): __invert__ fails on bool dtypes because ~True == -2
    op_record("__invert__", 1, int_dtypes, all_shapes, jtu.rand_default, []),
    # TODO(mattjj): investigate these failures
    # op_record("__or__", 2, number_dtypes, all_shapes, jtu.rand_bool, []),
    # op_record("__and__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
    # op_record("__xor__", 2, number_dtypes, all_shapes, jtu.rand_bool, []),
    # op_record("__divmod__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, []),
    op_record("__lshift__", 2, int_dtypes_no_uint64, all_shapes, partial(jtu.rand_int, high=8), []),
    op_record("__rshift__", 2, int_dtypes_no_uint64, all_shapes, partial(jtu.rand_int, high=8), []),
]
# Records for the reflected (right-hand) dunder methods. They drive
# testRightOperatorOverload and, like JAX_OPERATOR_OVERLOADS, decide which
# methods get installed on _OverrideEverything / _OverrideNothing.
JAX_RIGHT_OPERATOR_OVERLOADS = [
    op_record("__radd__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__rsub__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__rmul__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__rpow__", 2, inexact_dtypes, all_shapes, jtu.rand_positive, [],
              tolerance={np.float32: 2e-4, np.complex64: 1e-3}),
    op_record("__rmod__", 2, default_dtypes, all_shapes, jtu.rand_nonzero, [],
              tolerance={np.float16: 1e-1}),
    op_record("__rfloordiv__", 2, default_dtypes, all_shapes,
              jtu.rand_nonzero, []),
    op_record("__rtruediv__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, [],
              inexact=True),
    # The following records are disabled in the source.
    # op_record("__ror__", 2, number_dtypes, all_shapes, jtu.rand_bool, []),
    # op_record("__rand__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
    # op_record("__rxor__", 2, number_dtypes, all_shapes, jtu.rand_bool, []),
    # op_record("__rdivmod__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, []),
    op_record("__rlshift__", 2, int_dtypes_no_uint64, all_shapes, partial(jtu.rand_int, high=8), []),
    op_record("__rrshift__", 2, int_dtypes_no_uint64, all_shapes, partial(jtu.rand_int, high=8), [])
]
class _OverrideEverything(object):
  """Operand whose binary operators (installed below) all return `self`."""


# Install one method per two-argument operator record, forward and reflected.
for rec in JAX_OPERATOR_OVERLOADS + JAX_RIGHT_OPERATOR_OVERLOADS:
  if rec.nargs != 2:
    continue
  setattr(_OverrideEverything, rec.name, lambda self, other: self)
class _OverrideNothing(object):
  """Operand whose binary operators (installed below) all return NotImplemented."""


# Install one method per two-argument operator record, forward and reflected.
for rec in JAX_OPERATOR_OVERLOADS + JAX_RIGHT_OPERATOR_OVERLOADS:
  if rec.nargs != 2:
    continue
  setattr(_OverrideNothing, rec.name, lambda self, other: NotImplemented)
def _dtypes_are_compatible_for_bitwise_ops(args):
if len(args) <= 1:
return True
is_signed = lambda dtype: jnp.issubdtype(dtype, np.signedinteger)
width = lambda dtype: jnp.iinfo(dtype).bits
x, y = args
if width(x) > width(y):
x, y = y, x
# The following condition seems a little ad hoc, but seems to capture what
# numpy actually implements.
return (
is_signed(x) == is_signed(y)
or (width(x) == 32 and width(y) == 32)
or (width(x) == 32 and width(y) == 64 and is_signed(y)))
def _shapes_are_broadcast_compatible(shapes):
  """Return True iff `lax.broadcast_shapes` accepts `shapes`.

  Entries of `scalar_shapes` are replaced by `()` before the check; a
  `ValueError` raised during the check means "not compatible".
  """
  try:
    normalized = [() if shape in scalar_shapes else shape for shape in shapes]
    lax.broadcast_shapes(*normalized)
  except ValueError:
    return False
  return True
def _shapes_are_equal_length(shapes):
return all(len(shape) == len(shapes[0]) for shape in shapes[1:])
def _promote_like_jnp(fun, inexact=False):
  """Decorator that promotes the arguments of `fun` to `jnp.result_type(*args)`.

  jnp and np have different type promotion semantics; this decorator allows
  tests to make an np reference implementation act more like a jnp
  implementation.

  The positional arguments are flattened as a pytree, passed through
  `_promote_dtypes_inexact` (when `inexact` is true) or `_promote_dtypes`,
  and rebuilt before `fun` is called. Keyword arguments are forwarded
  untouched.
  """
  promote = _promote_dtypes_inexact if inexact else _promote_dtypes

  def wrapper(*args, **kw):
    leaves, treedef = tree_util.tree_flatten(args)
    promoted_args = tree_util.tree_unflatten(treedef, promote(*leaves))
    return fun(*promoted_args, **kw)

  return wrapper
class LaxBackedNumpyTests(jtu.JaxTestCase):
"""Tests for LAX-backed Numpy implementation."""
def _GetArgsMaker(self, rng, shapes, dtypes, np_arrays=True):
def f():
out = [rng(shape, dtype or jnp.float_)
for shape, dtype in zip(shapes, dtypes)]
if np_arrays:
return out
return [jnp.asarray(a) if isinstance(a, (np.ndarray, np.generic)) else a
for a in out]
return f
def testNotImplemented(self):
  # Every name listed in jnp._NOT_IMPLEMENTED must resolve to a callable that
  # raises NotImplementedError when invoked without arguments.
  for unimplemented_name in jnp._NOT_IMPLEMENTED:
    placeholder = getattr(jnp, unimplemented_name)
    with self.assertRaises(NotImplementedError):
      placeholder()
@parameterized.named_parameters(jtu.cases_from_list(
    # The test name previously misspelled the parameter as "allow_picke".
    {"testcase_name": "_{}_allow_pickle={}".format(dtype, allow_pickle),
     "dtype": dtype, "allow_pickle": allow_pickle}
    for dtype in float_dtypes + [object]
    for allow_pickle in [True, False]))
def testLoad(self, dtype, allow_pickle):
  # Round-trips an array through jnp.save / jnp.load using an in-memory
  # buffer and checks that it comes back unchanged.
  if dtype == object and not allow_pickle:
    self.skipTest("dtype=object requires allow_pickle=True")
  rng = jtu.rand_default(self.rng())
  # NOTE(review): `(10)` is the plain int 10, not a 1-tuple; it is passed
  # through as written -- confirm the rng treats an int as a 1-D length.
  arr = rng((10), dtype)
  with io.BytesIO() as f:
    jnp.save(f, arr)
    f.seek(0)  # Rewind so that load starts at the beginning of the buffer.
    arr_out = jnp.load(f, allow_pickle=allow_pickle)
  self.assertArraysEqual(arr, arr_out)
@parameterized.named_parameters(itertools.chain.from_iterable(
    jtu.cases_from_list(
        {"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes,
                                                      dtypes),
         "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes,
         "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name),
         "check_dtypes": rec.check_dtypes, "tolerance": rec.tolerance,
         "inexact": rec.inexact, "kwargs": rec.kwargs or {}}
        for shapes in filter(
            _shapes_are_broadcast_compatible,
            itertools.combinations_with_replacement(rec.shapes, rec.nargs))
        for dtypes in itertools.product(
            *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes)))
    for rec in itertools.chain(JAX_ONE_TO_ONE_OP_RECORDS,
                               JAX_COMPOUND_OP_RECORDS)))
@jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
def testOp(self, np_op, jnp_op, rng_factory, shapes, dtypes, check_dtypes,
           tolerance, inexact, kwargs):
  # Compares jnp_op against np_op (with jnp-style argument promotion applied
  # to the NumPy reference) and then checks the compiled jnp_op.
  jnp_fn = partial(jnp_op, **kwargs)
  np_fn = partial(np_op, **kwargs)
  # Silence the RuntimeWarnings raised by the NumPy reference.
  for warning_pattern in ("invalid value.*", "divide by zero.*"):
    np_fn = jtu.ignore_warning(category=RuntimeWarning,
                               message=warning_pattern)(np_fn)

  args_maker = self._GetArgsMaker(rng_factory(self.rng()), shapes, dtypes,
                                  np_arrays=False)

  # Loosest per-dtype tolerance, joined with the record's own tolerance and
  # the library default.
  loosest = max(jtu.tolerance(dtype, tolerance) for dtype in dtypes)
  tol = functools.reduce(jtu.join_tolerance,
                         [tolerance, loosest, jtu.default_tolerance()])

  self._CheckAgainstNumpy(_promote_like_jnp(np_fn, inexact), jnp_fn,
                          args_maker, check_dtypes=check_dtypes, tol=tol)
  self._CompileAndCheck(jnp_fn, args_maker, check_dtypes=check_dtypes,
                        atol=tol, rtol=tol)
@parameterized.named_parameters(itertools.chain.from_iterable(
    jtu.cases_from_list(
        {"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes,
                                                      dtypes),
         "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes,
         "name": rec.name, "tol": rec.tolerance}
        for shapes in filter(
            _shapes_are_broadcast_compatible,
            itertools.combinations_with_replacement(rec.shapes, rec.nargs))
        for dtypes in itertools.product(
            *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes)))
    for rec in JAX_OPERATOR_OVERLOADS))
@jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
def testOperatorOverload(self, name, rng_factory, shapes, dtypes, tol):
  # np and jnp arrays have different type promotion rules; force the use of
  # jnp arrays.
  args_maker = self._GetArgsMaker(rng_factory(self.rng()), shapes, dtypes,
                                  np_arrays=False)

  def fun(*xs):
    # "__add__" -> operator.add, "__invert__" -> operator.invert, ...
    return getattr(operator, name.strip('_'))(*xs)

  self._CompileAndCheck(fun, args_maker, atol=tol, rtol=tol)
@parameterized.named_parameters(itertools.chain.from_iterable(
    jtu.cases_from_list(
        {"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes,
                                                      dtypes),
         "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes,
         "name": rec.name, "op_tolerance": rec.tolerance}
        for shapes in filter(
            _shapes_are_broadcast_compatible,
            itertools.combinations_with_replacement(rec.shapes, rec.nargs))
        for dtypes in itertools.product(
            *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes)))
    for rec in JAX_RIGHT_OPERATOR_OVERLOADS))
@jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
def testRightOperatorOverload(self, name, rng_factory, shapes, dtypes,
                              op_tolerance):
  # The reflected method is looked up on the second operand, so a Python
  # scalar in that position cannot be tested.
  if shapes[1] is jtu.PYTHON_SCALAR_SHAPE:
    raise SkipTest("scalars not implemented")  # TODO(mattjj): clean up
  args_maker = self._GetArgsMaker(rng_factory(self.rng()), shapes, dtypes,
                                  np_arrays=False)

  def fun(fst, snd):
    return getattr(snd, name)(fst)

  tol = max(jtu.tolerance(dtype, op_tolerance) for dtype in dtypes)
  self._CompileAndCheck(fun, args_maker, atol=tol, rtol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": rec.test_name + "_{}".format(dtype),
     "rng_factory": rec.rng_factory,
     "op_name": rec.name, "dtype": dtype}
    for rec in JAX_OPERATOR_OVERLOADS if rec.nargs == 2
    for dtype in rec.dtypes))
def testBinaryOperatorDefers(self, op_name, rng_factory, dtype):
  # Checks that binary operators on a JAX array defer to the other operand's
  # implementation. unittest assertions are used instead of bare `assert`
  # statements so the checks still run under `python -O` and report the
  # offending values on failure.
  rng = rng_factory(self.rng())
  arg = jax.device_put(rng((), dtype))
  op = getattr(operator, op_name)

  # An operand that overrides every operator wins on either side.
  other = _OverrideEverything()
  self.assertIs(op(other, arg), other)
  self.assertIs(op(arg, other), other)

  # An operand that returns NotImplemented for every operator: equality
  # comparisons fall back to a plain bool, everything else is a TypeError.
  other = _OverrideNothing()
  if op_name == "__eq__":
    self.assertIs(op(other, arg), False)
    self.assertIs(op(arg, other), False)
  elif op_name == "__ne__":
    self.assertIs(op(other, arg), True)
    self.assertIs(op(arg, other), True)
  else:
    with self.assertRaises(TypeError):
      op(other, arg)
    with self.assertRaises(TypeError):
      op(arg, other)
def testArrayEqualExamples(self):
  # examples from the array_equal() docstring.
  self.assertTrue(jnp.array_equal([1, 2], [1, 2]))
  self.assertTrue(jnp.array_equal(np.array([1, 2]), np.array([1, 2])))
  self.assertFalse(jnp.array_equal([1, 2], [1, 2, 3]))
  self.assertFalse(jnp.array_equal([1, 2], [1, 4]))

  # NaN only compares equal to itself when equal_nan=True.
  with_nan = np.array([1, np.nan])
  self.assertFalse(jnp.array_equal(with_nan, with_nan))
  self.assertTrue(jnp.array_equal(with_nan, with_nan, equal_nan=True))

  # Complex values: NaN in the real part of one array, in the imaginary part
  # of the other.
  nan_real = np.array([1 + 1j])
  nan_imag = nan_real.copy()
  nan_real.real = np.nan
  nan_imag.imag = np.nan
  self.assertTrue(jnp.array_equal(nan_real, nan_imag, equal_nan=True))
def testArrayEquivExamples(self):
  # examples from the array_equiv() docstring.
  pair = [1, 2]
  self.assertTrue(jnp.array_equiv(pair, [1, 2]))
  self.assertFalse(jnp.array_equiv(pair, [1, 3]))
  # The remaining examples mix ranks, so implicit rank promotion is enabled.
  with jax.numpy_rank_promotion('allow'):
    self.assertTrue(jnp.array_equiv(pair, [[1, 2], [1, 2]]))
    self.assertFalse(jnp.array_equiv(pair, [[1, 2, 1, 2], [1, 2, 1, 2]]))
    self.assertFalse(jnp.array_equiv(pair, [[1, 2], [1, 3]]))
def testArrayModule(self):
  # numpy_dispatch is an optional import; skip when it is unavailable.
  if numpy_dispatch is None:
    raise SkipTest('requires https://github.com/seberg/numpy-dispatch')

  jnp_array = jnp.array(1.0)
  np_array = np.array(1.0)

  # jnp is selected for a jnp array alone and when it is mixed with a
  # NumPy array.
  self.assertIs(numpy_dispatch.get_array_module(jnp_array), jnp)
  self.assertIs(numpy_dispatch.get_array_module(jnp_array, np_array), jnp)

  def f(x):
    # Runs on whatever value jit / grad pass in for `x`.
    self.assertIs(numpy_dispatch.get_array_module(x), jnp)
    return x

  jax.jit(f)(jnp_array)
  jax.grad(f)(jnp_array)
@parameterized.named_parameters(itertools.chain.from_iterable(
    jtu.cases_from_list(
        {"testcase_name": jtu.format_test_name_suffix(
            rec.test_name, shapes, dtypes),
         "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes,
         "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name)}
        for shapes in filter(
            _shapes_are_broadcast_compatible,
            itertools.combinations_with_replacement(rec.shapes, rec.nargs))
        for dtypes in filter(
            _dtypes_are_compatible_for_bitwise_ops,
            itertools.combinations_with_replacement(rec.dtypes, rec.nargs)))
    for rec in JAX_BITWISE_OP_RECORDS))
@jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
def testBitwiseOp(self, np_op, jnp_op, rng_factory, shapes, dtypes):
  # Compares a bitwise jnp op with the same-named NumPy op through
  # _CheckAgainstNumpy, then runs _CompileAndCheck on the jnp op.
  rng = rng_factory(self.rng())
  if not config.x64_enabled:
    # Any 64-bit dtype among the arguments makes the case unrunnable here.
    if any(jnp.iinfo(dt).bits == 64 for dt in dtypes):
      self.skipTest("x64 types are disabled by jax_enable_x64")
  args_maker = self._GetArgsMaker(rng, shapes, dtypes)
  # Dtypes are only compared when no argument uses the Python-scalar shape.
  strict_dtypes = jtu.PYTHON_SCALAR_SHAPE not in shapes
  self._CheckAgainstNumpy(np_op, jnp_op, args_maker,
                          check_dtypes=strict_dtypes)
  self._CompileAndCheck(jnp_op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": jtu.format_test_name_suffix(op.__name__, shapes, dtypes),
     "op": op, "dtypes": dtypes, "shapes": shapes}
    for op in [jnp.left_shift, jnp.right_shift]
    for shapes in filter(
        _shapes_are_broadcast_compatible,
        # TODO numpy always promotes to shift dtype for zero-dim shapes:
        itertools.combinations_with_replacement(nonzerodim_shapes, 2))
    for dtypes in itertools.product(
        *(_valid_dtypes_for_shape(s, int_dtypes_no_uint64) for s in shapes))))
@jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
def testShiftOpAgainstNumpy(self, op, dtypes, shapes):
  # Compares jnp.left_shift / jnp.right_shift with the same-named NumPy
  # function for every sampled (value dtype, shift dtype) pair.
  #
  # `dtypes` is a pair (dtype of the shifted values, dtype of the shift
  # amounts); `shapes` holds the matching pair of operand shapes.
  dtype, shift_dtype = dtypes
  signed_mix = (np.issubdtype(dtype, np.signedinteger) !=
                np.issubdtype(shift_dtype, np.signedinteger))
  has_32 = any(np.iinfo(d).bits == 32 for d in dtypes)
  promoting_to_64 = has_32 and signed_mix
  if promoting_to_64 and not config.x64_enabled:
    # The two literals are concatenated, so the first needs a trailing space
    # to keep "int64" and "differs" apart in the reported skip reason.
    self.skipTest("np.right_shift/left_shift promoting to int64 "
                  "differs from jnp in 32 bit mode.")

  info, shift_info = map(np.iinfo, dtypes)
  # Values span the full range of the value dtype.
  x_rng = jtu.rand_int(self.rng(), low=info.min, high=info.max + 1)
  # NumPy requires shifts to be non-negative and below the bit width:
  shift_rng = jtu.rand_int(self.rng(), high=max(info.bits, shift_info.bits))

  def args_maker():
    return (x_rng(shapes[0], dtype), shift_rng(shapes[1], shift_dtype))

  self._CompileAndCheck(op, args_maker)
  np_op = getattr(np, op.__name__)
  self._CheckAgainstNumpy(np_op, op, args_maker)
@parameterized.named_parameters(itertools.chain.from_iterable(
    jtu.cases_from_list(
        {"testcase_name": "{}_inshape={}_axis={}_dtype={}_keepdims={}".format(
            rec.test_name.capitalize(),
            jtu.format_shape_dtype_string(shape, dtype), axis,
            "None" if out_dtype is None else np.dtype(out_dtype).name, keepdims),
         "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, "out_dtype": out_dtype,
         "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name),
         "axis": axis, "keepdims": keepdims, "inexact": rec.inexact}
        for shape in rec.shapes for dtype in rec.dtypes
        for out_dtype in [None] + rec.dtypes if out_dtype not in unsigned_dtypes
        for axis in list(range(-len(shape), len(shape))) + [None]
        for keepdims in [False, True])
    for rec in JAX_REDUCER_RECORDS))
def testReducer(self, np_op, jnp_op, rng_factory, shape, dtype, out_dtype,
                axis, keepdims, inexact):
  # Checks a reduction that takes a dtype= argument against its NumPy
  # counterpart for the given input dtype, output dtype, axis and keepdims.
  rng = rng_factory(self.rng())

  @jtu.ignore_warning(category=np.ComplexWarning)
  @jtu.ignore_warning(category=RuntimeWarning,
                      message="mean of empty slice.*")
  @jtu.ignore_warning(category=RuntimeWarning,
                      message="overflow encountered.*")
  def np_fun(x):
    # On the NumPy side bfloat16 is replaced by float32, both for the
    # operand and for the requested output dtype.
    if dtype != jnp.bfloat16:
      operand = x
    else:
      operand = x.astype(np.float32)
    if out_dtype != jnp.bfloat16:
      result_dtype = out_dtype
    else:
      result_dtype = np.float32
    return np_op(operand, axis, dtype=result_dtype, keepdims=keepdims)
  np_fun = _promote_like_jnp(np_fun, inexact)

  @jtu.ignore_warning(category=jnp.ComplexWarning)
  def jnp_fun(x):
    return jnp_op(x, axis, dtype=out_dtype, keepdims=keepdims)

  def args_maker():
    return [rng(shape, dtype)]

  tol_spec = {np.float16: 1e-2, np.int32: 1E-3, np.float32: 1e-3,
              np.complex64: 1e-3, np.float64: 1e-5, np.complex128: 1e-5}
  # Use the looser of the input-dtype and output-dtype tolerances.
  tol = jtu.tolerance(dtype, tol_spec)
  if out_dtype:
    tol = max(tol, jtu.tolerance(out_dtype, tol_spec))

  # Result dtypes are not compared when bfloat16 is involved on either side.
  compare_dtypes = jnp.bfloat16 not in (dtype, out_dtype)
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
                          check_dtypes=compare_dtypes, tol=tol)
  self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol)
@parameterized.named_parameters(itertools.chain.from_iterable(
    jtu.cases_from_list(
        {"testcase_name": "{}_inshape={}_axis={}_keepdims={}".format(
            rec.test_name.capitalize(),
            jtu.format_shape_dtype_string(shape, dtype), axis, keepdims),
         "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype,
         "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name),
         "axis": axis, "keepdims": keepdims, "inexact": rec.inexact}
        for shape in rec.shapes for dtype in rec.dtypes
        for axis in list(range(-len(shape), len(shape))) + [None]
        for keepdims in [False, True])
    for rec in JAX_REDUCER_NO_DTYPE_RECORDS))
def testReducerNoDtype(self, np_op, jnp_op, rng_factory, shape, dtype, axis,
                       keepdims, inexact):
  # Checks a reduction called without a dtype= argument against NumPy.
  rng = rng_factory(self.rng())
  is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan'

  @jtu.ignore_warning(category=RuntimeWarning,
                      message="Degrees of freedom <= 0 for slice.*")
  @jtu.ignore_warning(category=RuntimeWarning,
                      message="All-NaN (slice|axis) encountered.*")
  def np_fun(x):
    if not is_bf16_nan_test:
      return np_op(x, axis, keepdims=keepdims)
    # bfloat16 inputs drawn from rand_some_nan: reduce in float32, then cast
    # the result back to bfloat16.
    widened = x.astype(np.float32)
    return np_op(widened, axis, keepdims=keepdims).astype(jnp.bfloat16)
  np_fun = _promote_like_jnp(np_fun, inexact)

  def jnp_fun(x):
    return jnp_op(x, axis, keepdims=keepdims)

  def args_maker():
    return [rng(shape, dtype)]

  tol = {np.float16: 0.002}
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol)
  self._CompileAndCheck(jnp_fun, args_maker, rtol=tol, atol=tol)
@parameterized.named_parameters(itertools.chain.from_iterable(
    jtu.cases_from_list(
        {"testcase_name": "{}_inshape={}_axis={}_keepdims={}_initial={}".format(
            rec.test_name.capitalize(),
            jtu.format_shape_dtype_string(shape, dtype), axis, keepdims, initial),
         "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype,
         "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name),
         "initial": initial, "axis": axis, "keepdims": keepdims, "inexact": rec.inexact}
        for shape in rec.shapes for dtype in rec.dtypes
        for axis in list(range(-len(shape), len(shape))) + [None]
        for initial in [0, 1] for keepdims in [False, True])
    for rec in JAX_REDUCER_INITIAL_RECORDS))
def testReducerInitial(self, np_op, jnp_op, rng_factory, shape, dtype, axis,
                       keepdims, initial, inexact):
  # Checks a reduction called with an explicit initial= value against NumPy.
  rng = rng_factory(self.rng())
  is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan'

  @jtu.ignore_warning(category=RuntimeWarning,
                      message="Degrees of freedom <= 0 for slice.*")
  def np_fun(x):
    if not is_bf16_nan_test:
      return np_op(x, axis, keepdims=keepdims, initial=initial)
    # bfloat16 inputs drawn from rand_some_nan: reduce in float32, then cast
    # the result back to bfloat16.
    widened = x.astype(np.float32)
    reduced = np_op(widened, axis, keepdims=keepdims, initial=initial)
    return reduced.astype(jnp.bfloat16)
  np_fun = jtu.ignore_warning(category=np.ComplexWarning)(
      _promote_like_jnp(np_fun, inexact))

  @jtu.ignore_warning(category=jnp.ComplexWarning)
  def jnp_fun(x):
    return jnp_op(x, axis, keepdims=keepdims, initial=initial)

  def args_maker():
    return [rng(shape, dtype)]

  tol = {jnp.bfloat16: 3E-2}
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=tol)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(itertools.chain.from_iterable(
    jtu.cases_from_list(
        {"testcase_name": "{}_inshape={}_axis={}_keepdims={}_initial={}_whereshape={}".format(
            rec.test_name.capitalize(),
            jtu.format_shape_dtype_string(shape, dtype), axis, keepdims, initial,
            jtu.format_shape_dtype_string(whereshape, bool)),
         "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype,
         "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name), "whereshape": whereshape,
         "initial": initial, "axis": axis, "keepdims": keepdims, "inexact": rec.inexact}
        for shape in rec.shapes for dtype in rec.dtypes
        for whereshape in _compatible_shapes(shape)
        for axis in list(range(-len(shape), len(shape))) + [None]
        for initial in [0, 1] for keepdims in [False, True])
    for rec in JAX_REDUCER_INITIAL_RECORDS))
def testReducerWhere(self, np_op, jnp_op, rng_factory, shape, dtype, axis,
                     keepdims, initial, inexact, whereshape):
  # Checks a reduction called with both initial= and a boolean where= mask
  # against NumPy.
  is_scalar_shape = shape in [()] + scalar_shapes
  if (is_scalar_shape and
      dtype in [jnp.int16, jnp.uint16] and
      jnp_op in [jnp.min, jnp.max]):
    self.skipTest("Known XLA failure; see https://github.com/google/jax/issues/4971.")

  rng = rng_factory(self.rng())
  is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan'
  # Do not pass where via args_maker as that is incompatible with _promote_like_jnp.
  where = jtu.rand_bool(self.rng())(whereshape, np.bool_)

  @jtu.ignore_warning(category=RuntimeWarning,
                      message="Degrees of freedom <= 0 for slice.*")
  def np_fun(x):
    if not is_bf16_nan_test:
      return np_op(x, axis, keepdims=keepdims, initial=initial, where=where)
    # bfloat16 inputs drawn from rand_some_nan: reduce in float32, then cast
    # the result back to bfloat16.
    widened = x.astype(np.float32)
    reduced = np_op(widened, axis, keepdims=keepdims, initial=initial,
                    where=where)
    return reduced.astype(jnp.bfloat16)
  np_fun = jtu.ignore_warning(category=np.ComplexWarning)(
      _promote_like_jnp(np_fun, inexact))

  @jtu.ignore_warning(category=jnp.ComplexWarning)
  def jnp_fun(x):
    return jnp_op(x, axis, keepdims=keepdims, initial=initial, where=where)

  def args_maker():
    return [rng(shape, dtype)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@unittest.skipIf(numpy_version < (1, 20), "where parameter not supported in older numpy")
@parameterized.named_parameters(itertools.chain.from_iterable(
    jtu.cases_from_list(
        {"testcase_name": "{}_inshape={}_axis={}_keepdims={}_whereshape={}".format(
            rec.test_name.capitalize(),
            jtu.format_shape_dtype_string(shape, dtype), axis, keepdims,
            jtu.format_shape_dtype_string(whereshape, bool)),
         "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype,
         "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name), "whereshape": whereshape,
         "axis": axis, "keepdims": keepdims, "inexact": rec.inexact}
        for shape in rec.shapes for dtype in rec.dtypes
        for whereshape in _compatible_shapes(shape)
        for axis in list(range(-len(shape), len(shape))) + [None]
        for keepdims in [False, True])
    for rec in JAX_REDUCER_WHERE_NO_INITIAL_RECORDS))
def testReducerWhereNoInitial(self, np_op, jnp_op, rng_factory, shape, dtype, axis,
                              keepdims, inexact, whereshape):
  # Checks a reduction called with a boolean where= mask but no initial=
  # value against NumPy.
  rng = rng_factory(self.rng())
  # Unlike the sibling tests, this flag is set for every bfloat16 input.
  is_bf16_nan_test = dtype == jnp.bfloat16
  # Do not pass where via args_maker as that is incompatible with _promote_like_jnp.
  where = jtu.rand_bool(self.rng())(whereshape, np.bool_)

  @jtu.ignore_warning(category=RuntimeWarning,
                      message="Degrees of freedom <= 0 for slice.*")
  @jtu.ignore_warning(category=RuntimeWarning,
                      message="Mean of empty slice.*")
  @jtu.ignore_warning(category=RuntimeWarning,
                      message="invalid value encountered in true_divide*")
  def np_fun(x):
    if not is_bf16_nan_test:
      return np_op(x, axis, keepdims=keepdims, where=where)
    # bfloat16 inputs: reduce in float32, then cast the result back.
    widened = x.astype(np.float32)
    reduced = np_op(widened, axis, keepdims=keepdims, where=where)
    return reduced.astype(jnp.bfloat16)
  np_fun = jtu.ignore_warning(category=np.ComplexWarning)(
      _promote_like_jnp(np_fun, inexact))

  @jtu.ignore_warning(category=jnp.ComplexWarning)
  def jnp_fun(x):
    return jnp_op(x, axis, keepdims=keepdims, where=where)

  def args_maker():
    return [rng(shape, dtype)]

  # The NumPy comparison only runs on numpy >= 1.20.2, or for all/any on any
  # version; the compile check below always runs.
  compare_with_numpy = (numpy_version >= (1, 20, 2) or
                        np_op.__name__ in ("all", "any"))
  if compare_with_numpy:
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_axis={}".format(
        jtu.format_shape_dtype_string(shape, dtype), axis),
     "shape": shape, "dtype": dtype, "axis": axis}
    for shape in all_shapes for dtype in all_dtypes
    for axis in list(range(-len(shape), len(shape))) + [None]))
def testCountNonzero(self, shape, dtype, axis):
  # Compares jnp.count_nonzero with np.count_nonzero along `axis`; result
  # dtypes are not compared.
  rng = jtu.rand_some_zero(self.rng())

  def np_fun(x):
    return np.count_nonzero(x, axis)

  def jnp_fun(x):
    return jnp.count_nonzero(x, axis)

  def args_maker():
    return [rng(shape, dtype)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}".format(
        jtu.format_shape_dtype_string(shape, dtype)),
     "shape": shape, "dtype": dtype}
    for shape in all_shapes for dtype in all_dtypes))
def testNonzero(self, shape, dtype):
  # Compares jnp.nonzero with np.nonzero; result dtypes are not compared and
  # no compile check is run here.
  rng = jtu.rand_some_zero(self.rng())

  @jtu.ignore_warning(category=DeprecationWarning,
                      message="Calling nonzero on 0d arrays.*")
  def np_fun(x):
    return np.nonzero(x)

  def jnp_fun(x):
    return jnp.nonzero(x)

  def args_maker():
    return [rng(shape, dtype)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_size={}_fill_value={}".format(
        jtu.format_shape_dtype_string(shape, dtype), size, fill_value),
     "shape": shape, "dtype": dtype, "size": size, "fill_value": fill_value}
    for shape in nonempty_array_shapes
    for dtype in all_dtypes
    for fill_value in [None, -1, shape or (1,)]
    for size in [1, 5, 10]))
def testNonzeroSize(self, shape, dtype, size, fill_value):
  # Checks jnp.nonzero(size=..., fill_value=...) against a NumPy reference
  # that truncates or pads np.nonzero's output to exactly `size` entries.
  rng = jtu.rand_some_zero(self.rng())

  def args_maker():
    return [rng(shape, dtype)]

  @jtu.ignore_warning(category=DeprecationWarning, message="Calling nonzero on 0d arrays.*")
  def np_fun(x):
    indices = np.nonzero(x)
    if size <= len(indices[0]):
      # Enough hits: keep only the first `size` of them per dimension.
      return tuple(idx[:size] for idx in indices)
    # Too few hits: pad each index array. A scalar (or missing) fill value is
    # repeated once per dimension; None falls back to 0.
    if np.ndim(fill_value):
      fillvals = fill_value
    else:
      fillvals = len(indices) * [fill_value or 0]
    padded = []
    for fval, idx in safe_zip(fillvals, indices):
      filler = np.full(size - len(idx), fval, idx.dtype)
      padded.append(np.concatenate([idx, filler]))
    return tuple(padded)

  def jnp_fun(x):
    return jnp.nonzero(x, size=size, fill_value=fill_value)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}".format(
        jtu.format_shape_dtype_string(shape, dtype)),
     "shape": shape, "dtype": dtype}
    for shape in all_shapes for dtype in all_dtypes))
def testFlatNonzero(self, shape, dtype):
  # Compares jnp.flatnonzero with np.flatnonzero, then compile-checks a
  # variant that is given a static size.
  rng = jtu.rand_some_zero(self.rng())
  suppress_0d_warning = jtu.ignore_warning(
      category=DeprecationWarning,
      message="Calling nonzero on 0d arrays.*")
  np_fun = suppress_0d_warning(np.flatnonzero)

  def args_maker():
    return [rng(shape, dtype)]

  self._CheckAgainstNumpy(np_fun, jnp.flatnonzero, args_maker, check_dtypes=False)

  # JIT compilation requires specifying the size statically:
  def sized_flatnonzero(x):
    return jnp.flatnonzero(x, size=np.size(x) // 2)

  self._CompileAndCheck(sized_flatnonzero, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_size={}_fill_value={}".format(
        jtu.format_shape_dtype_string(shape, dtype), size, fill_value),
     "shape": shape, "dtype": dtype, "size": size, "fill_value": fill_value}
    for shape in nonempty_array_shapes
    for dtype in all_dtypes
    for fill_value in [None, -1, 10, (-1,), (10,)]
    for size in [1, 5, 10]))
def testFlatNonzeroSize(self, shape, dtype, size, fill_value):
  # Checks jnp.flatnonzero(size=..., fill_value=...) against a NumPy
  # reference that truncates or pads np.flatnonzero's output to `size`.
  rng = jtu.rand_some_zero(self.rng())

  def args_maker():
    return [rng(shape, dtype)]

  @jtu.ignore_warning(category=DeprecationWarning, message="Calling nonzero on 0d arrays.*")
  def np_fun(x):
    hits = np.flatnonzero(x)
    if size <= len(hits):
      return hits[:size]
    # Pad up to `size`; a missing fill value falls back to 0.
    fill_val = fill_value or 0
    filler = np.full(size - len(hits), fill_val, hits.dtype)
    return np.concatenate([hits, filler])

  def jnp_fun(x):
    return jnp.flatnonzero(x, size=size, fill_value=fill_value)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}".format(
        jtu.format_shape_dtype_string(shape, dtype)),
     "shape": shape, "dtype": dtype}
    for shape in all_shapes for dtype in all_dtypes))
def testArgWhere(self, shape, dtype):
  # Compares jnp.argwhere with np.argwhere, then compile-checks a variant
  # that is given a static size.
  rng = jtu.rand_some_zero(self.rng())
  suppress_0d_warning = jtu.ignore_warning(
      category=DeprecationWarning,
      message="Calling nonzero on 0d arrays.*")
  np_fun = suppress_0d_warning(np.argwhere)

  def args_maker():
    return [rng(shape, dtype)]

  self._CheckAgainstNumpy(np_fun, jnp.argwhere, args_maker, check_dtypes=False)

  # JIT compilation requires specifying a size statically. Full test of this
  # behavior is in testNonzeroSize().
  def sized_argwhere(x):
    return jnp.argwhere(x, size=np.size(x) // 2)

  self._CompileAndCheck(sized_argwhere, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_size={}_fill_value={}".format(
        jtu.format_shape_dtype_string(shape, dtype), size, fill_value),
     "shape": shape, "dtype": dtype, "size": size, "fill_value": fill_value}
    for shape in nonempty_array_shapes
    for dtype in all_dtypes
    for fill_value in [None, -1, shape or (1,)]
    for size in [1, 5, 10]))
def testArgWhereSize(self, shape, dtype, size, fill_value):
  # Checks jnp.argwhere(size=..., fill_value=...) against a NumPy reference
  # that truncates or pads np.argwhere's rows to exactly `size`.
  rng = jtu.rand_some_zero(self.rng())

  def args_maker():
    return [rng(shape, dtype)]

  @jtu.ignore_warning(category=DeprecationWarning, message="Calling nonzero on 0d arrays.*")
  def np_fun(x):
    found = np.argwhere(x)
    if size <= len(found):
      return found[:size]
    # A scalar (or missing) fill value is repeated once per index column;
    # None falls back to 0.
    if np.ndim(fill_value):
      fillvals = fill_value
    else:
      fillvals = found.shape[-1] * [fill_value or 0]
    if np.ndim(x) == 0:
      # 0-d input: the reference result has `size` rows and no columns.
      return np.empty((size, 0), dtype=int)
    # Pad each index column separately, then restore the row-major layout.
    columns = [np.concatenate([col, np.full(size - len(col), fval, col.dtype)])
               for fval, col in safe_zip(fillvals, found.T)]
    return np.stack(columns).T

  def jnp_fun(x):
    return jnp.argwhere(x, size=size, fill_value=fill_value)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "{}_inshape={}_axis={}_keepdims={}".format(
        rec.test_name.capitalize(),
        jtu.format_shape_dtype_string(shape, dtype), axis, keepdims),
     "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype,
     "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name),
     "axis": axis, "keepdims": keepdims}
    for rec in JAX_ARGMINMAX_RECORDS
    for shape, dtype in _shape_and_dtypes(rec.shapes, rec.dtypes)
    for axis in range(-len(shape), len(shape))
    for keepdims in [True, False]))
def testArgMinMax(self, np_op, jnp_op, rng_factory, shape, dtype, axis, keepdims):
  # Compares an arg-min/arg-max style jnp op with its NumPy counterpart along
  # `axis`, optionally with keepdims=True.
  rng = rng_factory(self.rng())

  # Configurations that cannot be compared are skipped up front.
  if dtype == np.complex128 and jtu.device_under_test() == "gpu":
    raise unittest.SkipTest("complex128 reductions not supported on GPU")
  if "nan" in np_op.__name__ and dtype == jnp.bfloat16:
    raise unittest.SkipTest("NumPy doesn't correctly handle bfloat16 arrays")
  if numpy_version < (1, 22) and keepdims:
    raise unittest.SkipTest("NumPy < 1.22 does not support keepdims argument to argmin/argmax")

  # keepdims is only forwarded when requested.
  kwds = {}
  if keepdims:
    kwds["keepdims"] = True
  np_fun = jtu.with_jax_dtype_defaults(partial(np_op, axis=axis, **kwds))
  jnp_fun = partial(jnp_op, axis=axis, **kwds)

  def args_maker():
    return [rng(shape, dtype)]

  try:
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  except ValueError as e:
    # Only the all-NaN ValueError is turned into a skip; anything else
    # propagates unchanged.
    if str(e) != "All-NaN slice encountered":
      raise
    self.skipTest("JAX doesn't support checking for all-NaN slices")
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": rec.test_name.capitalize(), "name": rec.name,
     "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name)}
    for rec in JAX_ARGMINMAX_RECORDS))
def testArgMinMaxEmpty(self, name, np_op, jnp_op):
  # Checks how an arg-min/arg-max style op behaves on empty inputs: reducing
  # over a zero-length axis must raise ValueError with the expected message,
  # while reducing over a non-empty axis of an empty array must match NumPy.
  #
  # The expected message uses the op name with any leading "nan" removed.
  name = name[3:] if name.startswith("nan") else name
  msg = "attempt to get {} of an empty sequence".format(name)
  # assertRaises(..., msg=msg) would only set the failure message and never
  # look at the exception text; assertRaisesRegex actually matches it.
  with self.assertRaisesRegex(ValueError, msg):
    jnp_op(np.array([]))
  with self.assertRaisesRegex(ValueError, msg):
    jnp_op(np.zeros((2, 0)), axis=1)

  # Reducing along axis 0 of a (2, 0) array is valid and is compared to NumPy.
  np_fun = jtu.with_jax_dtype_defaults(partial(np_op, axis=0))
  jnp_fun = partial(jnp_op, axis=0)
  args_maker = lambda: [np.zeros((2, 0))]
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_{}_{}".format(
        jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
        jtu.format_shape_dtype_string(rhs_shape, rhs_dtype),
        axes),
     "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
     "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
     "axes": axes}
    for lhs_shape, rhs_shape, axes in [
        [(2,), (2,), (-1, -1, -1, None)],  # scalar output
        [(2, 4), (2, 4), (-1, -1, -1, 0)],  # 2D vectors
        [(3, 4), (3, 4), (-1, -1, -1, 0)],  # 3D vectors
        [(3, 4), (3, 6, 5, 4), (-1, -1, -1, 0)],  # broadcasting
        [(4, 3), (3, 6, 5, 4), (1, 0, -1, None)],  # different axes
        [(6, 1, 3), (5, 3), (-1, -1, -1, None)],  # more broadcasting
        [(6, 1, 2), (5, 3), (-1, -1, -1, None)],  # mixed 2D and 3D vectors
        [(10, 5, 2, 8), (1, 5, 1, 3), (-2, -1, -3, None)],  # axes/broadcasting
        [(4, 5, 2), (4, 5, 2), (-1, -1, 0, None)],  # axisc should do nothing
        [(4, 5, 2), (4, 5, 2), (-1, -1, -1, None)]  # same as before
    ]
    for lhs_dtype, rhs_dtype in itertools.combinations_with_replacement(number_dtypes, 2)))
@jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
def testCross(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, axes):
  # Compares jnp.cross with np.cross for the given operand shapes, dtypes and
  # (axisa, axisb, axisc, axis) tuple.
  rng = jtu.rand_default(self.rng())

  def args_maker():
    return [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]

  axisa, axisb, axisc, axis = axes

  def jnp_fun(a, b):
    return jnp.cross(a, b, axisa, axisb, axisc, axis)

  def np_fun(a, b):
    # bfloat16 operands are widened to float32 before calling NumPy; the
    # result is then cast to the dtype jnp promotion would produce.
    if lhs_dtype == jnp.bfloat16:
      a = a.astype(np.float32)
    if rhs_dtype == jnp.bfloat16:
      b = b.astype(np.float32)
    out = np.cross(a, b, axisa, axisb, axisc, axis)
    return out.astype(jnp.promote_types(lhs_dtype, rhs_dtype))

  tol_spec = {dtypes.bfloat16: 3e-1, np.float16: 0.15}
  # Use the looser tolerance of the two operand dtypes.
  lhs_tol = jtu.tolerance(lhs_dtype, tol_spec)
  rhs_tol = jtu.tolerance(rhs_dtype, tol_spec)
  tol = max(lhs_tol, rhs_tol)
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol)
  self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_{}_{}".format(
        name,
        jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
        jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)),
     "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
     "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype}
    for name, lhs_shape, rhs_shape in [
        ("matrix-scalar", (3, 3), ()),
        ("scalar-matrix", (), (3, 3)),
        ("matrix-vector", (4, 5), (5,)),
        ("vector-matrix", (6,), (6, 4)),
        ("matrix-matrix", (3, 4), (4, 5)),
        ("tensor-vector", (4, 3, 2), (2,)),
        ("vector-tensor", (2,), (3, 2, 4)),
        ("tensor-matrix", (4, 3, 2), (2, 5)),
        ("matrix-tensor", (5, 2), (3, 2, 4)),
        ("tensor-tensor", (2, 3, 4), (5, 4, 1))]
    for lhs_dtype, rhs_dtype in itertools.combinations_with_replacement(number_dtypes, 2)))
def testDot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype):
  # Compares jnp.dot with np.dot for the given operand shapes and dtypes.
  rng = jtu.rand_default(self.rng())

  def args_maker():
    return [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]

  tol = {np.float16: 1e-2, np.float32: 1e-5, np.float64: 1e-14,
         np.complex128: 1e-14}
  if jtu.device_under_test() == "tpu":
    # Looser tolerances apply to these dtypes on TPU.
    tol.update({np.float16: 2e-1, np.float32: 2e-1, np.complex64: 2e-1})

  def np_dot(x, y):
    # bfloat16 operands are widened to float32 before calling NumPy; the
    # result is then cast to the dtype jnp promotion would produce.
    if lhs_dtype == jnp.bfloat16:
      x = x.astype(np.float32)
    if rhs_dtype == jnp.bfloat16:
      y = y.astype(np.float32)
    return np.dot(x, y).astype(jnp.promote_types(lhs_dtype, rhs_dtype))

  self._CheckAgainstNumpy(np_dot, jnp.dot, args_maker, tol=tol)
  self._CompileAndCheck(jnp.dot, args_maker, atol=tol, rtol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_{}_{}".format(
        name,
        jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
        jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)),
     "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
     "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype}
    for name, lhs_shape, rhs_shape in [
        ("vector-vector", (3,), (3,)),
        ("matrix-vector", (3, 3), (3,)),
        ("vector-matrix", (3,), (3, 3)),
        ("matrix-matrix", (3, 3), (3, 3)),
        ("vector-tensor", (3,), (5, 3, 2)),
        ("tensor-vector", (5, 3, 2), (2,)),
        ("matrix-tensor", (5, 2), (3, 2, 4)),
        ("tensor-matrix", (5, 2, 3), (3, 2)),
        ("tensor-tensor", (5, 3, 4), (5, 4, 1)),
        ("tensor-tensor-broadcast", (3, 1, 3, 4), (5, 4, 1))]
    for lhs_dtype, rhs_dtype in itertools.combinations_with_replacement(number_dtypes, 2)))
def testMatmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype):
  # Compares jnp.matmul with np.matmul for the given operand shapes and dtypes.
  rng = jtu.rand_default(self.rng())

  def np_fun(x, y):
    # Cast NumPy's result to the dtype jnp promotion would produce.
    result_dtype = jnp.promote_types(lhs_dtype, rhs_dtype)
    return np.matmul(x, y).astype(result_dtype)

  def args_maker():
    return [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]

  tol = {np.float16: 1e-2, np.float32: 2e-2, np.float64: 1e-12,
         np.complex128: 1e-12}
  if jtu.device_under_test() == "tpu":
    # Looser tolerances apply to these dtypes on TPU.
    tol.update({np.float16: 4e-2, np.float32: 4e-2, np.complex64: 4e-2})

  self._CheckAgainstNumpy(np_fun, jnp.matmul, args_maker, tol=tol)
  self._CompileAndCheck(jnp.matmul, args_maker, atol=tol, rtol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_{}_{}".format(
        jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
        jtu.format_shape_dtype_string(rhs_shape, rhs_dtype),
        axes),
     "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
     "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
     "axes": axes}
    for lhs_shape, rhs_shape, axes in [
        [(3,), (), 0],
        [(2, 3, 4), (5, 6, 7), 0],  # from issue #740
        [(2, 3, 4), (3, 4, 5, 6), 2],
        [(2, 3, 4), (5, 4, 3, 6), [1, 2]],
        [(2, 3, 4), (5, 4, 3, 6), [[1, 2], [2, 1]]],
        [(1, 2, 3, 4), (4, 5, 3, 6), [[2, 3], [2, 0]]],
    ]
    for lhs_dtype, rhs_dtype in itertools.combinations_with_replacement(number_dtypes, 2)))
def testTensordot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, axes):
  # Compares jnp.tensordot with np.tensordot for the given operand shapes,
  # dtypes and `axes` specification.
  rng = jtu.rand_default(self.rng())

  def args_maker():
    return [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]

  def jnp_fun(a, b):
    return jnp.tensordot(a, b, axes)

  def np_fun(a, b):
    # bfloat16 operands are widened to float32 before calling NumPy; the
    # result is then cast to the dtype jnp promotion would produce.
    if lhs_dtype == jnp.bfloat16:
      a = a.astype(np.float32)
    if rhs_dtype == jnp.bfloat16:
      b = b.astype(np.float32)
    result_dtype = jnp.promote_types(lhs_dtype, rhs_dtype)
    return np.tensordot(a, b, axes).astype(result_dtype)

  tol = {np.float16: 1e-1, np.float32: 1e-3, np.float64: 1e-12,
         np.complex64: 1e-3, np.complex128: 1e-12}
  if jtu.device_under_test() == "tpu":
    # Looser tolerances apply to these dtypes on TPU.
    tol.update({np.float16: 2e-1, np.float32: 2e-1, np.complex64: 2e-1})

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol)
  self._CompileAndCheck(jnp_fun, args_maker)
def testTensordotErrors(self):
a = self.rng().random((3, 2, 2))
b = self.rng().random((2,))
self.assertRaisesRegex(
TypeError, "Number of tensordot axes.*exceeds input ranks.*",
lambda: jnp.tensordot(a, b, axes=2))
self.assertRaisesRegex(
TypeError, "tensordot requires axes lists to have equal length.*",
lambda: jnp.tensordot(a, b, axes=([0], [0, 1])))
self.assertRaisesRegex(
TypeError, "tensordot requires both axes lists to be either ints, tuples or lists.*",
lambda: jnp.tensordot(a, b, axes=('bad', 'axes')))
self.assertRaisesRegex(
TypeError, "tensordot axes argument must be an int, a pair of ints, or a pair of lists.*",
lambda: jnp.tensordot(a, b, axes='badaxes'))
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_{}_invert={}".format(
        jtu.format_shape_dtype_string(element_shape, dtype),
        jtu.format_shape_dtype_string(test_shape, dtype), invert),
     "element_shape": element_shape, "test_shape": test_shape,
     "dtype": dtype, "invert": invert}
    for element_shape in all_shapes
    for test_shape in all_shapes
    for dtype in default_dtypes
    for invert in [True, False]))
def testIsin(self, element_shape, test_shape, dtype, invert):
  # Compares jnp.isin with np.isin, with and without invert=True.
  rng = jtu.rand_default(self.rng())

  def args_maker():
    return [rng(element_shape, dtype), rng(test_shape, dtype)]

  def jnp_fun(e, t):
    return jnp.isin(e, t, invert=invert)

  def np_fun(e, t):
    return np.isin(e, t, invert=invert)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_{}_invert={}".format(
        jtu.format_shape_dtype_string(element_shape, dtype),
        jtu.format_shape_dtype_string(test_shape, dtype), invert),
     "element_shape": element_shape, "test_shape": test_shape,
     "dtype": dtype, "invert": invert}
    for element_shape in all_shapes
    for test_shape in all_shapes
    for dtype in default_dtypes
    for invert in [True, False]))
def testIn1d(self, element_shape, test_shape, dtype, invert):
  # Compares jnp.in1d with np.in1d, with and without invert=True.
  rng = jtu.rand_default(self.rng())

  def args_maker():
    return [rng(element_shape, dtype), rng(test_shape, dtype)]

  def jnp_fun(e, t):
    return jnp.in1d(e, t, invert=invert)

  def np_fun(e, t):
    return np.in1d(e, t, invert=invert)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_{}".format(
        jtu.format_shape_dtype_string(shape1, dtype1),
        jtu.format_shape_dtype_string(shape2, dtype2)),
     "shape1": shape1, "shape2": shape2, "dtype1": dtype1, "dtype2": dtype2}
    for dtype1 in [s for s in default_dtypes if s != jnp.bfloat16]
    for dtype2 in [s for s in default_dtypes if s != jnp.bfloat16]
    for shape1 in all_shapes
    for shape2 in all_shapes))
def testSetdiff1d(self, shape1, shape2, dtype1, dtype2):
  # Compares jnp.setdiff1d with np.setdiff1d; no compile check is run here.
  rng = jtu.rand_default(self.rng())

  def args_maker():
    first = rng(shape1, dtype1)
    second = rng(shape2, dtype2)
    return [first, second]

  self._CheckAgainstNumpy(np.setdiff1d, jnp.setdiff1d, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_{}_size={}_fill_value={}".format(
        jtu.format_shape_dtype_string(shape1, dtype1),
        jtu.format_shape_dtype_string(shape2, dtype2),
        size, fill_value),
     "shape1": shape1, "shape2": shape2, "dtype1": dtype1, "dtype2": dtype2,
     "size": size, "fill_value": fill_value}
    for dtype1 in [s for s in default_dtypes if s != jnp.bfloat16]
    for dtype2 in [s for s in default_dtypes if s != jnp.bfloat16]
    for shape1 in all_shapes
    for shape2 in all_shapes
    for size in [1, 5, 10]
    for fill_value in [None, -1]))
def testSetdiff1dSize(self, shape1, shape2, dtype1, dtype2, size, fill_value):
  # Checks jnp.setdiff1d(size=..., fill_value=...) against a NumPy reference
  # that truncates or pads np.setdiff1d's output to exactly `size`.
  rng = jtu.rand_default(self.rng())

  def args_maker():
    return [rng(shape1, dtype1), rng(shape2, dtype2)]

  def np_fun(arg1, arg2):
    difference = np.setdiff1d(arg1, arg2)
    if size <= len(difference):
      return difference[:size]
    # Pad at the end only; a missing fill value falls back to 0.
    missing = size - len(difference)
    return np.pad(difference, (0, missing), constant_values=fill_value or 0)

  def jnp_fun(arg1, arg2):
    return jnp.setdiff1d(arg1, arg2, size=size, fill_value=fill_value)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_{}".format(
        jtu.format_shape_dtype_string(shape1, dtype1),
        jtu.format_shape_dtype_string(shape2, dtype2)),
     "shape1": shape1, "shape2": shape2, "dtype1": dtype1, "dtype2": dtype2}
    for dtype1 in [s for s in default_dtypes if s != jnp.bfloat16]
    for dtype2 in [s for s in default_dtypes if s != jnp.bfloat16]
    for shape1 in nonempty_nonscalar_array_shapes
    for shape2 in nonempty_nonscalar_array_shapes))
def testUnion1d(self, shape1, shape2, dtype1, dtype2):
  # Compares jnp.union1d with np.union1d; no compile check is run here.
  rng = jtu.rand_default(self.rng())

  def args_maker():
    return [rng(shape1, dtype1), rng(shape2, dtype2)]

  def np_fun(arg1, arg2):
    # Cast NumPy's result to the dtype jnp promotion would produce.
    result_dtype = jnp.promote_types(arg1.dtype, arg2.dtype)
    merged = np.union1d(arg1, arg2)
    return merged.astype(result_dtype)

  self._CheckAgainstNumpy(np_fun, jnp.union1d, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_{}_size={}_fill_value={}".format(
        jtu.format_shape_dtype_string(shape1, dtype1),
        jtu.format_shape_dtype_string(shape2, dtype2), size, fill_value),
     "shape1": shape1, "shape2": shape2, "dtype1": dtype1, "dtype2": dtype2,
     "size": size, "fill_value": fill_value}
    for dtype1 in [s for s in default_dtypes if s != jnp.bfloat16]
    for dtype2 in [s for s in default_dtypes if s != jnp.bfloat16]
    for shape1 in nonempty_nonscalar_array_shapes
    for shape2 in nonempty_nonscalar_array_shapes
    for size in [1, 5, 10]
    for fill_value in [None, -1]))
def testUnion1dSize(self, shape1, shape2, dtype1, dtype2, size, fill_value):
  # Checks jnp.union1d(size=..., fill_value=...) against a NumPy reference
  # that truncates or pads np.union1d's output to exactly `size`.
  rng = jtu.rand_default(self.rng())

  def args_maker():
    return [rng(shape1, dtype1), rng(shape2, dtype2)]

  def np_fun(arg1, arg2):
    result_dtype = jnp.promote_types(arg1.dtype, arg2.dtype)
    union = np.union1d(arg1, arg2).astype(result_dtype)
    # With no explicit fill value the reference pads with the smallest
    # element of the union.
    if fill_value is None:
      fv = union.min()
    else:
      fv = fill_value
    if size <= len(union):
      return union[:size]
    filler = np.full(size - len(union), fv, union.dtype)
    return np.concatenate([union, filler])

  def jnp_fun(arg1, arg2):
    return jnp.union1d(arg1, arg2, size=size, fill_value=fill_value)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    dict(testcase_name=(
             f"_{jtu.format_shape_dtype_string(shape1, dtype1)}"
             f"_{jtu.format_shape_dtype_string(shape2, dtype2)}"
             f"_assume_unique={assume_unique}"),
         shape1=shape1, dtype1=dtype1, shape2=shape2, dtype2=dtype2,
         assume_unique=assume_unique)
    for dtype1 in [dt for dt in default_dtypes if dt != jnp.bfloat16]
    for dtype2 in [dt for dt in default_dtypes if dt != jnp.bfloat16]
    for shape1 in all_shapes
    for shape2 in all_shapes
    for assume_unique in [False, True]))
def testSetxor1d(self, shape1, dtype1, shape2, dtype2, assume_unique):
  # Compares jnp.setxor1d with np.setxor1d; result dtypes are not compared.
  sampler = jtu.rand_default(self.rng())

  def args_maker():
    return [sampler(shape1, dtype1), sampler(shape2, dtype2)]

  def jnp_fun(ar1, ar2):
    return jnp.setxor1d(ar1, ar2, assume_unique=assume_unique)

  def np_fun(ar1, ar2):
    if not assume_unique:
      return np.setxor1d(ar1, ar2, assume_unique)
    # pre-flatten the arrays to match with jax implementation
    return np.setxor1d(np.ravel(ar1), np.ravel(ar2), assume_unique)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
@parameterized.named_parameters(jtu.cases_from_list(
    dict(testcase_name=(
             f"_{jtu.format_shape_dtype_string(shape1, dtype1)}"
             f"_{jtu.format_shape_dtype_string(shape2, dtype2)}"
             f"_assume_unique={assume_unique}"
             f"_return_indices={return_indices}"),
         shape1=shape1, dtype1=dtype1, shape2=shape2, dtype2=dtype2,
         assume_unique=assume_unique, return_indices=return_indices)
    for dtype1 in [dt for dt in default_dtypes if dt != jnp.bfloat16]
    for dtype2 in [dt for dt in default_dtypes if dt != jnp.bfloat16]
    for shape1 in all_shapes
    for shape2 in all_shapes
    for assume_unique in [False, True]
    for return_indices in [False, True]))
def testIntersect1d(self, shape1, dtype1, shape2, dtype2, assume_unique, return_indices):
  # Compares jnp.intersect1d with np.intersect1d using identical keyword
  # arguments on both sides; result dtypes are not compared.
  sampler = jtu.rand_default(self.rng())
  options = dict(assume_unique=assume_unique, return_indices=return_indices)

  def args_maker():
    return [sampler(shape1, dtype1), sampler(shape2, dtype2)]

  def jnp_fun(ar1, ar2):
    return jnp.intersect1d(ar1, ar2, **options)

  def np_fun(ar1, ar2):
    return np.intersect1d(ar1, ar2, **options)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
@parameterized.named_parameters(jtu.cases_from_list(
    dict(testcase_name=(
             f"_{jtu.format_shape_dtype_string(lhs_shape, lhs_dtype)}"
             f"_{jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)}"),
         lhs_shape=lhs_shape, lhs_dtype=lhs_dtype,
         rhs_shape=rhs_shape, rhs_dtype=rhs_dtype)
    # TODO(phawkins): support integer dtypes too.
    for lhs_shape, lhs_dtype in _shape_and_dtypes(all_shapes, inexact_dtypes)
    for rhs_shape, rhs_dtype in _shape_and_dtypes(all_shapes, inexact_dtypes)
    if len(jtu._dims_of_shape(lhs_shape)) == 0
    or len(jtu._dims_of_shape(rhs_shape)) == 0
    or lhs_shape[-1] == rhs_shape[-1]))
def testInner(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype):
  # Compares jnp.inner with np.inner for operand pairs where one side is
  # zero-dimensional or the trailing dimensions agree.
  sampler = jtu.rand_default(self.rng())

  def args_maker():
    return [sampler(lhs_shape, lhs_dtype), sampler(rhs_shape, rhs_dtype)]

  def np_fun(lhs, rhs):
    # bfloat16 operands are widened to float32 before calling NumPy.
    if lhs_dtype == jnp.bfloat16:
      lhs = lhs.astype(np.float32)
    if rhs_dtype == jnp.bfloat16:
      rhs = rhs.astype(np.float32)
    out_dtype = jnp.promote_types(lhs_dtype, rhs_dtype)
    return np.inner(lhs, rhs).astype(out_dtype)

  def jnp_fun(lhs, rhs):
    return jnp.inner(lhs, rhs)

  tol_spec = {np.float16: 1e-2, np.float32: 1e-5, np.float64: 1e-13,
              np.complex64: 1e-5}
  if jtu.device_under_test() == "tpu":
    tol_spec.update({np.float32: 2e-1, np.complex64: 2e-1})
  # Use the looser of the two per-operand tolerances.
  lhs_tol = jtu.tolerance(lhs_dtype, tol_spec)
  rhs_tol = jtu.tolerance(rhs_dtype, tol_spec)
  tol = max(lhs_tol, rhs_tol)
  # TODO(phawkins): there are float32/float64 disagreements for some inputs.
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False,
                          tol=tol)
  self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=False, atol=tol,
                        rtol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
    dict(testcase_name=(
             f"_{jtu.format_shape_dtype_string(shape, dtype)}"
             f"_deg={deg}_rcond={rcond}_full={full}_w={w}_cov={cov}"),
         shape=shape, dtype=dtype, deg=deg,
         rcond=rcond, full=full, w=w, cov=cov)
    for dtype in [dt for dt in float_dtypes if dt not in [jnp.float16, jnp.bfloat16]]
    for shape in [shape for shape in one_dim_array_shapes if shape != (1,)]
    for deg in [1, 2, 3]
    for rcond in [None, -1, 10e-3, 10e-5, 10e-10]
    for full in [False, True]
    for w in [False, True]
    for cov in [False, True, "unscaled"]))
@jtu.skip_on_devices("rocm")  # will be fixed in ROCm-5.1
def testPolyfit(self, shape, dtype, deg, rcond, full, w, cov):
  # Compares jnp.polyfit with np.polyfit. The third generated array is only
  # used (through its absolute value) as the weights when `w` is true.
  sampler = jtu.rand_default(self.rng())
  tol_spec = {np.float32: 1e-3, np.float64: 1e-13, np.complex64: 1e-5}
  if jtu.device_under_test() == "tpu":
    tol_spec.update({np.float32: 2e-1, np.complex64: 2e-1})
  tol = jtu.tolerance(dtype, tol_spec)

  def weights(a):
    return abs(a) if w else None

  def args_maker():
    return [sampler(shape, dtype), sampler(shape, dtype), sampler(shape, dtype)]

  def jnp_fun(x, y, a):
    return jnp.polyfit(x, y, deg=deg, rcond=rcond, full=full, w=weights(a),
                       cov=cov)

  @jtu.ignore_warning(message="Polyfit may be poorly conditioned*")
  def np_fun(x, y, a):
    return np.polyfit(x, y, deg=deg, rcond=rcond, full=full, w=weights(a),
                      cov=cov)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False,
                          tol=tol)
  self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=False, atol=tol,
                        rtol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
    dict(testcase_name=(
             f"_{jtu.format_shape_dtype_string(shape, dtype)}"
             f"_amin={a_min}_amax={a_max}"),
         shape=shape, dtype=dtype, a_min=a_min, a_max=a_max)
    for shape in all_shapes for dtype in number_dtypes
    for a_min, a_max in [(-1, None), (None, 1), (-0.9, 1),
                         (-np.ones(1), None),
                         (None, np.ones(1)),
                         (np.full(1, -0.9), np.ones(1))]))
@jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
def testClipStaticBounds(self, shape, dtype, a_min, a_max):
  # Compares jnp.clip with np.clip for bounds that are fixed per test case.
  if np.issubdtype(dtype, np.unsignedinteger):
    # Unsigned inputs get the absolute value of each bound that is present.
    if a_min is not None:
      a_min = abs(a_min)
    if a_max is not None:
      a_max = abs(a_max)
  sampler = jtu.rand_default(self.rng())

  def np_fun(x):
    return np.clip(x, a_min=a_min, a_max=a_max)

  def jnp_fun(x):
    return jnp.clip(x, a_min=a_min, a_max=a_max)

  def args_maker():
    return [sampler(shape, dtype)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
  self._CompileAndCheck(jnp_fun, args_maker)
def testClipError(self):
  # Calling jnp.clip without a_min and a_max is expected to raise ValueError.
  expected_message = "At most one of a_min and a_max.*"
  operand = jnp.zeros((3,))
  with self.assertRaisesRegex(ValueError, expected_message):
    jnp.clip(operand)
@parameterized.named_parameters(jtu.cases_from_list(
    dict(testcase_name=(
             f"_{jtu.format_shape_dtype_string(shape, dtype)}"
             f"_decimals={decimals}"),
         shape=shape, dtype=dtype, decimals=decimals)
    for shape, dtype in _shape_and_dtypes(all_shapes, number_dtypes)
    for decimals in [0, 1, -2]))
def testRoundStaticDecimals(self, shape, dtype, decimals):
  # Compares jnp.round with np.round for a fixed `decimals` value.
  sampler = jtu.rand_default(self.rng())
  if decimals < 0 and jnp.issubdtype(dtype, np.integer):
    self.skipTest("Integer rounding with decimals < 0 not implemented")

  def np_fun(x):
    return np.round(x, decimals=decimals)

  def jnp_fun(x):
    return jnp.round(x, decimals=decimals)

  def args_maker():
    return [sampler(shape, dtype)]

  tol = {jnp.bfloat16: 5e-2, np.float16: 1e-2}
  # Dtypes are not compared for Python scalar inputs.
  compare_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
                          check_dtypes=compare_dtypes, tol=tol)
  self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=compare_dtypes,
                        atol=tol, rtol=tol)
def testOperatorRound(self):
  # Checks that the builtin round() applied to JAX scalars and 0-d arrays
  # agrees with round() applied to the NumPy float32 scalar of the same value.
  #
  # Both sides receive the same input. Previously the JAX side was given 7.5
  # while the NumPy side was given 7.532, so the one-digit case would still
  # have passed if JAX's round() returned its argument unchanged.
  for value, ndigits in [(7.532, 1), (1.234, 2)]:
    expected = round(np.float32(value), ndigits)
    self.assertAllClose(expected, round(jnp.float32(value), ndigits))
    self.assertAllClose(expected,
                        round(jnp.array(value, jnp.float32), ndigits))
  # For the single-argument form of round(), dtypes are not compared.
  expected = round(np.float32(1.234))
  self.assertAllClose(expected, round(jnp.float32(1.234)),
                      check_dtypes=False)
  self.assertAllClose(expected, round(jnp.array(1.234, jnp.float32)),
                      check_dtypes=False)
@parameterized.named_parameters(jtu.cases_from_list(
    dict(testcase_name=(
             f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
             f"_mode={mode}_padwidth={pad_width}"
             f"_constantvalues={constant_values}"),
         shape=shape, dtype=dtype, mode=mode,
         pad_width=pad_width, constant_values=constant_values)
    for mode, shapes in [
        ('constant', all_shapes),
        ('wrap', nonempty_shapes),
        ('edge', nonempty_shapes),
    ]
    for shape, dtype in _shape_and_dtypes(shapes, all_dtypes)
    for constant_values in [
        # None is used for modes other than 'constant'
        None,
        # constant
        0, 1,
        # (constant,)
        (0,), (2.718,),
        # ((before_const, after_const),)
        ((0, 2),), ((-1, 3.14),),
        # ((before_1, after_1), ..., (before_N, after_N))
        tuple((i / 2, -3.14 * i) for i in range(len(shape))),
    ]
    for pad_width in [
        # ((before_1, after_1), ..., (before_N, after_N))
        tuple((i % 3, (i + 1) % 3) for i in range(len(shape))),
        # ((before, after),)
        ((1, 2),), ((2, 0),),
        # (before, after)  (not in the docstring but works in numpy)
        (2, 0), (0, 0),
        # (pad,)
        (1,), (2,),
        # pad
        0, 1,
    ]
    # constant_values is supplied exactly when the mode is 'constant'.
    if (pad_width != () and constant_values != () and
        (mode == 'constant') == (constant_values is not None))))
def testPad(self, shape, dtype, mode, pad_width, constant_values):
  # Compares jnp.pad with np.pad for the 'constant', 'wrap' and 'edge' modes.
  if np.issubdtype(dtype, np.unsignedinteger):
    constant_values = tree_util.tree_map(abs, constant_values)
  sampler = jtu.rand_default(self.rng())

  def args_maker():
    return [sampler(shape, dtype)]

  pad_kwargs = dict(pad_width=pad_width, mode=mode)
  if constant_values is not None:
    pad_kwargs["constant_values"] = constant_values
  np_fun = partial(np.pad, **pad_kwargs)
  jnp_fun = partial(jnp.pad, **pad_kwargs)
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
                          check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    dict(testcase_name=(
             f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
             f"_mode={mode}_pad_width={pad_width}"
             f"_stat_length={stat_length}"),
         shape=shape, dtype=dtype, mode=mode, pad_width=pad_width,
         stat_length=stat_length)
    for mode in ['maximum', 'minimum', 'mean', 'median']
    for shape, dtype in _shape_and_dtypes(nonempty_shapes, all_dtypes)
    for pad_width in [
        # ((before_1, after_1), ..., (before_N, after_N))
        tuple((i % 3, (i + 1) % 3) for i in range(len(shape))),
        # ((before, after),)
        ((1, 2),), ((2, 0),),
        # (before, after)  (not in the docstring but works in numpy)
        (2, 0), (0, 0),
        # (pad,)
        (1,), (2,),
        # pad
        0, 1,
    ]
    for stat_length in [
        None,
        # ((before_1, after_1), ..., (before_N, after_N))
        tuple(((i % 3 + 1), ((i + 1) % 3) + 1) for i in range(len(shape))),
        # ((before, after),)
        ((1, 2),), ((2, 2),),
        # (before, after)  (not in the docstring but works in numpy)
        (1, 1), (3, 4),
        # (pad,)
        (1,), (2,),
        # pad
        1, 2
    ]
    # The boolean dtype / 'mean' combination is excluded.
    if (pad_width != () and stat_length != () and
        (dtype not in bool_dtypes or mode != 'mean'))))
def testPadStatValues(self, shape, dtype, mode, pad_width, stat_length):
  # Compares jnp.pad with np.pad for the statistic-based padding modes.
  if mode == 'median' and np.issubdtype(dtype, np.complexfloating):
    self.skipTest("median statistic is not supported for dtype=complex.")
  sampler = jtu.rand_default(self.rng())

  def args_maker():
    return [sampler(shape, dtype)]

  pad_kwargs = dict(pad_width=pad_width, mode=mode, stat_length=stat_length)
  np_fun = partial(np.pad, **pad_kwargs)
  jnp_fun = partial(jnp.pad, **pad_kwargs)
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
                          check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    dict(testcase_name=(
             f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
             f"_mode={mode}_pad_width={pad_width}"
             f"_reflect_type={reflect_type}"),
         shape=shape, dtype=dtype, mode=mode, pad_width=pad_width,
         reflect_type=reflect_type)
    for mode in ['symmetric', 'reflect']
    for shape, dtype in _shape_and_dtypes(nonempty_shapes, all_dtypes)
    for pad_width in [
        # ((before_1, after_1), ..., (before_N, after_N))
        tuple((i % 3, (i + 1) % 3) for i in range(len(shape))),
        # ((before, after),)
        ((1, 2),), ((2, 3),),
        # (before, after)  (not in the docstring but works in numpy)
        (2, 1), (1, 2),
        # (pad,)
        (1,), (2,), (3,),
        # pad
        0, 5, 7, 10
    ]
    for reflect_type in ['even', 'odd']
    if (pad_width != () and
        # following types lack precision when calculating odd values
        (reflect_type != 'odd' or dtype not in [np.bool_, np.float16, jnp.bfloat16]))))
def testPadSymmetricAndReflect(self, shape, dtype, mode, pad_width, reflect_type):
  # Compares jnp.pad with np.pad for the 'symmetric' and 'reflect' modes.
  sampler = jtu.rand_default(self.rng())

  def args_maker():
    return [sampler(shape, dtype)]

  pad_kwargs = dict(pad_width=pad_width, mode=mode, reflect_type=reflect_type)
  np_fun = partial(np.pad, **pad_kwargs)
  jnp_fun = partial(jnp.pad, **pad_kwargs)
  tolerances = {np.float32: 1e-3, np.complex64: 1e-3}
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
                          check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE,
                          tol=tolerances)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    dict(testcase_name=(
             f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
             f"_mode=linear_ramp_pad_width={pad_width}"
             f"_end_values={end_values}"),
         shape=shape, dtype=dtype, pad_width=pad_width,
         end_values=end_values)
    for shape, dtype in _shape_and_dtypes(nonempty_shapes, default_dtypes + complex_dtypes)
    for pad_width in [
        # ((before_1, after_1), ..., (before_N, after_N))
        tuple((i % 3, (i + 1) % 3) for i in range(len(shape))),
        # ((before, after),)
        ((1, 2),), ((2, 0),),
        # (before, after)  (not in the docstring but works in numpy)
        (2, 0), (0, 0),
        # (pad,)
        (1,), (2,),
        # pad
        0, 1,
    ]
    for end_values in [
        # ((before_1, after_1), ..., (before_N, after_N))
        tuple((i % 3, (i + 1) % 3) for i in range(len(shape))),
        # ((before, after),)
        ((1, 2),), ((2.0, 3.14),),
        # (before, after)  (not in the docstring but works in numpy)
        (0, 0), (-8.0, 2.0),
        # (end_values,)
        (1,), (2,),
        # end_values
        0, 1, 100, 10.0, 3.5, 4.2, -5, -3
    ]
    if (pad_width != () and end_values != () and
        # following types lack precision
        dtype not in [np.int8, np.int16, np.float16, jnp.bfloat16])))
def testPadLinearRamp(self, shape, dtype, pad_width, end_values):
  # Compares jnp.pad with np.pad for the 'linear_ramp' mode.
  if numpy_version < (1, 20) and np.issubdtype(dtype, np.integer):
    self.skipTest("NumPy 1.20 changed the semantics of np.linspace")
  sampler = jtu.rand_default(self.rng())

  def args_maker():
    return [sampler(shape, dtype)]

  pad_kwargs = dict(pad_width=pad_width, mode="linear_ramp",
                    end_values=end_values)
  np_fun = partial(np.pad, **pad_kwargs)
  jnp_fun = partial(jnp.pad, **pad_kwargs)
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
                          check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE)
  self._CompileAndCheck(jnp_fun, args_maker)
def testPadEmpty(self):
arr = np.arange(6).reshape(2, 3)
pad_width = ((2, 3), (3, 1))
np_res = np.pad(arr, pad_width=pad_width, mode="empty")
jnp_res = jnp.pad(arr, pad_width=pad_width, mode="empty")
np.testing.assert_equal(np_res.shape, jnp_res.shape)
np.testing.assert_equal(arr, np_res[2:-3, 3:-1])
np.testing.assert_equal(arr, jnp_res[2:-3, 3:-1])
np.testing.assert_equal(np_res[2:-3, 3:-1], jnp_res[2:-3, 3:-1])
def testPadKwargs(self):
  # For each padding mode, checks that its own keyword arguments are accepted
  # and that keyword arguments belonging to other modes raise ValueError.
  # Also checks that unknown modes raise NotImplementedError.
  stat_kwargs = {'stat_length': None}
  reflect_kwargs = {'reflect_type': 'even'}
  modes = {
      'constant': {'constant_values': 0},
      'edge': {},
      'linear_ramp': {'end_values': 0},
      'maximum': dict(stat_kwargs),
      'mean': dict(stat_kwargs),
      'median': dict(stat_kwargs),
      'minimum': dict(stat_kwargs),
      'reflect': dict(reflect_kwargs),
      'symmetric': dict(reflect_kwargs),
      'wrap': {},
      'empty': {}
  }
  arr = jnp.array([1, 2, 3])
  pad_width = 1

  for mode, allowed in modes.items():
    # Every keyword argument of any mode whose argument set differs from
    # this mode's.
    not_allowed = {
        key: value
        for other_kwargs in modes.values()
        if other_kwargs != allowed
        for key, value in other_kwargs.items()
    }

    # Test if allowed keyword arguments pass
    jnp.pad(arr, pad_width, mode, **allowed)
    # Test if prohibited keyword arguments of other modes raise an error
    match = "unsupported keyword arguments for mode '{}'".format(mode)
    for key, value in not_allowed.items():
      with self.assertRaisesRegex(ValueError, match):
        jnp.pad(arr, pad_width, mode, **{key: value})

  # Test if unsupported mode raise error.
  for bad_mode in [1, None, "foo"]:
    match = "Unimplemented padding mode '{}' for np.pad.".format(bad_mode)
    with self.assertRaisesRegex(NotImplementedError, match):
      jnp.pad(arr, pad_width, bad_mode)
def testPadFunction(self):
  # Checks jnp.pad with a user-supplied padding callable against np.pad with
  # the equivalent in-place callable.
  default_pad_value = 10

  def np_pad_with(vector, pad_width, iaxis, kwargs):
    # NumPy variant: overwrites the leading and trailing entries in place.
    fill = kwargs.get('padder', default_pad_value)
    vector[:pad_width[0]] = fill
    vector[-pad_width[1]:] = fill

  def jnp_pad_with(vector, pad_width, iaxis, kwargs):
    # JAX variant: returns an updated copy of the vector.
    fill = kwargs.get('padder', default_pad_value)
    with_head = vector.at[:pad_width[0]].set(fill)
    return with_head.at[-pad_width[1]:].set(fill)

  arr = np.arange(6).reshape(2, 3)
  expected = np.pad(arr, 2, np_pad_with)
  actual = jnp.pad(arr, 2, jnp_pad_with)
  np.testing.assert_equal(expected, actual)

  arr = np.arange(24).reshape(2, 3, 4)
  expected = np.pad(arr, 1, np_pad_with, padder=100)
  actual = jnp.pad(arr, 1, jnp_pad_with, padder=100)
  np.testing.assert_equal(expected, actual)

  sampler = jtu.rand_default(self.rng())

  def args_maker():
    return [sampler(arr.shape, arr.dtype)]

  jnp_fun = partial(jnp.pad, pad_width=1, mode=jnp_pad_with)
  self._CompileAndCheck(jnp_fun, args_maker)
def testPadWithNumpyPadWidth(self):
a = jnp.array([1, 2, 3, 4, 5])
f = jax.jit(
partial(
jnp.pad,
pad_width=np.asarray((2, 3)),
mode="constant",
constant_values=(4, 6)))
np.testing.assert_array_equal(
f(a),
np.pad(
a,
pad_width=np.asarray((2, 3)),
mode="constant",
constant_values=(4, 6)))
def testPadWeakType(self):
  # Asserts that the result of jnp.pad is weakly typed, for every listed
  # mode, when padding the array built from jnp.array(1.0)[None] with width 0.
  all_modes = ('constant', 'edge', 'linear_ramp', 'maximum', 'mean', 'median',
               'minimum', 'reflect', 'symmetric', 'wrap', 'empty')
  x = jnp.array(1.0)[None]
  for mode in all_modes:
    padded = jnp.pad(x, 0, mode=mode)
    self.assertTrue(dtypes.is_weakly_typed(padded))
@parameterized.named_parameters(jtu.cases_from_list(
    dict(testcase_name=(
             f"_shape=[{jtu.format_shape_dtype_string(shape, dtype)}]"
             f"_reps={reps}"),
         shape=shape, dtype=dtype, reps=reps)
    for reps in [(), (2,), (3, 4), (2, 3, 4), (1, 0, 2)]
    for shape, dtype in _shape_and_dtypes(all_shapes, default_dtypes)
))
def testTile(self, shape, dtype, reps):
  # Compares jnp.tile with np.tile for a fixed `reps` tuple.
  sampler = jtu.rand_default(self.rng())

  def np_fun(arg):
    return np.tile(arg, reps)

  def jnp_fun(arg):
    return jnp.tile(arg, reps)

  def args_maker():
    return [sampler(shape, dtype)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
                          check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    dict(testcase_name=f"_shape={jtu.format_shape_dtype_string(shape, dtype)}",
         shape=shape, dtype=dtype)
    for shape in all_shapes
    for dtype in all_dtypes))
def testExtract(self, shape, dtype):
  # Compares jnp.extract with np.extract. The first argument (the condition)
  # is a float32 array produced by the rand_some_zero generator.
  sampler = jtu.rand_some_zero(self.rng())

  def args_maker():
    condition = sampler(shape, jnp.float32)
    values = sampler(shape, dtype)
    return [condition, values]

  self._CheckAgainstNumpy(np.extract, jnp.extract, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    dict(testcase_name=(
             f"_{jtu.format_shape_dtype_string(shape, dtype)}"
             f"_ncond={ncond}_nfunc={nfunc}"),
         shape=shape, dtype=dtype, ncond=ncond, nfunc=nfunc)
    for ncond in [1, 2, 3]
    for nfunc in [ncond, ncond + 1]
    for shape in all_shapes
    for dtype in all_dtypes))
def testPiecewise(self, shape, dtype, ncond, nfunc):
  # Compares jnp.piecewise with np.piecewise using `ncond` boolean conditions
  # and the first `nfunc` entries of a fixed list of callables and constants.
  sampler = jtu.rand_default(self.rng())
  bool_sampler = jtu.rand_int(self.rng(), 0, 2)
  candidates = [lambda x: x - 1, 1, lambda x: x, 0]
  funclist = candidates[:nfunc]

  def args_maker():
    operand = sampler(shape, dtype)
    condlist = [bool_sampler(shape, bool) for _ in range(ncond)]
    return (operand, condlist)

  np_fun = partial(np.piecewise, funclist=funclist)
  jnp_fun = partial(jnp.piecewise, funclist=funclist)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
  # This is a higher-order function, so the cache miss check will fail.
  self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True,
                        check_cache_misses=False)
def testPiecewiseRecompile(self):
  # Calls jnp.piecewise five times with the same function object and asserts
  # that the function body ran only once.
  trace_count = 0

  def g(x):
    nonlocal trace_count
    trace_count += 1
    return x

  x = jnp.arange(10.0)
  for _ in range(5):
    jnp.piecewise(x, [x < 0], [g, 0.])
  self.assertEqual(trace_count, 1)
@parameterized.named_parameters(jtu.cases_from_list(
    dict(testcase_name=(
             f"{jtu.format_shape_dtype_string(shape, dtype)}"
             f"_perm={perm}_{arg_type}"),
         dtype=dtype, shape=shape, perm=perm, arg_type=arg_type)
    for dtype in default_dtypes
    for shape in array_shapes
    for arg_type in ["splat", "value"]
    for perm in [None, tuple(np.random.RandomState(0).permutation(np.zeros(shape).ndim))]))
def testTransposeTuple(self, shape, dtype, perm, arg_type):
  # Compares the .transpose method of NumPy and JAX arrays, passing the
  # permutation either as one value or unpacked into positional arguments.
  sampler = jtu.rand_some_zero(self.rng())

  def args_maker():
    return [sampler(shape, dtype)]

  if arg_type == "value":
    def np_fun(x):
      return x.transpose(perm)

    def jnp_fun(x):
      return jnp.array(x).transpose(perm)
  else:
    def np_fun(x):
      # A perm of None is unpacked as no arguments at all.
      return x.transpose(*(perm or ()))

    def jnp_fun(x):
      return jnp.array(x).transpose(*(perm or ()))

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
  self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
    dict(testcase_name=(
             f"{jtu.format_shape_dtype_string(a_shape, dtype)}_trim={trim}"),
         dtype=dtype, a_shape=a_shape, trim=trim)
    for dtype in default_dtypes
    for a_shape in one_dim_array_shapes
    for trim in ["f", "b", "fb"]))
def testTrimZeros(self, a_shape, dtype, trim):
  # Compares jnp.trim_zeros with np.trim_zeros for each `trim` setting.
  sampler = jtu.rand_some_zero(self.rng())

  def args_maker():
    return [sampler(a_shape, dtype)]

  def np_fun(arg1):
    return np.trim_zeros(arg1, trim)

  def jnp_fun(arg1):
    return jnp.trim_zeros(arg1, trim)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
    dict(testcase_name=(
             f"_shape={jtu.format_shape_dtype_string(a_shape, dtype)}"
             f"_rank{rank}"),
         dtype=dtype, a_shape=a_shape, rank=rank)
    for rank in (1, 2)
    for dtype in default_dtypes
    for a_shape in one_dim_array_shapes))
def testPoly(self, a_shape, dtype, rank):
  # Compares jnp.poly with np.poly. The input shape is `a_shape` repeated
  # `rank` times.
  if dtype in (np.float16, jnp.bfloat16, np.int16):
    self.skipTest(f"{dtype} gets promoted to {np.float16}, which is not supported.")
  if rank == 2 and jtu.device_under_test() in ("tpu", "gpu"):
    self.skipTest("Nonsymmetric eigendecomposition is only implemented on the CPU backend.")
  sampler = jtu.rand_default(self.rng())
  tol_spec = {np.int8: 1e-3, np.int32: 1e-3, np.float32: 1e-3,
              np.float64: 1e-6}
  if jtu.device_under_test() == "tpu":
    tol_spec.update({np.int32: 1e-1, np.float32: 1e-1})
  tol = jtu.tolerance(dtype, tol_spec)
  input_shape = a_shape * rank

  def args_maker():
    return [sampler(input_shape, dtype)]

  self._CheckAgainstNumpy(np.poly, jnp.poly, args_maker, check_dtypes=False,
                          tol=tol)
  self._CompileAndCheck(jnp.poly, args_maker, check_dtypes=True, rtol=tol,
                        atol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
    dict(testcase_name=(
             f"a_shape={jtu.format_shape_dtype_string(a_shape, dtype)}"
             f" , b_shape={jtu.format_shape_dtype_string(b_shape, dtype)}"),
         dtype=dtype, a_shape=a_shape, b_shape=b_shape)
    for dtype in default_dtypes
    for a_shape in one_dim_array_shapes
    for b_shape in one_dim_array_shapes))
def testPolyAdd(self, a_shape, b_shape, dtype):
  # Compares jnp.polyadd with np.polyadd on two operands of the same dtype.
  sampler = jtu.rand_default(self.rng())

  def np_fun(arg1, arg2):
    return np.polyadd(arg1, arg2)

  def jnp_fun(arg1, arg2):
    return jnp.polyadd(arg1, arg2)

  def args_maker():
    return [sampler(a_shape, dtype), sampler(b_shape, dtype)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
  self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
    dict(testcase_name=(
             f"a_shape={jtu.format_shape_dtype_string(a_shape, dtype)}"
             f" , b_shape={jtu.format_shape_dtype_string(b_shape, dtype)}"),
         dtype=dtype, a_shape=a_shape, b_shape=b_shape)
    for dtype in default_dtypes
    for a_shape in one_dim_array_shapes
    for b_shape in one_dim_array_shapes))
def testPolySub(self, a_shape, b_shape, dtype):
  # Compares jnp.polysub with np.polysub on two operands of the same dtype.
  sampler = jtu.rand_default(self.rng())

  def np_fun(arg1, arg2):
    return np.polysub(arg1, arg2)

  def jnp_fun(arg1, arg2):
    return jnp.polysub(arg1, arg2)

  def args_maker():
    return [sampler(a_shape, dtype), sampler(b_shape, dtype)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
  self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
    dict(testcase_name=(
             f"_shape={jtu.format_shape_dtype_string(a_shape, dtype)}"
             f"_order={order}_k={k}"),
         dtype=dtype, a_shape=a_shape, order=order, k=k)
    for dtype in default_dtypes
    for a_shape in one_dim_array_shapes
    for order in range(5)
    for k in [np.arange(order, dtype=dtype), np.ones(1, dtype), None]))
def testPolyInt(self, a_shape, order, k, dtype):
  # Compares jnp.polyint with np.polyint for integration order `order` and
  # integration constants `k`. Dtypes are not compared against NumPy.
  sampler = jtu.rand_default(self.rng())

  def np_fun(arg1):
    return np.polyint(arg1, m=order, k=k)

  def jnp_fun(arg1):
    return jnp.polyint(arg1, m=order, k=k)

  def args_maker():
    return [sampler(a_shape, dtype)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
  self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
    dict(testcase_name=(
             f"_shape={jtu.format_shape_dtype_string(a_shape, dtype)}"
             f"_order={order}"),
         dtype=dtype, a_shape=a_shape, order=order)
    for dtype in default_dtypes
    for a_shape in one_dim_array_shapes
    for order in range(5)))
def testPolyDer(self, a_shape, order, dtype):
  # Compares jnp.polyder with np.polyder for derivative order `order`.
  # Dtypes are not compared against NumPy.
  sampler = jtu.rand_default(self.rng())

  def np_fun(arg1):
    return np.polyder(arg1, m=order)

  def jnp_fun(arg1):
    return jnp.polyder(arg1, m=order)

  def args_maker():
    return [sampler(a_shape, dtype)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
  self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
    dict(testcase_name=f"_ptype={ptype}", ptype=ptype)
    for ptype in ['int', 'np.int', 'jnp.int']))
def testIntegerPower(self, ptype):
  # Asserts that jnp.power with an integer exponent of value 2 traces to a
  # single equation using lax.integer_pow_p.
  exponents = {'int': 2, 'np.int': np.int32(2), 'jnp.int': jnp.int32(2)}
  exponent = exponents[ptype]
  power_fn = partial(jnp.power, x2=exponent)
  closed_jaxpr = jax.make_jaxpr(power_fn)(1)
  equations = closed_jaxpr.jaxpr.eqns
  self.assertLen(equations, 1)
  self.assertEqual(equations[0].primitive, lax.integer_pow_p)
@parameterized.named_parameters(jtu.cases_from_list(
    dict(testcase_name=f"_x={x}_y={y}", x=x, y=y)
    for x in [-1, 0, 1]
    for y in [0, 32, 64, 128]))
def testIntegerPowerOverflow(self, x, y):
  # Regression test for https://github.com/google/jax/issues/5987
  def args_maker():
    return [x, y]

  self._CheckAgainstNumpy(np.power, jnp.power, args_maker)
  self._CompileAndCheck(jnp.power, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    dict(testcase_name=(
             f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
             f"_axis={axis}"),
         shape=shape, dtype=dtype, axis=axis)
    for shape in all_shapes
    for dtype in all_dtypes
    for axis in [None] + list(range(len(shape)))))
def testCompress(self, shape, dtype, axis):
  # Compares jnp.compress with np.compress. The condition is a float32 array
  # whose length matches the axis being compressed.
  sampler = jtu.rand_some_zero(self.rng())
  if shape in scalar_shapes or len(shape) == 0:
    cond_length = 0
  elif axis is None:
    # With axis=None the condition covers every element of the input.
    cond_length = prod(shape)
  else:
    cond_length = shape[axis]
  cond_shape = (cond_length,)

  def args_maker():
    return [sampler(cond_shape, jnp.float32), sampler(shape, dtype)]

  np_fun = partial(np.compress, axis=axis)
  jnp_fun = partial(jnp.compress, axis=axis)
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    dict(testcase_name=(
             f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
             f"_condition=array[{len(condition)}]_axis={axis}"),
         shape=shape, dtype=dtype, condition=condition, axis=axis)
    for shape in [(2, 3)]
    for dtype in int_dtypes
    # condition entries beyond axis size must be zero.
    for condition in [[1], [1, 0, 0, 0, 0, 0, 0]]
    for axis in [None, 0, 1]))
def testCompressMismatchedShapes(self, shape, dtype, condition, axis):
  # Compares jnp.compress with np.compress when the condition's length does
  # not match the size of the compressed axis.
  sampler = jtu.rand_default(self.rng())

  def args_maker():
    return [np.array(condition), sampler(shape, dtype)]

  np_fun = partial(np.compress, axis=axis)
  jnp_fun = partial(jnp.compress, axis=axis)
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    dict(testcase_name=(
             f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
             f"_axis={axis}"),
         shape=shape, dtype=dtype, axis=axis)
    for shape in array_shapes
    for dtype in all_dtypes
    for axis in [None] + list(range(len(shape)))))
def testCompressMethod(self, shape, dtype, axis):
  # Compares the array method x.compress(...) with np.compress.
  sampler = jtu.rand_some_zero(self.rng())
  if shape in scalar_shapes or len(shape) == 0:
    cond_length = 0
  elif axis is None:
    # With axis=None the condition covers every element of the input.
    cond_length = prod(shape)
  else:
    cond_length = shape[axis]
  cond_shape = (cond_length,)

  def args_maker():
    return [sampler(cond_shape, jnp.float32), sampler(shape, dtype)]

  def np_fun(condition, x):
    return np.compress(condition, x, axis=axis)

  def jnp_fun(condition, x):
    return x.compress(condition, axis=axis)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    dict(testcase_name="_axis={}_baseshape=[{}]_dtypes=[{}]".format(
             axis, ",".join(str(d) for d in base_shape),
             ",".join(np.dtype(dtype).name for dtype in arg_dtypes)),
         axis=axis, base_shape=base_shape, arg_dtypes=arg_dtypes)
    for num_arrs in [3]
    for arg_dtypes in itertools.combinations_with_replacement(default_dtypes, num_arrs)
    for base_shape in [(4,), (3, 4), (2, 3, 4)]
    for axis in range(-len(base_shape)+1, len(base_shape))))
def testConcatenate(self, axis, base_shape, arg_dtypes):
  # Compares jnp.concatenate with np.concatenate. The operands share
  # `base_shape` except along `axis`, where their sizes cycle through 3, 1, 4.
  sampler = jtu.rand_default(self.rng())
  wrapped_axis = axis % len(base_shape)
  leading = base_shape[:wrapped_axis]
  trailing = base_shape[wrapped_axis + 1:]
  sizes = itertools.islice(itertools.cycle([3, 1, 4]), len(arg_dtypes))
  shapes = [leading + (size,) + trailing for size in sizes]

  def np_fun(*args):
    # bfloat16 operands are widened to float32 before calling NumPy, and the
    # result is cast to the dtype that JAX's promotion rules select.
    widened = [x.astype(np.float32) if x.dtype == jnp.bfloat16 else x
               for x in args]
    out_dtype = functools.reduce(jnp.promote_types, arg_dtypes)
    return np.concatenate(widened, axis=axis).astype(out_dtype)

  def jnp_fun(*args):
    return jnp.concatenate(args, axis=axis)

  def args_maker():
    return [sampler(shape, dtype) for shape, dtype in zip(shapes, arg_dtypes)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    dict(testcase_name=(
             f"_{jtu.format_shape_dtype_string(shape, dtype)}_axis={axis}"),
         shape=shape, dtype=dtype, axis=axis)
    for shape in [(4, 1), (4, 3), (4, 5, 6)]
    for dtype in all_dtypes
    for axis in [None] + list(range(1 - len(shape), len(shape) - 1))))
def testConcatenateArray(self, shape, dtype, axis):
  # Compares jnp.concatenate with np.concatenate when a single array, rather
  # than a sequence of arrays, is passed as the first argument.
  sampler = jtu.rand_default(self.rng())

  def args_maker():
    return [sampler(shape, dtype)]

  def np_fun(x):
    return np.concatenate(x, axis=axis)

  def jnp_fun(x):
    return jnp.concatenate(x, axis=axis)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
def testConcatenateAxisNone(self):
# https://github.com/google/jax/issues/3419
a = jnp.array([[1, 2], [3, 4]])
b = jnp.array([[5]])
jnp.concatenate((a, b), axis=None)
@parameterized.named_parameters(jtu.cases_from_list(
    dict(testcase_name="_axis={}_baseshape=[{}]_dtypes=[{}]".format(
             axis, ",".join(str(d) for d in base_shape),
             ",".join(np.dtype(dtype).name for dtype in arg_dtypes)),
         axis=axis, base_shape=base_shape, arg_dtypes=arg_dtypes)
    for arg_dtypes in itertools.combinations_with_replacement(default_dtypes, 2)
    for base_shape in [(4,), (3, 4), (2, 3, 4)]
    for axis in range(-len(base_shape)+1, len(base_shape))))
def testAppend(self, axis, base_shape, arg_dtypes):
  # Compares jnp.append with np.append. The two operands share `base_shape`
  # except along `axis`, where their sizes are 3 and 1.
  sampler = jtu.rand_default(self.rng())
  wrapped_axis = axis % len(base_shape)
  leading = base_shape[:wrapped_axis]
  trailing = base_shape[wrapped_axis + 1:]
  sizes = itertools.islice(itertools.cycle([3, 1, 4]), len(arg_dtypes))
  shapes = [leading + (size,) + trailing for size in sizes]

  def np_fun(arr, values):
    # bfloat16 operands are widened to float32 before calling NumPy, and the
    # result is cast to the dtype that JAX's promotion rules select.
    if arr.dtype == jnp.bfloat16:
      arr = arr.astype(np.float32)
    if values.dtype == jnp.bfloat16:
      values = values.astype(np.float32)
    appended = np.append(arr, values, axis=axis)
    return appended.astype(jnp.promote_types(*arg_dtypes))

  def jnp_fun(arr, values):
    return jnp.append(arr, values, axis=axis)

  def args_maker():
    return [sampler(shape, dtype) for shape, dtype in zip(shapes, arg_dtypes)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_axis={}_idx={}".format(
        jtu.format_shape_dtype_string(shape, dtype), axis, idx),
     "dtype": dtype, "shape": shape, "axis": axis, "idx": idx}
    for shape in nonempty_nonscalar_array_shapes
    for dtype in all_dtypes
    for axis in [None] + list(range(-len(shape), len(shape)))
    for idx in (range(-prod(shape), prod(shape))
                if axis is None else
                range(-shape[axis], shape[axis]))))
def testDeleteInteger(self, shape, dtype, idx, axis):
  """Compare jnp.delete with np.delete for a single integer index."""
  rand = jtu.rand_default(self.rng())

  def args_maker():
    return [rand(shape, dtype)]

  def np_fun(arg):
    return np.delete(arg, idx, axis=axis)

  def jnp_fun(arg):
    return jnp.delete(arg, idx, axis=axis)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_axis={}_slc={}".format(
        jtu.format_shape_dtype_string(shape, dtype), axis, slc),
     "dtype": dtype, "shape": shape, "axis": axis, "slc": slc}
    for shape in nonempty_nonscalar_array_shapes
    for dtype in all_dtypes
    for axis in [None] + list(range(-len(shape), len(shape)))
    for slc in [slice(None), slice(1, 3), slice(1, 5, 2)]))
def testDeleteSlice(self, shape, dtype, axis, slc):
  """Compare jnp.delete with np.delete when the index is a slice object."""
  rand = jtu.rand_default(self.rng())

  def args_maker():
    return [rand(shape, dtype)]

  def np_fun(arg):
    return np.delete(arg, slc, axis=axis)

  def jnp_fun(arg):
    return jnp.delete(arg, slc, axis=axis)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_axis={}_idx={}".format(
        jtu.format_shape_dtype_string(shape, dtype), axis,
        jtu.format_shape_dtype_string(idx_shape, int)),
     "dtype": dtype, "shape": shape, "axis": axis, "idx_shape": idx_shape}
    for shape in nonempty_nonscalar_array_shapes
    for dtype in all_dtypes
    for axis in [None] + list(range(-len(shape), len(shape)))
    for idx_shape in all_shapes))
def testDeleteIndexArray(self, shape, dtype, axis, idx_shape):
  """Compare jnp.delete with np.delete for an integer index array.

  The index array is drawn once, up front, and shared by both functions.
  """
  rand = jtu.rand_default(self.rng())
  placeholder = np.zeros(shape)
  if axis is None:
    max_idx = placeholder.size
  else:
    max_idx = placeholder.shape[axis]
  # Previous to numpy 1.19, negative indices were ignored so we don't test this.
  if numpy_version < (1, 19, 0):
    low = 0
  else:
    low = -max_idx
  idx_rand = jtu.rand_int(self.rng(), low=low, high=max_idx)
  idx = idx_rand(idx_shape, int)

  def args_maker():
    return [rand(shape, dtype)]

  def np_fun(arg):
    return np.delete(arg, idx, axis=axis)

  def jnp_fun(arg):
    return jnp.delete(arg, idx, axis=axis)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@unittest.skipIf(numpy_version < (1, 19), "boolean mask not supported in numpy < 1.19.0")
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_axis={}".format(
        jtu.format_shape_dtype_string(shape, dtype), axis),
     "dtype": dtype, "shape": shape, "axis": axis}
    for shape in nonempty_nonscalar_array_shapes
    for dtype in all_dtypes
    for axis in [None] + list(range(-len(shape), len(shape)))))
def testDeleteMaskArray(self, shape, dtype, axis):
  """Compare jnp.delete with np.delete when the index is a boolean mask.

  The mask is drawn once, up front, and sized to the whole array
  (axis=None) or to the extent of `axis`.
  """
  rand = jtu.rand_default(self.rng())
  placeholder = np.zeros(shape)
  if axis is None:
    mask_size = placeholder.size
  else:
    mask_size = placeholder.shape[axis]
  mask_rand = jtu.rand_int(self.rng(), low=0, high=2)
  mask = mask_rand(mask_size, bool)

  def args_maker():
    return [rand(shape, dtype)]

  def np_fun(arg):
    return np.delete(arg, mask, axis=axis)

  def jnp_fun(arg):
    return jnp.delete(arg, mask, axis=axis)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_axis={}".format(
        jtu.format_shape_dtype_string(shape, dtype), axis),
     "dtype": dtype, "shape": shape, "axis": axis}
    for shape in nonempty_nonscalar_array_shapes
    for dtype in all_dtypes
    for axis in [None] + list(range(-len(shape), len(shape)))))
def testInsertInteger(self, shape, dtype, axis):
  """Compare jnp.insert with np.insert for a scalar index and scalar value."""
  template = jnp.empty(shape)
  if axis is None:
    max_ind = template.size
  else:
    max_ind = template.shape[axis]
  rand = jtu.rand_default(self.rng())
  index_rand = jtu.rand_int(self.rng(), -max_ind, max_ind)

  def args_maker():
    # (array, insertion index, value to insert)
    return [rand(shape, dtype), index_rand((), np.int32), rand((), dtype)]

  def np_fun(*args):
    return np.insert(*args, axis=axis)

  def jnp_fun(*args):
    return jnp.insert(*args, axis=axis)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_axis={}".format(
        jtu.format_shape_dtype_string(shape, dtype), axis),
     "dtype": dtype, "shape": shape, "axis": axis}
    for shape in nonempty_nonscalar_array_shapes
    for dtype in all_dtypes
    for axis in [None] + list(range(-len(shape), len(shape)))))
def testInsertSlice(self, shape, dtype, axis):
  """Compare jnp.insert with np.insert when the index is a random slice."""
  template = jnp.empty(shape)
  if axis is None:
    max_ind = template.size
  else:
    max_ind = template.shape[axis]
  rand = jtu.rand_default(self.rng())
  index_rand = jtu.rand_int(self.rng(), -max_ind, max_ind)
  # Slice endpoints are drawn once (start first, then stop) and shared by
  # both implementations.
  start = index_rand((), jnp.int32).item()
  stop = index_rand((), jnp.int32).item()
  slc = slice(start, stop)

  def args_maker():
    return [rand(shape, dtype), rand((), dtype)]

  def np_fun(x, val):
    return np.insert(x, slc, val, axis=axis)

  def jnp_fun(x, val):
    return jnp.insert(x, slc, val, axis=axis)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.parameters([
    [[[1, 1], [2, 2], [3, 3]], 1, 5, None],
    [[[1, 1], [2, 2], [3, 3]], 1, 5, 1],
    [[[1, 1], [2, 2], [3, 3]], 1, [1, 2, 3], 1],
    [[[1, 1], [2, 2], [3, 3]], [1], [[1],[2],[3]], 1],
    [[1, 1, 2, 2, 3, 3], [2, 2], [5, 6], None],
    [[1, 1, 2, 2, 3, 3], slice(2, 4), [5, 6], None],
    [[1, 1, 2, 2, 3, 3], [2, 2], [7.13, False], None],
    [[[0, 1, 2, 3], [4, 5, 6, 7]], (1, 3), 999, 1]
])
def testInsertExamples(self, arr, index, values, axis):
  """Compare jnp.insert with np.insert on fixed example inputs."""
  # Test examples from the np.insert docstring
  def args_maker():
    array_arg = np.asarray(arr)
    # Slices are passed through unchanged; any other index becomes an array.
    if isinstance(index, slice):
      index_arg = index
    else:
      index_arg = np.array(index)
    return (array_arg, index_arg, np.asarray(values), axis)

  self._CheckAgainstNumpy(np.insert, jnp.insert, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_axis={}_out_dims={}".format(
        jtu.format_shape_dtype_string(shape, dtype),
        axis, out_dims),
     "shape": shape, "dtype": dtype, "axis": axis, "out_dims": out_dims}
    for shape in nonempty_array_shapes
    for dtype in default_dtypes
    for axis in range(-len(shape), len(shape))
    for out_dims in [0, 1, 2]))
def testApplyAlongAxis(self, shape, dtype, axis, out_dims):
  """Compare jnp.apply_along_axis with NumPy for 0-, 1- and 2-d results."""
  def func(x, out_dims):
    # Return a result whose rank is selected by `out_dims`.
    if out_dims == 0:
      return x.sum()
    if out_dims == 1:
      return x * x[0]
    if out_dims == 2:
      return x[:, None] + x[None, :]
    raise NotImplementedError(f"out_dims={out_dims}")

  rand = jtu.rand_default(self.rng())

  def args_maker():
    return [rand(shape, dtype)]

  def np_fun(arr):
    return np.apply_along_axis(func, axis, arr, out_dims=out_dims)

  def jnp_fun(arr):
    return jnp.apply_along_axis(func, axis, arr, out_dims=out_dims)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_func={}_keepdims={}_axes={}".format(
        jtu.format_shape_dtype_string(shape, dtype),
        func, keepdims, axes),
     "shape": shape, "dtype": dtype, "func": func, "keepdims": keepdims, "axes": axes}
    for shape in nonempty_shapes
    for func in ["sum"]
    for keepdims in [True, False]
    for axes in itertools.combinations(range(len(shape)), 2)
    # Avoid low-precision types in sum()
    for dtype in default_dtypes if dtype not in [np.float16, jnp.bfloat16]))
def testApplyOverAxes(self, shape, dtype, func, keepdims, axes):
  """Compare jnp.apply_over_axes with NumPy for a named reduction method."""
  def reduce_one_axis(x, axis):
    # Look up the reduction by name on the array and apply it to one axis.
    method = getattr(x, func)
    return method(axis=axis, keepdims=keepdims)

  rand = jtu.rand_default(self.rng())

  def args_maker():
    return (rand(shape, dtype),)

  def np_fun(a):
    return np.apply_over_axes(reduce_one_axis, a, axes)

  def jnp_fun(a):
    return jnp.apply_over_axes(reduce_one_axis, a, axes)

  self._CompileAndCheck(jnp_fun, args_maker)
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape=[{}]_axis={}_repeats={}_fixed_size={}".format(
        jtu.format_shape_dtype_string(shape, dtype),
        axis, repeats, fixed_size),
     "axis": axis, "shape": shape, "dtype": dtype, "repeats": repeats,
     'fixed_size': fixed_size}
    for repeats in [0, 1, 2]
    for shape, dtype in _shape_and_dtypes(all_shapes, default_dtypes)
    for axis in [None] + list(range(-len(shape), max(1, len(shape))))
    for fixed_size in [True, False]))
def testRepeat(self, axis, shape, dtype, repeats, fixed_size):
  """Compare jnp.repeat with np.repeat for scalar `repeats`.

  With `fixed_size`, jnp.repeat is additionally given total_repeat_length
  and is compiled with `repeats` passed as an argument.
  """
  rand = jtu.rand_default(self.rng())

  def np_reference(arg):
    return np.repeat(arg, repeats=repeats, axis=axis)
  np_fun = _promote_like_jnp(np_reference)

  def array_only_args():
    return [rand(shape, dtype)]

  if not fixed_size:
    # Now repeats is in a closure, so a constant.
    def jnp_fun(arg):
      return jnp.repeat(arg, repeats=repeats, axis=axis)

    self._CheckAgainstNumpy(np_fun, jnp_fun, array_only_args)
    self._CompileAndCheck(jnp_fun, array_only_args)
    return

  # Length of the repeated axis, obtained by running np.repeat on zeros.
  total_repeat_length = np.repeat(np.zeros(shape), repeats, axis).shape[axis or 0]

  def jnp_fun(arg, rep):
    return jnp.repeat(arg, repeats=rep, axis=axis,
                      total_repeat_length=total_repeat_length)

  def jnp_args_maker():
    return [rand(shape, dtype), repeats]

  def clo_fun(arg):
    # Same call, but `repeats` comes from the enclosing scope.
    return jnp.repeat(arg, repeats=repeats, axis=axis,
                      total_repeat_length=total_repeat_length)

  self._CompileAndCheck(jnp_fun, jnp_args_maker)
  self._CheckAgainstNumpy(np_fun, clo_fun, array_only_args)
def testRepeatScalarFastPath(self):
a = jnp.array([1,2,3,4])
f = lambda a: jnp.repeat(a, repeats=2)
jaxpr = jax.make_jaxpr(f)(a)
self.assertLessEqual(len(jaxpr.jaxpr.eqns), 6)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_axis={}_ind={}_inv={}_count={}".format(
        jtu.format_shape_dtype_string(shape, dtype), axis,
        return_index, return_inverse, return_counts),
     "shape": shape, "dtype": dtype, "axis": axis,
     "return_index": return_index, "return_inverse": return_inverse,
     "return_counts": return_counts}
    for dtype in number_dtypes
    for shape in all_shapes
    for axis in [None] + list(range(len(shape)))
    for return_index in [False, True]
    for return_inverse in [False, True]
    for return_counts in [False, True]))
def testUnique(self, shape, dtype, axis, return_index, return_inverse, return_counts):
  """Compare jnp.unique with np.unique for every return_* flag combination."""
  if axis is not None and numpy_version < (1, 19) and np.empty(shape).size == 0:
    self.skipTest("zero-sized axis in unique leads to error in older numpy.")
  rand = jtu.rand_some_equal(self.rng())

  def args_maker():
    return [rand(shape, dtype)]

  extra_args = (return_index, return_inverse, return_counts)
  # One entry per returned value: False for the unique values themselves,
  # True for each optional output that was requested.
  if any(extra_args):
    use_defaults = (False,) + tuple(True for flag in extra_args if flag)
  else:
    use_defaults = False

  def np_reference(x):
    return np.unique(x, *extra_args, axis=axis)
  np_fun = jtu.with_jax_dtype_defaults(np_reference, use_defaults)

  def jnp_fun(x):
    return jnp.unique(x, *extra_args, axis=axis)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_axis={}_size={}_fill_value={}".format(
        jtu.format_shape_dtype_string(shape, dtype), axis, size, fill_value),
     "shape": shape, "dtype": dtype, "axis": axis,
     "size": size, "fill_value": fill_value}
    for dtype in number_dtypes
    for size in [1, 5, 10]
    for fill_value in [None, -1.0, "slice"]
    for shape in nonempty_array_shapes
    for axis in [None] + list(range(len(shape)))))
def testUniqueSize(self, shape, dtype, axis, size, fill_value):
  """Check jnp.unique(size=..., fill_value=...) against a NumPy emulation.

  np.unique is called without a size here, so `np_fun` truncates or pads its
  outputs to `size` entries along `axis` before the comparison.
  """
  rng = jtu.rand_some_equal(self.rng())
  args_maker = lambda: [rng(shape, dtype)]
  kwds = dict(axis=axis, return_index=True, return_inverse=True, return_counts=True)

  # "slice" is a placeholder: swap it for a random fill value that is a
  # scalar (axis=None) or has the input shape with `axis` removed.
  if fill_value == "slice":
    if axis is None:
      fill_value = rng((), dtype)
    else:
      fill_value = rng(shape[:axis] + shape[axis + 1:], dtype)

  @partial(jtu.with_jax_dtype_defaults, use_defaults=(False, True, True, True))
  def np_fun(x, fill_value=fill_value):
    # `fill_value` is bound as a default so it can be reassigned locally.
    u, ind, inv, counts = np.unique(x, **kwds)
    axis = kwds['axis']
    if axis is None:
      x = x.ravel()
      axis = 0

    n_unique = u.shape[axis]
    if size <= u.shape[axis]:
      # At least `size` unique values: keep only the first `size` of them.
      slc = (slice(None),) * axis + (slice(size),)
      u, ind, counts = u[slc], ind[:size], counts[:size]
    else:
      # Fewer than `size` unique values: pad every output up to `size`.
      extra = (0, size - n_unique)
      pads = [(0, 0)] * u.ndim
      pads[axis] = extra
      u = np.pad(u, pads, constant_values=0)
      slices = [slice(None)] * u.ndim
      slices[axis] = slice(1)
      if fill_value is None:
        # Default fill: the first entry of `u` along `axis`.
        fill_value = u[tuple(slices)]
      elif np.ndim(fill_value):
        # Non-scalar fill: add a length-1 dimension at `axis`.
        fill_value = lax.expand_dims(fill_value, (axis,))
      # Overwrite the padded tail of `u` with the fill value.
      slices[axis] = slice(n_unique, None)
      u[tuple(slices)] = fill_value
      ind = np.pad(ind, extra, constant_values=ind[0])
      counts = np.pad(counts, extra, constant_values=0)
    # `inv` is returned exactly as np.unique produced it.
    return u, ind, inv, counts

  jnp_fun = lambda x: jnp.unique(x, size=size, fill_value=fill_value, **kwds)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@unittest.skipIf(numpy_version < (1, 21), "Numpy < 1.21 does not properly handle NaN values in unique.")
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": f"_{dtype.__name__}", "dtype": dtype}
    for dtype in inexact_dtypes))
def testUniqueNans(self, dtype):
  """Compare jnp.unique with np.unique on signed zeros and NaNs."""
  def args_maker():
    values = [-0.0, 0.0, 1.0, 1.0, np.nan, -np.nan]
    if np.issubdtype(dtype, np.complexfloating):
      # Complex dtypes: every (real, imag) pairing of the values above.
      values = [complex(re, im)
                for re, im in itertools.product(values, repeat=2)]
    return [np.array(values, dtype=dtype)]

  kwds = {"return_index": True, "return_inverse": True, "return_counts": True}
  jnp_fun = partial(jnp.unique, **kwds)

  def np_fun(x):
    original_dtype = x.dtype
    # numpy unique fails for bfloat16 NaNs, so we cast to float64
    if original_dtype == jnp.bfloat16:
      x = x.astype('float64')
    unique_values, *others = np.unique(x, **kwds)
    return (unique_values.astype(original_dtype), *others)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_fixed_size={}".format(fixed_size),
     "fixed_size": fixed_size}
    for fixed_size in [True, False]))
def testNonScalarRepeats(self, fixed_size):
  '''Check jnp.repeat against np.repeat for scalar and array `repeats`.

  Following numpy test suite from `test_repeat` at
  https://github.com/numpy/numpy/blob/main/numpy/core/tests/test_multiarray.py
  '''
  tol = 1e-5

  def test_single(m, args_maker, repeats, axis):
    # Direct comparison of the jnp and np results.
    lax_ans = jnp.repeat(m, repeats, axis)
    numpy_ans = np.repeat(m, repeats, axis)

    self.assertAllClose(lax_ans, numpy_ans, rtol=tol, atol=tol)
    if fixed_size:

      # Calculate expected size of the repeated axis.
      rep_length = np.repeat(np.zeros_like(m), repeats, axis).shape[axis or 0]
      # `repeats` is passed as an argument, with the output length fixed.
      jnp_fun = lambda arg, rep: jnp.repeat(
          arg, repeats=rep, axis=axis, total_repeat_length=rep_length)
    else:
      # `repeats` is captured from the enclosing scope.
      jnp_fun = lambda arg: jnp.repeat(arg, repeats = repeats, axis=axis)
    self._CompileAndCheck(jnp_fun, args_maker)

  m = jnp.array([1,2,3,4,5,6])
  # NOTE: the args_maker lambdas below look up `repeats` when they are
  # called, so they see the value of the loop variable at that moment.
  if fixed_size:
    args_maker = lambda: [m, repeats]
  else:
    args_maker = lambda: [m]

  # 1-D input: scalar, per-element, and length-1 repeats.
  for repeats in [2, jnp.array([1,3,0,1,1,2]), jnp.array([1,3,2,1,1,2]), jnp.array([2])]:
    test_single(m, args_maker, repeats, axis=None)
    test_single(m, args_maker, repeats, axis=0)

  # 2-D input of shape (2, 3), repeated along each axis in turn.
  m_rect = m.reshape((2,3))
  if fixed_size:
    args_maker = lambda: [m_rect, repeats]
  else:
    args_maker = lambda: [m_rect]

  for repeats in [2, jnp.array([2,1]), jnp.array([2])]:
    test_single(m_rect, args_maker, repeats, axis=0)

  for repeats in [2, jnp.array([1,3,2]), jnp.array([2])]:
    test_single(m_rect, args_maker, repeats, axis=1)
def testIssue2330(self):
  """
  Make sure return value of jnp.concatenate is a jax.ndarray and is side-effect safe
  """
  def attempt_sideeffect(x):
    # Concatenate a one-element list, then subtract in place from the result.
    out = jnp.concatenate([x])
    out -= 1.
    return out

  np_input = np.ones(1)
  jnp_input = jnp.ones(1)
  expected_np_input_after_call = np.ones(1)
  expected_jnp_input_after_call = jnp.ones(1)

  concatenated = jnp.concatenate([np_input])
  self.assertTrue(device_array.type_is_device_array(concatenated))

  # The in-place subtraction must leave both original inputs unchanged.
  attempt_sideeffect(np_input)
  attempt_sideeffect(jnp_input)

  self.assertAllClose(np_input, expected_np_input_after_call)
  self.assertAllClose(jnp_input, expected_jnp_input_after_call)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "op={}_xshape=[{}]_yshape=[{}]_mode={}".format(
        op,
        jtu.format_shape_dtype_string(xshape, dtype),
        jtu.format_shape_dtype_string(yshape, dtype),
        mode),
     "xshape": xshape, "yshape": yshape, "dtype": dtype, "mode": mode,
     "jnp_op": getattr(jnp, op),
     "np_op": getattr(np, op)}
    for mode in ['full', 'same', 'valid']
    for op in ['convolve', 'correlate']
    for dtype in number_dtypes
    for xshape in one_dim_array_shapes
    for yshape in one_dim_array_shapes))
def testConvolutions(self, xshape, yshape, dtype, mode, jnp_op, np_op):
  """Compare jnp convolve/correlate with NumPy for each `mode`."""
  rand = jtu.rand_default(self.rng())

  def args_maker():
    return [rand(xshape, dtype), rand(yshape, dtype)]

  # Highest precision is requested only when running on TPU.
  if jtu.device_under_test() == "tpu":
    precision = lax.Precision.HIGHEST
  else:
    precision = None

  def np_fun(x, y):
    return np_op(x, y, mode=mode)

  def jnp_fun(x, y):
    return jnp_op(x, y, mode=mode, precision=precision)

  tol = {
      np.float16: 2e-1,
      np.float32: 1e-2,
      np.float64: 1e-14,
      np.complex128: 1e-14,
  }
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False,
                          tol=tol)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "op={}_shape=[{}]_axis={}_out_dtype={}".format(
        op, jtu.format_shape_dtype_string(shape, dtype), axis,
        out_dtype.__name__),
     "axis": axis, "shape": shape, "dtype": dtype, "out_dtype": out_dtype,
     "jnp_op": getattr(jnp, op), "np_op": getattr(np, op)}
    for op in ["cumsum", "cumprod"]
    for dtype in all_dtypes
    for out_dtype in default_dtypes
    for shape in all_shapes
    for axis in [None] + list(range(-len(shape), len(shape)))))
def testCumSumProd(self, axis, shape, dtype, out_dtype, np_op, jnp_op):
  """Compare jnp cumsum/cumprod with NumPy for an explicit output dtype."""
  rand = jtu.rand_default(self.rng())

  # ComplexWarning is silenced on both sides.
  @jtu.ignore_warning(category=np.ComplexWarning)
  def np_fun(arg):
    return np_op(arg, axis=axis, dtype=out_dtype)

  @jtu.ignore_warning(category=jnp.ComplexWarning)
  def jnp_fun(arg):
    return jnp_op(arg, axis=axis, dtype=out_dtype)

  def args_maker():
    return [rand(shape, dtype)]

  tol_thresholds = {dtypes.bfloat16: 4e-2}
  # Use the looser of the input-dtype and output-dtype tolerances.
  in_tol = jtu.tolerance(dtype, tol_thresholds)
  out_tol = jtu.tolerance(out_dtype, tol_thresholds)
  tol = max(in_tol, out_tol)
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
                          tol=tol)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "op={}_shape=[{}]_axis={}_out_dtype={}".format(
        op, jtu.format_shape_dtype_string(shape, dtype), axis,
        out_dtype.__name__),
     "axis": axis, "shape": shape, "dtype": dtype, "out_dtype": out_dtype,
     "jnp_op": getattr(jnp, op), "np_op": getattr(np, op)}
    for op in ["nancumsum", "nancumprod"]
    for dtype in all_dtypes
    for out_dtype in default_dtypes
    for shape in all_shapes
    for axis in [None] + list(range(-len(shape), len(shape)))))
def testNanCumSumProd(self, axis, shape, dtype, out_dtype, np_op, jnp_op):
  """Compare jnp nancumsum/nancumprod with NumPy on inputs containing NaNs."""
  rand = jtu.rand_some_nan(self.rng())

  # ComplexWarning is silenced on both sides.
  @jtu.ignore_warning(category=np.ComplexWarning)
  def np_fun(arg):
    return np_op(arg, axis=axis, dtype=out_dtype)

  @jtu.ignore_warning(category=jnp.ComplexWarning)
  def jnp_fun(arg):
    return jnp_op(arg, axis=axis, dtype=out_dtype)

  def args_maker():
    return [rand(shape, dtype)]

  tol_thresholds = {dtypes.bfloat16: 4e-2}
  # Use the looser of the input-dtype and output-dtype tolerances.
  in_tol = jtu.tolerance(dtype, tol_thresholds)
  out_tol = jtu.tolerance(out_dtype, tol_thresholds)
  tol = max(in_tol, out_tol)
  if dtype != jnp.bfloat16:
    # numpy functions do not properly handle bfloat16
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True,
                            tol=tol)
  self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_yshape={}_xshape={}_dx={}_axis={}".format(
        jtu.format_shape_dtype_string(yshape, dtype),
        jtu.format_shape_dtype_string(xshape, dtype) if xshape is not None else None,
        dx, axis),
     "yshape": yshape, "xshape": xshape, "dtype": dtype, "dx": dx, "axis": axis}
    for dtype in default_dtypes
    for yshape, xshape, dx, axis in [
        ((10,), None, 1.0, -1),
        ((3, 10), None, 2.0, -1),
        ((3, 10), None, 3.0, -0),
        ((10, 3), (10,), 1.0, -2),
        ((3, 10), (10,), 1.0, -1),
        ((3, 10), (3, 10), 1.0, -1),
        ((2, 3, 10), (3, 10), 1.0, -2),
    ]))
@jtu.skip_on_devices("tpu")  # TODO(jakevdp): fix and reenable this test.
@jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
def testTrapz(self, yshape, xshape, dtype, dx, axis):
  """Compare jnp.trapz with np.trapz, with and without explicit x samples."""
  rand = jtu.rand_default(self.rng())

  def args_maker():
    y = rand(yshape, dtype)
    # A None xshape means trapz is called with x=None.
    x = None if xshape is None else rand(xshape, dtype)
    return [y, x]

  trapz_kwargs = {"dx": dx, "axis": axis}
  np_fun = partial(np.trapz, **trapz_kwargs)
  jnp_fun = partial(jnp.trapz, **trapz_kwargs)
  tol_overrides = {np.float64: 1e-12, dtypes.bfloat16: 4e-2}
  tol = jtu.tolerance(dtype, tol_overrides)
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol,
                          check_dtypes=False)
  self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol,
                        check_dtypes=False)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_dtype={}_m={}_n={}_k={}".format(
        np.dtype(dtype).name, m, n, k),
     "m": m, "n": n, "k": k, "dtype": dtype}
    for dtype in default_dtypes
    for n in [0, 4]
    for m in [None, 0, 1, 3, 4]
    for k in list(range(-4, 4))))
def testTri(self, m, n, k, dtype):
  """Compare jnp.tri with np.tri; all inputs are closed over."""
  def np_fun():
    return np.tri(n, M=m, k=k, dtype=dtype)

  def jnp_fun():
    return jnp.tri(n, M=m, k=k, dtype=dtype)

  def args_maker():
    return []

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_op={}_shape={}_k={}".format(
        op, jtu.format_shape_dtype_string(shape, dtype), k),
     "dtype": dtype, "shape": shape, "op": op, "k": k}
    for dtype in default_dtypes
    for shape in [shape for shape in all_shapes if len(shape) >= 2]
    for op in ["tril", "triu"]
    for k in list(range(-3, 3))))
def testTriLU(self, dtype, shape, op, k):
  """Compare jnp.tril / jnp.triu with NumPy for diagonal offset `k`."""
  rand = jtu.rand_default(self.rng())
  # `op` is a function name; resolve it in each namespace.
  np_op = getattr(np, op)
  jnp_op = getattr(jnp, op)

  def np_fun(arg):
    return np_op(arg, k=k)

  def jnp_fun(arg):
    return jnp_op(arg, k=k)

  def args_maker():
    return [rand(shape, dtype)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "n={}_k={}_m={}".format(n, k, m),
     "n": n, "k": k, "m": m}
    for n in range(1, 5)
    for k in [-1, 0, 1]
    for m in range(1, 5)))
def testTrilIndices(self, n, k, m):
  """Compare jnp.tril_indices with np.tril_indices."""
  def np_fun(n, k, m):
    return np.tril_indices(n, k=k, m=m)

  def jnp_fun(n, k, m):
    return jnp.tril_indices(n, k=k, m=m)

  def args_maker():
    return [n, k, m]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "n={}_k={}_m={}".format(n, k, m),
     "n": n, "k": k, "m": m}
    for n in range(1, 5)
    for k in [-1, 0, 1]
    for m in range(1, 5)))
def testTriuIndices(self, n, k, m):
  """Compare jnp.triu_indices with np.triu_indices."""
  def np_fun(n, k, m):
    return np.triu_indices(n, k=k, m=m)

  def jnp_fun(n, k, m):
    return jnp.triu_indices(n, k=k, m=m)

  def args_maker():
    return [n, k, m]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_k={}".format(
        jtu.format_shape_dtype_string(shape, dtype), k),
     "dtype": dtype, "shape": shape, "k": k}
    for dtype in default_dtypes
    for shape in [(1,1), (1,2), (2,2), (2,3), (3,2), (3,3), (4,4)]
    for k in [-1, 0, 1]))
def testTriuIndicesFrom(self, shape, dtype, k):
  """Compare jnp.triu_indices_from with np.triu_indices_from."""
  rand = jtu.rand_default(self.rng())

  def np_fun(arr, k):
    return np.triu_indices_from(arr, k=k)

  def jnp_fun(arr, k):
    return jnp.triu_indices_from(arr, k=k)

  def args_maker():
    return [rand(shape, dtype), k]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_k={}".format(
        jtu.format_shape_dtype_string(shape, dtype), k),
     "dtype": dtype, "shape": shape, "k": k}
    for dtype in default_dtypes
    for shape in [(1,1), (1,2), (2,2), (2,3), (3,2), (3,3), (4,4)]
    for k in [-1, 0, 1]))
def testTrilIndicesFrom(self, shape, dtype, k):
  """Compare jnp.tril_indices_from with np.tril_indices_from."""
  rand = jtu.rand_default(self.rng())

  def np_fun(arr, k):
    return np.tril_indices_from(arr, k=k)

  def jnp_fun(arr, k):
    return jnp.tril_indices_from(arr, k=k)

  def args_maker():
    return [rand(shape, dtype), k]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_ndim={}_n={}".format(ndim, n),
     "ndim": ndim, "n": n}
    for ndim in [0, 1, 4]
    for n in [0, 1, 7]))
def testDiagIndices(self, ndim, n):
  """Check that jnp.diag_indices equals np.diag_indices under JAX dtype defaults."""
  np_diag_indices = jtu.with_jax_dtype_defaults(np.diag_indices)
  expected = np_diag_indices(n, ndim)
  actual = jnp.diag_indices(n, ndim)
  np.testing.assert_equal(expected, actual)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "arr_shape={}".format(
        jtu.format_shape_dtype_string(shape, dtype)
     ),
     "dtype": dtype, "shape": shape}
    for dtype in default_dtypes
    for shape in [(1,1), (2,2), (3,3), (4,4), (5,5)]))
def testDiagIndicesFrom(self, dtype, shape):
  """Compare jnp.diag_indices_from with NumPy on square arrays."""
  rand = jtu.rand_default(self.rng())

  def args_maker():
    return [rand(shape, dtype)]

  jnp_fun = jnp.diag_indices_from
  np_fun = jtu.with_jax_dtype_defaults(np.diag_indices_from)
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_k={}".format(
        jtu.format_shape_dtype_string(shape, dtype), k),
     "dtype": dtype, "shape": shape, "k": k}
    for dtype in default_dtypes
    for shape in [shape for shape in all_shapes if len(shape) in (1, 2)]
    for k in list(range(-4, 4))))
def testDiag(self, shape, dtype, k):
  """Compare jnp.diag with np.diag on 1-D and 2-D inputs."""
  rand = jtu.rand_default(self.rng())

  def np_fun(arg):
    return np.diag(arg, k)

  def jnp_fun(arg):
    return jnp.diag(arg, k)

  def args_maker():
    return [rand(shape, dtype)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_k={}".format(
        jtu.format_shape_dtype_string(shape, dtype), k),
     "dtype": dtype, "shape": shape, "k": k}
    for dtype in default_dtypes
    for shape in all_shapes
    for k in range(-4, 4)))
def testDiagFlat(self, shape, dtype, k):
  """Compare jnp.diagflat with np.diagflat applied to an at-least-1-D input."""
  rand = jtu.rand_default(self.rng())

  # numpy has inconsistencies for scalar values
  # https://github.com/numpy/numpy/issues/16477
  # jax differs in that it treats scalars values as length-1 arrays
  def np_fun(arg):
    return np.diagflat(np.atleast_1d(arg), k)

  def jnp_fun(arg):
    return jnp.diagflat(arg, k)

  def args_maker():
    return [rand(shape, dtype)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
  self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_a1_shape={}_a2_shape2={}".format(
        jtu.format_shape_dtype_string(a1_shape, dtype),
        jtu.format_shape_dtype_string(a2_shape, dtype)),
     "dtype": dtype, "a1_shape": a1_shape, "a2_shape": a2_shape}
    for dtype in default_dtypes
    for a1_shape in one_dim_array_shapes
    for a2_shape in one_dim_array_shapes))
def testPolyMul(self, a1_shape, a2_shape, dtype):
  """Compare jnp.polymul with np.polymul.

  The NumPy comparison uses trim_leading_zeros=True; the compile check uses
  the default jnp.polymul call.
  """
  rand = jtu.rand_default(self.rng())

  def np_fun(arg1, arg2):
    return np.polymul(arg1, arg2)

  def jnp_fun_np(arg1, arg2):
    return jnp.polymul(arg1, arg2, trim_leading_zeros=True)

  def jnp_fun_co(arg1, arg2):
    return jnp.polymul(arg1, arg2)

  def args_maker():
    return [rand(a1_shape, dtype), rand(a2_shape, dtype)]

  tol = {np.float16: 2e-1, np.float32: 5e-2, np.float64: 1e-13}
  self._CheckAgainstNumpy(np_fun, jnp_fun_np, args_maker, check_dtypes=False, tol=tol)
  self._CompileAndCheck(jnp_fun_co, args_maker, check_dtypes=False)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "a_shape={} , b_shape={}".format(
        jtu.format_shape_dtype_string(a_shape, dtype),
        jtu.format_shape_dtype_string(b_shape, dtype)),
     "dtype": dtype, "a_shape": a_shape, "b_shape" : b_shape}
    for dtype in default_dtypes
    for a_shape in one_dim_array_shapes
    for b_shape in one_dim_array_shapes))
def testPolyDiv(self, a_shape, b_shape, dtype):
  """Compare jnp.polydiv with np.polydiv.

  Both remainders are left-padded with zeros to a common length before the
  comparison.
  """
  rand = jtu.rand_default(self.rng())

  @jtu.ignore_warning(category=RuntimeWarning, message="divide by zero.*")
  @jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*")
  def np_fun(arg1, arg2):
    quotient, remainder = np.polydiv(arg1, arg2)
    target = max(arg1.size, arg2.size)
    while remainder.size < target:  # Pad residual to same size
      remainder = np.pad(remainder, (1, 0), 'constant')
    return quotient, remainder

  def jnp_fun(arg1, arg2):
    quotient, remainder = jnp.polydiv(arg1, arg2, trim_leading_zeros=True)
    target = max(arg1.size, arg2.size)
    while remainder.size < target:  # Pad residual to same size
      remainder = jnp.pad(remainder, (1, 0), 'constant')
    return quotient, remainder

  def args_maker():
    return [rand(a_shape, dtype), rand(b_shape, dtype)]

  tol = {np.float16: 2e-1, np.float32: 5e-2, np.float64: 1e-13}

  # Without trim_leading_zeros (trim_zeros make it unable to be compiled by XLA)
  jnp_compile = jnp.polydiv
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol)
  self._CompileAndCheck(jnp_compile, args_maker, check_dtypes=True, atol=tol, rtol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_offset={}_axis1={}_axis2={}".format(
        jtu.format_shape_dtype_string(shape, dtype), offset, axis1, axis2),
     "dtype": dtype, "shape": shape, "offset": offset, "axis1": axis1,
     "axis2": axis2}
    for dtype in default_dtypes
    for shape in [shape for shape in all_shapes if len(shape) >= 2]
    for axis1 in range(-len(shape), len(shape))
    for axis2 in [a for a in range(-len(shape), len(shape))
                  if a % len(shape) != axis1 % len(shape)]
    for offset in list(range(-4, 4))))
def testDiagonal(self, shape, dtype, offset, axis1, axis2):
  """Compare jnp.diagonal with np.diagonal over distinct axis pairs."""
  rand = jtu.rand_default(self.rng())

  def np_fun(arg):
    return np.diagonal(arg, offset, axis1, axis2)

  def jnp_fun(arg):
    return jnp.diagonal(arg, offset, axis1, axis2)

  def args_maker():
    return [rand(shape, dtype)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_n={}".format(np.dtype(dtype).name, n),
     "dtype": dtype, "n": n}
    for dtype in default_dtypes
    for n in list(range(4))))
def testIdentity(self, n, dtype):
  """Compare jnp.identity with np.identity; all inputs are closed over."""
  def np_fun():
    return np.identity(n, dtype)

  def jnp_fun():
    return jnp.identity(n, dtype)

  def args_maker():
    return []

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_period={}_left={}_right={}".format(
        jtu.format_shape_dtype_string(shape, dtype), period, left, right),
     "shape": shape, "dtype": dtype,
     "period": period, "left": left, "right": right}
    for shape in nonempty_shapes
    for period in [None, 0.59]
    for left in [None, 0]
    for right in [None, 1]
    for dtype in default_dtypes
    # following types lack precision for meaningful tests
    if dtype not in [np.int8, np.int16, np.float16, jnp.bfloat16]
))
def testInterp(self, shape, dtype, period, left, right):
  """Compare jnp.interp with np.interp for period/left/right combinations."""
  rand = jtu.rand_default(self.rng(), scale=10)
  interp_kwargs = {"period": period, "left": left, "right": right}
  np_fun = partial(np.interp, **interp_kwargs)
  jnp_fun = partial(jnp.interp, **interp_kwargs)

  def args_maker():
    # (query points, sorted sample positions, sample values)
    return [rand(shape, dtype), np.sort(rand((20,), dtype)), np.linspace(0, 1, 20)]

  # skip numpy comparison for integer types with period specified, because numpy
  # uses an unstable sort and so results differ for duplicate values.
  if not period or not np.issubdtype(dtype, np.integer):
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol={np.float32: 2E-4})
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_x1={}_x2={}_x1_rng={}".format(
        jtu.format_shape_dtype_string(x1_shape, x1_dtype),
        jtu.format_shape_dtype_string(x2_shape, np.int32),
        x1_rng_factory_id),
     "x1_shape": x1_shape, "x1_dtype": x1_dtype,
     "x2_shape": x2_shape, "x1_rng_factory": x1_rng_factory,
     "x2_rng_factory": x2_rng_factory}
    for x1_rng_factory_id, x1_rng_factory in
      enumerate([jtu.rand_some_inf_and_nan, jtu.rand_some_zero])
    for x2_rng_factory in [partial(jtu.rand_int, low=-1075, high=1024)]
    for x1_shape, x2_shape in filter(_shapes_are_broadcast_compatible,
                                     itertools.combinations_with_replacement(array_shapes, 2))
    for x1_dtype in default_dtypes))
@jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
def testLdexp(self, x1_shape, x1_dtype, x2_shape, x1_rng_factory, x2_rng_factory):
  """Compare jnp.ldexp with np.ldexp for int32 exponents."""
  # integer types are converted to float64 in numpy's implementation
  needs_x64 = x1_dtype not in [jnp.bfloat16, np.float16, np.float32]
  if needs_x64 and not config.x64_enabled:
    self.skipTest("Only run float64 testcase when float64 is enabled.")
  x1_rand = x1_rng_factory(self.rng())
  x2_rand = x2_rng_factory(self.rng())

  # Overflow RuntimeWarnings are silenced on the NumPy side.
  @jtu.ignore_warning(category=RuntimeWarning, message="overflow.*")
  def np_fun(x1, x2):
    return np.ldexp(x1, x2)

  def jnp_fun(x1, x2):
    return jnp.ldexp(x1, x2)

  def args_maker():
    return [x1_rand(x1_shape, x1_dtype), x2_rand(x2_shape, np.int32)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_x={}_rng_factory={}".format(
        jtu.format_shape_dtype_string(shape, dtype), rng_factory_id),
     "shape": shape, "dtype": dtype, "rng_factory": rng_factory}
    for rng_factory_id, rng_factory in enumerate([
        jtu.rand_some_inf_and_nan,
        jtu.rand_some_zero,
        partial(jtu.rand_not_small, offset=1e8),
    ])
    for shape in all_shapes
    for dtype in default_dtypes))
def testFrexp(self, shape, dtype, rng_factory):
  """Compare jnp.frexp with np.frexp for several input generators."""
  # integer types are converted to float64 in numpy's implementation
  needs_x64 = dtype not in [jnp.bfloat16, np.float16, np.float32]
  if needs_x64 and not config.x64_enabled:
    self.skipTest("Only run float64 testcase when float64 is enabled.")
  rand = rng_factory(self.rng())

  def np_fun(x):
    return np.frexp(x)

  def jnp_fun(x):
    return jnp.frexp(x)

  def args_maker():
    return [rand(shape, dtype)]

  # Result dtypes are only compared for inexact input dtypes.
  compare_dtypes = np.issubdtype(dtype, np.inexact)
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
                          check_dtypes=compare_dtypes)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name": "_shape={}_dtype_{}_offset={}_axis1={}_axis2={}".format(
            jtu.format_shape_dtype_string(shape, dtype),
            out_dtype, offset, axis1, axis2),
        "dtype": dtype,
        "out_dtype": out_dtype,
        "shape": shape,
        "offset": offset,
        "axis1": axis1,
        "axis2": axis2,
    }
    for dtype in default_dtypes
    for out_dtype in [None] + number_dtypes
    for shape in [shape for shape in all_shapes if len(shape) >= 2]
    for axis1 in range(-len(shape), len(shape))
    for axis2 in range(-len(shape), len(shape))
    if (axis1 % len(shape)) != (axis2 % len(shape))
    for offset in list(range(-4, 4))))
def testTrace(self, shape, dtype, out_dtype, offset, axis1, axis2):
  # Runs jnp.trace and np.trace with the same offset/axes/output dtype through
  # the shared harness helpers.
  sampler = jtu.rand_default(self.rng())

  def np_fun(arg):
    # For a bfloat16 output the NumPy reference is computed in float32 and
    # then cast to bfloat16.
    if out_dtype == jnp.bfloat16:
      as_f32 = np.trace(arg, offset, axis1, axis2, np.float32)
      return as_f32.astype(jnp.bfloat16)
    return np.trace(arg, offset, axis1, axis2, out_dtype)

  def jnp_fun(arg):
    return jnp.trace(arg, offset, axis1, axis2, out_dtype)

  def args_maker():
    return [sampler(shape, dtype)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name": "_a={}_v={}_side={}".format(
            jtu.format_shape_dtype_string(ashape, dtype),
            jtu.format_shape_dtype_string(vshape, dtype),
            side),
        "ashape": ashape,
        "vshape": vshape,
        "side": side,
        "dtype": dtype,
    }
    for ashape in [(15,), (16,), (17,)]
    for vshape in [(), (5,), (5, 5)]
    for side in ['left', 'right']
    for dtype in number_dtypes
))
def testSearchsorted(self, ashape, vshape, side, dtype):
  # Runs jnp.searchsorted and np.searchsorted on a sorted array `a` and query
  # values `v` through the shared harness helpers.
  sampler = jtu.rand_default(self.rng())

  def args_maker():
    # The first operand is sorted with np.sort before being searched.
    return [np.sort(sampler(ashape, dtype)), sampler(vshape, dtype)]

  def np_fun(a, v):
    return np.searchsorted(a, v, side=side)

  def jnp_fun(a, v):
    return jnp.searchsorted(a, v, side=side)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name": f"_dtype={dtype.__name__}_side={side}",
        "dtype": dtype,
        "side": side,
    }
    for dtype in inexact_dtypes
    for side in ['left', 'right']))
def testSearchsortedNans(self, dtype, side):
  # Checks that jnp.searchsorted gives the same answer for an array containing
  # signed zeros/NaNs as for an integer stand-in array, eagerly and under jit.
  if np.issubdtype(dtype, np.complexfloating):
    raise SkipTest("Known failure for complex inputs; see #9107")

  x = np.array([-np.inf, -1.0, 0.0, -0.0, 1.0, np.inf, np.nan, -np.nan],
               dtype=dtype)
  # The sign bit should not matter for 0.0 or NaN, so argsorting the above should be
  # equivalent to argsorting the following:
  x_equiv = np.array([0, 1, 2, 2, 3, 4, 5, 5])

  # NOTE: currently unreachable because complex dtypes are skipped above.
  if jnp.issubdtype(dtype, jnp.complexfloating):
    x = np.array(
        [complex(r, c) for r, c in itertools.product(x, repeat=2)])
    x_equiv = np.array(
        [complex(r, c) for r, c in itertools.product(x_equiv, repeat=2)])

  def fun(a, v):
    return jnp.searchsorted(a, v, side=side)

  eager_result = fun(x, x)
  self.assertArraysEqual(eager_result, fun(x_equiv, x_equiv))
  jitted_result = jax.jit(fun)(x, x)
  self.assertArraysEqual(jitted_result, fun(x_equiv, x_equiv))
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name": "_x={}_bins={}_right={}_reverse={}".format(
            jtu.format_shape_dtype_string(xshape, dtype),
            jtu.format_shape_dtype_string(binshape, dtype),
            right, reverse),
        "xshape": xshape,
        "binshape": binshape,
        "right": right,
        "reverse": reverse,
        "dtype": dtype,
    }
    for xshape in [(20,), (5, 4)]
    for binshape in [(1,), (5,)]
    for right in [True, False]
    for reverse in [True, False]
    for dtype in default_dtypes
))
def testDigitize(self, xshape, binshape, right, reverse, dtype):
  # Runs jnp.digitize and np.digitize on the same data and sorted bins through
  # the shared harness helpers; `reverse` flips the sorted bins.
  if reverse:
    order = jnp.index_exp[::-1]
  else:
    order = jnp.index_exp[:]
  sampler = jtu.rand_default(self.rng())

  def args_maker():
    values = sampler(xshape, dtype)
    bin_edges = jnp.sort(sampler(binshape, dtype))[order]
    return [values, bin_edges]

  def np_fun(x, bins):
    return np.digitize(x, bins, right=right)

  def jnp_fun(x, bins):
    return jnp.digitize(x, bins, right=right)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_array={}".format(
        jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes), array_input),
     "shape": shape, "dtypes": dtypes, "array_input": array_input}
    for dtypes in [
      [np.float32],
      [np.float32, np.float32],
      [np.float32, np.int32, np.float32],
      [np.float32, np.int64, np.float32],
      [np.float32, np.int32, np.float64],
    ]
    for shape in [(), (2,), (3, 4), (1, 5)]
    for array_input in [True, False]))
def testColumnStack(self, shape, dtypes, array_input):
  # Runs jnp.column_stack and np.column_stack on the same inputs through the
  # shared harness helpers. With array_input the operands are packed into one
  # np.array; otherwise they are passed as a Python list.
  rng = jtu.rand_default(self.rng())
  if array_input:
    args_maker = lambda: [np.array([rng(shape, dtype) for dtype in dtypes])]
  else:
    args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
  # The NumPy function is wrapped in _promote_like_jnp before comparison.
  np_fun = _promote_like_jnp(np.column_stack)
  jnp_fun = jnp.column_stack
  # The NumPy reference goes first, then the jnp function under test: the
  # (np_fun, jnp_fun, args_maker) order used elsewhere in this file.
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_axis={}_array={}".format(
        jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes), axis, array_input),
     "shape": shape, "axis": axis, "dtypes": dtypes, "array_input": array_input}
    for dtypes in [
      [np.float32],
      [np.float32, np.float32],
      [np.float32, np.int32, np.float32],
      [np.float32, np.int64, np.float32],
      [np.float32, np.int32, np.float64],
    ]
    for shape in [(), (2,), (3, 4), (1, 100)]
    for axis in range(-len(shape), len(shape) + 1)
    for array_input in [True, False]))
def testStack(self, shape, axis, dtypes, array_input):
  # Runs jnp.stack and np.stack along `axis` on the same inputs through the
  # shared harness helpers. With array_input the operands are packed into one
  # np.array; otherwise they are passed as a Python list.
  rng = jtu.rand_default(self.rng())
  if array_input:
    args_maker = lambda: [np.array([rng(shape, dtype) for dtype in dtypes])]
  else:
    args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
  # The NumPy function is wrapped in _promote_like_jnp before comparison.
  np_fun = _promote_like_jnp(partial(np.stack, axis=axis))
  jnp_fun = partial(jnp.stack, axis=axis)
  # The NumPy reference goes first, then the jnp function under test: the
  # (np_fun, jnp_fun, args_maker) order used elsewhere in this file.
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_op={}_{}_array={}".format(
        op, jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes), array_input),
     "shape": shape, "op": op, "dtypes": dtypes, "array_input": array_input}
    for op in ["hstack", "vstack", "dstack"]
    for dtypes in [
      [np.float32],
      [np.float32, np.float32],
      [np.float32, np.int32, np.float32],
      [np.float32, np.int64, np.float32],
      [np.float32, np.int32, np.float64],
    ]
    for shape in [(), (2,), (3, 4), (1, 100), (2, 3, 4)]
    for array_input in [True, False]))
def testHVDStack(self, shape, op, dtypes, array_input):
  # Runs the jnp and np versions of `op` (hstack/vstack/dstack) on the same
  # inputs through the shared harness helpers. With array_input the operands
  # are packed into one np.array; otherwise they are passed as a Python list.
  rng = jtu.rand_default(self.rng())
  if array_input:
    args_maker = lambda: [np.array([rng(shape, dtype) for dtype in dtypes])]
  else:
    args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
  # The NumPy function is wrapped in _promote_like_jnp before comparison.
  np_fun = _promote_like_jnp(getattr(np, op))
  jnp_fun = getattr(jnp, op)
  # The NumPy reference goes first, then the jnp function under test: the
  # (np_fun, jnp_fun, args_maker) order used elsewhere in this file.
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name": "_inshape={}_outdtype={}_fillshape={}".format(
            jtu.format_shape_dtype_string(shape, fill_value_dtype),
            np.dtype(out_dtype).name if out_dtype else "None",
            fill_value_shape),
        "fill_value_dtype": fill_value_dtype,
        "fill_value_shape": fill_value_shape,
        "shape": shape,
        "out_dtype": out_dtype,
    }
    for shape in array_shapes + [3, np.array(7, dtype=np.int32)]
    for fill_value_dtype in default_dtypes
    for fill_value_shape in _compatible_shapes(shape)
    for out_dtype in [None] + default_dtypes))
def testFull(self, shape, fill_value_dtype, fill_value_shape, out_dtype):
  # Runs jnp.full and np.full with the same shape/dtype and a random fill value
  # through the shared harness helpers.
  sampler = jtu.rand_default(self.rng())

  def np_fun(fill_value):
    return np.full(shape, fill_value, dtype=out_dtype)

  def jnp_fun(fill_value):
    return jnp.full(shape, fill_value, dtype=out_dtype)

  def args_maker():
    return [sampler(fill_value_shape, fill_value_dtype)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
    "testcase_name": "_shape={}_n={}_axis={}_prepend={}_append={}".format(
        jtu.format_shape_dtype_string(shape, dtype),
        n, axis, prepend, append),
    "shape": shape,
    "dtype": dtype,
    "n": n,
    "axis": axis,
    "prepend": prepend,
    "append": append,
  } for shape, dtype in s(_shape_and_dtypes(nonempty_nonscalar_array_shapes, default_dtypes))
  for n in s([0, 1, 2])
  for axis in s(list(range(-len(shape), max(1, len(shape)))))
  for prepend in s([None, 1, np.zeros(shape, dtype=dtype)])
  for append in s([None, 1, np.zeros(shape, dtype=dtype)])
)))
def testDiff(self, shape, dtype, n, axis, prepend, append):
  # Runs jnp.diff and np.diff with the same n/axis/prepend/append through the
  # shared harness helpers (dtype comparison is disabled for the NumPy check).
  sampler = jtu.rand_default(self.rng())

  def args_maker():
    return [sampler(shape, dtype)]

  def _for_numpy(pad_value):
    # None maps to NumPy's "argument not given" sentinel; non-scalar bfloat16
    # values are cast to float32 before being handed to NumPy.
    if pad_value is None:
      return np._NoValue
    if not np.isscalar(pad_value) and pad_value.dtype == jnp.bfloat16:
      return pad_value.astype(np.float32)
    return pad_value

  def np_fun(x):
    diff_kwargs = dict(n=n, axis=axis,
                       prepend=_for_numpy(prepend),
                       append=_for_numpy(append))
    # bfloat16 inputs are diffed in float32 and the result cast back.
    if x.dtype == jnp.bfloat16:
      return np.diff(x.astype(np.float32), **diff_kwargs).astype(jnp.bfloat16)
    return np.diff(x, **diff_kwargs)

  def jnp_fun(x):
    return jnp.diff(x, n=n, axis=axis, prepend=prepend, append=append)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(
    jtu.cases_from_list(
        {
            "testcase_name": ("_op={}_shape={}_dtype={}").format(op, shape, dtype),
            "np_op": getattr(np, op),
            "jnp_op": getattr(jnp, op),
            "shape": shape,
            "dtype": dtype,
        }
        for op in ["zeros", "ones"]
        for shape in [2, (), (2,), (3, 0), np.array((4, 5, 6), dtype=np.int32),
                      np.array(4, dtype=np.int32)]
        for dtype in all_dtypes))
def testZerosOnes(self, np_op, jnp_op, shape, dtype):
  # Runs the NumPy and jnp constructors (zeros/ones) with the same shape and
  # dtype through the shared harness helpers; they take no runtime arguments.
  def args_maker():
    return []

  np_fun = partial(np_op, shape, dtype)
  jnp_fun = partial(jnp_op, shape, dtype)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
def testOnesWithInvalidShape(self):
  # A shape containing a negative dimension must be rejected with TypeError.
  self.assertRaises(TypeError, jnp.ones, (-1, 1))
@parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
    "testcase_name": "_inshape={}_filldtype={}_fillshape={}_outdtype={}_outshape={}".format(
        jtu.format_shape_dtype_string(shape, in_dtype),
        np.dtype(fill_value_dtype).name, fill_value_shape,
        np.dtype(out_dtype).name, out_shape),
    "shape": shape,
    "in_dtype": in_dtype,
    "fill_value_dtype": fill_value_dtype,
    "fill_value_shape": fill_value_shape,
    "out_dtype": out_dtype,
    "out_shape": out_shape,
  } for shape in s(array_shapes)
  for out_shape in s([None] + array_shapes)
  for in_dtype in s(default_dtypes)
  for fill_value_dtype in s(default_dtypes)
  for fill_value_shape in s(_compatible_shapes(shape if out_shape is None else out_shape))
  for out_dtype in s(default_dtypes))))
def testFullLike(self, shape, in_dtype, fill_value_dtype, fill_value_shape, out_dtype, out_shape):
  # Runs jnp.full_like and np.full_like with the same dtype/shape overrides
  # through the shared harness helpers.
  if numpy_version < (1, 19) and out_shape == ():
    raise SkipTest("Numpy < 1.19 treats out_shape=() like out_shape=None")

  sampler = jtu.rand_default(self.rng())

  def np_fun(x, fill_value):
    return np.full_like(x, fill_value, dtype=out_dtype, shape=out_shape)

  def jnp_fun(x, fill_value):
    return jnp.full_like(x, fill_value, dtype=out_dtype, shape=out_shape)

  def args_maker():
    template = sampler(shape, in_dtype)
    fill = sampler(fill_value_shape, fill_value_dtype)
    return [template, fill]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name": "_func={}_inshape={}_outshape={}_outdtype={}".format(
            func, jtu.format_shape_dtype_string(shape, in_dtype),
            out_shape, out_dtype),
        "func": func,
        "shape": shape,
        "in_dtype": in_dtype,
        "out_shape": out_shape,
        "out_dtype": out_dtype,
    }
    for shape in array_shapes
    for out_shape in [None] + array_shapes
    for in_dtype in default_dtypes
    for func in ["ones_like", "zeros_like"]
    for out_dtype in default_dtypes))
def testZerosOnesLike(self, func, shape, in_dtype, out_shape, out_dtype):
  # Runs the jnp and np versions of `func` (ones_like/zeros_like) with the same
  # dtype/shape overrides through the shared harness helpers.
  if numpy_version < (1, 19) and out_shape == ():
    raise SkipTest("Numpy < 1.19 treats out_shape=() like out_shape=None")

  sampler = jtu.rand_default(self.rng())

  def np_fun(x):
    return getattr(np, func)(x, dtype=out_dtype, shape=out_shape)

  def jnp_fun(x):
    return getattr(jnp, func)(x, dtype=out_dtype, shape=out_shape)

  def args_maker():
    return [sampler(shape, in_dtype)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name": "_func={}_inshape={}_weak_type={}_outshape={}_outdtype={}".format(
            func, jtu.format_shape_dtype_string(shape, in_dtype),
            weak_type, out_shape, out_dtype),
        "func": func,
        "args": args,
        "shape": shape,
        "in_dtype": in_dtype,
        "weak_type": weak_type,
        "out_shape": out_shape,
        "out_dtype": out_dtype,
    }
    for shape in array_shapes
    for in_dtype in [np.int32, np.float32, np.complex64]
    for weak_type in [True, False]
    for out_shape in [None, (), (10,)]
    for func, args in [("full_like", (-100,)), ("ones_like", ()), ("zeros_like", ())]
    for out_dtype in [None, float]))
def testZerosOnesFullLikeWeakType(self, func, args, shape, in_dtype, weak_type, out_shape, out_dtype):
  # Checks the weak-type flag on the result of full_like/ones_like/zeros_like,
  # both eagerly and under jax.jit.
  if numpy_version < (1, 19) and out_shape == ():
    raise SkipTest("Numpy < 1.19 treats out_shape=() like out_shape=None")

  sampler = jtu.rand_default(self.rng())
  operand = lax_internal._convert_element_type(sampler(shape, in_dtype),
                                               weak_type=weak_type)

  def fun(x):
    return getattr(jnp, func)(x, *args, dtype=out_dtype, shape=out_shape)

  # The result is expected to be weakly typed only for a weakly typed input
  # with no explicit output dtype.
  expected_weak_type = out_dtype is None and weak_type
  eager_out = fun(operand)
  self.assertEqual(dtypes.is_weakly_typed(eager_out), expected_weak_type)
  jitted_out = jax.jit(fun)(operand)
  self.assertEqual(dtypes.is_weakly_typed(jitted_out), expected_weak_type)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name": "_funcname={}_input_type={}_val={}_dtype={}".format(
            funcname, input_type, val, dtype),
        "funcname": funcname,
        "input_type": input_type,
        "val": val,
        "dtype": dtype,
    }
    for funcname in ["array", "asarray"]
    for dtype in [int, float, None]
    for val in [0, 1]
    for input_type in [int, float, np.int32, np.float32]))
def testArrayWeakType(self, funcname, input_type, val, dtype):
  # Checks the weak-type flag on the result of jnp.array / jnp.asarray applied
  # to a scalar, both eagerly and under jax.jit.
  def func(x):
    return getattr(jnp, funcname)(x, dtype=dtype)

  fjit = jax.jit(func)
  scalar = input_type(val)

  # Weak typing is only expected when no dtype is requested and the input type
  # is one of dtypes._weak_types.
  if dtype is None:
    expected_weak_type = input_type in set(dtypes._weak_types)
  else:
    expected_weak_type = False

  self.assertEqual(dtypes.is_weakly_typed(func(scalar)), expected_weak_type)
  self.assertEqual(dtypes.is_weakly_typed(fjit(scalar)), expected_weak_type)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name": "_{}_weak_type={}_slc={}".format(
            jtu.format_shape_dtype_string(shape, dtype), weak_type, slc),
        "shape": shape,
        "dtype": dtype,
        "weak_type": weak_type,
        "slc": slc,
    }
    for shape in nonempty_nonscalar_array_shapes
    for dtype in [int, float, complex]
    for weak_type in [True, False]
    for slc in [slice(None), slice(0), slice(3), 0, ...]))
def testSliceWeakTypes(self, shape, dtype, weak_type, slc):
  # Checks that indexing with `slc` keeps the operand's weak-type flag, both
  # eagerly and under jax.jit.
  sampler = jtu.rand_default(self.rng())
  operand = lax_internal._convert_element_type(sampler(shape, dtype),
                                               weak_type=weak_type)

  def op(x):
    return x[slc]

  eager_out = op(operand)
  self.assertEqual(eager_out.aval.weak_type, weak_type)
  jitted_out = jax.jit(op)(operand)
  self.assertEqual(jitted_out.aval.weak_type, weak_type)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name": "_{}_axis={}_{}sections".format(
            jtu.format_shape_dtype_string(shape, dtype), axis, num_sections),
        "shape": shape,
        "num_sections": num_sections,
        "axis": axis,
        "dtype": dtype,
    }
    for shape, axis, num_sections in [
        ((3,), 0, 3), ((12,), 0, 3), ((12, 4), 0, 4), ((12, 4), 1, 2),
        ((2, 3, 4), -1, 2), ((2, 3, 4), -2, 3)]
    for dtype in default_dtypes))
def testSplitStaticInt(self, shape, num_sections, axis, dtype):
  # Runs jnp.split and np.split with an integer section count through the
  # shared harness helpers.
  sampler = jtu.rand_default(self.rng())

  def np_fun(x):
    return np.split(x, num_sections, axis=axis)

  def jnp_fun(x):
    return jnp.split(x, num_sections, axis=axis)

  def args_maker():
    return [sampler(shape, dtype)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name": "_{}_axis={}_{}sections".format(
            jtu.format_shape_dtype_string(shape, dtype), axis, num_sections),
        "shape": shape,
        "num_sections": num_sections,
        "axis": axis,
        "dtype": dtype,
    }
    # All testcases split the specified axis unequally
    for shape, axis, num_sections in [
        ((3,), 0, 2), ((12,), 0, 5), ((12, 4), 0, 7), ((12, 4), 1, 3),
        ((2, 3, 5), -1, 2), ((2, 4, 4), -2, 3), ((7, 2, 2), 0, 3)]
    for dtype in default_dtypes))
def testArraySplitStaticInt(self, shape, num_sections, axis, dtype):
  # Runs jnp.array_split and np.array_split with an integer section count
  # through the shared harness helpers.
  sampler = jtu.rand_default(self.rng())

  def np_fun(x):
    return np.array_split(x, num_sections, axis=axis)

  def jnp_fun(x):
    return jnp.array_split(x, num_sections, axis=axis)

  def args_maker():
    return [sampler(shape, dtype)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
def testSplitTypeError(self):
  # Exercises which kinds of `indices_or_sections` values jnp.split accepts
  # and which raise a concretization TypeError.
  def split_at(idx):
    return jnp.split(jnp.zeros((12, 2)), idx)

  def split_at_list(idx):
    return jnp.split(jnp.zeros((12, 2)), [2, idx])

  def split_at_tuple(idx):
    return jnp.split(jnp.zeros((12, 2)), (1, idx))

  # If we pass an ndarray for indices_or_sections -> no error
  pieces = jnp.split(jnp.zeros(3), jnp.array([1, 2]))
  self.assertEqual(3, len(pieces))

  CONCRETIZATION_MSG = "Abstract tracer value encountered where concrete value is expected."
  # An abstract tracer for idx
  self.assertRaisesRegex(TypeError, CONCRETIZATION_MSG,
                         jax.jit(split_at), 2.)
  # A list including an abstract tracer
  self.assertRaisesRegex(TypeError, CONCRETIZATION_MSG,
                         jax.jit(split_at_list), 2.)

  # A concrete tracer -> no error
  jax.jvp(split_at, (2.,), (1.,))
  # A tuple including a concrete tracer -> no error
  jax.jvp(split_at_tuple, (2.,), (1.,))
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_bins={}_range={}_weights={}".format(
        jtu.format_shape_dtype_string(shape, dtype), bins, range, weights),
     "shape": shape, "dtype": dtype, "bins": bins, "range": range,
     "weights": weights}
    for shape in [(5,), (5, 5)]
    for dtype in number_dtypes
    for bins in [10, np.arange(-5, 6), np.array([-5, 0, 3])]
    for range in [None, (0, 0), (0, 10)]
    for weights in [True, False]
))
def testHistogramBinEdges(self, shape, dtype, bins, range, weights):
  # Runs jnp.histogram_bin_edges and np.histogram_bin_edges with the same
  # bins/range/weights through the shared harness helpers.
  sampler = jtu.rand_default(self.rng())

  def _weights(w):
    # Weights are made non-negative with abs(), or omitted entirely.
    if weights:
      return abs(w)
    return None

  def np_fun(a, w, r):
    return np.histogram_bin_edges(a, bins=bins, range=r, weights=_weights(w))

  def jnp_fun(a, w, r):
    return jnp.histogram_bin_edges(a, bins=bins, range=r, weights=_weights(w))

  def args_maker():
    return [sampler(shape, dtype), sampler(shape, dtype), range]

  tol = {jnp.bfloat16: 2E-2, np.float16: 1E-2}
  # linspace() compares poorly to numpy when using bfloat16
  if dtype != jnp.bfloat16:
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
                            check_dtypes=False, tol=tol)
  self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_bins={}_density={}_weights={}".format(
        jtu.format_shape_dtype_string(shape, dtype), bins, density, weights),
     "shape": shape, "dtype": dtype, "bins": bins, "density": density,
     "weights": weights}
    for shape in [(5,), (5, 5)]
    for dtype in default_dtypes
    # We only test explicit integer-valued bin edges because in other cases
    # rounding errors lead to flaky tests.
    for bins in [np.arange(-5, 6), np.array([-5, 0, 3])]
    for density in [True, False]
    for weights in [True, False]
))
def testHistogram(self, shape, dtype, bins, density, weights):
  # Runs jnp.histogram and np.histogram with the same bins/density/weights
  # through the shared harness helpers.
  sampler = jtu.rand_default(self.rng())

  def _weights(w):
    # Weights are made non-negative with abs(), or omitted entirely.
    if weights:
      return abs(w)
    return None

  def np_fun(a, w):
    return np.histogram(a, bins=bins, density=density, weights=_weights(w))

  def jnp_fun(a, w):
    return jnp.histogram(a, bins=bins, density=density, weights=_weights(w))

  def args_maker():
    return [sampler(shape, dtype), sampler(shape, dtype)]

  tol = {jnp.bfloat16: 2E-2, np.float16: 1E-1}
  # np.searchsorted errors on bfloat16 with
  # "TypeError: invalid type promotion with custom data type"
  if dtype != jnp.bfloat16:
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
                            check_dtypes=False, tol=tol)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_bins={}_weights={}_density={}_range={}".format(
        jtu.format_shape_dtype_string(shape, dtype), bins, weights, density, range),
     "shape": shape, "dtype": dtype, "bins": bins, "weights": weights,
     "density": density, "range": range}
    for shape in [(5,), (12,)]
    for dtype in int_dtypes
    for bins in [2, [2, 2], [np.array([0, 1, 3, 5]), np.array([0, 2, 3, 4, 6])]]
    for weights in [False, True]
    for density in [False, True]
    for range in [None, [(-1, 1), None], [(-1, 1), (-2, 2)]]
))
def testHistogram2d(self, shape, dtype, bins, weights, density, range):
  # Runs jnp.histogram2d and np.histogram2d with the same
  # bins/weights/density/range through the shared harness helpers.
  sampler = jtu.rand_default(self.rng())

  def _weights(w):
    # Weights are made non-negative with abs(), or omitted entirely.
    if weights:
      return abs(w)
    return None

  # "invalid value" RuntimeWarnings from the NumPy reference are silenced.
  @jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*")
  def np_fun(a, b, w):
    return np.histogram2d(a, b, bins=bins, weights=_weights(w),
                          density=density, range=range)

  def jnp_fun(a, b, w):
    return jnp.histogram2d(a, b, bins=bins, weights=_weights(w),
                           density=density, range=range)

  def args_maker():
    return [sampler(shape, dtype), sampler(shape, dtype), sampler(shape, dtype)]

  tol = {jnp.bfloat16: 2E-2, np.float16: 1E-1}
  # np.searchsorted errors on bfloat16 with
  # "TypeError: invalid type promotion with custom data type"
  with np.errstate(divide='ignore', invalid='ignore'):
    if dtype != jnp.bfloat16:
      self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
                              check_dtypes=False, tol=tol)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_bins={}_weights={}_density={}_range={}".format(
        jtu.format_shape_dtype_string(shape, dtype), bins, weights, density, range),
     "shape": shape, "dtype": dtype, "bins": bins, "weights": weights,
     "density": density, "range": range}
    for shape in [(5, 3), (10, 3)]
    for dtype in int_dtypes
    for bins in [(2, 2, 2), [np.array([-5, 0, 4]), np.array([-4, -1, 2]), np.array([-6, -1, 4])]]
    for weights in [False, True]
    for density in [False, True]
    for range in [None, [(-1, 1), None, None], [(-1, 1), (-2, 2), (-3, 3)]]
))
def testHistogramdd(self, shape, dtype, bins, weights, density, range):
  # Runs jnp.histogramdd and np.histogramdd with the same
  # bins/weights/density/range through the shared harness helpers.
  sampler = jtu.rand_default(self.rng())

  def _weights(w):
    # Weights are made non-negative with abs(), or omitted entirely.
    if weights:
      return abs(w)
    return None

  # "invalid value" RuntimeWarnings from the NumPy reference are silenced.
  @jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*")
  def np_fun(a, w):
    return np.histogramdd(a, bins=bins, weights=_weights(w),
                          density=density, range=range)

  def jnp_fun(a, w):
    return jnp.histogramdd(a, bins=bins, weights=_weights(w),
                           density=density, range=range)

  def args_maker():
    # One weight per sample row (leading axis of the data).
    return [sampler(shape, dtype), sampler((shape[0],), dtype)]

  tol = {jnp.bfloat16: 2E-2, np.float16: 1E-1}
  # np.searchsorted errors on bfloat16 with
  # "TypeError: invalid type promotion with custom data type"
  if dtype != jnp.bfloat16:
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
                            check_dtypes=False, tol=tol)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name": "_{}_axis={}_{}sections".format(
            jtu.format_shape_dtype_string(shape, dtype), axis, num_sections),
        "shape": shape,
        "num_sections": num_sections,
        "axis": axis,
        "dtype": dtype,
    }
    for shape, axis, num_sections in [
        ((12, 4), 0, 4), ((12, 4), 1, 2),
        ((2, 3, 4), 2, 2), ((4, 3, 4), 0, 2)]
    for dtype in default_dtypes))
def testHVDSplit(self, shape, num_sections, axis, dtype):
  # Runs the jnp and np versions of vsplit/hsplit/dsplit (chosen by `axis`)
  # through the shared harness helpers.
  sampler = jtu.rand_default(self.rng())

  def fn(module, axis):
    # axis 0 -> vsplit, axis 1 -> hsplit, axis 2 -> dsplit.
    if axis == 0:
      return module.vsplit
    if axis == 1:
      return module.hsplit
    assert axis == 2
    return module.dsplit

  def np_fun(x):
    return fn(np, axis)(x, num_sections)

  def jnp_fun(x):
    return fn(jnp, axis)(x, num_sections)

  def args_maker():
    return [sampler(shape, dtype)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name": "_inshape={}_outshape={}_order={}".format(
            jtu.format_shape_dtype_string(arg_shape, dtype),
            jtu.format_shape_dtype_string(out_shape, dtype),
            order),
        "arg_shape": arg_shape,
        "out_shape": out_shape,
        "dtype": dtype,
        "order": order,
    }
    for dtype in default_dtypes
    for order in ["C", "F"]
    for arg_shape, out_shape in [
        (jtu.NUMPY_SCALAR_SHAPE, (1, 1, 1)),
        ((), (1, 1, 1)),
        ((7, 0), (0, 42, 101)),
        ((3, 4), 12),
        ((3, 4), (12,)),
        ((3, 4), -1),
        ((2, 1, 4), (-1,)),
        ((2, 2, 4), (2, 8))
    ]))
def testReshape(self, arg_shape, out_shape, dtype, order):
  # Runs jnp.reshape and np.reshape with the same target shape and memory
  # order through the shared harness helpers.
  sampler = jtu.rand_default(self.rng())

  def np_fun(x):
    return np.reshape(x, out_shape, order=order)

  def jnp_fun(x):
    return jnp.reshape(x, out_shape, order=order)

  def args_maker():
    return [sampler(arg_shape, dtype)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name": "_inshape={}_outshape={}".format(
            jtu.format_shape_dtype_string(arg_shape, dtype),
            jtu.format_shape_dtype_string(out_shape, dtype)),
        "arg_shape": arg_shape,
        "out_shape": out_shape,
        "dtype": dtype,
    }
    for dtype in default_dtypes
    for arg_shape, out_shape in [
        ((7, 0), (0, 42, 101)),
        ((2, 1, 4), (-1,)),
        ((2, 2, 4), (2, 8))
    ]))
def testReshapeMethod(self, arg_shape, out_shape, dtype):
  # Compares the `.reshape(*dims)` method (dimensions passed as separate
  # positional arguments) with np.reshape through the shared harness helpers.
  sampler = jtu.rand_default(self.rng())

  def np_fun(x):
    return np.reshape(x, out_shape)

  def jnp_fun(x):
    return x.reshape(*out_shape)

  def args_maker():
    return [sampler(arg_shape, dtype)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name": "_inshape={}_outshape={}".format(
            jtu.format_shape_dtype_string(arg_shape, dtype),
            jtu.format_shape_dtype_string(out_shape, dtype)),
        "arg_shape": arg_shape,
        "out_shape": out_shape,
        "dtype": dtype,
    }
    for dtype in default_dtypes
    for arg_shape, out_shape in itertools.product(all_shapes, array_shapes)))
def testResize(self, arg_shape, out_shape, dtype):
  # Runs jnp.resize and np.resize with the same target shape through the
  # shared harness helpers.
  sampler = jtu.rand_default(self.rng())

  def np_fun(x):
    return np.resize(x, out_shape)

  def jnp_fun(x):
    return jnp.resize(x, out_shape)

  def args_maker():
    return [sampler(arg_shape, dtype)]

  # The NumPy comparison is skipped for an empty out_shape on NumPy < 1.20.
  compare_with_numpy = len(out_shape) > 0 or numpy_version >= (1, 20, 0)
  if compare_with_numpy:
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name": "_inshape={}_expanddim={!r}".format(
            jtu.format_shape_dtype_string(arg_shape, dtype), dim),
        "arg_shape": arg_shape,
        "dtype": dtype,
        "dim": dim,
    }
    for arg_shape in [(), (3,), (3, 4)]
    for dtype in default_dtypes
    for dim in (list(range(-len(arg_shape)+1, len(arg_shape)))
                + [np.array(0), np.array(-1), (0,), [np.array(0)],
                   (len(arg_shape), len(arg_shape) + 1)])))
def testExpandDimsStaticDim(self, arg_shape, dtype, dim):
  # Runs jnp.expand_dims and np.expand_dims with the same `dim` through the
  # shared harness helpers (_CompileAndCheck first, then the NumPy comparison).
  sampler = jtu.rand_default(self.rng())

  def np_fun(x):
    return np.expand_dims(x, dim)

  def jnp_fun(x):
    return jnp.expand_dims(x, dim)

  def args_maker():
    return [sampler(arg_shape, dtype)]

  self._CompileAndCheck(jnp_fun, args_maker)
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
def testExpandDimsRepeatedAxisError(self):
  # expand_dims must reject axis lists that name the same output axis twice.
  # jnp is checked first; np is then checked too to ensure this is numpy's
  # behavior as well, so that we remain consistent.
  for module in (jnp, np):
    x = module.ones((2, 3))
    for repeated_axes in ([1, 1], [3, -1]):
      with self.assertRaisesRegex(ValueError, 'repeated axis.*'):
        module.expand_dims(x, repeated_axes)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name": "_inshape={}_axes=({},{})".format(
            jtu.format_shape_dtype_string(arg_shape, dtype), ax1, ax2),
        "arg_shape": arg_shape,
        "dtype": dtype,
        "ax1": ax1,
        "ax2": ax2,
    }
    for arg_shape, ax1, ax2 in [
        ((3, 4), 0, 1), ((3, 4), 1, 0), ((3, 4, 5), 1, 2),
        ((3, 4, 5), -1, -2), ((3, 4, 5), 0, 1)]
    for dtype in default_dtypes))
def testSwapAxesStaticAxes(self, arg_shape, dtype, ax1, ax2):
  # Runs jnp.swapaxes and np.swapaxes with the same axis pair through the
  # shared harness helpers.
  sampler = jtu.rand_default(self.rng())

  def np_fun(x):
    return np.swapaxes(x, ax1, ax2)

  def jnp_fun(x):
    return jnp.swapaxes(x, ax1, ax2)

  def args_maker():
    return [sampler(arg_shape, dtype)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name": "_inshape={}_axis={!r}".format(
            jtu.format_shape_dtype_string(arg_shape, dtype), ax),
        "arg_shape": arg_shape,
        "dtype": dtype,
        "ax": ax,
    }
    for arg_shape, ax in [
        ((3, 1), None),
        ((3, 1), 1),
        ((3, 1), -1),
        ((3, 1), np.array(1)),
        ((1, 3, 1), (0, 2)),
        ((1, 3, 1), (0,)),
        ((1, 4, 1), (np.array(0),))]
    for dtype in default_dtypes))
def testSqueeze(self, arg_shape, dtype, ax):
  # Runs jnp.squeeze and np.squeeze with the same axis argument through the
  # shared harness helpers.
  sampler = jtu.rand_default(self.rng())

  def np_fun(x):
    return np.squeeze(x, ax)

  def jnp_fun(x):
    return jnp.squeeze(x, ax)

  def args_maker():
    return [sampler(arg_shape, dtype)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name": "_shape={}_axis={}_weights={}_returned={}".format(
            jtu.format_shape_dtype_string(shape, dtype),
            axis,
            (None if weights_shape is None else jtu.format_shape_dtype_string(weights_shape, dtype)),
            returned),
        "shape": shape,
        "dtype": dtype,
        "axis": axis,
        "weights_shape": weights_shape,
        "returned": returned,
    }
    for shape, dtype in _shape_and_dtypes(nonempty_shapes, number_dtypes)
    for axis in list(range(-len(shape), len(shape))) + [None]
    # `weights_shape` is either `None`, same as the averaged axis, or same as
    # that of the input
    for weights_shape in ([None, shape] if axis is None or len(shape) == 1
                          else [None, (shape[axis],), shape])
    for returned in [False, True]))
def testAverage(self, shape, dtype, axis, weights_shape, returned):
  # Runs jnp.average and np.average, with or without weights, through the
  # shared harness helpers.
  sampler = jtu.rand_default(self.rng())

  if weights_shape is not None:
    def np_fun(x, weights):
      return np.average(x, axis, weights, returned)

    def jnp_fun(x, weights):
      return jnp.average(x, axis, weights, returned)

    def args_maker():
      return [sampler(shape, dtype), sampler(weights_shape, dtype)]
  else:
    def np_fun(x):
      return np.average(x, axis, returned=returned)

    def jnp_fun(x):
      return jnp.average(x, axis, returned=returned)

    def args_maker():
      return [sampler(shape, dtype)]

  # The NumPy reference is wrapped in _promote_like_jnp with inexact=True.
  np_fun = _promote_like_jnp(np_fun, inexact=True)
  tol = {dtypes.bfloat16: 2e-1, np.float16: 1e-2, np.float32: 1e-5,
         np.float64: 1e-12, np.complex64: 1e-5}
  # Dtypes are not compared for the Python-scalar shape sentinel.
  check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE

  try:
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
                            check_dtypes=check_dtypes, tol=tol)
  except ZeroDivisionError:
    self.skipTest("don't support checking for ZeroDivisionError")
  self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=check_dtypes,
                        rtol=tol, atol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name":
            f"_arg{i}_ndmin={ndmin}_dtype={np.dtype(dtype) if dtype else None}",
        "arg": arg,
        "ndmin": ndmin,
        "dtype": dtype,
    }
    for i, (arg, dtypes) in enumerate([
        ([True, False, True], all_dtypes),
        (3., all_dtypes),
        ([1, 2, 3], all_dtypes),
        (np.array([1, 2, 3], dtype=np.int64), all_dtypes),
        ([1., 2., 3.], all_dtypes),
        ([[1, 2], [3, 4], [5, 6]], all_dtypes),
        ([[1, 2.], [3, 4], [5, 6]], all_dtypes),
        ([[1., 2j], [3., 4.], [5., 6.]], complex_dtypes),
        ([[3, np.array(2, dtype=jnp.float_), 1],
          np.arange(3., dtype=jnp.float_)], all_dtypes),
    ])
    for dtype in [None] + dtypes
    for ndmin in [None, np.ndim(arg), np.ndim(arg) + 1, np.ndim(arg) + 2]))
def testArray(self, arg, ndmin, dtype):
  # Compares jnp.array with np.array; the NumPy reference is given the
  # canonicalized dtype while jnp.array receives the requested dtype as-is.
  def args_maker():
    return [arg]

  canonical_dtype = dtypes.canonicalize_dtype(dtype or np.array(arg).dtype)
  # ndmin is only forwarded when a value was supplied.
  ndmin_kwargs = {} if ndmin is None else {"ndmin": ndmin}
  np_fun = partial(np.array, dtype=canonical_dtype, **ndmin_kwargs)
  jnp_fun = partial(jnp.array, dtype=dtype, **ndmin_kwargs)

  # We are testing correct canonicalization behavior here, so we turn off the
  # permissive canonicalization logic in the test harness.
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
                          canonicalize_dtypes=False)
  self._CompileAndCheck(jnp_fun, args_maker)
@jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*")
def testArrayDtypeInference(self):
  # Checks the dtype and weak-type flag that jnp.array infers for a range of
  # inputs, both eagerly and under jax.jit.
  def _check(obj, out_dtype, weak_type):
    dtype_reference = np.array(obj, dtype=out_dtype)
    for make_array in (jnp.array, jax.jit(jnp.array)):
      result = make_array(obj)
      self.assertDtypesMatch(result, dtype_reference)
      self.assertEqual(dtypes.is_weakly_typed(result), weak_type)

  cases = [
      # Python scalars become 64-bit weak types.
      (1, np.int64, True),
      (1.0, np.float64, True),
      (1.0j, np.complex128, True),
      # Lists become strongly-typed defaults.
      ([1], jnp.int_, False),
      ([1.0], jnp.float_, False),
      ([1.0j], jnp.complex_, False),
      # Lists of weakly-typed objects become strongly-typed defaults.
      ([jnp.array(1)], jnp.int_, False),
      ([jnp.array(1.0)], jnp.float_, False),
      ([jnp.array(1.0j)], jnp.complex_, False),
      # Lists of strongly-typed objects maintain their strong type.
      ([jnp.int64(1)], np.int64, False),
      ([jnp.float64(1)], np.float64, False),
      ([jnp.complex128(1)], np.complex128, False),
      # Mixed inputs use JAX-style promotion.
      # (regression test for https://github.com/google/jax/issues/8945)
      ([0, np.int16(1)], np.int16, False),
      ([0.0, np.float16(1)], np.float16, False),
  ]
  for obj, out_dtype, weak_type in cases:
    _check(obj, out_dtype, weak_type)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": f"_dtype={np.dtype(dtype)}_func={func}",
     "dtype": dtype, "func": func}
    for dtype in all_dtypes
    for func in ["array", "copy"]))
def testArrayCopy(self, dtype, func):
  # jnp.asarray results are expected to share the input's buffer, while the
  # results of getattr(jnp, func) are expected to live in a different buffer.
  # Both are checked eagerly and under jit.
  def buffer_address(arr):
    return arr.device_buffer.unsafe_buffer_pointer()

  make_copy = getattr(jnp, func)
  original = jnp.ones(10, dtype=dtype)

  views = (jnp.asarray(original), jax.jit(jnp.asarray)(original))
  copies = (make_copy(original), jax.jit(make_copy)(original))

  original_address = buffer_address(original)
  for view in views:
    self.assertEqual(original_address, buffer_address(view))
  for copied in copies:
    self.assertNotEqual(original_address, buffer_address(copied))

  # Deleting the original must invalidate the views but leave copies intact.
  original.delete()
  for view in views:
    self.assertTrue(view.is_deleted())
  for copied in copies:
    self.assertFalse(copied.is_deleted())
def testArrayCopyAutodiff(self):
f = lambda x: jnp.array(x, copy=True)
x = jnp.ones(10)
xdot = jnp.ones(10)
y, ydot = jax.jvp(f, (x,), (xdot,))
self.assertIsNot(x, y)
self.assertIsNot(xdot, ydot)
ybar = jnp.ones(10)
y, f_vjp = jax.vjp(f, x)
xbar, = f_vjp(ybar)
self.assertIsNot(x, y)
self.assertIsNot(xbar, ybar)
def testArrayCopyVmap(self):
f = lambda x: jnp.array(x, copy=True)
x = jnp.ones(10)
y = jax.vmap(f)(x)
self.assertIsNot(x, y)
def testArrayUnsupportedDtypeError(self):
  # Requesting a structured (record) dtype must raise a TypeError.
  record_dtype = [('a','<i4'),('b','<i4')]
  expected_message = "JAX only supports number and bool dtypes.*"
  with self.assertRaisesRegex(TypeError, expected_message):
    jnp.array(3, record_dtype)
def testArrayFromInteger(self):
  # jnp.array on a Python int: in-range values get the canonical int64 dtype,
  # one past the maximum overflows, and an explicit uint64 request is honored
  # when x64 mode is on.
  int_dtype = dtypes.canonicalize_dtype(jnp.int64)
  limits = jnp.iinfo(int_dtype)

  # Values at extremes are converted correctly.
  for in_range in (limits.min, 0, limits.max):
    self.assertEqual(jnp.array(in_range).dtype, int_dtype)

  # out of bounds leads to an OverflowError
  val = limits.max + 1
  expected_message = f"Python int {val} too large to convert to {int_dtype.name}"
  with self.assertRaisesRegex(OverflowError, expected_message):
    jnp.array(val)

  # explicit uint64 should work
  if config.x64_enabled:
    self.assertEqual(np.uint64(val), jnp.array(val, dtype='uint64'))
def testArrayFromList(self):
  # jnp.array on lists of Python ints: dtype at the int64 extremes, promotion
  # with a NumPy scalar, and overflow one below the int64 minimum.
  limits = jnp.iinfo(jnp.int64)
  canonical_int = dtypes.canonicalize_dtype('int64')

  # Values at extremes are converted correctly.
  for in_range in (limits.min, 0, limits.max):
    self.assertEqual(jnp.array([in_range]).dtype, canonical_int)

  # list of values results in promoted type.
  mixed = jnp.array([0, np.float16(1)])
  self.assertEqual(mixed.dtype, jnp.result_type('int64', 'float16'))

  # out of bounds leads to an OverflowError
  below_minimum = limits.min - 1
  with self.assertRaisesRegex(OverflowError, "Python int too large.*"):
    jnp.array([0, below_minimum])
def testIssue121(self):
assert not np.isscalar(jnp.array(3))
def testArrayOutputsDeviceArrays(self):
  # jnp.array must produce a device array from a list, from an ndarray, and
  # from an object that implements __array__.
  class NDArrayLike:
    def __array__(self, dtype=None):
      return np.array([], dtype=dtype)

  for source in ([], np.array([]), NDArrayLike()):
    assert device_array.type_is_device_array(jnp.array(source))

  # NOTE(mattjj): disabled b/c __array__ must produce ndarrays
  # class DeviceArrayLike:
  #     def __array__(self, dtype=None):
  #         return jnp.array([], dtype=dtype)
  # assert xla.type_is_device_array(jnp.array(DeviceArrayLike()))
def testArrayMethod(self):
  # jnp.array accepts an object exposing `dtype` and `__array__`.
  class ArrayLike:
    dtype = np.dtype('float32')

    def __array__(self, dtype=None):
      return np.array(3., dtype=dtype)

  converted = jnp.array(ArrayLike())
  self.assertEqual(converted, 3.)
def testJaxArrayOps(self):
  # Multiplying an object that only defines __jax_array__ by a jnp array must
  # match multiplying the array it converts to.
  class ArrayLike:
    def __jax_array__(self):
      return jnp.array(3.)

  actual = ArrayLike() * jnp.arange(10)
  expected = jnp.array(3.) * jnp.arange(10)
  self.assertArraysEqual(actual, expected)
def testMemoryView(self):
  # jnp.array and np.array must agree on bytearray inputs, with and without
  # an explicit ndmin.
  cases = (
      (b'\x2a', {}),
      (b'\x2a\xf3', {"ndmin": 2}),
  )
  for payload, kwargs in cases:
    self.assertAllClose(
        jnp.array(bytearray(payload), **kwargs),
        np.array(bytearray(payload), **kwargs)
    )
def testIsClose(self):
  # Compares np.isclose with jnp.isclose (eager and jitted) on every ordered
  # pair drawn from random, infinite and NaN operands, with equal_nan both at
  # its default and set to True.
  default_jit = jax.jit(jnp.isclose)
  equal_nan_jit = jax.jit(partial(jnp.isclose, equal_nan=True))

  n = 2
  rng = self.rng()
  x = rng.randn(n, 1)
  y = rng.randn(n, 1)
  inf = np.full((n, 1), np.inf)
  nan = np.full((n, 1), np.nan)
  operands = [x, y, inf, -inf, nan]

  def assert_all_match(expected, *actuals):
    for actual in actuals:
      self.assertTrue(jnp.all(jnp.equal(expected, actual)))

  for lhs, rhs in itertools.product(operands, repeat=2):
    assert_all_match(np.isclose(lhs, rhs),
                     jnp.isclose(lhs, rhs),
                     default_jit(lhs, rhs))
    assert_all_match(np.isclose(lhs, rhs, equal_nan=True),
                     jnp.isclose(lhs, rhs, equal_nan=True),
                     equal_nan_jit(lhs, rhs))
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_x={}_y={}_equal_nan={}".format(x, y, equal_nan),
     "x": x, "y": y, "equal_nan": equal_nan}
    for x, y in itertools.product([
        1, [1], [1, 1 + 1E-4], [1, np.nan]], repeat=2)
    for equal_nan in [True, False]))
def testAllClose(self, x, y, equal_nan):
  # jnp.allclose must agree with np.allclose at a relative tolerance of 1e-3.
  shared_kwargs = {"equal_nan": equal_nan, "rtol": 1E-3}
  np_fun = partial(np.allclose, **shared_kwargs)
  jnp_fun = partial(jnp.allclose, **shared_kwargs)

  def args_maker():
    return [np.array(x), np.array(y)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
def testZeroStridesConstantHandler(self):
  # A jitted function must handle a closed-over constant produced by
  # np.broadcast_to.
  seed_values = self.rng().randn(1, 2, 1, 1, 5, 1)
  const = np.broadcast_to(seed_values, (3, 2, 3, 4, 5, 6))

  @jax.jit
  def scale(x):
    return x * const

  self.assertAllClose(scale(3.), 3. * const, check_dtypes=False)
def testIsInstanceNdarrayDuringTracing(self):
arr = np.ones(3)
@jax.jit
def f(x):
self.assertIsInstance(x, jnp.ndarray)
return jnp.sum(x)
f(arr)
def testNonArrayErrorMessage(self):
x = [1., 2.]
y = np.array([3., 4.])
def g(x, y):
return jnp.add(x, y)
def f(x, y):
return jnp.dot(x, y)
self.assertRaises(TypeError, lambda: g(x, y))
self.assertRaises(TypeError, lambda: f(x, y))
self.assertRaises(TypeError, lambda: jax.jit(g)(x, y))
self.assertRaises(TypeError, lambda: jax.jit(f)(x, y))
def testAbstractionErrorMessage(self):
@jax.jit
def f(x, n):
for _ in range(n):
x = x * x
return x
self.assertRaises(jax.errors.TracerIntegerConversionError, lambda: f(3., 3))
@jax.jit
def g(x):
if x > 0.:
return x * 2
else:
return x + 2
self.assertRaises(jax.errors.ConcretizationTypeError, lambda: g(3.))
def testTracingPrimitiveWithNoTranslationErrorMessage(self):
  # Currently skipped unconditionally; everything after skipTest is kept for
  # the pending update and does not run.
  # TODO(mattjj): update this for jax3
  self.skipTest("test needs jax3 update")
  unimplemented = jnp._not_implemented(lambda x: x)

  # No error if there's no tracing.
  unimplemented(np.arange(3))

  compiled = jax.jit(unimplemented)
  with self.assertRaises(NotImplementedError):
    compiled(np.arange(3))
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_axis={}".format(
        jtu.format_shape_dtype_string(shape, dtype), axis),
     "shape": shape, "dtype": dtype, "axis": axis}
    for shape in [(3,), (2, 3)]
    for dtype in default_dtypes
    for axis in list(range(-len(shape), len(shape))) + [None] + [tuple(range(len(shape)))]  # Test negative axes and tuples
))
def testFlip(self, shape, dtype, axis):
  # jnp.flip must match np.flip for integer, None and tuple axes.
  def jnp_op(x):
    return jnp.flip(x, axis)

  def np_op(x):
    return np.flip(x, axis)

  rng = jtu.rand_default(self.rng())
  args_maker = self._GetArgsMaker(rng, [shape], [dtype])
  self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
  self._CompileAndCheck(jnp_op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}".format(
        jtu.format_shape_dtype_string(shape, dtype)),
     "shape": shape, "dtype": dtype}
    for shape in [(3,), (2, 3), (3, 2, 4)]
    for dtype in default_dtypes))
def testFlipud(self, shape, dtype):
  # jnp.flipud must match np.flipud.
  def jnp_op(x):
    return jnp.flipud(x)

  def np_op(x):
    return np.flipud(x)

  rng = jtu.rand_default(self.rng())
  args_maker = self._GetArgsMaker(rng, [shape], [dtype])
  self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
  self._CompileAndCheck(jnp_op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}".format(
        jtu.format_shape_dtype_string(shape, dtype)),
     "shape": shape, "dtype": dtype}
    for shape in [(3, 2), (2, 3), (3, 2, 4)]
    for dtype in default_dtypes))
def testFliplr(self, shape, dtype):
  # jnp.fliplr must match np.fliplr.
  def jnp_op(x):
    return jnp.fliplr(x)

  def np_op(x):
    return np.fliplr(x)

  rng = jtu.rand_default(self.rng())
  args_maker = self._GetArgsMaker(rng, [shape], [dtype])
  self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
  self._CompileAndCheck(jnp_op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_k={}_axes={}".format(
        jtu.format_shape_dtype_string(shape, dtype), k, axes),
     "shape": shape, "dtype": dtype, "k": k, "axes": axes}
    for shape, axes in [
        [(2, 3), (0, 1)],
        [(2, 3), (1, 0)],
        [(4, 3, 2), (0, 2)],
        [(4, 3, 2), (2, 1)],
    ]
    for k in range(-3, 4)
    for dtype in default_dtypes))
def testRot90(self, shape, dtype, k, axes):
  # jnp.rot90 must match np.rot90 for k in [-3, 3] over several axis pairs.
  def jnp_op(x):
    return jnp.rot90(x, k, axes)

  def np_op(x):
    return np.rot90(x, k, axes)

  rng = jtu.rand_default(self.rng())
  args_maker = self._GetArgsMaker(rng, [shape], [dtype])
  self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
  self._CompileAndCheck(jnp_op, args_maker)
# TODO(mattjj): test infix operator overrides
def testRavel(self):
  # The .ravel() method must give consistent results with and without jit.
  rng = self.rng()

  def args_maker():
    return [rng.randn(3, 4).astype("float32")]

  def ravel_method(x):
    return x.ravel()

  self._CompileAndCheck(ravel_method, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_order={}_mode={}".format(
        shape, order, mode),
     "shape": shape, "order": order, "mode": mode}
    for shape in nonempty_nonscalar_array_shapes
    for order in ['C', 'F']
    for mode in ['wrap', 'clip', 'raise']))
@jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
def testRavelMultiIndex(self, shape, order, mode):
  # Compares jnp.ravel_multi_index with np.ravel_multi_index, including
  # out-of-bounds indices, and checks the jit behavior for each mode.

  # generate indices in each dimension with a few out of bounds.
  rngs = [jtu.rand_int(self.rng(), low=-1, high=dim + 1)
          for dim in shape]

  # generate multi_indices of different dimensions that broadcast.
  def args_maker():
    multi_index = tuple(rng(ndim * (3,), jnp.int_)
                        for ndim, rng in enumerate(rngs))
    return [multi_index]

  def sentinel_on_invalid(ravel):
    # Wraps `ravel` so that an 'invalid entry' ValueError turns into a
    # sentinel value, letting both libraries be compared on failing inputs.
    def wrapped(x):
      try:
        return ravel(x, shape, order=order, mode=mode)
      except ValueError as err:
        if str(err).startswith('invalid entry'):
          # sentinel indicating expected error.
          return -999
        raise
    return wrapped

  np_fun = sentinel_on_invalid(np.ravel_multi_index)
  jnp_fun = sentinel_on_invalid(jnp.ravel_multi_index)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
  if mode != 'raise':
    self._CompileAndCheck(jnp_fun, args_maker)
    return

  # mode='raise' is expected to fail with a concretization error under jit.
  msg = ("The error occurred because ravel_multi_index was jit-compiled "
         "with mode='raise'. Use mode='wrap' or mode='clip' instead.")
  with self.assertRaisesRegex(jax.core.ConcretizationTypeError, msg):
    jax.jit(jnp_fun)(*args_maker())
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_ashape={}{}_cshapes={}{}_mode={}".format(
        adtype.__name__, ashape, cdtype.__name__, cshapes, mode),
     "ashape": ashape, "adtype": adtype, "cshapes": cshapes, "cdtype": cdtype, "mode": mode}
    for ashape in ((), (4,), (3, 4))
    for cshapes in [
        [(), (4,)],
        [(3, 4), (4,), (3, 1)]
    ]
    for adtype in int_dtypes
    for cdtype in default_dtypes
    for mode in ['wrap', 'clip', 'raise']))
def testChoose(self, ashape, adtype, cshapes, cdtype, mode):
  # Compares jnp.choose with np.choose for each mode and checks the jit
  # behavior for each mode.
  rng = jtu.rand_default(self.rng())

  def args_maker():
    return [rng(ashape, adtype), [rng(s, cdtype) for s in cshapes]]

  def sentinel_on_invalid(choose):
    # Wraps `choose` so that, in 'raise' mode, an 'invalid entry' ValueError
    # becomes a sentinel value that can be compared across libraries.
    def wrapped(a, c):
      try:
        return choose(a, c, mode=mode)
      except ValueError as err:
        if mode == 'raise' and str(err).startswith('invalid entry'):
          return -999  # sentinel indicating expected error.
        raise
    return wrapped

  np_fun = sentinel_on_invalid(np.choose)
  jnp_fun = sentinel_on_invalid(jnp.choose)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
  if mode != 'raise':
    self._CompileAndCheck(jnp_fun, args_maker)
    return

  # mode='raise' is expected to fail with a concretization error under jit.
  msg = ("The error occurred because jnp.choose was jit-compiled"
         " with mode='raise'. Use mode='wrap' or mode='clip' instead.")
  with self.assertRaisesRegex(jax.core.ConcretizationTypeError, msg):
    jax.jit(jnp_fun)(*args_maker())
@parameterized.parameters(
    (0, (2, 1, 3)),
    (5, (2, 1, 3)),
    (0, ()),
    (np.array([0, 1, 2]), (2, 2)),
    (np.array([[[0, 1], [2, 3]]]), (2, 2)))
def testUnravelIndex(self, flat_index, shape):
  # jnp.unravel_index must match np.unravel_index. JAX dtype defaults are
  # applied to the NumPy reference only when the index carries no dtype of its
  # own (i.e. it is a plain Python value).
  index_has_dtype = hasattr(flat_index, 'dtype')
  np_fun = jtu.with_jax_dtype_defaults(np.unravel_index,
                                       use_defaults=not index_has_dtype)

  def args_maker():
    return (flat_index, shape)

  self._CheckAgainstNumpy(np_fun, jnp.unravel_index, args_maker)
  self._CompileAndCheck(jnp.unravel_index, args_maker)
def testUnravelIndexOOB(self):
self.assertEqual(jnp.unravel_index(2, (2,)), (1,))
self.assertEqual(jnp.unravel_index(-2, (2, 1, 3,)), (1, 0, 1))
self.assertEqual(jnp.unravel_index(-3, (2,)), (0,))
def testAstype(self):
  # The .astype(int32) method must agree between NumPy and jnp arrays.
  rng = self.rng()

  def args_maker():
    return [rng.randn(3, 4).astype("float32")]

  def np_op(x):
    return np.asarray(x).astype(jnp.int32)

  def jnp_op(x):
    return jnp.asarray(x).astype(jnp.int32)

  self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
  self._CompileAndCheck(jnp_op, args_maker)
def testAstypeNone(self):
  # .astype(None) must agree between NumPy (with JAX dtype defaults applied)
  # and jnp arrays.
  rng = self.rng()

  def args_maker():
    return [rng.randn(3, 4).astype("int32")]

  def np_astype_none(x):
    return np.asarray(x).astype(None)

  def jnp_op(x):
    return jnp.asarray(x).astype(None)

  np_op = jtu.with_jax_dtype_defaults(np_astype_none)
  self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
  self._CompileAndCheck(jnp_op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}".format(
        jtu.format_shape_dtype_string(shape, dtype)),
     "shape": shape, "dtype": dtype}
    for shape in array_shapes
    for dtype in all_dtypes))
def testNbytes(self, shape, dtype):
  # The .nbytes attribute must agree between NumPy and jnp arrays.
  rng = jtu.rand_default(self.rng())

  def np_op(x):
    return np.asarray(x).nbytes

  def jnp_op(x):
    return jnp.asarray(x).nbytes

  def args_maker():
    return [rng(shape, dtype)]

  self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
  self._CompileAndCheck(jnp_op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}".format(
        jtu.format_shape_dtype_string(shape, dtype)),
     "shape": shape, "dtype": dtype}
    for shape in array_shapes
    for dtype in all_dtypes))
def testItemsize(self, shape, dtype):
  # The .itemsize attribute must agree between NumPy and jnp arrays.
  rng = jtu.rand_default(self.rng())

  def np_op(x):
    return np.asarray(x).itemsize

  def jnp_op(x):
    return jnp.asarray(x).itemsize

  def args_maker():
    return [rng(shape, dtype)]

  self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
  self._CompileAndCheck(jnp_op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_dtype={}".format(
        jtu.format_shape_dtype_string(shape, a_dtype), dtype),
     "shape": shape, "a_dtype": a_dtype, "dtype": dtype}
    for shape in [(8,), (3, 8)]  # last dim = 8 to ensure shape compatibility
    for a_dtype in (default_dtypes + unsigned_dtypes + bool_dtypes)
    for dtype in (default_dtypes + unsigned_dtypes + bool_dtypes)))
def testView(self, shape, a_dtype, dtype):
  # The .view(dtype) method must agree between NumPy and jnp arrays for every
  # pairing of source and target dtype that the platform supports.
  source_size = jnp.dtype(a_dtype).itemsize
  target_size = jnp.dtype(dtype).itemsize

  on_tpu = jtu.device_under_test() == 'tpu'
  if on_tpu and (source_size in [1, 2] or target_size in [1, 2]):
    self.skipTest("arr.view() not supported on TPU for 8- or 16-bit types.")
  if not config.x64_enabled and (source_size == 8 or target_size == 8):
    self.skipTest("x64 types are disabled by jax_enable_x64")

  rng = jtu.rand_fullrange(self.rng())

  def args_maker():
    return [rng(shape, a_dtype)]

  def np_op(x):
    return np.asarray(x).view(dtype)

  def jnp_op(x):
    return jnp.asarray(x).view(dtype)

  # Above may produce signaling nans; ignore warnings from invalid values.
  with np.errstate(invalid='ignore'):
    self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
    self._CompileAndCheck(jnp_op, args_maker)
def testPathologicalFloats(self):
  # Viewing uint32 bit patterns as float32 and back again must agree between
  # NumPy and jnp for special floating-point encodings.
  bit_patterns = [
      0b_0111_1111_1000_0000_0000_0000_0000_0000,  # inf
      0b_1111_1111_1000_0000_0000_0000_0000_0000,  # -inf
      0b_0111_1111_1100_0000_0000_0000_0000_0000,  # qnan
      0b_1111_1111_1100_0000_0000_0000_0000_0000,  # -qnan
      0b_0111_1111_1000_0000_0000_0000_0000_0001,  # snan
      0b_1111_1111_1000_0000_0000_0000_0000_0001,  # -snan
      0b_0111_1111_1000_0000_0000_1100_0000_0000,  # nonstandard nan
      0b_1111_1111_1000_0000_0000_1100_0000_0000,  # -nonstandard nan
      0b_0000_0000_0000_0000_0000_0000_0000_0000,  # zero
      0b_1000_0000_0000_0000_0000_0000_0000_0000,  # -zero
  ]

  def args_maker():
    return [np.array(bit_patterns, dtype='uint32')]

  def np_op(x):
    return np.asarray(x).view('float32').view('uint32')

  def jnp_op(x):
    return jnp.asarray(x).view('float32').view('uint32')

  self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
  self._CompileAndCheck(jnp_op, args_maker)
# TODO(mattjj): test other ndarray-like method overrides
def testNpMean(self):
  # np.mean applied to a jnp array must give the right answer.
  # from https://github.com/google/jax/issues/125
  identity_plus_zero = jnp.eye(3, dtype=float) + 0.
  mean_value = np.mean(identity_plus_zero)
  self.assertAllClose(mean_value, np.array(1./3), check_dtypes=False)
def testArangeOnFloats(self):
  # jnp.arange must match np.arange (with JAX dtype defaults applied) for
  # floating-point arguments.
  np_arange = jtu.with_jax_dtype_defaults(np.arange)
  argument_sets = (
      (0.0, 1.0, 0.1),  # from https://github.com/google/jax/issues/145
      (2.5,),           # from https://github.com/google/jax/issues/3450
  )
  for arange_args in argument_sets:
    self.assertAllClose(np_arange(*arange_args),
                        jnp.arange(*arange_args))
def testArangeTypes(self):
  # Test that arange() output type is equal to the default types.
  int_ = dtypes.canonicalize_dtype(jnp.int_)
  float_ = dtypes.canonicalize_dtype(jnp.float_)

  results_and_dtypes = [
      (jnp.arange(10), int_),
      (jnp.arange(10.), float_),
      (jnp.arange(10, dtype='uint16'), np.uint16),
      (jnp.arange(10, dtype='bfloat16'), jnp.bfloat16),
      (jnp.arange(0, 10, 1), int_),
      (jnp.arange(0, 10, 1.), float_),
      (jnp.arange(0., 10, 1), float_),
  ]
  for result, expected_dtype in results_and_dtypes:
    self.assertEqual(result.dtype, expected_dtype)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_axis={}".format(
        jtu.format_shape_dtype_string(shape, dtype), axis),
     "shape": shape, "dtype": dtype, "axis": axis}
    for dtype in all_dtypes
    for shape in nonzerodim_shapes
    for axis in (None, *range(len(shape)))))
def testSort(self, dtype, shape, axis):
  # jnp.sort must match np.sort. An axis of None means "use each library's
  # default axis", so the argument is only bound when one is given.
  if axis is None:
    jnp_fun, np_fun = jnp.sort, np.sort
  else:
    jnp_fun = partial(jnp.sort, axis=axis)
    np_fun = partial(np.sort, axis=axis)

  rng = jtu.rand_some_equal(self.rng())

  def args_maker():
    return [rng(shape, dtype)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_axis={}".format(
        jtu.format_shape_dtype_string(shape, dtype), axis),
     "shape": shape, "dtype": dtype, "axis": axis}
    for dtype in all_dtypes
    for shape in one_dim_array_shapes
    for axis in [None]))
def testSortComplex(self, dtype, shape, axis):
  # jnp.sort_complex must match np.sort_complex in value; dtypes are not
  # compared. `axis` is part of the parameterization but is not used here.
  rng = jtu.rand_some_equal(self.rng())

  def args_maker():
    return [rng(shape, dtype)]

  self._CheckAgainstNumpy(np.sort_complex, jnp.sort_complex, args_maker,
                          check_dtypes=False)
  self._CompileAndCheck(jnp.sort_complex, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_input_type={}_axis={}".format(
        jtu.format_shape_dtype_string(shape, dtype),
        input_type.__name__, axis),
     "shape": shape, "dtype": dtype, "input_type": input_type, "axis": axis}
    for dtype in all_dtypes
    for shape in nonempty_nonscalar_array_shapes
    for input_type in [np.array, tuple]
    for axis in (-1, *range(len(shape) - 1))))
def testLexsort(self, dtype, shape, input_type, axis):
  # jnp.lexsort must match np.lexsort (with JAX dtype defaults applied) when
  # the keys are given either as an array or as a tuple.
  rng = jtu.rand_some_equal(self.rng())

  def args_maker():
    return [input_type(rng(shape, dtype))]

  def np_lexsort(x):
    return np.lexsort(x, axis=axis)

  def jnp_op(x):
    return jnp.lexsort(x, axis=axis)

  np_op = jtu.with_jax_dtype_defaults(np_lexsort)
  self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
  self._CompileAndCheck(jnp_op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_axis={}".format(
        jtu.format_shape_dtype_string(shape, dtype), axis),
     "shape": shape, "dtype": dtype, "axis": axis}
    for dtype in all_dtypes
    for shape in nonzerodim_shapes
    for axis in (None, *range(len(shape)))))
def testArgsort(self, dtype, shape, axis):
  # jnp.argsort must match np.argsort (with JAX dtype defaults applied). The
  # axis argument is only bound when one is given.
  np_argsort = jtu.with_jax_dtype_defaults(np.argsort)
  if axis is None:
    jnp_fun, np_fun = jnp.argsort, np_argsort
  else:
    jnp_fun = partial(jnp.argsort, axis=axis)
    np_fun = partial(np_argsort, axis=axis)

  rng = jtu.rand_some_equal(self.rng())

  def args_maker():
    return [rng(shape, dtype)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}".format(
        jtu.format_shape_dtype_string(shape, dtype)),
     "shape": shape, "dtype": dtype}
    for dtype in all_dtypes
    for shape in nonzerodim_shapes))
def testMsort(self, dtype, shape):
  # jnp.msort must match np.msort.
  rng = jtu.rand_some_equal(self.rng())

  def args_maker():
    return [rng(shape, dtype)]

  self._CheckAgainstNumpy(np.msort, jnp.msort, args_maker)
  self._CompileAndCheck(jnp.msort, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_shifts={}_axis={}".format(
        jtu.format_shape_dtype_string(shape, dtype),
        shifts, axis),
     "shape": shape, "dtype": dtype, "shifts": shifts, "axis": axis}
    for dtype in all_dtypes
    for shape in [(3, 4), (3, 4, 5), (7, 4, 0)]
    for shifts, axis in [
        (3, None),
        (1, 1),
        ((3,), (0,)),
        ((-2,), (-2,)),
        ((1, 2), (0, -1)),
        ((4, 2, 5, 5, 2, 4), None),
        (100, None),
    ]))
def testRoll(self, shape, dtype, shifts, axis):
  # jnp.roll must match np.roll; the shifts are passed as an array argument
  # while the axis is fixed per test case.
  rng = jtu.rand_default(self.rng())

  def args_maker():
    return [rng(shape, dtype), np.array(shifts)]

  def jnp_op(x, shift):
    return jnp.roll(x, shift, axis=axis)

  def np_op(x, shift):
    return np.roll(x, shift, axis=axis)

  self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
  self._CompileAndCheck(jnp_op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_axis={}_start={}".format(
        jtu.format_shape_dtype_string(shape, dtype),
        axis, start),
     "shape": shape, "dtype": dtype, "axis": axis,
     "start": start}
    for dtype in all_dtypes
    for shape in [(1, 2, 3, 4)]
    for axis in [-3, 0, 2, 3]
    for start in [-4, -1, 2, 4]))
def testRollaxis(self, shape, dtype, start, axis):
  # jnp.rollaxis must match np.rollaxis.
  rng = jtu.rand_default(self.rng())

  def args_maker():
    return [rng(shape, dtype)]

  def jnp_op(x):
    return jnp.rollaxis(x, axis=axis, start=start)

  def np_op(x):
    return np.rollaxis(x, axis=axis, start=start)

  self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
  self._CompileAndCheck(jnp_op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_axis={}_bitorder={}".format(
        jtu.format_shape_dtype_string(shape, dtype), axis, bitorder),
     "shape": shape, "dtype": dtype, "axis": axis,
     "bitorder": bitorder}
    for dtype in [np.uint8, np.bool_]
    for bitorder in ['big', 'little']
    for shape in [(1, 2, 3, 4)]
    for axis in [None, 0, 1, -2, -1]))
def testPackbits(self, shape, dtype, axis, bitorder):
  # jnp.packbits must match np.packbits for each axis and bit order.
  rng = jtu.rand_some_zero(self.rng())

  def args_maker():
    return [rng(shape, dtype)]

  def jnp_op(x):
    return jnp.packbits(x, axis=axis, bitorder=bitorder)

  def np_op(x):
    return np.packbits(x, axis=axis, bitorder=bitorder)

  self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
  self._CompileAndCheck(jnp_op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_axis={}_bitorder={}_count={}".format(
        jtu.format_shape_dtype_string(shape, dtype), axis, bitorder, count),
     "shape": shape, "dtype": dtype, "axis": axis, "bitorder": bitorder,
     "count": count}
    for dtype in [np.uint8]
    for bitorder in ['big', 'little']
    for shape in [(1, 2, 3, 4)]
    for axis in [None, 0, 1, -2, -1]
    for count in [None, 20]))
def testUnpackbits(self, shape, dtype, axis, bitorder, count):
  # jnp.unpackbits must match np.unpackbits for each axis, bit order and
  # count.
  #
  # `count` is forwarded to both implementations. Previously it was part of
  # the parameterization (and of the test name) but never reached either
  # call, so the count=20 cases silently re-ran the count=None ones.
  rng = jtu.rand_int(self.rng(), 0, 256)
  args_maker = lambda: [rng(shape, dtype)]
  jnp_op = partial(jnp.unpackbits, axis=axis, bitorder=bitorder, count=count)
  np_op = partial(np.unpackbits, axis=axis, bitorder=bitorder, count=count)
  self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
  self._CompileAndCheck(jnp_op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_index={}_axis={}_mode={}".format(
        jtu.format_shape_dtype_string(shape, dtype),
        jtu.format_shape_dtype_string(index_shape, index_dtype),
        axis, mode),
     "shape": shape, "index_shape": index_shape, "dtype": dtype,
     "index_dtype": index_dtype, "axis": axis, "mode": mode}
    for shape in [(3,), (3, 4), (3, 4, 5)]
    for index_shape in scalar_shapes + [(3,), (2, 1, 3)]
    for axis in itertools.chain(range(-len(shape), len(shape)),
                                [cast(Optional[int], None)])
    for dtype in all_dtypes
    for index_dtype in int_dtypes
    for mode in [None, 'wrap', 'clip']))
def testTake(self, shape, dtype, index_shape, index_dtype, axis, mode):
  # jnp.take must match np.take for each axis and indexing mode.
  rng = jtu.rand_default(self.rng())
  if mode is None:
    # With no explicit mode, indices are drawn from [-size, size) of the
    # indexed axis (axis 0 is used for the bound when axis is None).
    bound = shape[axis or 0]
    rng_indices = jtu.rand_int(self.rng(), -bound, bound)
  else:
    # With an explicit mode, indices are drawn from a fixed range [-5, 5).
    rng_indices = jtu.rand_int(self.rng(), -5, 5)

  def args_maker():
    return rng(shape, dtype), rng_indices(index_shape, index_dtype)

  def jnp_op(x, i):
    return jnp.take(x, i, axis=axis, mode=mode)

  def np_op(x, i):
    return np.take(x, i, axis=axis, mode=mode)

  self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
  self._CompileAndCheck(jnp_op, args_maker)
def testTakeEmpty(self):
np.testing.assert_array_equal(
jnp.array([], dtype=jnp.float32),
jnp.take(jnp.array([], jnp.float32), jnp.array([], jnp.int32)))
np.testing.assert_array_equal(
jnp.ones((2, 0, 4), dtype=jnp.float32),
jnp.take(jnp.ones((2, 0, 4), dtype=jnp.float32), jnp.array([], jnp.int32),
axis=1))
with self.assertRaisesRegex(IndexError, "non-empty jnp.take"):
jnp.take(jnp.ones((2, 0, 4), dtype=jnp.float32),
jnp.array([0], jnp.int32), axis=1)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_index={}_axis={}".format(
        jtu.format_shape_dtype_string(x_shape, dtype),
        jtu.format_shape_dtype_string(i_shape, index_dtype), axis),
     "x_shape": x_shape, "i_shape": i_shape, "dtype": dtype,
     "index_dtype": index_dtype, "axis": axis}
    for x_shape, i_shape in filter(
        _shapes_are_equal_length,
        filter(_shapes_are_broadcast_compatible,
               itertools.combinations_with_replacement(nonempty_nonscalar_array_shapes, 2)))
    for axis in itertools.chain(range(len(x_shape)), [-1],
                                [cast(Optional[int], None)])
    for dtype in default_dtypes
    for index_dtype in int_dtypes))
def testTakeAlongAxis(self, x_shape, i_shape, dtype, index_dtype, axis):
  # jnp.take_along_axis must match np.take_along_axis for in-bounds indices.
  rng = jtu.rand_default(self.rng())

  if axis is None:
    # The index shape collapses to a single dimension holding the product of
    # the original index dimensions.
    index_shape = [np.prod(np.array(i_shape), dtype=np.int64)]
  else:
    # Test the case where the size of the axis doesn't necessarily broadcast.
    widened = np.array(i_shape)
    widened[axis] *= 3
    index_shape = list(widened)

  # Valid indices lie in [-n, n); unsigned index dtypes are limited to [0, n).
  n = np.prod(x_shape, dtype=np.int32) if axis is None else x_shape[axis]
  lowest_index = 0 if np.issubdtype(index_dtype, np.unsignedinteger) else -n

  def args_maker():
    x = rng(x_shape, dtype)
    index_rng = jtu.rand_int(self.rng(), lowest_index, n)
    return x, index_rng(index_shape, index_dtype)

  def jnp_op(x, i):
    return jnp.take_along_axis(x, i, axis=axis)

  # The NumPy comparison only runs when NumPy provides take_along_axis.
  if hasattr(np, "take_along_axis"):
    def np_op(x, i):
      return np.take_along_axis(x, i, axis=axis)
    self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
  self._CompileAndCheck(jnp_op, args_maker)
def testTakeAlongAxisWithUint8IndicesDoesNotOverflow(self):
  # jnp.take_along_axis with uint8 indices must match NumPy on a
  # (256, 256, 100) operand.
  # https://github.com/google/jax/issues/5088
  operand = jtu.rand_default(self.rng())((256, 256, 100), np.float32)
  indices = jtu.rand_int(self.rng(), 0, 100)((256, 256, 1), np.uint8)
  actual = jnp.take_along_axis(operand, indices, axis=-1)
  expected = np.take_along_axis(operand, indices, axis=-1)
  np.testing.assert_equal(actual, expected)
def testTakeAlongAxisOutOfBounds(self):
  # Pins the out-of-bounds behavior of jnp.take_along_axis: the default mode
  # and mode="clip" clamp indices, while mode="fill" yields NaN.
  x = jnp.arange(10, dtype=jnp.float32)
  idx = jnp.array([-11, -10, -9, -5, -1, 0, 1, 5, 9, 10, 11])

  expected_clip = np.array([0, 0, 1, 5, 9, 0, 1, 5, 9, 9, 9], np.float32)
  for mode_kwargs in ({}, {"mode": "clip"}):
    clipped = jnp.take_along_axis(x, idx, axis=0, **mode_kwargs)
    np.testing.assert_array_equal(expected_clip, clipped)

  expected_fill = np.array([jnp.nan, 0, 1, 5, 9, 0, 1, 5, 9, jnp.nan,
                            jnp.nan], np.float32)
  filled = jnp.take_along_axis(x, idx, axis=0, mode="fill")
  np.testing.assert_array_equal(expected_fill, filled)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_n={}_increasing={}".format(
        jtu.format_shape_dtype_string([shape], dtype),
        n, increasing),
     "dtype": dtype, "shape": shape, "n": n, "increasing": increasing}
    for dtype in inexact_dtypes
    for shape in [0, 5]
    for n in [2, 4]
    for increasing in [False, True]))
def testVander(self, shape, dtype, n, increasing):
  # jnp.vander must match np.vander in value; dtypes are not compared.
  rng = jtu.rand_default(self.rng())

  def args_maker():
    return [rng([shape], dtype)]

  def np_fun(arg):
    # bfloat16 inputs are converted to float32 before calling NumPy.
    if dtype == jnp.bfloat16:
      arg = arg.astype(np.float32)
    return np.vander(arg, N=n, increasing=increasing)

  def jnp_fun(arg):
    return jnp.vander(arg, N=n, increasing=increasing)

  # np.vander seems to return float64 for all floating types. We could obey
  # those semantics, but they seem like a bug.
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False,
                          tol={np.float32: 1e-3})
  self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=False)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": jtu.format_test_name_suffix(
        "nan_to_num", [shape], [dtype]),
     "shape": shape, "dtype": dtype}
    for shape in array_shapes
    for dtype in inexact_dtypes))
def testNanToNum(self, shape, dtype):
  # jnp.nan_to_num must match a NumPy reference on inputs containing infs and
  # NaNs.
  rng = jtu.rand_some_inf_and_nan(self.rng())
  dtype = np.dtype(dtypes.canonicalize_dtype(dtype)).type

  def np_fun(x):
    if dtype == jnp.bfloat16:
      # bfloat16 takes a hand-written reference path instead of
      # np.nan_to_num (presumably NumPy does not handle this extension dtype
      # here; confirm).
      limits = jnp.finfo(dtype)
      x = np.where(np.isnan(x), dtype(0), x)
      x = np.where(np.isposinf(x), limits.max, x)
      return np.where(np.isneginf(x), limits.min, x)
    return np.nan_to_num(x).astype(dtype)

  def args_maker():
    return [rng(shape, dtype)]

  # Dtypes are not compared for Python scalar inputs.
  check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE
  self._CheckAgainstNumpy(np_fun, jnp.nan_to_num, args_maker,
                          check_dtypes=check_dtypes)
  self._CompileAndCheck(jnp.nan_to_num, args_maker,
                        check_dtypes=check_dtypes)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": jtu.format_test_name_suffix("ix_", shapes, dtypes),
     "shapes": shapes, "dtypes": dtypes}
    for shapes, dtypes in (
        ((), ()),
        (((7,),), (np.int32,)),
        (((3,), (4,)), (np.int32, np.int32)),
        (((3,), (1,), (4,)), (np.int32, np.int32, np.int32)),
    )))
def testIx_(self, shapes, dtypes):
  # jnp.ix_ must match np.ix_ for zero, one or several one-dimensional inputs.
  rng = jtu.rand_default(self.rng())

  def args_maker():
    return [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]

  self._CheckAgainstNumpy(np.ix_, jnp.ix_, args_maker)
  self._CompileAndCheck(jnp.ix_, args_maker)
@parameterized.named_parameters(
    jtu.cases_from_list(
        {"testcase_name": "_dimensions={}_dtype={}_sparse={}".format(
            dimensions, dtype, sparse),
         "dimensions": dimensions, "dtype": dtype, "sparse": sparse}
        for dimensions in [(), (2,), (3, 0), (4, 5, 6)]
        for dtype in number_dtypes
        for sparse in [True, False]))
def testIndices(self, dimensions, dtype, sparse):
  # jnp.indices must match np.indices; every argument is bound up front, so
  # the functions under test take no positional inputs.
  bound_kwargs = {"dimensions": dimensions, "dtype": dtype, "sparse": sparse}
  np_fun = partial(np.indices, **bound_kwargs)
  jnp_fun = partial(jnp.indices, **bound_kwargs)
  args_maker = lambda: []
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
  {"testcase_name":
     "_op={}_a_shape={}_q_shape={}_axis={}_keepdims={}_method={}".format(
       op,
       jtu.format_shape_dtype_string(a_shape, a_dtype),
       jtu.format_shape_dtype_string(q_shape, q_dtype),
       axis, keepdims, method),
   "a_rng": jtu.rand_some_nan,
   "q_rng": q_rng, "op": op,
   "a_shape": a_shape, "a_dtype": a_dtype,
   "q_shape": q_shape, "q_dtype": q_dtype, "axis": axis,
   "keepdims": keepdims,
   "method": method}
  for (op, q_rng) in (
    ("percentile", partial(jtu.rand_uniform, low=0., high=100.)),
    ("quantile", partial(jtu.rand_uniform, low=0., high=1.)),
    ("nanpercentile", partial(jtu.rand_uniform, low=0., high=100.)),
    ("nanquantile", partial(jtu.rand_uniform, low=0., high=1.)),
  )
  for a_dtype in default_dtypes
  for a_shape, axis in (
    ((7,), None),
    ((47, 7), 0),
    ((47, 7), ()),
    ((4, 101), 1),
    ((4, 47, 7), (1, 2)),
    ((4, 47, 7), (0, 2)),
    ((4, 47, 7), (1, 0, 2)),
  )
  for q_dtype in [np.float32]
  for q_shape in scalar_shapes + [(1,), (4,)]
  for keepdims in [False, True]
  for method in ['linear', 'lower', 'higher', 'nearest', 'midpoint']))
def testQuantile(self, op, a_rng, q_rng, a_shape, a_dtype, q_shape, q_dtype,
                 axis, keepdims, method):
  # Checks jnp.{percentile,quantile,nanpercentile,nanquantile} against the
  # NumPy function of the same name.
  #
  # `a_rng` and `q_rng` are rng factories that are instantiated here with the
  # test's random state; `op` is the name of the function under test and the
  # remaining arguments are forwarded to it.
  a_rng = a_rng(self.rng())
  q_rng = q_rng(self.rng())
  # Every op covered by this test takes the data array and the quantile(s).
  # (The median variants, which take only the data, live in testMedian.)
  args_maker = lambda: [a_rng(a_shape, a_dtype), q_rng(q_shape, q_dtype)]

  # NumPy 1.22 renamed the `interpolation` keyword to `method`. Note that
  # `numpy_version` is a 3-tuple, so the comparison must be a strict `<` for
  # 1.22.x releases to take the `method` spelling.
  method_kwarg = "interpolation" if numpy_version < (1, 22) else "method"

  def np_fun(*args):
    # bfloat16 inputs are converted to float32 before being handed to NumPy.
    args = [x if jnp.result_type(x) != jnp.bfloat16 else
            np.asarray(x, np.float32) for x in args]
    return getattr(np, op)(*args, axis=axis, keepdims=keepdims,
                           **{method_kwarg: method})

  jnp_fun = partial(getattr(jnp, op), axis=axis, keepdims=keepdims,
                    method=method)
  # TODO(phawkins): we currently set check_dtypes=False because we aren't as
  # aggressive about promoting to float64. It's not clear we want to mimic
  # Numpy here.
  tol_spec = {np.float16: 1E-2, np.float32: 2e-4, np.float64: 5e-6}
  tol = max(jtu.tolerance(a_dtype, tol_spec),
            jtu.tolerance(q_dtype, tol_spec))
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False,
                          tol=tol)
  self._CompileAndCheck(jnp_fun, args_maker, rtol=tol)
@unittest.skipIf(not config.jax_enable_x64, "test requires X64")
@unittest.skipIf(jtu.device_under_test() != 'cpu', "test is for CPU float64 precision")
def testPercentilePrecision(self):
  # Regression test for https://github.com/google/jax/issues/8513
  # The 50th percentile of these six float64 values must be exactly 3.5.
  samples = jnp.float64([1, 2, 3, 4, 7, 10])
  midpoint = jnp.percentile(samples, 50)
  self.assertEqual(midpoint, 3.5)
@parameterized.named_parameters(jtu.cases_from_list(
  {"testcase_name":
     "_{}_a_shape={}_axis={}_keepdims={}".format(
       op, jtu.format_shape_dtype_string(a_shape, a_dtype),
       axis, keepdims),
   "op": op, "a_shape": a_shape, "a_dtype": a_dtype,
   "axis": axis,
   "keepdims": keepdims}
  for a_dtype in default_dtypes
  for a_shape, axis in (
    ((7,), None),
    ((47, 7), 0),
    ((4, 101), 1),
  )
  for keepdims in [False, True]
  for op in ["median", "nanmedian"]))
def testMedian(self, op, a_shape, a_dtype, axis, keepdims):
  # Checks jnp.median / jnp.nanmedian against NumPy. Only the NaN-aware
  # variant is fed inputs that contain NaNs.
  make_sampler = jtu.rand_default if op == "median" else jtu.rand_some_nan
  sample = make_sampler(self.rng())

  def args_maker():
    return [sample(a_shape, a_dtype)]

  def np_fun(*args):
    # bfloat16 inputs are converted to float32 before being handed to NumPy.
    converted = [np.asarray(arg, np.float32)
                 if jnp.result_type(arg) == jnp.bfloat16 else arg
                 for arg in args]
    return getattr(np, op)(*converted, axis=axis, keepdims=keepdims)

  jnp_fun = partial(getattr(jnp, op), axis=axis, keepdims=keepdims)
  # TODO(phawkins): dtype checking is disabled (check_dtypes=False) because
  # we aren't as aggressive about promoting to float64. It's not clear we
  # want to mimic Numpy here.
  tol = jtu.tolerance(a_dtype, {np.float32: 2e-4, np.float64: 5e-6})
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False,
                          tol=tol)
  self._CompileAndCheck(jnp_fun, args_maker, rtol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
  {"testcase_name": "_shape={}".format(
      jtu.format_shape_dtype_string(shape, dtype)),
   "shape": shape, "dtype": dtype}
  for shape in all_shapes for dtype in all_dtypes))
def testWhereOneArgument(self, shape, dtype):
  # Single-argument where() compared against NumPy on inputs with some zeros.
  sample = jtu.rand_some_zero(self.rng())

  # NumPy's deprecation warning for 0d inputs is silenced for the reference.
  @jtu.ignore_warning(category=DeprecationWarning,
                      message="Calling nonzero on 0d arrays.*")
  def np_fun(x):
    return np.where(x)

  def jnp_fun(x):
    return jnp.where(x)

  def args_maker():
    return [sample(shape, dtype)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)

  # JIT compilation requires specifying a size statically. Full test of
  # this behavior is in testNonzeroSize().
  def sized_jnp_fun(x):
    return jnp.where(x, size=np.size(x) // 2)

  self._CompileAndCheck(sized_jnp_fun, args_maker)
@parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
    "testcase_name": "_{}".format("_".join(
        jtu.format_shape_dtype_string(shape, dtype)
        for shape, dtype in zip(shapes, dtypes))),
    "shapes": shapes, "dtypes": dtypes
  } for shapes in s(filter(_shapes_are_broadcast_compatible,
                           itertools.combinations_with_replacement(all_shapes, 3)))
  for dtypes in s(itertools.combinations_with_replacement(all_dtypes, 3)))))
def testWhereThreeArgument(self, shapes, dtypes):
  # Three-argument where(cond, x, y) compared against NumPy.
  sample = jtu.rand_default(self.rng())
  args_maker = self._GetArgsMaker(sample, shapes, dtypes)

  def np_fun(cond, x, y):
    # The condition is bound first so that only x and y are passed through
    # the _promote_like_jnp wrapper.
    choose = partial(np.where, cond)
    return _promote_like_jnp(choose)(x, y)

  self._CheckAgainstNumpy(np_fun, jnp.where, args_maker)
  self._CompileAndCheck(jnp.where, args_maker)
def testWhereScalarPromotion(self):
x = jnp.where(jnp.array([True, False]), 3,
jnp.ones((2,), dtype=jnp.float32))
self.assertEqual(x.dtype, np.dtype(np.float32))
@parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
    "testcase_name": jtu.format_test_name_suffix("", shapes, (np.bool_,) * n + dtypes),
    "shapes": shapes, "dtypes": dtypes
  } for n in s(range(1, 3))
  for shapes in s(filter(
    _shapes_are_broadcast_compatible,
    itertools.combinations_with_replacement(all_shapes, 2 * n + 1)))
  for dtypes in s(itertools.combinations_with_replacement(all_dtypes, n + 1)))))
def testSelect(self, shapes, dtypes):
  # jnp.select compared against np.select. `shapes` holds n condition
  # shapes, then n choice shapes, then the default's shape; `dtypes` holds n
  # choice dtypes followed by the default's dtype.
  sample = jtu.rand_default(self.rng())
  num_choices = len(dtypes) - 1
  cond_shapes = shapes[:num_choices]
  choice_shapes = shapes[num_choices:-1]
  choice_dtypes = dtypes[:num_choices]

  def args_maker():
    condlist = [sample(shp, np.bool_) for shp in cond_shapes]
    choicelist = [sample(shp, dt)
                  for shp, dt in zip(choice_shapes, choice_dtypes)]
    default = sample(shapes[-1], dtypes[-1])
    return condlist, choicelist, default

  # TODO(phawkins): float32/float64 type mismatches
  def np_fun(condlist, choicelist, default):
    # bfloat16 choices are converted to float32, then everything is cast to
    # the common result type computed by jnp.result_type.
    widened = [c.astype(np.float32) if jnp.result_type(c) == jnp.bfloat16
               else c for c in choicelist]
    common = jnp.result_type(default, *widened)
    cast_choices = [np.asarray(c, dtype=common) for c in widened]
    cast_default = np.asarray(default, dtype=common)
    return np.select(condlist, cast_choices, cast_default)

  self._CheckAgainstNumpy(np_fun, jnp.select, args_maker,
                          check_dtypes=False)
  self._CompileAndCheck(jnp.select, args_maker,
                        rtol={np.float64: 1e-7, np.complex128: 1e-7})
def testIssue330(self):
x = jnp.full((1, 1), jnp.array([1])[0]) # doesn't crash
self.assertEqual(x[0, 0], 1)
def testScalarDtypePromotion(self):
orig_numpy_result = (1 + np.eye(1, dtype=np.float32)).dtype
jax_numpy_result = (1 + jnp.eye(1, dtype=jnp.float32)).dtype
self.assertEqual(orig_numpy_result, jax_numpy_result)
def testSymmetrizeDtypePromotion(self):
x = np.eye(3, dtype=np.float32)
orig_numpy_result = ((x + x.T) / 2).dtype
x = jnp.eye(3, dtype=jnp.float32)
jax_numpy_result = ((x + x.T) / 2).dtype
self.assertEqual(orig_numpy_result, jax_numpy_result)
# NOTE(mattjj): I disabled this test when removing lax._safe_mul because
# introducing the convention 0 * inf = 0 leads to silently wrong results in
# some cases. See this comment for details:
# https://github.com/google/jax/issues/1052#issuecomment-514083352
# def testIssue347(self):
# # https://github.com/google/jax/issues/347
# def test_fail(x):
# x = jnp.sqrt(jnp.sum(x ** 2, axis=1))
# ones = jnp.ones_like(x)
# x = jnp.where(x > 0.5, x, ones)
# return jnp.sum(x)
# x = jnp.array([[1, 2], [3, 4], [0, 0]], dtype=jnp.float64)
# result = jax.grad(test_fail)(x)
# assert not np.any(np.isnan(result))
def testIssue453(self):
  # https://github.com/google/jax/issues/453
  # Fortran-order reshape must agree with NumPy.
  values = np.arange(6) + 1
  target_shape = (3, 2)
  actual = jnp.reshape(values, target_shape, order='F')
  reference = np.reshape(values, target_shape, order='F')
  self.assertAllClose(actual, reference)
@parameterized.named_parameters(jtu.cases_from_list(
  {"testcase_name": "_op={}_dtype={}".format(op, dtype.__name__),
   "dtype": dtype, "op": op}
  for dtype in [int, float, bool, complex]
  for op in ["atleast_1d", "atleast_2d", "atleast_3d"]))
def testAtLeastNdLiterals(self, dtype, op):
  # Fixes: https://github.com/google/jax/issues/634
  # The atleast_*d functions are applied to a Python scalar literal; the
  # NumPy result is cast to the dtype listed for that Python type in
  # dtypes.python_scalar_dtypes.
  reference_dtype = dtypes.python_scalar_dtypes[dtype]

  def np_fun(arg):
    return getattr(np, op)(arg).astype(reference_dtype)

  def jnp_fun(arg):
    return getattr(jnp, op)(arg)

  def args_maker():
    return [dtype(2)]

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
  {
    "testcase_name": "_shape={}_dtype={}_weights={}_minlength={}_length={}".format(
      shape, dtype, weights, minlength, length
    ),
    "shape": shape,
    "dtype": dtype,
    "weights": weights,
    "minlength": minlength,
    "length": length}
  for shape in [(0,), (5,), (10,)]
  for dtype in int_dtypes
  for weights in [True, False]
  for minlength in [0, 20]
  for length in [None, 8]
))
def testBincount(self, shape, dtype, weights, minlength, length):
  # jnp.bincount compared against a NumPy reference that emulates the
  # clipping of negative inputs and the static `length` argument.
  sample = jtu.rand_default(self.rng())

  def args_maker():
    values = sample(shape, dtype)
    sample_weights = sample(shape, 'float32') if weights else None
    return (values, sample_weights)

  def np_fun(x, *args):
    x = np.clip(x, 0, None)  # jnp.bincount clips negative values to zero.
    counts = np.bincount(x, *args, minlength=minlength)
    # Zero-pad when an explicit length exceeds the number of bins; otherwise
    # truncate (a slice bound of None keeps every bin).
    shortfall = length - counts.size if length else 0
    if shortfall > 0:
      return np.pad(counts, (0, shortfall))
    return counts[:length]

  jnp_fun = partial(jnp.bincount, minlength=minlength, length=length)
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
  # The compiled check is only run when a static length is supplied.
  if length is not None:
    self._CompileAndCheck(jnp_fun, args_maker)
def testBincountNegative(self):
  # Test that jnp.bincount clips negative values to zero (np.bincount would
  # raise on them), so the reference is computed by NumPy on a copy of the
  # input whose negative entries have been replaced by 0. The weights are
  # used unchanged: a clipped entry contributes its weight to bin 0.
  x_rng = jtu.rand_int(self.rng(), -100, 100)
  w_rng = jtu.rand_uniform(self.rng())
  shape = (1000,)
  x = x_rng(shape, 'int32')
  w = w_rng(shape, 'float32')

  clipped = np.array(x)
  clipped[clipped < 0] = 0
  np_result = np.bincount(clipped, np.array(w))
  jnp_result = jnp.bincount(x, w)
  self.assertAllClose(np_result, jnp_result, check_dtypes=False)
@parameterized.named_parameters(*jtu.cases_from_list(
  {"testcase_name": "_case={}".format(i),
   "input": input}
  for i, input in enumerate([
    3,
    [3],
    [np.array(3)],
    [np.array([3])],
    [[np.array(3)]],
    [[np.array([3])]],
    [3, 4, 5],
    [
      [np.eye(2, dtype=np.int32) * 2, np.zeros((2, 3), dtype=np.int32)],
      [np.ones((3, 2), dtype=np.int32), np.eye(3, dtype=np.int32) * 3],
    ],
    [np.array([1, 2, 3]), np.array([2, 3, 4]), 10],
    [np.ones((2, 2), dtype=np.int32), np.zeros((2, 2), dtype=np.int32)],
    [[np.array([1, 2, 3])], [np.array([2, 3, 4])]],
  ])))
def testBlock(self, input):
  # jnp.block compared against np.block; the (possibly nested) block
  # specification is passed as a single argument.
  def args_maker():
    return [input]

  self._CheckAgainstNumpy(np.block, jnp.block, args_maker)
  self._CompileAndCheck(jnp.block, args_maker)
def testLongLong(self):
  # An np.longlong scalar passed through a jitted identity function must
  # come back matching np.int64(7).
  identity = jax.jit(lambda x: x)
  round_tripped = identity(np.longlong(7))
  self.assertAllClose(np.int64(7), round_tripped)
@jtu.ignore_warning(category=UserWarning,
                    message="Explicitly requested dtype.*")
def testArange(self):
  # test cases inspired by dask tests at
  # https://github.com/dask/dask/blob/main/dask/array/tests/test_creation.py#L92
  np_arange = jtu.with_jax_dtype_defaults(np.arange)
  # (positional args, keyword args) pairs handed to both implementations.
  value_cases = (
    ((77,), {}),
    ((2, 13), {}),
    ((4, 21, 9), {}),
    ((53, 5, -3), {}),
    ((77,), {"dtype": float}),
    ((2, 13), {"dtype": int}),
    ((0, 1, -0.5), {}),
  )
  for args, kwargs in value_cases:
    self.assertAllClose(jnp.arange(*args, **kwargs),
                        np_arange(*args, **kwargs))

  # Calling arange with no arguments is an error.
  with self.assertRaises(TypeError):
    jnp.arange()

  iota_type = type(lax.iota(np.int32, 77))

  # test that jnp.arange(N) doesn't instantiate an ndarray
  default_type = type(jnp.arange(77))
  self.assertNotEqual(default_type, type(np.arange(77)))
  self.assertEqual(default_type, iota_type)

  # test that jnp.arange(N, dtype=int32) doesn't instantiate an ndarray
  int32_type = type(jnp.arange(77, dtype=jnp.int32))
  self.assertNotEqual(int32_type, type(np.arange(77, dtype=np.int32)))
  self.assertEqual(int32_type, iota_type)
def testArangeJit(self):
  # jnp.arange called inside a jitted function must match np.arange wrapped
  # with JAX's dtype defaults.
  jitted_arange = jax.jit(lambda: jnp.arange(5))
  reference = jtu.with_jax_dtype_defaults(np.arange)(5)
  self.assertAllClose(jitted_arange(), reference)
def testIssue830(self):
a = jnp.arange(4, dtype=jnp.complex64)
self.assertEqual(a.dtype, jnp.complex64)
def testIssue728(self):
assert jnp.allclose(jnp.eye(5000), np.eye(5000))
self.assertEqual(0, np.sum(jnp.eye(1050) - np.eye(1050)))
def testIssue746(self):
jnp.arange(12).reshape(3, 4) # doesn't crash
def testIssue764(self):
  # Gradient of sum(tanh(x)) for large x, compared with tiny reference
  # values.
  def sum_tanh(values):
    return jnp.sum(jnp.tanh(values))

  points = jnp.linspace(190, 200, 4)
  gradient = jax.grad(sum_tanh)(points)
  # Expected values computed with autograd in float64 precision.
  expected = np.array([3.71669453e-165, 4.72999108e-168, 6.01954653e-171,
                       7.66067839e-174], np.float64)
  self.assertAllClose(gradient, expected, check_dtypes=False)
def testIssue776(self):
  """Tests that the scatter-add transpose rule instantiates symbolic zeros."""
  scatter_indices = np.array([2, 4, 5])

  def f(u):
    scattered = jnp.ones(10).at[scatter_indices].add(u)
    # The transpose rule for lax.tie_in returns a symbolic zero for its first
    # argument.
    return lax.tie_in(scattered, 7.)

  # The output is the constant 7., so the gradient with respect to u must
  # be a concrete array of zeros.
  gradient = jax.grad(f)(np.ones(3,))
  self.assertAllClose(np.zeros(3,), gradient)
# NOTE(mattjj): I disabled this test when removing lax._safe_mul because this
# is a numerical stability issue that should be solved with a custom jvp rule
# of the sigmoid function being differentiated here, not by safe_mul.
# def testIssue777(self):
# x = jnp.linspace(-200, 0, 4, dtype=np.float32)
# f = jax.grad(lambda x: jnp.sum(1 / (1 + jnp.exp(-x))))
# self.assertAllClose(f(x), np.array([0., 0., 0., 0.25], dtype=np.float32))
@parameterized.named_parameters(
  jtu.cases_from_list(
    {"testcase_name": jtu.format_test_name_suffix(op, [()], [dtype]),
     "dtype": dtype, "op": op}
    for dtype in float_dtypes
    for op in ("sqrt", "arccos", "arcsin", "arctan", "sin", "cos", "tan",
               "sinh", "cosh", "tanh", "arccosh", "arcsinh", "arctanh", "exp",
               "log", "expm1", "log1p")))
def testMathSpecialFloatValues(self, op, dtype):
  # Evaluates a unary math function at NaN, infinities and very large or
  # small scalar inputs and compares jax.numpy with NumPy.
  np_op = getattr(np, op)
  # Silence the RuntimeWarnings NumPy emits for these inputs.
  for pattern in ("invalid value.*", "divide by zero.*", "overflow.*"):
    np_op = jtu.ignore_warning(category=RuntimeWarning,
                               message=pattern)(np_op)
  jnp_op = getattr(jnp, op)
  scalar_type = np.dtype(dtypes.canonicalize_dtype(dtype)).type
  largest = jnp.finfo(scalar_type).max
  special_values = (np.nan, -np.inf, -100., -2., -1., 0., 1., 2., 100.,
                    np.inf, largest, np.sqrt(largest),
                    np.sqrt(largest) * 2.)

  if op in ("sin", "cos", "tan") and jtu.device_under_test() == "tpu":
    return  # TODO(b/132196789): fix and reenable.

  tol = jtu.tolerance(scalar_type, {np.float32: 1e-3, np.float64: 1e-7})
  for raw in special_values:
    point = scalar_type(raw)
    self.assertAllClose(np_op(point), jnp_op(point), atol=tol, rtol=tol)
def testReductionOfOutOfBoundsAxis(self): # Issue 888
x = jnp.ones((3, 4))
self.assertRaises(ValueError, lambda: jnp.sum(x, axis=2))
def testIssue956(self):
self.assertRaises(TypeError, lambda: jnp.ndarray((1, 1)))
@parameterized.named_parameters(
  jtu.cases_from_list(
    {"testcase_name":
       "_shape={}_dtype={}_out_dtype={}_axis={}_ddof={}_keepdims={}"
       .format(shape, dtype.__name__, out_dtype.__name__, axis, ddof, keepdims),
     "shape": shape, "dtype": dtype, "out_dtype": out_dtype, "axis": axis,
     "ddof": ddof, "keepdims": keepdims}
    for shape in [(5,), (10, 5)]
    for dtype in all_dtypes
    for out_dtype in inexact_dtypes
    for axis in [None, 0, -1]
    for ddof in [0, 1, 2]
    for keepdims in [False, True]))
def testVar(self, shape, dtype, out_dtype, axis, ddof, keepdims):
  # jnp.var with an explicit output dtype, compared against np.var.
  sample = jtu.rand_default(self.rng())
  args_maker = self._GetArgsMaker(sample, [shape], [dtype])
  reduction = dict(axis=axis, ddof=ddof, keepdims=keepdims)

  @jtu.ignore_warning(category=RuntimeWarning,
                      message="Degrees of freedom <= 0 for slice.")
  def np_fun(x):
    # The reference computes in at least float32, then casts to out_dtype.
    promoted = x.astype(jnp.promote_types(np.float32, dtype))
    return np.var(promoted, **reduction).astype(out_dtype)

  jnp_fun = partial(jnp.var, dtype=out_dtype, **reduction)
  tol = jtu.tolerance(out_dtype, {np.float16: 1e-1, np.float32: 1e-3,
                                  np.float64: 1e-3, np.complex128: 1e-6})

  complex_in = jnp.issubdtype(dtype, jnp.complexfloating)
  if complex_in and not jnp.issubdtype(out_dtype, jnp.complexfloating):
    # Complex input with a non-complex output dtype is expected to raise.
    with self.assertRaises(ValueError):
      jnp_fun(*args_maker())
    return

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
                          tol=tol)
  self._CompileAndCheck(jnp_fun, args_maker, rtol=tol,
                        atol=tol)
@parameterized.named_parameters(
  jtu.cases_from_list(
    {"testcase_name":
       "_shape={}_dtype={}_out_dtype={}_axis={}_ddof={}_keepdims={}"
       .format(shape, dtype, out_dtype, axis, ddof, keepdims),
     "shape": shape, "dtype": dtype, "out_dtype": out_dtype, "axis": axis,
     "ddof": ddof, "keepdims": keepdims}
    for shape in [(5,), (10, 5)]
    for dtype in all_dtypes
    for out_dtype in inexact_dtypes
    for axis in [None, 0, -1]
    for ddof in [0, 1, 2]
    for keepdims in [False, True]))
def testNanVar(self, shape, dtype, out_dtype, axis, ddof, keepdims):
  # jnp.nanvar with an explicit output dtype, compared against np.nanvar on
  # inputs that contain some NaNs.
  sample = jtu.rand_some_nan(self.rng())
  args_maker = self._GetArgsMaker(sample, [shape], [dtype])
  reduction = dict(axis=axis, ddof=ddof, keepdims=keepdims)

  @jtu.ignore_warning(category=RuntimeWarning,
                      message="Degrees of freedom <= 0 for slice.")
  def np_fun(x):
    # The reference computes in at least float32, then casts to out_dtype.
    promoted = x.astype(jnp.promote_types(np.float32, dtype))
    return np.nanvar(promoted, **reduction).astype(out_dtype)

  jnp_fun = partial(jnp.nanvar, dtype=out_dtype, **reduction)
  tol = jtu.tolerance(out_dtype, {np.float16: 1e-1, np.float32: 1e-3,
                                  np.float64: 1e-3, np.complex128: 1e-6})

  complex_in = jnp.issubdtype(dtype, jnp.complexfloating)
  if complex_in and not jnp.issubdtype(out_dtype, jnp.complexfloating):
    # Complex input with a non-complex output dtype is expected to raise.
    with self.assertRaises(ValueError):
      jnp_fun(*args_maker())
    return

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
                          tol=tol)
  self._CompileAndCheck(jnp_fun, args_maker, rtol=tol,
                        atol=tol)
def testNanStdGrad(self):
  # Regression test for https://github.com/google/jax/issues/8128
  # Input is [nan, 1., 2., 3., 4.].
  data = jnp.arange(5.0).at[0].set(jnp.nan)

  var_grad = jax.grad(jnp.nanvar)(data)
  expected_var_grad = jnp.array([0.0, -0.75, -0.25, 0.25, 0.75])
  self.assertAllClose(var_grad, expected_var_grad)

  # The nanstd gradient must not contain any NaN.
  std_grad = jax.grad(jnp.nanstd)(data)
  nan_count = jnp.isnan(std_grad).sum()
  self.assertEqual(nan_count, 0)
@parameterized.named_parameters(
  jtu.cases_from_list(
    {"testcase_name":
       "_shape={}_dtype={}_y_shape={}_y_dtype={}_rowvar={}_ddof={}_bias={}_fweights={}_aweights={}".format(
         shape, dtype, y_shape, y_dtype, rowvar, ddof, bias, fweights, aweights),
     "shape": shape, "y_shape": y_shape, "dtype": dtype, "y_dtype": y_dtype,"rowvar": rowvar, "ddof": ddof,
     "bias": bias, "fweights": fweights, "aweights": aweights}
    for shape in [(5,), (10, 5), (5, 10)]
    for dtype in all_dtypes
    for y_dtype in [None, dtype]
    for rowvar in [True, False]
    for y_shape in _get_y_shapes(y_dtype, shape, rowvar)
    for bias in [True, False]
    for ddof in [None, 2, 3]
    for fweights in [True, False]
    for aweights in [True, False]))
def testCov(self, shape, dtype, y_shape, y_dtype, rowvar, ddof, bias, fweights, aweights):
  # jnp.cov compared against np.cov, with an optional second operand and
  # optional frequency / observation weights.
  data_rng = jtu.rand_default(self.rng())
  weight_rng = jtu.rand_positive(self.rng())
  # Observation weights use the real dtype corresponding to `dtype`.
  weight_dtype = np.real(dtype(0)).dtype
  if rowvar or shape[0] == 1:
    weight_shape = shape[-1:]
  else:
    weight_shape = shape[:1]

  def args_maker():
    m = data_rng(shape, dtype)
    y = data_rng(y_shape, y_dtype) if y_dtype else None
    f = weight_rng(weight_shape, int) if fweights else None
    a = weight_rng(weight_shape, weight_dtype) if aweights else None
    return [m, y, f, a]

  options = dict(rowvar=rowvar, ddof=ddof, bias=bias)

  def np_fun(m, y, f, a):
    return np.cov(m, y, fweights=f, aweights=a, **options)

  def jnp_fun(m, y, f, a):
    return jnp.cov(m, y, fweights=f, aweights=a, **options)

  if jtu.device_under_test() == "tpu":
    tol = 7e-2
  else:
    tol = {jnp.bfloat16: 5E-2, np.float16: 1E-2, np.float32: 1e-5,
           np.float64: 1e-13, np.complex64: 1e-5, np.complex128: 1e-13}
  tol = jtu.join_tolerance(tol, jtu.tolerance(dtype))
  self._CheckAgainstNumpy(
      np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol)
  self._CompileAndCheck(jnp_fun, args_maker, atol=tol,
                        rtol=tol)
def testIssue967(self):
self.assertRaises(TypeError, lambda: jnp.zeros(1.5))
@parameterized.named_parameters(
  jtu.cases_from_list(
    {"testcase_name": "_shape={}_dtype={}_rowvar={}".format(
        shape, dtype.__name__, rowvar),
     "shape": shape, "dtype": dtype, "rowvar": rowvar}
    for shape in [(5,), (10, 5), (3, 10)]
    for dtype in number_dtypes
    for rowvar in [True, False]))
def testCorrCoef(self, shape, dtype, rowvar):
  # jnp.corrcoef compared against np.corrcoef.
  sample = jtu.rand_default(self.rng())

  def args_maker():
    # Redraw until the sample's standard deviation is not close to zero.
    while True:
      x = sample(shape, dtype)
      if not np.any(np.isclose(np.std(x), 0.0)):
        return (x,)

  np_fun = jtu.ignore_warning(
      category=RuntimeWarning,
      message="invalid value encountered.*")(partial(np.corrcoef, rowvar=rowvar))
  jnp_fun = partial(jnp.corrcoef, rowvar=rowvar)

  # TPU gets an explicit, looser tolerance; other backends use the default.
  if jtu.device_under_test() == "tpu":
    tol = 1e-2
  else:
    tol = None
  self._CheckAgainstNumpy(
      np_fun, jnp_fun, args_maker, check_dtypes=False,
      tol=tol)
  self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
  {"testcase_name": "_{}_{}_{}".format(jtu.format_shape_dtype_string(shape, dtype),
      "None" if end_dtype is None else jtu.format_shape_dtype_string(end_shape, end_dtype),
      "None" if begin_dtype is None else jtu.format_shape_dtype_string(begin_shape, begin_dtype)),
   "shape": shape, "dtype": dtype, "end_shape": end_shape,
   "end_dtype": end_dtype, "begin_shape": begin_shape,
   "begin_dtype": begin_dtype}
  for dtype in number_dtypes
  for end_dtype in [None] + [dtype]
  for begin_dtype in [None] + [dtype]
  for shape in [s for s in all_shapes if s != jtu.PYTHON_SCALAR_SHAPE]
  for begin_shape in (
    [None] if begin_dtype is None
    else [s for s in all_shapes if s != jtu.PYTHON_SCALAR_SHAPE])
  for end_shape in (
    [None] if end_dtype is None
    else [s for s in all_shapes if s != jtu.PYTHON_SCALAR_SHAPE])))
def testEDiff1d(self, shape, dtype, end_shape, end_dtype, begin_shape,
                begin_dtype):
  # jnp.ediff1d compared against np.ediff1d, with to_end / to_begin each
  # either omitted (None) or drawn with the same dtype as the input.
  sample = jtu.rand_default(self.rng())

  def args_maker():
    x = sample(shape, dtype)
    to_end = None if end_dtype is None else sample(end_shape, end_dtype)
    to_begin = (None if begin_dtype is None
                else sample(begin_shape, begin_dtype))
    return [x, to_end, to_begin]

  def np_fun(x, to_end, to_begin):
    return np.ediff1d(x, to_end, to_begin)

  def jnp_fun(x, to_end, to_begin):
    return jnp.ediff1d(x, to_end, to_begin)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
def testEDiff1dWithDtypeCast(self):
  # ediff1d where to_end (int32) has a different dtype than the float32
  # input and to_begin.
  sample = jtu.rand_default(self.rng())
  scalar_shape = jtu.NUMPY_SCALAR_SHAPE
  value_dtype = jnp.float32
  to_end_dtype = jnp.int32

  def args_maker():
    x = sample(scalar_shape, value_dtype)
    to_end = sample(scalar_shape, to_end_dtype)
    to_begin = sample(scalar_shape, value_dtype)
    return [x, to_end, to_begin]

  def np_fun(x, to_end, to_begin):
    return np.ediff1d(x, to_end, to_begin)

  def jnp_fun(x, to_end, to_begin):
    return jnp.ediff1d(x, to_end, to_begin)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(
  jtu.cases_from_list(
    {"testcase_name": "_shapes={}_dtype={}_indexing={}_sparse={}".format(
        shapes, dtype, indexing, sparse),
     "shapes": shapes, "dtype": dtype, "indexing": indexing,
     "sparse": sparse}
    for shapes in [(), (5,), (5, 3)]
    for dtype in number_dtypes
    for indexing in ['xy', 'ij']
    for sparse in [True, False]))
def testMeshGrid(self, shapes, dtype, indexing, sparse):
  # jnp.meshgrid compared against np.meshgrid. Each entry of `shapes` is
  # the length of one 1-D coordinate vector.
  sample = jtu.rand_default(self.rng())
  vector_shapes = [(length,) for length in shapes]
  vector_dtypes = [dtype] * len(shapes)
  args_maker = self._GetArgsMaker(sample, vector_shapes, vector_dtypes)

  options = dict(indexing=indexing, sparse=sparse)
  np_fun = partial(np.meshgrid, **options)
  jnp_fun = partial(jnp.meshgrid, **options)
  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
def testMgrid(self):
  # wrap indexer for appropriate dtype defaults.
  np_mgrid = _indexer_with_default_outputs(np.mgrid)
  # np.s_[...] yields exactly the key object that obj[...] would receive,
  # which lets the same index be applied to both indexers.
  s_ = np.s_

  # Integer slices must match exactly.
  exact_indices = (
    s_[:4],
    s_[:4,],
    s_[:5, :5],
    s_[:3, :2],
    s_[1:4:2],
    s_[1:5:3, :5],
    s_[:3, :2, :5],
    s_[:3:2, :2, :5],
    # Corner cases
    s_[:],
  )
  for index in exact_indices:
    self.assertAllClose(np_mgrid[index], jnp.mgrid[index], atol=0, rtol=0)
  self.assertAllClose(np_mgrid[:4], jax.jit(lambda: jnp.mgrid[:4])(),
                      atol=0, rtol=0)

  # When the step length is a complex number or a non-integer, floating
  # point calculation means the jnp and np values might differ slightly, so
  # these are compared with a small tolerance.
  approximate_indices = (
    s_[-1:1:5j],
    s_[3:4:7j],
    s_[1:6:8j, 2:4],
    # Non-integer steps
    s_[0:3.5:0.5],
    s_[1.3:4.2:0.3],
  )
  for index in approximate_indices:
    self.assertAllClose(np_mgrid[index], jnp.mgrid[index],
                        atol=1e-6, rtol=1e-6)

  # abstract tracer value for jnp.mgrid slice
  with self.assertRaisesRegex(jax.core.ConcretizationTypeError,
                              "slice start of jnp.mgrid"):
    jax.jit(lambda a, b: jnp.mgrid[a:b])(0, 2)
def testOgrid(self):
  # wrap indexer for appropriate dtype defaults.
  np_ogrid = _indexer_with_default_outputs(np.ogrid)
  # np.s_[...] yields exactly the key object that obj[...] would receive,
  # which lets the same index be applied to both indexers.
  s_ = np.s_

  def assertListOfArraysEqual(xs, ys):
    # Both results must be lists of equal length with equal elements.
    for seq in (xs, ys):
      self.assertIsInstance(seq, list)
    self.assertEqual(len(xs), len(ys))
    for x, y in zip(xs, ys):
      self.assertArraysEqual(x, y)

  # Single slices
  self.assertArraysEqual(np_ogrid[:5], jnp.ogrid[:5])
  self.assertArraysEqual(np_ogrid[:5], jax.jit(lambda: jnp.ogrid[:5])())
  self.assertArraysEqual(np_ogrid[1:7:2], jnp.ogrid[1:7:2])

  # List of arrays
  for index in (s_[:5,], s_[0:5, 1:3], s_[1:3:2, 2:9:3], s_[:5, :9, :11]):
    assertListOfArraysEqual(np_ogrid[index], jnp.ogrid[index])

  # Corner cases
  self.assertArraysEqual(np_ogrid[:], jnp.ogrid[:])

  # Complex number steps and non-integer steps are compared with a small
  # tolerance.
  approximate_indices = (
    s_[-1:1:5j],
    s_[0:3.5:0.3],
    s_[1.2:4.8:0.24],
  )
  for index in approximate_indices:
    self.assertAllClose(np_ogrid[index], jnp.ogrid[index],
                        atol=1e-6, rtol=1e-6)

  # abstract tracer value for ogrid slice
  with self.assertRaisesRegex(jax.core.ConcretizationTypeError,
                              "slice start of jnp.ogrid"):
    jax.jit(lambda a, b: jnp.ogrid[a:b])(0, 2)
def testR_(self):
  # jnp.r_ compared against np.r_ for plain concatenation, string
  # directives and slice notation.
  a = np.arange(6).reshape((2,3))
  left, right = np.array([1,2,3]), np.array([4,5,6])
  self.assertArraysEqual(np.r_[left, 0, 0, right], jnp.r_[left, 0, 0, right])
  self.assertArraysEqual(np.r_['-1', a, a], jnp.r_['-1', a, a])

  # wrap indexer for appropriate dtype defaults.
  np_r_ = _indexer_with_default_outputs(np.r_)

  def check_directive(directive):
    # A tuple key is what obj[directive, [...], [...]] passes to the indexer.
    key = (directive, [1,2,3], [4,5,6])
    self.assertArraysEqual(np_r_[key], jnp.r_[key])

  for directive in ('0,2', '0,2,0', '1,2,0'):
    check_directive(directive)
  # negative 1d axis start
  for directive in ('0,4,-1', '0,4,-2'):
    check_directive(directive)

  # matrix directives
  with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=PendingDeprecationWarning)
    for directive in ('r', 'c'):
      check_directive(directive)

  # bad directive
  with self.assertRaisesRegex(ValueError, "could not understand directive.*"):
    jnp.r_["asdfgh",[1,2,3]]

  # abstract tracer value for r_ slice
  with self.assertRaisesRegex(jax.core.ConcretizationTypeError,
                              "slice start of jnp.r_"):
    jax.jit(lambda a, b: jnp.r_[a:b])(0, 2)

  # Complex number steps and non-integer steps are compared with a small
  # tolerance.
  approximate_indices = (
    np.s_[-1:1:6j],
    np.s_[-1:1:6j, [0]*3, 5, 6],
    np.s_[1.2:4.8:0.24],
  )
  for index in approximate_indices:
    self.assertAllClose(np_r_[index], jnp.r_[index], atol=1e-6, rtol=1e-6)
def testC_(self):
  # jnp.c_ compared against np.c_ for plain concatenation, string
  # directives and slice notation.
  a = np.arange(6).reshape((2, 3))
  left, right = np.array([1,2,3]), np.array([4,5,6])
  self.assertArraysEqual(np.c_[left, right], jnp.c_[left, right])
  left_row, right_row = np.array([[1,2,3]]), np.array([[4,5,6]])
  self.assertArraysEqual(np.c_[left_row, 0, 0, right_row],
                         jnp.c_[left_row, 0, 0, right_row])
  self.assertArraysEqual(np.c_['-1', a, a], jnp.c_['-1', a, a])

  # wrap indexer for appropriate dtype defaults.
  np_c_ = _indexer_with_default_outputs(np.c_)

  def check_directive(directive):
    # A tuple key is what obj[directive, [...], [...]] passes to the indexer.
    key = (directive, [1,2,3], [4,5,6])
    self.assertArraysEqual(np_c_[key], jnp.c_[key])

  for directive in ('0,2', '0,2,0', '1,2,0'):
    check_directive(directive)
  # negative 1d axis start
  for directive in ('0,4,-1', '0,4,-2'):
    check_directive(directive)

  # matrix directives, avoid numpy deprecation warning
  with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=PendingDeprecationWarning)
    for directive in ('r', 'c'):
      check_directive(directive)

  # bad directive
  with self.assertRaisesRegex(ValueError, "could not understand directive.*"):
    jnp.c_["asdfgh",[1,2,3]]

  # abstract tracer value for c_ slice
  with self.assertRaisesRegex(jax.core.ConcretizationTypeError,
                              "slice start of jnp.c_"):
    jax.jit(lambda a, b: jnp.c_[a:b])(0, 2)

  # Complex number steps and non-integer steps are compared with a small
  # tolerance.
  for index in (np.s_[-1:1:6j], np.s_[1.2:4.8:0.24]):
    self.assertAllClose(np_c_[index], jnp.c_[index], atol=1e-6, rtol=1e-6)
def testS_(self):
self.assertEqual(np.s_[1:2:20],jnp.s_[1:2:20])
def testIndex_exp(self):
self.assertEqual(np.index_exp[5:3:2j],jnp.index_exp[5:3:2j])
@parameterized.named_parameters(
  jtu.cases_from_list(
    {"testcase_name": f"_start_shape={start_shape}_stop_shape={stop_shape}"
                      f"_num={num}_endpoint={endpoint}_retstep={retstep}"
                      f"_dtype={dtype.__name__ if dtype else 'None'}",
     "start_shape": start_shape, "stop_shape": stop_shape,
     "num": num, "endpoint": endpoint, "retstep": retstep,
     "dtype": dtype}
    for start_shape in [(), (2,), (2, 2)]
    for stop_shape in [(), (2,), (2, 2)]
    for num in [0, 1, 2, 5, 20]
    for endpoint in [True, False]
    for retstep in [True, False]
    # floating-point compute between jitted platforms and non-jit + rounding
    # cause unavoidable variation in integer truncation for some inputs, so
    # we currently only test inexact 'dtype' arguments.
    for dtype in inexact_dtypes + [None,]))
@jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
def testLinspace(self, start_shape, stop_shape, num, endpoint, retstep, dtype):
  # jnp.linspace compared against np.linspace for every valid `axis` of the
  # broadcast start/stop shape.
  sample = jtu.rand_default(self.rng())
  # relax default tolerances slightly
  tol = jtu.tolerance(dtype if dtype else np.float32) * 10
  args_maker = self._GetArgsMaker(sample,
                                  [start_shape, stop_shape],
                                  [dtype, dtype])
  # One sample pair is drawn only to find the rank of the broadcast result.
  sample_start, sample_stop = args_maker()
  ndim = len(np.shape(sample_start + sample_stop))

  # NumPy 1.20.0 changed the semantics of linspace to floor for integer
  # dtypes.
  numpy_matches_directly = (numpy_version >= (1, 20) or
                            not np.issubdtype(dtype, np.integer))

  for axis in range(-ndim, ndim):
    def jnp_op(start, stop):
      return jnp.linspace(
        start, stop, num,
        endpoint=endpoint, retstep=retstep, dtype=dtype, axis=axis)

    def np_op(start, stop):
      if numpy_matches_directly:
        return np.linspace(
          start, stop, num,
          endpoint=endpoint, retstep=retstep, dtype=dtype, axis=axis)
      # Otherwise compute without a dtype and floor by hand before casting.
      out = np.linspace(start, stop, num, endpoint=endpoint,
                        retstep=retstep, axis=axis)
      if retstep:
        return np.floor(out[0]).astype(dtype), out[1]
      return np.floor(out).astype(dtype)

    self._CheckAgainstNumpy(np_op, jnp_op, args_maker,
                            check_dtypes=False, tol=tol)
    self._CompileAndCheck(jnp_op, args_maker,
                          check_dtypes=False, atol=tol, rtol=tol)
@parameterized.named_parameters(
  jtu.cases_from_list(
    {"testcase_name": f"_dtype={dtype.__name__}", "dtype": dtype}
    for dtype in number_dtypes))
def testLinspaceEndpoints(self, dtype):
  """Regression test for Issue #3014."""
  sample = jtu.rand_default(self.rng())
  endpoints = sample((2,), dtype)
  first, last = endpoints
  grid = jnp.linspace(first, last, 10, dtype=dtype)
  # The first and last grid points must equal the endpoints exactly.
  ends_of_grid = grid[np.array([0, -1])]
  self.assertAllClose(ends_of_grid, endpoints, rtol=0, atol=0)
@parameterized.named_parameters(
  jtu.cases_from_list(
    {"testcase_name": ("_start_shape={}_stop_shape={}_num={}_endpoint={}"
                       "_base={}_dtype={}").format(
        start_shape, stop_shape, num, endpoint, base,
        dtype.__name__ if dtype else "None"),
     "start_shape": start_shape,
     "stop_shape": stop_shape,
     "num": num, "endpoint": endpoint, "base": base,
     "dtype": dtype}
    for start_shape in [(), (2,), (2, 2)]
    for stop_shape in [(), (2,), (2, 2)]
    for num in [0, 1, 2, 5, 20]
    for endpoint in [True, False]
    for base in [10.0, 2, np.e]
    for dtype in inexact_dtypes + [None,]))
@jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
def testLogspace(self, start_shape, stop_shape, num,
                 endpoint, base, dtype):
  # Compares jnp.logspace with np.logspace for every axis of the broadcast
  # (start, stop) result, then checks the jit-compiled version.
  on_accelerator = jtu.device_under_test() in ("gpu", "tpu")
  if dtype in int_dtypes and on_accelerator and not config.x64_enabled:
    raise unittest.SkipTest("GPUx32 truncated exponentiation"
                            " doesn't exactly match other platforms.")
  sampler = jtu.rand_default(self.rng())
  # relax default tolerances slightly
  tol = {
      np.float16: 2e-2,
      np.float32: 1e-2,
      np.float64: 1e-6,
      np.complex64: 1e-3,
      np.complex128: 1e-6,
  }
  args_maker = self._GetArgsMaker(
      sampler, [start_shape, stop_shape], [dtype, dtype])
  # One sample is drawn only to learn the rank of the broadcast operands.
  sample_start, sample_stop = args_maker()
  rank = len(np.shape(sample_start + sample_stop))
  for axis in range(-rank, rank):
    def jnp_op(start, stop):
      return jnp.logspace(start, stop, num, endpoint=endpoint, base=base,
                          dtype=dtype, axis=axis)

    @jtu.ignore_warning(category=RuntimeWarning,
                        message="overflow encountered in power")
    def np_op(start, stop):
      return np.logspace(start, stop, num, endpoint=endpoint,
                         base=base, dtype=dtype, axis=axis)

    self._CheckAgainstNumpy(np_op, jnp_op, args_maker,
                            check_dtypes=False, tol=tol)
    if dtype not in (inexact_dtypes + [None,]):
      continue
    # Why do compiled and op-by-op float16 np.power numbers differ
    # slightly more than expected?
    atol = {np.float16: 1e-2}
    self._CompileAndCheck(jnp_op, args_maker,
                          check_dtypes=False, atol=atol, rtol=tol)
@parameterized.named_parameters(
  jtu.cases_from_list(
    {"testcase_name": ("_start_shape={}_stop_shape={}_num={}_endpoint={}"
                       "_dtype={}_axis={}").format(
        start_shape, stop_shape, num, endpoint,
        dtype.__name__ if dtype else "None", axis),
     "start_shape": start_shape,
     "stop_shape": stop_shape,
     "num": num, "endpoint": endpoint,
     "dtype": dtype, "axis": axis}
    for start_shape in [(), (2,), (2, 2)]
    for stop_shape in [(), (2,), (2, 2)]
    for num in [0, 1, 2, 5, 20]
    for endpoint in [True, False]
    # NB: numpy's geomspace gives nonsense results on integer types
    for dtype in inexact_dtypes + [None,]
    for axis in range(-max(len(start_shape), len(stop_shape)),
                      max(len(start_shape), len(stop_shape)))))
@jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
def testGeomspace(self, start_shape, stop_shape, num,
                  endpoint, dtype, axis):
  # Compares jnp.geomspace with np.geomspace on inputs for which the NumPy
  # function is well defined, then checks the jit-compiled version.
  sampler = jtu.rand_default(self.rng())
  # relax default tolerances slightly
  tol = {
      np.float16: 4e-3,
      np.float32: 2e-3,
      np.float64: 1e-14,
      np.complex128: 1e-14,
  }

  def args_maker():
    """Test the set of inputs np.geomspace is well-defined on."""
    raw_maker = self._GetArgsMaker(
        sampler, [start_shape, stop_shape], [dtype, dtype])
    start, stop = raw_maker()
    # np.geomspace can't handle differently ranked tensors
    # w. negative numbers!
    start, stop = jnp.broadcast_arrays(start, stop)
    if dtype not in complex_dtypes:
      # to avoid NaNs, non-complex start and stop cannot
      # differ in sign, elementwise
      start = start * jnp.sign(start) * jnp.sign(stop)
    return start, stop

  # Draw one sample up front; the values are unused, but the draw advances
  # the random state used by every later call to args_maker.
  _unused_start, _unused_stop = args_maker()

  def jnp_op(start, stop):
    return jnp.geomspace(start, stop, num, endpoint=endpoint, dtype=dtype,
                         axis=axis)

  def np_op(start, stop):
    # The NumPy reference is computed in float32 for bfloat16 and the result
    # is cast back to the requested dtype afterwards.
    if dtype == jnp.bfloat16:
      start = start.astype(np.float32)
    if dtype == jnp.bfloat16:
      stop = stop.astype(np.float32)
    compute_dtype = dtype if dtype != jnp.bfloat16 else np.float32
    reference = np.geomspace(start, stop, num, endpoint=endpoint,
                             dtype=compute_dtype, axis=axis)
    return reference.astype(dtype)

  self._CheckAgainstNumpy(np_op, jnp_op, args_maker,
                          check_dtypes=False, tol=tol)
  if dtype in (inexact_dtypes + [None,]):
    self._CompileAndCheck(jnp_op, args_maker,
                          check_dtypes=False, atol=tol, rtol=tol)
def testDisableNumpyRankPromotionBroadcasting(self):
  # Exercises the "allow", "raise" and "warn" settings of the
  # jax_numpy_rank_promotion flag, restoring the previous value after each.
  #
  # The previous value is read *before* entering each try block: if the read
  # itself failed inside the try, the finally clause would reference an
  # unbound (or stale) prev_flag and mask the real error.
  prev_flag = config._read('jax_numpy_rank_promotion')
  try:
    FLAGS.jax_numpy_rank_promotion = "allow"
    jnp.ones(2) + jnp.ones((1, 2))  # works just fine
  finally:
    FLAGS.jax_numpy_rank_promotion = prev_flag

  prev_flag = config._read('jax_numpy_rank_promotion')
  try:
    FLAGS.jax_numpy_rank_promotion = "raise"
    self.assertRaises(ValueError, lambda: jnp.ones(2) + jnp.ones((1, 2)))
    jnp.ones(2) + 3  # don't want to raise for scalars
  finally:
    FLAGS.jax_numpy_rank_promotion = prev_flag

  prev_flag = config._read('jax_numpy_rank_promotion')
  try:
    FLAGS.jax_numpy_rank_promotion = "warn"
    self.assertWarnsRegex(UserWarning, "Following NumPy automatic rank promotion for add on "
                          r"shapes \(2,\) \(1, 2\).*", lambda: jnp.ones(2) + jnp.ones((1, 2)))
    jnp.ones(2) + 3  # don't want to warn for scalars
  finally:
    FLAGS.jax_numpy_rank_promotion = prev_flag
@unittest.skip("Test fails on CI, perhaps due to JIT caching")
def testDisableNumpyRankPromotionBroadcastingDecorator(self):
  # Same checks as the flag-based test, but using the
  # jax.numpy_rank_promotion context manager to select the mode.
  def promote_ranks():
    return jnp.ones(2) + jnp.ones((1, 2))

  with jax.numpy_rank_promotion("allow"):
    promote_ranks()  # works just fine

  with jax.numpy_rank_promotion("raise"):
    self.assertRaises(ValueError, promote_ranks)
    jnp.ones(2) + 3  # don't want to raise for scalars

  with jax.numpy_rank_promotion("warn"):
    self.assertWarnsRegex(UserWarning, "Following NumPy automatic rank promotion for add on "
                          r"shapes \(2,\) \(1, 2\).*", promote_ranks)
    jnp.ones(2) + 3  # don't want to warn for scalars
def testStackArrayArgument(self):
  # tests https://github.com/google/jax/issues/1271
  # Under jit, jnp.stack and jnp.concatenate must accept a single array
  # argument in place of a sequence of arrays.
  @jax.jit
  def stack_rows(x):
    return jnp.stack(x)

  @jax.jit
  def concatenate_rows(x):
    return jnp.concatenate(x)

  stack_rows(np.zeros(2))  # doesn't crash
  concatenate_rows(np.zeros((2, 2)))  # doesn't crash
def testReluGradientConstants(self):
  # This is a regression test that verifies that constants associated with the
  # gradient of np.maximum (from lax._balanced_eq) aren't hoisted into the
  # outermost jaxpr. This was producing some large materialized constants for
  # every relu activation in a model.
  def relu_sum(z):
    return jnp.sum(jnp.maximum(z, 0.))

  def body(i, carry):
    x, acc = carry
    return x, acc + jax.grad(relu_sum)(x)

  def run_loop(y):
    return lax.fori_loop(0, 5, body, (y, y))

  jaxpr = jax.make_jaxpr(run_loop)(np.zeros((3, 4), np.float32))
  # No hoisted constant may equal a (3, 4) float32 array filled with 2.
  unwanted = np.full((3, 4), 2., dtype=np.float32)
  hoisted = any(np.array_equal(const, unwanted) for const in jaxpr.consts)
  self.assertFalse(hoisted)
@parameterized.named_parameters(
  {"testcase_name": "_from={}_to={}".format(from_shape, to_shape),
   "from_shape": from_shape, "to_shape": to_shape}
  for from_shape, to_shape in [
    [(1, 3), (4, 3)],
    [(3,), (2, 1, 3)],
    [(3,), (3, 3)],
    [(1,), (3,)],
    [(1,), 3],
  ])
def testBroadcastTo(self, from_shape, to_shape):
  # jnp.broadcast_to must agree with np.broadcast_to, op-by-op and under jit.
  sampler = jtu.rand_default(self.rng())
  args_maker = self._GetArgsMaker(sampler, [from_shape], [np.float32])

  def np_op(x):
    return np.broadcast_to(x, to_shape)

  def jnp_op(x):
    return jnp.broadcast_to(x, to_shape)

  self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
  self._CompileAndCheck(jnp_op, args_maker)
@parameterized.named_parameters(
  {"testcase_name": f"_{shapes}", "shapes": shapes, "broadcasted_shape": broadcasted_shape}
  for shapes, broadcasted_shape in [
    [[], ()],
    [[()], ()],
    [[(1, 3), (4, 3)], (4, 3)],
    [[(3,), (2, 1, 3)], (2, 1, 3)],
    [[(3,), (3, 3)], (3, 3)],
    [[(1,), (3,)], (3,)],
    [[(1,), 3], (3,)],
    [[(6, 7), (5, 6, 1), (7,), (5, 1, 7)], (5, 6, 7)],
    [[[1], [0, 1]], (0, 1)],
    [[(1,), np.array([0, 1])], (0, 1)],
  ])
def testBroadcastShapes(self, shapes, broadcasted_shape):
  # The expected shapes are hard-coded rather than computed by NumPy.
  # Test against np.broadcast_shapes once numpy 1.20 is minimum required version
  result = jnp.broadcast_shapes(*shapes)
  np.testing.assert_equal(result, broadcasted_shape)
def testBroadcastToIssue1522(self):
  # Broadcasting a (2, 3) array to the incompatible shape (1, 3) must raise
  # a ValueError naming the incompatibility.
  source = np.ones((2, 3))
  with self.assertRaisesRegex(
      ValueError, "Incompatible shapes for broadcasting: .*"):
    jnp.broadcast_to(source, (1, 3))
def testBroadcastToIntIssue1548(self):
  # Broadcasting the Python int 1 to shape (3, 2) must give an array of ones;
  # dtypes are deliberately not compared.
  target_shape = (3, 2)
  broadcasted = jnp.broadcast_to(1, target_shape)
  self.assertAllClose(broadcasted, np.ones(target_shape), check_dtypes=False)
def testBroadcastToOnScalar(self):
  # Broadcasting a Python scalar to the empty shape yields an array object
  # in both libraries.
  jnp_result = jnp.broadcast_to(10.0, ())
  self.assertIsInstance(jnp_result, jnp.ndarray)
  np_result = np.broadcast_to(10.0, ())
  self.assertIsInstance(np_result, np.ndarray)
def testPrecision(self):
  # Checks that the `precision` argument of the dot-like jnp functions is
  # forwarded, using jtu.assert_dot_precision on small all-ones operands.
  ones_1d = np.ones((2,))
  ones_2d = np.ones((2, 2))
  ones_3d = np.ones((2, 2, 2))
  HIGHEST = lax.Precision.HIGHEST

  # Without an explicit precision argument the expected precision is None.
  jtu.assert_dot_precision(None, jnp.dot, ones_1d, ones_1d)

  # (function with precision=HIGHEST bound, operand passed for both inputs)
  highest_cases = [
      (partial(jnp.dot, precision=HIGHEST), ones_1d),
      (partial(jnp.dot, precision=HIGHEST), ones_3d),
      (partial(jnp.matmul, precision=HIGHEST), ones_2d),
      (partial(jnp.vdot, precision=HIGHEST), ones_1d),
      (partial(jnp.tensordot, axes=2, precision=HIGHEST), ones_2d),
      (partial(jnp.tensordot, axes=(0, 0), precision=HIGHEST), ones_1d),
      (partial(jnp.tensordot, axes=((0,), (0,)), precision=HIGHEST), ones_1d),
      (partial(jnp.einsum, 'i,i', precision=HIGHEST), ones_1d),
      (partial(jnp.einsum, 'ij,ij', precision=HIGHEST), ones_2d),
      (partial(jnp.inner, precision=HIGHEST), ones_1d),
  ]
  for fun, operand in highest_cases:
    jtu.assert_dot_precision(HIGHEST, fun, operand, operand)
@parameterized.named_parameters(
  jtu.cases_from_list(
    {"testcase_name": "_shape={}_varargs={} axis={}_dtype={}".format(
        shape, varargs, axis, dtype),
     "shape": shape, "varargs": varargs, "axis": axis, "dtype": dtype}
    for shape in [(10,), (10, 15), (10, 15, 20)]
    for _num_axes in range(len(shape))
    for varargs in itertools.combinations(range(1, len(shape) + 1), _num_axes)
    for axis in itertools.combinations(range(len(shape)), _num_axes)
    for dtype in inexact_dtypes))
def testGradient(self, shape, varargs, axis, dtype):
  # jnp.gradient must agree with np.gradient for the given spacings
  # (varargs) and axes, op-by-op and under jit.
  sampler = jtu.rand_default(self.rng())
  args_maker = self._GetArgsMaker(sampler, [shape], [dtype])

  def jnp_fun(y):
    return jnp.gradient(y, *varargs, axis=axis)

  def np_fun(y):
    return np.gradient(y, *varargs, axis=axis)

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
  self._CompileAndCheck(jnp_fun, args_maker)
def testZerosShapeErrors(self):
  # see https://github.com/google/jax/issues/1822
  # A float shape is rejected with a TypeError.
  with self.assertRaisesRegex(
      TypeError,
      "Shapes must be 1D sequences of concrete values of integer type.*"):
    jnp.zeros(1.)

  # A traced shape under jit is rejected too, with a hint about
  # static_argnums appended to the message.
  jitted_zeros = jax.jit(jnp.zeros)
  with self.assertRaisesRegex(
      TypeError,
      r"Shapes must be 1D sequences of concrete values of integer type.*\n"
      "If using `jit`, try using `static_argnums` or applying `jit` to "
      "smaller subfunctions."):
    jitted_zeros(2)
def testTraceMethod(self):
  # The .trace() method of a jnp array must match NumPy's, both when called
  # directly and when called inside a jitted function.
  x = self.rng().randn(3, 4).astype(jnp.float_)
  direct = jnp.array(x).trace()
  self.assertAllClose(x.trace(), direct)
  jitted = jax.jit(lambda y: y.trace())(x)
  self.assertAllClose(x.trace(), jitted)
def testIntegerPowersArePrecise(self):
  # See https://github.com/google/jax/pull/3036
  # Checks if the squares of float32 integers have no numerical errors.
  # It should be satisfied with all integers less than sqrt(2**24).
  ints = jnp.arange(-2**12, 2**12, dtype=jnp.int32)
  as_float = ints.astype(jnp.float32)
  exact_squares = ints * ints
  np.testing.assert_array_equal(jnp.square(as_float), exact_squares)
  np.testing.assert_array_equal(ints.astype(jnp.float32) ** 2, ints * ints)

  # Similarly for cubes.
  ints = jnp.arange(-2**8, 2**8, dtype=jnp.int32)
  np.testing.assert_array_equal(ints.astype(jnp.float32) ** 3,
                                ints * ints * ints)

  # Small non-negative powers of a NumPy float32 range, compared with
  # assertAllClose and without dtype checking.
  base = np.arange(10, dtype=np.float32)
  for exponent in range(10):
    self.assertAllClose(base.astype(jnp.float32) ** exponent,
                        base ** exponent, check_dtypes=False)
def testToBytes(self):
  # tobytes() of a jnp array must equal NumPy's for both memory orders.
  reference = np.arange(12, dtype=np.int32).reshape(3, 4)
  for order in ('C', 'F'):
    expected = reference.tobytes(order)
    self.assertEqual(jnp.asarray(reference).tobytes(order), expected)
def testToList(self):
  # tolist() of a jnp array must equal NumPy's nested-list result.
  reference = np.arange(12, dtype=np.int32).reshape(3, 4)
  converted = jnp.asarray(reference).tolist()
  self.assertEqual(converted, reference.tolist())
def testReductionWithRepeatedAxisError(self):
  # Reducing over the same axis twice must raise a ValueError that names the
  # duplicated axis tuple.
  expected_message = r"duplicate value in 'axis': \(0, 0\)"
  operand = jnp.arange(3)
  with self.assertRaisesRegex(ValueError, expected_message):
    jnp.sum(operand, (0, 0))
def testArangeConcretizationError(self):
  # Passing a traced value to jnp.arange under jit must raise a
  # ConcretizationTypeError whose message names the offending argument.
  template = r"It arose in jax.numpy.arange argument `{}`"
  # (argument named in the error, thunk that triggers it)
  cases = [
      ('stop', lambda: jax.jit(jnp.arange)(3)),
      ('start', lambda: jax.jit(lambda start: jnp.arange(start, 3))(0)),
      ('stop', lambda: jax.jit(lambda stop: jnp.arange(0, stop))(3)),
  ]
  for argument, trigger in cases:
    with self.assertRaisesRegex(jax.core.ConcretizationTypeError,
                                template.format(argument)):
      trigger()
@parameterized.named_parameters(jtu.cases_from_list(
  {"testcase_name": str(dtype), "dtype": dtype}
  for dtype in [None] + float_dtypes))
def testArange64Bit(self, dtype):
  # Test that jnp.arange uses 64-bit arithmetic to define its range, even if the
  # output has another dtype. The issue here is that if python scalar inputs to
  # jnp.arange are cast to float32 before the range is computed, it changes the
  # number of elements output by the range. It's unclear whether this was deliberate
  # behavior in the initial implementation, but it's behavior that downstream users
  # have come to rely on.
  args = (1.2, 4.8, 0.24)

  # Ensure that this test case leads to differing lengths if cast to float32.
  float32_args = map(np.float32, args)
  self.assertLen(np.arange(*args), 15)
  self.assertLen(np.arange(*float32_args), 16)

  def jnp_fun():
    return jnp.arange(*args, dtype=dtype)

  def numpy_arange():
    return np.arange(*args, dtype=dtype)

  np_fun = jtu.with_jax_dtype_defaults(numpy_arange, dtype is None)

  def args_maker():
    return []

  self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
  self._CompileAndCheck(jnp_fun, args_maker)
def testIssue2347(self):
  # https://github.com/google/jax/issues/2347
  # jnp.array must raise TypeError when handed a typing construct, either
  # directly or wrapped in a NumPy array.
  object_list = List[Tuple[jnp.array, float, float, jnp.array, bool]]
  with self.assertRaises(TypeError):
    jnp.array(object_list)

  np_object_list = np.array(object_list)
  with self.assertRaises(TypeError):
    jnp.array(np_object_list)
@parameterized.named_parameters(jtu.cases_from_list(
  {"testcase_name": jtu.format_test_name_suffix("", shapes, dtypes),
   "shapes": shapes, "dtypes": dtypes}
  for shapes in filter(
    _shapes_are_broadcast_compatible,
    itertools.combinations_with_replacement(all_shapes, 2))
  for dtypes in itertools.product(
    *(_valid_dtypes_for_shape(s, complex_dtypes) for s in shapes))))
@jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
def testLogaddexpComplex(self, shapes, dtypes):
  # Compares jnp.logaddexp on complex inputs (some of them NaN) with the
  # direct NumPy formula log(exp(x1) + exp(x2)).
  @jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*")
  def np_op(x1, x2):
    return np.log(np.exp(x1) + np.exp(x2))

  sampler = jtu.rand_some_nan(self.rng())

  def args_maker():
    return tuple(sampler(shape, dtype) for shape, dtype in zip(shapes, dtypes))

  # TPU gets looser tolerances than the other platforms.
  on_tpu = jtu.device_under_test() == 'tpu'
  tol = ({np.complex64: 1e-3, np.complex128: 1e-10} if on_tpu
         else {np.complex64: 1e-5, np.complex128: 1e-14})
  self._CheckAgainstNumpy(_promote_like_jnp(np_op), jnp.logaddexp, args_maker, tol=tol)
  self._CompileAndCheck(jnp.logaddexp, args_maker, rtol=tol, atol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
  {"testcase_name": jtu.format_test_name_suffix("", shapes, dtypes),
   "shapes": shapes, "dtypes": dtypes}
  for shapes in filter(
    _shapes_are_broadcast_compatible,
    itertools.combinations_with_replacement(all_shapes, 2))
  for dtypes in itertools.product(
    *(_valid_dtypes_for_shape(s, complex_dtypes) for s in shapes))))
@jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
def testLogaddexp2Complex(self, shapes, dtypes):
  # Compares jnp.logaddexp2 on complex inputs (some of them NaN) with the
  # direct NumPy formula log2(exp2(x1) + exp2(x2)).
  @jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*")
  def np_op(x1, x2):
    return np.log2(np.exp2(x1) + np.exp2(x2))

  sampler = jtu.rand_some_nan(self.rng())

  def args_maker():
    return tuple(sampler(shape, dtype) for shape, dtype in zip(shapes, dtypes))

  # TPU gets looser tolerances than the other platforms.
  on_tpu = jtu.device_under_test() == 'tpu'
  tol = ({np.complex64: 1e-3, np.complex128: 1e-10} if on_tpu
         else {np.complex64: 1e-5, np.complex128: 1e-14})
  self._CheckAgainstNumpy(_promote_like_jnp(np_op), jnp.logaddexp2, args_maker, tol=tol)
  self._CompileAndCheck(jnp.logaddexp2, args_maker, rtol=tol, atol=tol)
def testDefaultDtypes(self):
  # The jnp default scalar types must track config.jax_default_dtype_bits:
  # the 32-bit NumPy types when it is '32', the 64-bit ones when it is '64'.
  precision = config.jax_default_dtype_bits
  assert precision in ['32', '64']
  self.assertEqual(jnp.bool_, np.bool_)
  # (jnp type, expected NumPy type at 32 bits, expected NumPy type at 64 bits)
  expectations = [
      (jnp.int_, np.int32, np.int64),
      (jnp.uint, np.uint32, np.uint64),
      (jnp.float_, np.float32, np.float64),
      (jnp.complex_, np.complex64, np.complex128),
  ]
  for jnp_type, narrow, wide in expectations:
    self.assertEqual(jnp_type, narrow if precision == '32' else wide)
def testFromBuffer(self):
  # jnp.frombuffer must decode a bytes object exactly as np.frombuffer does.
  raw = b'\x01\x02\x03'
  kwargs = dict(dtype='uint8')
  expected = np.frombuffer(raw, **kwargs)
  actual = jnp.frombuffer(raw, **kwargs)
  self.assertArraysEqual(expected, actual)
def testFromFunction(self):
  # jnp.fromfunction must build the same array as np.fromfunction for a
  # simple linear function of the three index grids.
  def weighted_index_sum(x, y, z):
    return x + 2 * y + 3 * z

  shape = (3, 4, 5)
  expected = np.fromfunction(weighted_index_sum, shape=shape)
  actual = jnp.fromfunction(weighted_index_sum, shape=shape)
  self.assertArraysEqual(expected, actual)
def testFromString(self):
  # jnp.fromstring must parse a comma-separated integer string exactly as
  # np.fromstring does.
  text = "1,2,3"
  kwargs = dict(sep=',', dtype=int)
  expected = np.fromstring(text, **kwargs)
  actual = jnp.fromstring(text, **kwargs)
  self.assertArraysEqual(expected, actual)
# Most grad tests are at the lax level (see lax_test.py), but we add some here
# as needed for e.g. particular compound ops of interest.

# Record describing one gradient test: the op under test, how many arguments
# it takes, the derivative order to check, the factory for its input sampler,
# the dtypes to test, a display name, and an optional tolerance.
GradTestSpec = collections.namedtuple(
    "GradTestSpec", "op nargs order rng_factory dtypes name tol")


def grad_test_spec(op, nargs, order, rng_factory, dtypes, name=None, tol=None):
  """Builds a GradTestSpec, defaulting `name` to the op's `__name__`.

  `op.__name__` is only consulted when `name` is missing or falsy, so ops
  without a `__name__` attribute are fine as long as a name is supplied.
  """
  display_name = name if name else op.__name__
  return GradTestSpec(
      op=op,
      nargs=nargs,
      order=order,
      rng_factory=rng_factory,
      dtypes=dtypes,
      name=display_name,
      tol=tol,
  )
# Gradient test table. Each entry is built by grad_test_spec and names the
# jnp op, its number of arguments, the derivative order to check, the
# factory for its input sampler, the dtypes to test, and a tolerance.
# The inverse hyperbolic ops are tested at float64 and complex64, with an
# explicit tolerance given only for complex64; the logaddexp ops are tested
# at float64 only, with inputs drawn uniformly from [-0.9, 0.9].
GRAD_TEST_RECORDS = [
grad_test_spec(jnp.arcsinh, nargs=1, order=2,
rng_factory=jtu.rand_positive,
dtypes=[np.float64, np.complex64],
tol={np.complex64: 2e-2}),
grad_test_spec(jnp.arccosh, nargs=1, order=2,
rng_factory=jtu.rand_positive,
dtypes=[np.float64, np.complex64],
tol={np.complex64: 2e-2}),
grad_test_spec(jnp.arctanh, nargs=1, order=2,
rng_factory=partial(jtu.rand_uniform, low=-0.9, high=0.9),
dtypes=[np.float64, np.complex64],
tol={np.complex64: 2e-2}),
grad_test_spec(jnp.logaddexp, nargs=2, order=1,
rng_factory=partial(jtu.rand_uniform, low=-0.9, high=0.9),
dtypes=[np.float64], tol=1e-4),
grad_test_spec(jnp.logaddexp2, nargs=2, order=2,
rng_factory=partial(jtu.rand_uniform, low=-0.9, high=0.9),
dtypes=[np.float64], tol=1e-4),
]
# Record for gradient checks at hand-picked input values: the op under test,
# the list of scalar values to evaluate it at, and the derivative order.
GradSpecialValuesTestSpec = collections.namedtuple(
"GradSpecialValuesTestSpec", ["op", "values", "order"])
# Table of (op, special values, order) entries; one test case is generated
# per special value.
GRAD_SPECIAL_VALUE_TEST_RECORDS = [
GradSpecialValuesTestSpec(jnp.arcsinh, [0., 1000.], 2),
GradSpecialValuesTestSpec(jnp.arccosh, [1000.], 2),
GradSpecialValuesTestSpec(jnp.arctanh, [0.], 2),
GradSpecialValuesTestSpec(jnp.sinc, [0.], 1),
]
class NumpyGradTests(jtu.JaxTestCase):
  # Gradient checks for selected jax.numpy ops, driven by GRAD_TEST_RECORDS
  # and GRAD_SPECIAL_VALUE_TEST_RECORDS plus a few hand-written cases.

  @parameterized.named_parameters(itertools.chain.from_iterable(
    jtu.cases_from_list(
      {"testcase_name": jtu.format_test_name_suffix(
          rec.name, shapes, itertools.repeat(dtype)),
       "op": rec.op, "rng_factory": rec.rng_factory, "shapes": shapes, "dtype": dtype,
       "order": rec.order, "tol": rec.tol}
      for shapes in itertools.combinations_with_replacement(nonempty_shapes, rec.nargs)
      for dtype in rec.dtypes)
    for rec in GRAD_TEST_RECORDS))
  @jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
  def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol):
    # Checks forward- and reverse-mode derivatives of `op` up to `order`.
    sampler = rng_factory(self.rng())
    fallback_tol = {
        np.float32: 1e-1,
        np.float64: 1e-3,
        np.complex64: 1e-1,
        np.complex128: 1e-3,
    }
    tol = jtu.join_tolerance(tol, fallback_tol)
    operands = tuple(sampler(shape, dtype) for shape in shapes)
    check_grads(op, operands, order, ["fwd", "rev"], tol, tol)

  @parameterized.named_parameters(itertools.chain.from_iterable(
    jtu.cases_from_list(
      {"testcase_name": "_{}_{}".format(rec.op.__name__, special_value),
       "op": rec.op, "special_value": special_value, "order": rec.order}
      for special_value in rec.values)
    for rec in GRAD_SPECIAL_VALUE_TEST_RECORDS))
  def testOpGradSpecialValue(self, op, special_value, order):
    # Checks derivatives of `op` at one hand-picked scalar input.
    operands = (special_value,)
    check_grads(op, operands, order, ["fwd", "rev"],
                atol={np.float32: 3e-3})

  def testSincAtZero(self):
    # Some manual tests for sinc at zero, since it doesn't have well-behaved
    # numerical derivatives at zero
    def deriv(f):
      return lambda x: jax.jvp(f, (x,), (1.,))[1]

    def apply_all(fns, x):
      # Applies each transformation in `fns` to `x`, left to right.
      return functools.reduce(lambda acc, fn: fn(acc), fns, x)

    # (derivative order, expected value of that derivative of sinc at 0)
    expected_by_order = [
        (1, 0.),
        (2, -np.pi ** 2 / 3),
        (3, 0.),
        (4, np.pi ** 4 / 5),
    ]
    for order, expected in expected_by_order:
      # Every mix of forward-mode (deriv) and reverse-mode (jax.grad)
      # differentiation of the given order must agree.
      for ops in itertools.combinations_with_replacement([deriv, jax.grad], order):
        self.assertAllClose(apply_all(ops, jnp.sinc)(0.), expected)

  def testSincGradArrayInput(self):
    # tests for a bug almost introduced in #5077
    def summed_sinc(x):
      return jnp.sinc(x).sum()

    jax.grad(summed_sinc)(jnp.arange(10.))  # doesn't crash

  def testTakeAlongAxisIssue1521(self):
    # https://github.com/google/jax/issues/1521
    idx = jnp.repeat(jnp.arange(3), 10).reshape((30, 1))

    def f(x):
      scaled = x * jnp.arange(3.).reshape((1, 3))
      gathered = jnp.take_along_axis(scaled, idx, -1)
      return gathered.sum()

    check_grads(f, (1.,), order=1)

  @parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": jtu.format_test_name_suffix("", shapes, itertools.repeat(dtype)),
     "shapes": shapes, "dtype": dtype}
    for shapes in filter(
      _shapes_are_broadcast_compatible,
      itertools.combinations_with_replacement(nonempty_shapes, 2))
    for dtype in (np.complex128, )))
  @jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
  def testGradLogaddexpComplex(self, shapes, dtype):
    # First-order gradient check of jnp.logaddexp on complex inputs.
    sampler = jtu.rand_default(self.rng())
    operands = tuple(sampler(shape, dtype) for shape in shapes)
    # TPU gets a looser tolerance than the other platforms.
    tol = 5e-2 if jtu.device_under_test() == "tpu" else 3e-2
    check_grads(jnp.logaddexp, operands, 1, ["fwd", "rev"], tol, tol)

  @parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": jtu.format_test_name_suffix("", shapes, itertools.repeat(dtype)),
     "shapes": shapes, "dtype": dtype}
    for shapes in filter(
      _shapes_are_broadcast_compatible,
      itertools.combinations_with_replacement(nonempty_shapes, 2))
    for dtype in (np.complex128, )))
  @jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
  def testGradLogaddexp2Complex(self, shapes, dtype):
    # First-order gradient check of jnp.logaddexp2 on complex inputs.
    sampler = jtu.rand_default(self.rng())
    operands = tuple(sampler(shape, dtype) for shape in shapes)
    # TPU gets a looser tolerance than the other platforms.
    tol = 5e-2 if jtu.device_under_test() == "tpu" else 3e-2
    check_grads(jnp.logaddexp2, operands, 1, ["fwd", "rev"], tol, tol)
class NumpySignaturesTest(jtu.JaxTestCase):
  # Guards against drift between jax.numpy signatures and the NumPy
  # functions they wrap.

  def testWrappedSignaturesMatch(self):
    """Test that jax.numpy function signatures match numpy."""
    # Pair every jnp attribute that records the NumPy function it wraps
    # (via __np_wrapped__) with that NumPy function.
    func_pairs = {}
    for name in dir(jnp):
      fun = getattr(jnp, name)
      if hasattr(fun, '__np_wrapped__'):
        func_pairs[name] = (fun, fun.__np_wrapped__)
    assert len(func_pairs) > 0

    # TODO(jakevdp): fix some of the following signatures. Some are due to wrong argument names.
    # NumPy parameters that the jnp counterpart is known not to accept.
    unsupported_params = {
      'asarray': ['like'],
      'broadcast_to': ['subok', 'array'],
      'clip': ['kwargs'],
      'copy': ['subok'],
      'corrcoef': ['ddof', 'bias', 'dtype'],
      'cov': ['dtype'],
      'empty_like': ['subok', 'order'],
      'einsum': ['kwargs'],
      'einsum_path': ['einsum_call'],
      'eye': ['order', 'like'],
      'identity': ['like'],
      'full': ['order', 'like'],
      'full_like': ['subok', 'order'],
      'fromfunction': ['like'],
      'histogram': ['normed'],
      'histogram2d': ['normed'],
      'histogramdd': ['normed'],
      'ones': ['order', 'like'],
      'ones_like': ['subok', 'order'],
      'tri': ['like'],
      'unwrap': ['period'],
      'zeros_like': ['subok', 'order']
    }

    # jnp parameters that have no NumPy counterpart.
    extra_params = {
      'broadcast_to': ['arr'],
      'einsum': ['precision'],
      'einsum_path': ['subscripts'],
      'take_along_axis': ['mode'],
    }

    # (minimum numpy version, names skipped below that version).
    version_gated = [
      # broadcast_shapes is not available in numpy < 1.20
      ((1, 20), ["broadcast_shapes"]),
      # Some signatures have changed; skip for older numpy versions.
      ((1, 19), ['einsum_path', 'gradient', 'isscalar']),
      ((1, 22), ['quantile', 'nanquantile', 'percentile', 'nanpercentile']),
    ]

    var_args = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

    def only_variadic(params):
      return all(p.kind in var_args for p in params.values())

    mismatches = {}
    for name, (jnp_fun, np_fun) in func_pairs.items():
      if any(numpy_version < minimum and name in gated
             for minimum, gated in version_gated):
        continue

      # Note: can't use inspect.getfullargspec due to numpy issue
      # https://github.com/numpy/numpy/issues/12225
      try:
        np_params = inspect.signature(np_fun).parameters
      except ValueError:
        # Some functions cannot be inspected
        continue
      jnp_params = inspect.signature(jnp_fun).parameters
      extra = set(extra_params.get(name, []))
      unsupported = set(unsupported_params.get(name, []))

      # Checks to prevent tests from becoming out-of-date. If these fail,
      # it means that extra_params or unsupported_params need to be updated.
      assert extra.issubset(jnp_params), f"{name}: extra={extra} is not a subset of jnp_params={set(jnp_params)}."
      assert not unsupported.intersection(jnp_params), f"{name}: unsupported={unsupported} overlaps with jnp_params={set(jnp_params)}."

      # Skip functions that only have *args and **kwargs; we can't introspect these further.
      if only_variadic(jnp_params) or only_variadic(np_params):
        continue

      # Remove known extra parameters.
      jnp_names = [arg for arg in jnp_params if arg not in extra]
      # Remove known unsupported parameters.
      np_names = [arg for arg in np_params if arg not in unsupported]

      # Older versions of numpy may have fewer parameters; to avoid extraneous errors on older numpy
      # versions, we allow for jnp to have more parameters.
      if jnp_names[:len(np_names)] != np_names:
        mismatches[name] = {'np_params': np_names, 'jnp_params': jnp_names}

    self.assertEqual(mismatches, {})
# Names of every dtype in jtu.dtypes.all except bfloat16.
_available_numpy_dtypes: List[str] = [
    candidate.__name__
    for candidate in jtu.dtypes.all
    if candidate != dtypes.bfloat16
]
def _all_numpy_ufuncs() -> Iterator[str]:
  """Generate the names of all ufuncs in the top-level numpy namespace."""
  # Names are produced in the order returned by dir(np).
  yield from (
      attr_name
      for attr_name in dir(np)
      if isinstance(getattr(np, attr_name), np.ufunc)
  )
def _dtypes_for_ufunc(name: str) -> Iterator[Tuple[str, ...]]:
  """Generate valid dtypes of inputs to the given numpy ufunc."""
  ufunc = getattr(np, name)
  # Try every combination of available dtype names, one per ufunc input, and
  # keep the combinations the ufunc accepts without raising TypeError.
  for candidate in itertools.product(_available_numpy_dtypes, repeat=ufunc.nin):
    operands = (np.ones(1, dtype=dtype_name) for dtype_name in candidate)
    with warnings.catch_warnings():
      warnings.filterwarnings("ignore", "divide by zero", RuntimeWarning)
      try:
        ufunc(*operands)
      except TypeError:
        continue
    yield candidate
class NumpyUfuncTests(jtu.JaxTestCase):
  # Checks that each jnp counterpart of a NumPy ufunc accepts the same input
  # dtype combinations as NumPy and produces close results on all-ones input.

  @parameterized.named_parameters(
    {"testcase_name": f"_{name}_{','.join(arg_dtypes)}",
     "name": name, "arg_dtypes": arg_dtypes}
    for name in _all_numpy_ufuncs()
    for arg_dtypes in jtu.cases_from_list(_dtypes_for_ufunc(name)))
  def testUfuncInputTypes(self, name, arg_dtypes):
    # TODO(jakevdp): fix following failures and remove from this exception list.
    if (name in ['divmod', 'floor_divide', 'fmod', 'gcd', 'left_shift', 'mod',
                 'power', 'remainder', 'right_shift', 'rint', 'square']
        and 'bool_' in arg_dtypes):
      self.skipTest(f"jax.numpy does not support {name}{tuple(arg_dtypes)}")
    if name == 'arctanh' and jnp.issubdtype(arg_dtypes[0], jnp.complexfloating):
      self.skipTest("np.arctanh & jnp.arctanh have mismatched NaNs for complex input.")

    jnp_op = getattr(jnp, name)
    np_op = getattr(np, name)
    np_op = jtu.ignore_warning(category=RuntimeWarning,
                               message="divide by zero.*")(np_op)
    args_maker = lambda: tuple(np.ones(1, dtype=dtype) for dtype in arg_dtypes)

    # Probe the jnp function once so that unimplemented ops are reported as
    # skipped rather than failed.
    try:
      jnp_op(*args_maker())
    except NotImplementedError:
      # The function being probed lives in jax.numpy (jnp), so the message
      # names it with the jnp prefix.
      self.skipTest(f"jnp.{name} is not yet implemented.")

    # large tol comes from the fact that numpy returns float16 in places
    # that jnp returns float32. e.g. np.cos(np.uint8(0))
    self._CheckAgainstNumpy(np_op, jnp_op, args_maker, check_dtypes=False, tol=1E-2)
# Tests for the docstring wrapping applied to jax.numpy functions (_wraps)
# and for the NumPy docstring parser (_parse_numpydoc).
class NumpyDocTests(jtu.JaxTestCase):
def test_lax_numpy_docstrings(self):
# Test that docstring wrapping & transformation didn't fail.
# Functions that have their own docstrings & don't wrap numpy.
known_exceptions = {'broadcast_arrays', 'fromfile', 'fromiter', 'vectorize'}
for name in dir(jnp):
if name in known_exceptions or name.startswith('_'):
continue
# We only check signatures of functions.
obj = getattr(jnp, name)
if isinstance(obj, type) or not callable(obj):
continue
# Some jnp functions are imported from numpy or jax.dtypes directly.
if any(obj is getattr(mod, obj.__name__, None) for mod in [np, dtypes]):
continue
# Every remaining callable is expected to carry __np_wrapped__; a missing
# attribute surfaces here as an AttributeError.
wrapped_fun = obj.__np_wrapped__
# If the wrapped function has a docstring, obj should too
if wrapped_fun.__doc__ and not obj.__doc__:
raise Exception(f"jnp.{name} does not contain wrapped docstring.")
# A non-empty jnp docstring must contain the marker text that this check
# looks for.
if obj.__doc__ and "*Original docstring below.*" not in obj.__doc__:
raise Exception(f"jnp.{name} does not have a wrapped docstring.")
@parameterized.named_parameters(
{"testcase_name": "_jit" if jit else "", "jit": jit} for jit in [True, False])
def test_wrapped_function_parameters(self, jit):
# Wraps a small function with _wraps(orig, skip_params=['out']) and checks
# which parts of orig's docstring appear in the wrapped docstring.
def orig(x):
"""Example Docstring
Parameters
----------
x : array_like
Input Data
.. versionadded:: 1.8.0
out : array_like, optional
Output to overwrite
other_arg : Any
not used
Returns
-------
x : input
"""
return x
def wrapped(x, out=None):
return x
# The same assertions must hold when the wrapped function is jitted first.
if jit:
wrapped = jax.jit(wrapped)
wrapped = _wraps(orig, skip_params=['out'])(wrapped)
doc = wrapped.__doc__
self.assertStartsWith(doc, "Example Docstring")
self.assertIn("Original docstring below", doc)
self.assertIn("Parameters", doc)
self.assertIn("Returns", doc)
# 'out' is listed in skip_params; 'other_arg' is documented by orig but is
# not a parameter of wrapped. Neither may appear, nor may the
# versionadded directive.
self.assertNotIn('out', doc)
self.assertNotIn('other_arg', doc)
self.assertNotIn('versionadded', doc)
def test_parse_numpydoc(self):
# Unit test ensuring that _parse_numpydoc correctly parses docstrings for all
# functions in NumPy's top-level namespace.
section_titles = {'Attributes', 'Examples', 'Notes',
'Parameters', 'Raises', 'References',
'Returns', 'See also', 'See Also', 'Warnings', 'Warns'}
# Each heading is the section title followed by a dashed underline of the
# same length.
headings = [title + '\n' + '-'*len(title) for title in section_titles]
for name in dir(np):
# Only public, callable, non-type, non-builtin objects are checked.
if name.startswith('_'):
continue
obj = getattr(np, name)
if isinstance(obj, type):
continue
if not callable(obj):
continue
if 'built-in function' in repr(obj):
continue
parsed = _parse_numpydoc(obj.__doc__)
# Check that no docstring is handled gracefully.
if not obj.__doc__:
self.assertEqual(parsed, ParsedDoc(obj.__doc__))
continue
# Check that no unexpected section names are found.
extra_keys = parsed.sections.keys() - section_titles
if extra_keys:
raise ValueError(f"Extra section headers found in np.{name}: {extra_keys}")
# Check that every docstring has a summary.
if not parsed.summary:
raise ValueError(f"No summary found for np.{name}")
# Check that no expected headings are missed.
for heading in headings:
assert heading not in parsed.front_matter
# When executed as a script, run the tests through absltest with the test
# loader provided by jtu.
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())