rocm_jax/tests/lax_numpy_test.py
Tuan Nguyen 0ebf8488ae Implement np.flip with axis = None (#1783)
* super minimal starter code

* Update optimizers.py

* implement flip with axis = None
2019-11-28 11:54:29 -08:00

2536 lines
114 KiB
Python

# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import functools
from functools import partial
import itertools
import operator
import unittest
from unittest import SkipTest
import warnings
from absl.testing import absltest
from absl.testing import parameterized
import six
import numpy as onp
import jax
import jax.ops
from jax import api
from jax import lax
from jax import linear_util
from jax import numpy as lnp
from jax import test_util as jtu
from jax import dtypes
from jax.interpreters import partial_eval
from jax.test_util import check_grads
from jax.config import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS
# Shape pools used to parameterize the tests below. The pools are combined by
# list concatenation, so their element order is significant for the order in
# which test cases are generated.

# Non-empty array shapes of rank >= 1.
nonempty_nonscalar_array_shapes = [(4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)]
# The rank-0 array shape () followed by the shapes above.
nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes
# Array shapes containing a zero-sized dimension.
empty_array_shapes = [(0,), (0, 4), (3, 0),]
# Sentinel "shapes" provided by jtu; presumably they request a NumPy scalar or
# a Python scalar argument instead of an array (defined in jax.test_util).
scalar_shapes = [jtu.NUMPY_SCALAR_SHAPE, jtu.PYTHON_SCALAR_SHAPE]
# Every real array shape: non-empty ones (including rank 0) and empty ones.
array_shapes = nonempty_array_shapes + empty_array_shapes
# Array shapes of rank >= 1, whether empty or not.
nonzerodim_shapes = nonempty_nonscalar_array_shapes + empty_array_shapes
# Scalar sentinels plus non-empty array shapes (no zero-sized dimensions).
nonempty_shapes = scalar_shapes + nonempty_array_shapes
# Scalar sentinels plus every array shape.
all_shapes = scalar_shapes + array_shapes
# Floating-point dtypes supported by the backend under test.
#
# The list is built by filtering a fixed-order sequence rather than by
# converting a set intersection to a list: set iteration order depends on
# object hashes, which can differ between processes, and every dtype pool
# derived from this one (and hence the order of generated test cases) would
# inherit that instability.
float_dtypes = [
    t for t in [lnp.bfloat16, onp.float16, onp.float32, onp.float64]
    if t in jtu.supported_dtypes()]
# Dtype pools used to parameterize the tests below; the composite pools are
# plain list concatenations of the basic ones.
complex_dtypes = [onp.complex64, onp.complex128]
int_dtypes = [onp.int32, onp.int64]
unsigned_dtypes = [onp.uint32, onp.uint64]
bool_dtypes = [onp.bool_]
# Real-valued numeric dtypes (floats and signed ints).
default_dtypes = float_dtypes + int_dtypes
# Floating-point and complex dtypes.
inexact_dtypes = float_dtypes + complex_dtypes
# All numeric dtypes except unsigned ints and bool.
number_dtypes = float_dtypes + complex_dtypes + int_dtypes
# number_dtypes plus bool.
all_dtypes = number_dtypes + bool_dtypes
# The only dtypes accepted for the Python-scalar sentinel shape (see
# _valid_dtypes_for_shape).
python_scalar_dtypes = [lnp.bool_, lnp.int_, lnp.float_, lnp.complex_]
def _valid_dtypes_for_shape(shape, dtypes):
  """Returns the subset of `dtypes` that may be paired with `shape`.

  Python scalars have a single type per category (bool, int, float, complex),
  so for the Python-scalar sentinel shape only the entries that appear in
  `python_scalar_dtypes` are kept. Any other shape accepts `dtypes` unchanged
  (the very same object is returned).
  """
  if shape is not jtu.PYTHON_SCALAR_SHAPE:
    return dtypes
  allowed = []
  for candidate in dtypes:
    if candidate in python_scalar_dtypes:
      allowed.append(candidate)
  return allowed
def _shape_and_dtypes(shapes, dtypes):
  """Yields every valid `(shape, dtype)` pair, grouped by shape.

  For each shape in `shapes`, the dtypes are first narrowed with
  `_valid_dtypes_for_shape` and then paired with that shape in order.
  """
  for shape in shapes:
    valid_dtypes = _valid_dtypes_for_shape(shape, dtypes)
    for pair in itertools.product([shape], valid_dtypes):
      yield pair
# Immutable description of one operation under test; instances are normally
# created through `op_record`, which fills in the optional fields.
OpRecord = collections.namedtuple(
    "OpRecord",
    ["name", "nargs", "dtypes", "shapes", "rng_factory", "diff_modes",
     "test_name", "check_dtypes", "tolerance"])


def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes,
              test_name=None, check_dtypes=True, tolerance=None):
  """Builds an `OpRecord`, defaulting `test_name` to `name`.

  Args:
    name: name of the operation.
    nargs: number of arguments the operation takes.
    dtypes: dtypes to test the operation with.
    shapes: shapes to test the operation with.
    rng_factory: zero-argument callable returning the input generator.
    diff_modes: differentiation modes associated with the operation.
    test_name: name used in generated test ids; any falsy value is replaced
      by `name`.
    check_dtypes: whether result dtypes should be compared.
    tolerance: optional tolerance specification for comparisons.

  Returns:
    The populated `OpRecord`.
  """
  if not test_name:
    test_name = name
  return OpRecord(
      name=name, nargs=nargs, dtypes=dtypes, shapes=shapes,
      rng_factory=rng_factory, diff_modes=diff_modes, test_name=test_name,
      check_dtypes=check_dtypes, tolerance=tolerance)
# Element-wise operations checked by testOp. Each entry is
# op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes, ...);
# `name` is looked up on both onp (reference) and lnp (implementation under
# test), and the optional `tolerance` maps dtypes to comparison tolerances.
JAX_ONE_TO_ONE_OP_RECORDS = [
op_record("abs", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
op_record("add", 2, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
op_record("ceil", 1, float_dtypes, all_shapes, jtu.rand_default, []),
op_record("conj", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
op_record("equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
op_record("exp", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
op_record("fabs", 1, float_dtypes, all_shapes, jtu.rand_default, ["rev"]),
op_record("float_power", 2, inexact_dtypes, all_shapes, jtu.rand_default, ["rev"],
tolerance={lnp.bfloat16: 1e-2, onp.float32: 1e-3,
onp.float64: 1e-12, onp.complex64: 2e-4,
onp.complex128: 1e-12}),
op_record("floor", 1, float_dtypes, all_shapes, jtu.rand_default, []),
op_record("greater", 2, number_dtypes, all_shapes, jtu.rand_some_equal, []),
op_record("greater_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal, []),
op_record("less", 2, number_dtypes, all_shapes, jtu.rand_some_equal, []),
op_record("less_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal, []),
op_record("log", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"]),
op_record("logical_and", 2, all_dtypes, all_shapes, jtu.rand_bool, []),
op_record("logical_not", 1, all_dtypes, all_shapes, jtu.rand_bool, []),
op_record("logical_or", 2, all_dtypes, all_shapes, jtu.rand_bool, []),
op_record("logical_xor", 2, all_dtypes, all_shapes, jtu.rand_bool, []),
op_record("maximum", 2, number_dtypes, all_shapes, jtu.rand_some_inf, []),
op_record("minimum", 2, number_dtypes, all_shapes, jtu.rand_some_inf, []),
op_record("multiply", 2, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
op_record("negative", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
op_record("not_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]),
op_record("array_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]),
op_record("reciprocal", 1, inexact_dtypes, all_shapes, jtu.rand_default, []),
op_record("subtract", 2, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
op_record("signbit", 1, default_dtypes + bool_dtypes, all_shapes,
jtu.rand_some_inf_and_nan, ["rev"]),
op_record("sin", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
op_record("cos", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
# Inputs are drawn from (-1.5, 1.5); presumably to stay clear of the poles of
# tan at +/- pi/2.
op_record("tan", 1, number_dtypes, all_shapes,
partial(jtu.rand_uniform, -1.5, 1.5), ["rev"]),
op_record("sinh", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
op_record("cosh", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
# TODO(b/142975473): on CPU, tanh for complex128 is only accurate to
# ~float32 precision.
# TODO(b/143135720): on GPU, tanh has only ~float32 precision.
op_record("tanh", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"],
tolerance={onp.float64: 1e-7, onp.complex128: 1e-7}),
op_record("arcsin", 1, float_dtypes, all_shapes, jtu.rand_small, ["rev"]),
op_record("arccos", 1, float_dtypes, all_shapes, jtu.rand_small, ["rev"]),
op_record("arctan", 1, float_dtypes, all_shapes, jtu.rand_small, ["rev"]),
op_record("arctan2", 2, float_dtypes, all_shapes, jtu.rand_small, ["rev"]),
op_record("arcsinh", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"]),
op_record("arccosh", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"]),
op_record("arctanh", 1, number_dtypes, all_shapes, jtu.rand_small, ["rev"]),
]
# Further operations checked by testOp alongside JAX_ONE_TO_ONE_OP_RECORDS
# (the two lists are chained there). Entries use the same op_record layout;
# `test_name` distinguishes two records that share an op name (e.g. the
# "expm1_large" / "log1p_large" variants, which use a different rng_factory).
# The list is extended further down when the installed NumPy is >= 1.15.
JAX_COMPOUND_OP_RECORDS = [
# angle has inconsistent 32/64-bit return types across numpy versions.
op_record("angle", 1, number_dtypes, all_shapes, jtu.rand_default, [],
check_dtypes=False),
op_record("atleast_1d", 1, default_dtypes, all_shapes, jtu.rand_default, []),
op_record("atleast_2d", 1, default_dtypes, all_shapes, jtu.rand_default, []),
op_record("atleast_3d", 1, default_dtypes, all_shapes, jtu.rand_default, []),
op_record("cbrt", 1, default_dtypes, all_shapes, jtu.rand_default, ["rev"]),
op_record("conjugate", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
op_record("deg2rad", 1, float_dtypes, all_shapes, jtu.rand_default, []),
op_record("divide", 2, number_dtypes, all_shapes, jtu.rand_nonzero, ["rev"]),
op_record("divmod", 2, int_dtypes + float_dtypes, all_shapes,
jtu.rand_nonzero, []),
op_record("exp2", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"],
tolerance={lnp.bfloat16: 2e-2, onp.float16: 1e-2}),
# TODO(b/142975473): on CPU, expm1 for float64 is only accurate to ~float32
# precision.
op_record("expm1", 1, number_dtypes, all_shapes, jtu.rand_positive, [],
test_name="expm1_large", tolerance={onp.float64: 1e-8}),
op_record("expm1", 1, number_dtypes, all_shapes, jtu.rand_small_positive,
[], tolerance={onp.float64: 1e-8}),
op_record("fix", 1, float_dtypes, all_shapes, jtu.rand_default, []),
op_record("floor_divide", 2, number_dtypes, all_shapes, jtu.rand_nonzero, ["rev"]),
op_record("heaviside", 2, default_dtypes, all_shapes, jtu.rand_default, []),
op_record("hypot", 2, default_dtypes, all_shapes, jtu.rand_default, []),
op_record("kron", 2, number_dtypes, nonempty_shapes, jtu.rand_default, []),
op_record("outer", 2, number_dtypes, all_shapes, jtu.rand_default, []),
op_record("imag", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []),
op_record("iscomplex", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []),
op_record("isfinite", 1, inexact_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []),
op_record("isinf", 1, inexact_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []),
op_record("isnan", 1, inexact_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []),
op_record("isneginf", 1, float_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []),
op_record("isposinf", 1, float_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []),
op_record("isreal", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []),
op_record("isrealobj", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []),
op_record("log2", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"]),
op_record("log10", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"]),
op_record("log1p", 1, number_dtypes, all_shapes, jtu.rand_positive, [],
test_name="log1p_large", tolerance={onp.float64: 1e-12}),
op_record("log1p", 1, number_dtypes, all_shapes, jtu.rand_small_positive, [],
tolerance={onp.float64: 1e-12}),
op_record("logaddexp", 2, float_dtypes, all_shapes,
jtu.rand_some_inf_and_nan, ["rev"],
tolerance={onp.float64: 1e-12}),
op_record("logaddexp2", 2, float_dtypes, all_shapes,
jtu.rand_some_inf_and_nan, ["rev"],
tolerance={onp.float16: 1e-2}),
op_record("polyval", 2, number_dtypes, nonempty_nonscalar_array_shapes,
jtu.rand_default, [], check_dtypes=False,
tolerance={onp.float16: 1e-2, onp.float64: 1e-12}),
op_record("positive", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
op_record("power", 2, number_dtypes, all_shapes, jtu.rand_positive, ["rev"],
tolerance={onp.complex128: 1e-14}),
op_record("rad2deg", 1, float_dtypes, all_shapes, jtu.rand_default, []),
op_record("ravel", 1, all_dtypes, all_shapes, jtu.rand_default, ["rev"]),
op_record("real", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []),
op_record("remainder", 2, default_dtypes, all_shapes, jtu.rand_nonzero, [],
tolerance={onp.float16: 1e-2}),
op_record("mod", 2, default_dtypes, all_shapes, jtu.rand_nonzero, []),
# bfloat16 is excluded from the sinc dtypes (reason not recorded here).
op_record("sinc", 1, [t for t in number_dtypes if t != lnp.bfloat16],
all_shapes, jtu.rand_default, ["rev"],
tolerance={onp.complex64: 1e-5}),
op_record("square", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
op_record("sqrt", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"]),
op_record("transpose", 1, all_dtypes, all_shapes, jtu.rand_default, ["rev"],
check_dtypes=False),
op_record("true_divide", 2, all_dtypes, all_shapes, jtu.rand_nonzero, ["rev"]),
# NOTE: the dtypes here are a tuple, unlike the lists used by other records.
op_record("where", 3, (onp.float32, onp.int64), all_shapes, jtu.rand_some_zero, []),
op_record("diff", 1, number_dtypes, nonzerodim_shapes, jtu.rand_default, ["rev"]),
]
# Bitwise operations checked by testBitwiseOp over signed and unsigned
# integer dtypes; dtype pairs are additionally filtered there with
# _dtypes_are_compatible_for_bitwise_ops.
JAX_BITWISE_OP_RECORDS = [
op_record("bitwise_and", 2, int_dtypes + unsigned_dtypes, all_shapes,
jtu.rand_bool, []),
op_record("bitwise_not", 1, int_dtypes + unsigned_dtypes, all_shapes,
jtu.rand_bool, []),
op_record("bitwise_or", 2, int_dtypes + unsigned_dtypes, all_shapes,
jtu.rand_bool, []),
op_record("bitwise_xor", 2, int_dtypes + unsigned_dtypes, all_shapes,
jtu.rand_bool, []),
]
# Reductions checked by testReducer, which calls them with an explicit
# `dtype=` argument in addition to `axis` and `keepdims`.
JAX_REDUCER_RECORDS = [
op_record("mean", 1, number_dtypes, nonempty_shapes, jtu.rand_default, []),
op_record("prod", 1, all_dtypes, all_shapes, jtu.rand_small_positive, []),
op_record("sum", 1, all_dtypes, all_shapes, jtu.rand_default, []),
op_record("nanmean", 1, inexact_dtypes, nonempty_shapes, jtu.rand_some_nan, []),
op_record("nanprod", 1, inexact_dtypes, all_shapes, jtu.rand_some_nan, []),
op_record("nansum", 1, number_dtypes, all_shapes, jtu.rand_some_nan, []),
]
# Reductions checked by testReducerNoDtype, which calls them with `axis` and
# `keepdims` only (no `dtype=` argument). Extended further down with "ptp"
# when the installed NumPy is >= 1.15.
JAX_REDUCER_NO_DTYPE_RECORDS = [
op_record("all", 1, all_dtypes, all_shapes, jtu.rand_some_zero, []),
op_record("any", 1, all_dtypes, all_shapes, jtu.rand_some_zero, []),
op_record("max", 1, all_dtypes, nonempty_shapes, jtu.rand_default, []),
op_record("min", 1, all_dtypes, nonempty_shapes, jtu.rand_default, []),
op_record("var", 1, all_dtypes, nonempty_shapes, jtu.rand_default, []),
op_record("std", 1, all_dtypes, nonempty_shapes, jtu.rand_default, []),
]
# Index reductions checked by testArgMinMax. Inputs come from
# jtu.rand_some_equal; presumably so that ties between elements are exercised.
JAX_ARGMINMAX_RECORDS = [
op_record("argmin", 1, all_dtypes, nonempty_shapes, jtu.rand_some_equal, []),
op_record("argmax", 1, all_dtypes, nonempty_shapes, jtu.rand_some_equal, []),
]
# Python operator dunder methods checked by testOperatorOverload, which maps
# each name to the function of the same name (underscores stripped) in the
# `operator` module. Entries use the op_record layout
# (name, nargs, dtypes, shapes, rng_factory, diff_modes, ...).
# Extended further down with "__div__" when running under Python 2.
JAX_OPERATOR_OVERLOADS = [
    op_record("__add__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__sub__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__mul__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__eq__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__ne__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__lt__", 2, default_dtypes, all_shapes, jtu.rand_default, []),
    # "__le__" completes the set of ordering comparisons; it was the only one
    # of the four without a record.
    op_record("__le__", 2, default_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__gt__", 2, default_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__ge__", 2, default_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__pos__", 1, number_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__neg__", 1, number_dtypes, all_shapes, jtu.rand_default, []),
    op_record("__pow__", 2, inexact_dtypes, all_shapes, jtu.rand_positive, [],
              tolerance={onp.float32: 2e-4, onp.complex64: 2e-4, onp.complex128: 1e-14}),
    op_record("__mod__", 2, default_dtypes, all_shapes, jtu.rand_nonzero, [],
              tolerance={onp.float16: 1e-1}),
    op_record("__floordiv__", 2, default_dtypes, all_shapes, jtu.rand_nonzero, []),
    op_record("__truediv__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, []),
    op_record("__abs__", 1, number_dtypes, all_shapes, jtu.rand_default, []),
    # TODO(mattjj): __invert__ fails on bool dtypes because ~True == -2
    op_record("__invert__", 1, int_dtypes, all_shapes, jtu.rand_default, []),
    # TODO(mattjj): investigate these failures
    # op_record("__or__", 2, number_dtypes, all_shapes, jtu.rand_bool, []),
    # op_record("__and__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
    # op_record("__xor__", 2, number_dtypes, all_shapes, jtu.rand_bool, []),
    # op_record("__divmod__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, []),
    # TODO(mattjj): lshift, rshift
]
# Reflected ("right-hand") operator dunder methods checked by
# testRightOperatorOverload, which invokes the named method on the second
# argument with the first argument as operand. Extended further down with
# "__rdiv__" when running under Python 2.
JAX_RIGHT_OPERATOR_OVERLOADS = [
op_record("__radd__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
op_record("__rsub__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
op_record("__rmul__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
op_record("__rpow__", 2, inexact_dtypes, all_shapes, jtu.rand_positive, [],
tolerance={onp.float32: 2e-4, onp.complex64: 1e-3}),
op_record("__rmod__", 2, default_dtypes, all_shapes, jtu.rand_nonzero, [],
tolerance={onp.float16: 1e-1}),
op_record("__rfloordiv__", 2, default_dtypes, all_shapes, jtu.rand_nonzero, []),
op_record("__rtruediv__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, []),
# The records below are disabled; the matching left-hand operators in
# JAX_OPERATOR_OVERLOADS are disabled as well.
# op_record("__ror__", 2, number_dtypes, all_shapes, jtu.rand_bool, []),
# op_record("__rand__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
# op_record("__rxor__", 2, number_dtypes, all_shapes, jtu.rand_bool, []),
# op_record("__rdivmod__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, []),
]
# (major, minor) of the installed NumPy, used for feature gating below.
# Only the first two dotted components are parsed: later components may carry
# non-numeric suffixes (e.g. "1.18.0rc1" or "1.19.0.dev0+abcdef") on which
# int() would raise ValueError at import time.
numpy_version = tuple(map(int, onp.version.version.split('.')[:2]))
# Operations that are only tested when the installed NumPy is at least 1.15.
if numpy_version >= (1, 15):
  JAX_COMPOUND_OP_RECORDS.extend([
      op_record("isclose", 2, [t for t in all_dtypes if t != lnp.bfloat16],
                all_shapes, jtu.rand_small_positive, []),
      op_record("gcd", 2, int_dtypes, all_shapes, jtu.rand_default, []),
      op_record("lcm", 2, int_dtypes, all_shapes, jtu.rand_default, []),
  ])
  JAX_REDUCER_NO_DTYPE_RECORDS.extend([
      op_record("ptp", 1, number_dtypes, nonempty_shapes, jtu.rand_default, []),
  ])
# The classic-division dunder methods are only exercised under Python 2.
if six.PY2:
  JAX_OPERATOR_OVERLOADS.extend([
      op_record("__div__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, []),
  ])
  JAX_RIGHT_OPERATOR_OVERLOADS.extend([
      op_record("__rdiv__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, []),
  ])
# Short alias for itertools.combinations_with_replacement; the test
# parameterizations below use it to enumerate argument shape tuples (and, for
# bitwise ops, dtype tuples) without regard to order.
CombosWithReplacement = itertools.combinations_with_replacement
def _dtypes_are_compatible_for_bitwise_ops(args):
if len(args) <= 1:
return True
is_signed = lambda dtype: lnp.issubdtype(dtype, onp.signedinteger)
width = lambda dtype: lnp.iinfo(dtype).bits
x, y = args
if width(x) > width(y):
x, y = y, x
# The following condition seems a little ad hoc, but seems to capture what
# numpy actually implements.
return (
is_signed(x) == is_signed(y)
or (width(x) == 32 and width(y) == 32)
or (width(x) == 32 and width(y) == 64 and is_signed(y)))
def _shapes_are_broadcast_compatible(shapes):
accumulator = onp.zeros([])
for shape in shapes:
try:
accumulator = accumulator + onp.zeros(shape)
except ValueError:
return False
return True
def _shapes_are_equal_length(shapes):
return all(len(shape) == len(shapes[0]) for shape in shapes[1:])
class LaxBackedNumpyTests(jtu.JaxTestCase):
"""Tests for LAX-backed Numpy implementation."""
def _GetArgsMaker(self, rng, shapes, dtypes):
return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]
# One case per record, per broadcast-compatible combination of argument
# shapes, per combination of dtypes valid for those shapes.
@parameterized.named_parameters(itertools.chain.from_iterable(
    jtu.cases_from_list(
        {"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes,
                                                      dtypes),
         "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes,
         "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name),
         "check_dtypes": rec.check_dtypes, "tolerance": rec.tolerance}
        for shapes in filter(
            _shapes_are_broadcast_compatible,
            CombosWithReplacement(rec.shapes, rec.nargs))
        for dtypes in itertools.product(
            *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes)))
    for rec in itertools.chain(JAX_ONE_TO_ONE_OP_RECORDS,
                               JAX_COMPOUND_OP_RECORDS)))
def testOp(self, onp_op, lnp_op, rng_factory, shapes, dtypes, check_dtypes,
           tolerance):
  # Checks lnp_op against onp_op, then the compiled lnp_op against itself.
  args_maker = self._GetArgsMaker(rng_factory(), shapes, dtypes)

  has_python_scalar = jtu.PYTHON_SCALAR_SHAPE in shapes
  has_scalar = (has_python_scalar or
                jtu.NUMPY_SCALAR_SHAPE in shapes or
                () in shapes)
  has_empty = any(isinstance(s, tuple) and 0 in s for s in shapes)

  # Loosest per-dtype tolerance, joined with the record's own tolerance spec
  # and the default tolerances.
  loosest = max(jtu.tolerance(dtype, tolerance) for dtype in dtypes)
  tol = functools.reduce(
      jtu.join_tolerance, [tolerance, loosest, jtu.default_tolerance()])

  # Dtypes are not compared against NumPy for scalar or empty arguments.
  self._CheckAgainstNumpy(
      onp_op, lnp_op, args_maker,
      check_dtypes=check_dtypes and not (has_scalar or has_empty),
      tol=tol)
  self._CompileAndCheck(
      lnp_op, args_maker,
      check_dtypes=check_dtypes and not has_python_scalar,
      atol=tol, rtol=tol)
# One case per record, per broadcast-compatible combination of argument
# shapes, per combination of dtypes valid for those shapes.
@parameterized.named_parameters(itertools.chain.from_iterable(
    jtu.cases_from_list(
        {"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes,
                                                      dtypes),
         "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes, "name": rec.name,
         "tol": rec.tolerance}
        for shapes in filter(
            _shapes_are_broadcast_compatible,
            CombosWithReplacement(rec.shapes, rec.nargs))
        for dtypes in itertools.product(
            *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes)))
    for rec in JAX_OPERATOR_OVERLOADS))
def testOperatorOverload(self, name, rng_factory, shapes, dtypes, tol):
  # Applies the `operator` module function matching the dunder `name`
  # (e.g. "__add__" -> operator.add) and checks it with _CompileAndCheck.
  args_maker = self._GetArgsMaker(rng_factory(), shapes, dtypes)

  def fun(*xs):
    return getattr(operator, name.strip('_'))(*xs)

  has_scalar = (jtu.PYTHON_SCALAR_SHAPE in shapes or
                jtu.NUMPY_SCALAR_SHAPE in shapes or
                () in shapes)
  has_empty = any(isinstance(s, tuple) and 0 in s for s in shapes)
  # Dtypes are only compared when no argument is a scalar or empty.
  self._CompileAndCheck(
      fun, args_maker, check_dtypes=not has_scalar and not has_empty,
      atol=tol, rtol=tol)
# One case per record, per broadcast-compatible combination of argument
# shapes, per combination of dtypes valid for those shapes.
@parameterized.named_parameters(itertools.chain.from_iterable(
    jtu.cases_from_list(
        {"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes,
                                                      dtypes),
         "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes, "name": rec.name,
         "op_tolerance": rec.tolerance}
        for shapes in filter(
            _shapes_are_broadcast_compatible,
            CombosWithReplacement(rec.shapes, rec.nargs))
        for dtypes in itertools.product(
            *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes)))
    for rec in JAX_RIGHT_OPERATOR_OVERLOADS))
def testRightOperatorOverload(self, name, rng_factory, shapes, dtypes,
                              op_tolerance):
  # Calls the reflected dunder method `name` on the second argument, passing
  # the first argument as operand, and checks it with _CompileAndCheck.
  if shapes[1] is jtu.PYTHON_SCALAR_SHAPE:
    raise SkipTest()  # TODO(mattjj): clean up
  args_maker = self._GetArgsMaker(rng_factory(), shapes, dtypes)

  def fun(fst, snd):
    return getattr(snd, name)(fst)

  # Use the loosest tolerance among the argument dtypes.
  tol = max(jtu.tolerance(dtype, op_tolerance) for dtype in dtypes)
  has_scalar = (jtu.PYTHON_SCALAR_SHAPE in shapes or
                jtu.NUMPY_SCALAR_SHAPE in shapes or
                () in shapes)
  has_empty = any(isinstance(s, tuple) and 0 in s for s in shapes)
  # Dtypes are only compared when no argument is a scalar or empty.
  self._CompileAndCheck(
      fun, args_maker, check_dtypes=not has_scalar and not has_empty,
      atol=tol, rtol=tol)
# One case per record, per broadcast-compatible combination of argument
# shapes, per dtype combination accepted by
# _dtypes_are_compatible_for_bitwise_ops.
@parameterized.named_parameters(itertools.chain.from_iterable(
    jtu.cases_from_list(
        {"testcase_name": jtu.format_test_name_suffix(
            rec.test_name, shapes, dtypes),
         "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes,
         "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name)}
        for shapes in filter(
            _shapes_are_broadcast_compatible,
            CombosWithReplacement(rec.shapes, rec.nargs))
        for dtypes in filter(
            _dtypes_are_compatible_for_bitwise_ops,
            CombosWithReplacement(rec.dtypes, rec.nargs)))
    for rec in JAX_BITWISE_OP_RECORDS))
def testBitwiseOp(self, onp_op, lnp_op, rng_factory, shapes, dtypes):
  # Checks a bitwise lnp op against its onp counterpart; cases involving
  # 64-bit dtypes are skipped unless jax_enable_x64 is set.
  rng = rng_factory()
  if not FLAGS.jax_enable_x64:
    needs_x64 = any(lnp.iinfo(dtype).bits == 64 for dtype in dtypes)
    if needs_x64:
      self.skipTest("x64 types are disabled by jax_enable_x64")
  args_maker = self._GetArgsMaker(rng, shapes, dtypes)
  # Dtypes are not compared against NumPy when a Python scalar is involved.
  compare_dtypes = jtu.PYTHON_SCALAR_SHAPE not in shapes
  self._CheckAgainstNumpy(onp_op, lnp_op, args_maker,
                          check_dtypes=compare_dtypes)
  self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)
# One case per record, input shape, input dtype, requested output dtype
# (None = not specified), axis (every valid index, then None) and keepdims.
@parameterized.named_parameters(itertools.chain.from_iterable(
    jtu.cases_from_list(
        {"testcase_name": "{}_inshape={}_axis={}_dtype={}_keepdims={}".format(
            rec.test_name.capitalize(),
            jtu.format_shape_dtype_string(shape, dtype), axis,
            "None" if out_dtype is None else onp.dtype(out_dtype).name, keepdims),
         "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, "out_dtype": out_dtype,
         "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name),
         "axis": axis, "keepdims": keepdims}
        for shape in rec.shapes for dtype in rec.dtypes
        for out_dtype in [None] + rec.dtypes
        # A list rather than a set: the iteration order of a set containing
        # None depends on object hashes and can differ between processes,
        # which would make the order of generated cases irreproducible.
        for axis in list(range(-len(shape), len(shape))) + [None]
        for keepdims in [False, True])
    for rec in JAX_REDUCER_RECORDS))
def testReducer(self, onp_op, lnp_op, rng_factory, shape, dtype, out_dtype, axis, keepdims):
  # Checks a reduction called with explicit `dtype=` against NumPy.
  rng = rng_factory()
  def onp_fun(x):
    # The NumPy reference is computed in float32 wherever bfloat16 is
    # involved, both for the input and for the requested output dtype.
    x_cast = x if dtype != lnp.bfloat16 else x.astype(onp.float32)
    t = out_dtype if out_dtype != lnp.bfloat16 else onp.float32
    return onp_op(x_cast, axis, dtype=t, keepdims=keepdims)
  lnp_fun = lambda x: lnp_op(x, axis, dtype=out_dtype, keepdims=keepdims)
  args_maker = lambda: [rng(shape, dtype)]
  tol_spec = {onp.float16: 1e-2, onp.float32: 1e-3, onp.complex64: 1e-3,
              onp.float64: 1e-5, onp.complex128: 1e-5}
  # Use the looser of the input-dtype and (if given) output-dtype tolerances.
  tol = jtu.tolerance(dtype, tol_spec)
  tol = max(tol, jtu.tolerance(out_dtype, tol_spec)) if out_dtype else tol
  # The reference runs in float32 for bfloat16, so dtypes cannot match then.
  self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker,
                          check_dtypes=lnp.bfloat16 not in (dtype, out_dtype),
                          tol=tol)
  self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True, atol=tol,
                        rtol=tol)
# One case per record, input shape, input dtype, axis (every valid index,
# then None) and keepdims.
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "{}_inshape={}_axis={}_keepdims={}".format(
        rec.test_name.capitalize(),
        jtu.format_shape_dtype_string(shape, dtype), axis, keepdims),
     "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype,
     "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name),
     "axis": axis, "keepdims": keepdims}
    for rec in JAX_REDUCER_NO_DTYPE_RECORDS
    for shape in rec.shapes for dtype in rec.dtypes
    # A list rather than a set: the iteration order of a set containing None
    # depends on object hashes and can differ between processes, which would
    # make the order of generated cases irreproducible.
    for axis in list(range(-len(shape), len(shape))) + [None]
    for keepdims in [False, True]))
def testReducerNoDtype(self, onp_op, lnp_op, rng_factory, shape, dtype, axis, keepdims):
  # Checks a reduction called with `axis` and `keepdims` only against NumPy.
  rng = rng_factory()
  onp_fun = lambda x: onp_op(x, axis, keepdims=keepdims)
  lnp_fun = lambda x: lnp_op(x, axis, keepdims=keepdims)
  args_maker = lambda: [rng(shape, dtype)]
  self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
  self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
# One case per shape, dtype and axis (every valid index, then None).
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_axis={}".format(
        jtu.format_shape_dtype_string(shape, dtype), axis),
     "shape": shape, "dtype": dtype, "axis": axis}
    for shape in all_shapes for dtype in all_dtypes
    # A list rather than a set: the iteration order of a set containing None
    # depends on object hashes and can differ between processes, which would
    # make the order of generated cases irreproducible.
    for axis in list(range(-len(shape), len(shape))) + [None]))
def testCountNonzero(self, shape, dtype, axis):
  # Checks lnp.count_nonzero against onp.count_nonzero on inputs generated by
  # jtu.rand_some_zero.
  rng = jtu.rand_some_zero()
  onp_fun = lambda x: onp.count_nonzero(x, axis)
  lnp_fun = lambda x: lnp.count_nonzero(x, axis)
  args_maker = lambda: [rng(shape, dtype)]
  self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
  self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
# One case per record, valid (shape, dtype) pair and axis index.
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "{}_inshape={}_axis={}".format(
        rec.test_name.capitalize(),
        jtu.format_shape_dtype_string(shape, dtype), axis),
     "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype,
     "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name),
     "axis": axis}
    for rec in JAX_ARGMINMAX_RECORDS
    for shape, dtype in _shape_and_dtypes(rec.shapes, rec.dtypes)
    for axis in range(-len(shape), len(shape))))
def testArgMinMax(self, onp_op, lnp_op, rng_factory, shape, dtype, axis):
  # Checks argmin/argmax along `axis` against NumPy.
  rng = rng_factory()
  if dtype == onp.complex128 and jtu.device_under_test() == "gpu":
    raise unittest.SkipTest("complex128 reductions not supported on GPU")

  # The NumPy result is cast to lnp.int_ so that dtypes can be compared.
  onp_fun = lambda array_to_reduce: (
      onp_op(array_to_reduce, axis).astype(lnp.int_))
  lnp_fun = lambda array_to_reduce: lnp_op(array_to_reduce, axis)

  def args_maker():
    return [rng(shape, dtype)]

  self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
  self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
# Each entry is [lhs_shape, rhs_shape, (axisa, axisb, axisc, axis)]; every
# entry is combined with every unordered pair of number dtypes.
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_{}_{}".format(
        jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
        jtu.format_shape_dtype_string(rhs_shape, rhs_dtype),
        axes),
     "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
     "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
     "axes": axes, "rng_factory": rng_factory}
    for rng_factory in [jtu.rand_default]
    for lhs_shape, rhs_shape, axes in [
        [(2,), (2,), (-1, -1, -1, None)],  # scalar output
        [(2, 4), (2, 4), (-1, -1, -1, 0)],  # 2D vectors
        [(3, 4), (3, 4), (-1, -1, -1, 0)],  # 3D vectors
        [(3, 4), (3, 6, 5, 4), (-1, -1, -1, 0)],  # broadcasting
        [(4, 3), (3, 6, 5, 4), (1, 0, -1, None)],  # different axes
        [(6, 1, 3), (5, 3), (-1, -1, -1, None)],  # more broadcasting
        [(6, 1, 2), (5, 3), (-1, -1, -1, None)],  # mixed 2D and 3D vectors
        [(10, 5, 2, 8), (1, 5, 1, 3), (-2, -1, -3, None)],  # axes/broadcasting
        [(4, 5, 2), (4, 5, 2), (-1, -1, 0, None)],  # axisc should do nothing
        [(4, 5, 2), (4, 5, 2), (-1, -1, -1, None)]  # same as before
    ]
    for lhs_dtype, rhs_dtype in CombosWithReplacement(number_dtypes, 2)))
def testCross(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, axes, rng_factory):
  # Checks lnp.cross against onp.cross for the given axis arguments.
  rng = rng_factory()
  axisa, axisb, axisc, axis = axes

  def args_maker():
    return [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]

  def lnp_fun(a, b):
    return lnp.cross(a, b, axisa, axisb, axisc, axis)

  def onp_fun(a, b):
    # The reference is computed in float32 for bfloat16 inputs and cast back
    # to the promoted dtype afterwards.
    if lhs_dtype == lnp.bfloat16:
      a = a.astype(onp.float32)
    if rhs_dtype == lnp.bfloat16:
      b = b.astype(onp.float32)
    reference = onp.cross(a, b, axisa, axisb, axisc, axis)
    return reference.astype(lnp.promote_types(lhs_dtype, rhs_dtype))

  tol_spec = {dtypes.bfloat16: 3e-1, onp.float16: 1e-2}
  lhs_tol = jtu.tolerance(lhs_dtype, tol_spec)
  rhs_tol = jtu.tolerance(rhs_dtype, tol_spec)
  tol = max(lhs_tol, rhs_tol)
  self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True,
                          tol=tol)
  self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True, atol=tol,
                        rtol=tol)
# Each named (lhs_shape, rhs_shape) pair is combined with every unordered
# pair of number dtypes.
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_{}_{}".format(
        name,
        jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
        jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)),
     "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
     "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
     "rng_factory": rng_factory}
    for rng_factory in [jtu.rand_default]
    for name, lhs_shape, rhs_shape in [
        ("matrix-scalar", (3, 3), ()),
        ("scalar-matrix", (), (3, 3)),
        ("matrix-vector", (4, 5), (5,)),
        ("vector-matrix", (6,), (6, 4)),
        ("matrix-matrix", (3, 4), (4, 5)),
        ("tensor-vector", (4, 3, 2), (2,)),
        ("vector-tensor", (2,), (3, 2, 4)),
        ("tensor-matrix", (4, 3, 2), (2, 5)),
        ("matrix-tensor", (5, 2), (3, 2, 4)),
        ("tensor-tensor", (2, 3, 4), (5, 4, 1))]
    for lhs_dtype, rhs_dtype in CombosWithReplacement(number_dtypes, 2)))
def testDot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng_factory):
  # Checks lnp.dot against onp.dot.
  rng = rng_factory()

  def args_maker():
    return [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]

  tol = {onp.float16: 1e-2, onp.float32: 1e-5, onp.float64: 1e-14,
         onp.complex128: 1e-14}
  # Much looser single-precision tolerances are used on TPU.
  if jtu.device_under_test() == "tpu":
    tol[onp.float32] = 2e-1
    tol[onp.complex64] = 2e-1

  def onp_dot(x, y):
    # The reference is computed in float32 for bfloat16 inputs and cast back
    # to the promoted dtype afterwards.
    if lhs_dtype == lnp.bfloat16:
      x = x.astype(onp.float32)
    if rhs_dtype == lnp.bfloat16:
      y = y.astype(onp.float32)
    result_dtype = lnp.promote_types(lhs_dtype, rhs_dtype)
    return onp.dot(x, y).astype(result_dtype)

  self._CheckAgainstNumpy(onp_dot, lnp.dot, args_maker, check_dtypes=True,
                          tol=tol)
  self._CompileAndCheck(lnp.dot, args_maker, check_dtypes=True, atol=tol,
                        rtol=tol)
# Each named (lhs_shape, rhs_shape) pair is combined with every unordered
# pair of number dtypes.
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_{}_{}".format(
        name,
        jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
        jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)),
     "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
     "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
     "rng_factory": rng_factory}
    for rng_factory in [jtu.rand_default]
    for name, lhs_shape, rhs_shape in [
        ("vector-vector", (3,), (3,)),
        ("matrix-vector", (3, 3), (3,)),
        ("vector-matrix", (3,), (3, 3)),
        ("matrix-matrix", (3, 3), (3, 3)),
        ("vector-tensor", (3,), (5, 3, 2)),
        ("tensor-vector", (5, 3, 2), (2,)),
        ("matrix-tensor", (5, 2), (3, 2, 4)),
        ("tensor-matrix", (5, 2, 3), (3, 2)),
        ("tensor-tensor", (5, 3, 4), (5, 4, 1)),
        ("tensor-tensor-broadcast", (3, 1, 3, 4), (5, 4, 1))]
    for lhs_dtype, rhs_dtype in CombosWithReplacement(number_dtypes, 2)))
def testMatmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng_factory):
  # Checks lnp.matmul against onp.matmul.
  rng = rng_factory()
  def onp_fun(x, y):
    # Compute the reference in float32 for bfloat16 inputs, as the other
    # contraction tests in this file (dot, tensordot, cross) do, rather than
    # handing bfloat16 arrays to onp.matmul directly; the result is cast back
    # to the promoted dtype afterwards.
    x = x.astype(onp.float32) if lhs_dtype == lnp.bfloat16 else x
    y = y.astype(onp.float32) if rhs_dtype == lnp.bfloat16 else y
    dtype = lnp.promote_types(lhs_dtype, rhs_dtype)
    return onp.matmul(x, y).astype(dtype)
  args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
  tol = {onp.float16: 1e-2, onp.float32: 2e-2, onp.float64: 1e-12,
         onp.complex128: 1e-12}
  # Looser single-precision tolerances are used on TPU.
  if jtu.device_under_test() == "tpu":
    tol[onp.float32] = tol[onp.complex64] = 4e-2
  self._CheckAgainstNumpy(onp_fun, lnp.matmul, args_maker,
                          check_dtypes=True, tol=tol)
  self._CompileAndCheck(lnp.matmul, args_maker, check_dtypes=True, atol=tol,
                        rtol=tol)
# Each entry is [lhs_shape, rhs_shape, axes]; every entry is combined with
# every unordered pair of number dtypes.
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_{}_{}".format(
        jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
        jtu.format_shape_dtype_string(rhs_shape, rhs_dtype),
        axes),
     "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
     "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
     "axes": axes, "rng_factory": rng_factory}
    for rng_factory in [jtu.rand_default]
    for lhs_shape, rhs_shape, axes in [
        [(2, 3, 4), (5, 6, 7), 0],  # from issue #740
        [(2, 3, 4), (3, 4, 5, 6), 2],
        [(2, 3, 4), (5, 4, 3, 6), [1, 2]],
        [(2, 3, 4), (5, 4, 3, 6), [[1, 2], [2, 1]]],
        [(1, 2, 3, 4), (4, 5, 3, 6), [[2, 3], [2, 0]]],
    ]
    for lhs_dtype, rhs_dtype in CombosWithReplacement(number_dtypes, 2)))
def testTensordot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, axes, rng_factory):
  # Checks lnp.tensordot against onp.tensordot for the given `axes`.
  rng = rng_factory()

  def args_maker():
    return [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]

  def lnp_fun(a, b):
    return lnp.tensordot(a, b, axes)

  def onp_fun(a, b):
    # The reference is computed in float32 for bfloat16 inputs and cast back
    # to the promoted dtype afterwards.
    if lhs_dtype == lnp.bfloat16:
      a = a.astype(onp.float32)
    if rhs_dtype == lnp.bfloat16:
      b = b.astype(onp.float32)
    result_dtype = lnp.promote_types(lhs_dtype, rhs_dtype)
    return onp.tensordot(a, b, axes).astype(result_dtype)

  tol = {onp.float16: 1e-1, onp.float32: 1e-3, onp.float64: 1e-12,
         onp.complex64: 1e-3, onp.complex128: 1e-12}
  # Much looser single-precision tolerances are used on TPU.
  if jtu.device_under_test() == "tpu":
    tol[onp.float32] = 2e-1
    tol[onp.complex64] = 2e-1
  self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True,
                          tol=tol)
  # NOTE: the compiled check runs with default tolerances (no `tol` passed).
  self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_{}".format(
        jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
        jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)),
     "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
     "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
     "rng_factory": jtu.rand_default}
    # TODO(phawkins): support integer dtypes too.
    for lhs_shape, lhs_dtype in _shape_and_dtypes(all_shapes, inexact_dtypes)
    for rhs_shape, rhs_dtype in _shape_and_dtypes(all_shapes, inexact_dtypes)
    if len(jtu._dims_of_shape(lhs_shape)) == 0
    or len(jtu._dims_of_shape(rhs_shape)) == 0
    or lhs_shape[-1] == rhs_shape[-1]))
def testInner(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng_factory):
  """Compares lnp.inner with onp.inner on scalar or last-axis-compatible operands."""
  rng = rng_factory()

  def args_maker():
    return [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]

  def onp_fun(lhs, rhs):
    # bfloat16 operands are handed to the reference as float32; the result is
    # cast to the promoted dtype of the two inputs.
    if lhs_dtype == lnp.bfloat16:
      lhs = lhs.astype(onp.float32)
    if rhs_dtype == lnp.bfloat16:
      rhs = rhs.astype(onp.float32)
    expected_dtype = lnp.promote_types(lhs_dtype, rhs_dtype)
    return onp.inner(lhs, rhs).astype(expected_dtype)

  def lnp_fun(lhs, rhs):
    return lnp.inner(lhs, rhs)

  tol_spec = {onp.float16: 1e-2, onp.float32: 1e-5, onp.float64: 1e-13}
  if jtu.device_under_test() == "tpu":
    tol_spec.update({onp.float32: 2e-1, onp.complex64: 2e-1})
  # Use the looser of the two per-operand tolerances.
  lhs_tol = jtu.tolerance(lhs_dtype, tol_spec)
  rhs_tol = jtu.tolerance(rhs_dtype, tol_spec)
  tol = max(lhs_tol, rhs_tol)

  # TODO(phawkins): there are float32/float64 disagreements for some inputs.
  self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=False,
                          tol=tol)
  self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=False, atol=tol,
                        rtol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_amin={}_amax={}".format(
        jtu.format_shape_dtype_string(shape, dtype), a_min, a_max),
     "shape": shape, "dtype": dtype, "a_min": a_min, "a_max": a_max,
     "rng_factory": jtu.rand_default}
    for shape in all_shapes for dtype in number_dtypes
    for a_min, a_max in [(-1, None), (None, 1), (-1, 1),
                         (-onp.ones(1), None),
                         (None, onp.ones(1)),
                         (-onp.ones(1), onp.ones(1))]))
def testClipStaticBounds(self, shape, dtype, a_min, a_max, rng_factory):
  """Compares lnp.clip with onp.clip for bounds fixed outside the traced function."""
  rng = rng_factory()

  def onp_fun(x):
    return onp.clip(x, a_min=a_min, a_max=a_max)

  def lnp_fun(x):
    return lnp.clip(x, a_min=a_min, a_max=a_max)

  def args_maker():
    return [rng(shape, dtype)]

  # TODO(phawkins): the promotion behavior changed in Numpy 1.17.
  self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=False)
  self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_decimals={}".format(
        jtu.format_shape_dtype_string(shape, dtype), decimals),
     "shape": shape, "dtype": dtype, "decimals": decimals,
     "rng_factory": jtu.rand_default}
    for shape, dtype in _shape_and_dtypes(all_shapes, number_dtypes)
    for decimals in [0, 1, -2]))
def testRoundStaticDecimals(self, shape, dtype, decimals, rng_factory):
  """Compares lnp.round with onp.round for a fixed `decimals` value."""
  rng = rng_factory()
  integer_input = lnp.issubdtype(dtype, onp.integer)
  if integer_input and decimals < 0:
    self.skipTest("Integer rounding with decimals < 0 not implemented")

  def onp_fun(x):
    return onp.round(x, decimals=decimals)

  def lnp_fun(x):
    return lnp.round(x, decimals=decimals)

  def args_maker():
    return [rng(shape, dtype)]

  tol = {lnp.bfloat16: 5e-2, onp.float16: 1e-2}
  # Dtypes are only compared when the operand is not a Python scalar.
  check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE
  self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker,
                          check_dtypes=check_dtypes, tol=tol)
  self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=check_dtypes,
                        atol=tol, rtol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_mode={}_rpadwidth={}_rconstantvalues={}".format(
        jtu.format_shape_dtype_string(shape, dtype), mode, pad_width_rank,
        constant_values_rank),
     "shape": shape, "dtype": dtype, "mode": mode,
     "pad_width_rank": pad_width_rank,
     "constant_values_rank": constant_values_rank,
     "rng_factory": jtu.rand_default,
     "irng_factory": partial(jtu.rand_int, 3)}
    for mode, constant_values_rank, shapes in [
        ('constant', 0, all_shapes),
        ('constant', 1, all_shapes),
        ('constant', 2, all_shapes),
        ('symmetric', None, nonempty_shapes),
        ('reflect', None, nonempty_shapes),
        ('wrap', None, nonempty_shapes),
    ]
    for shape, dtype in _shape_and_dtypes(shapes, all_dtypes)
    for pad_width_rank in range(3)))
def testPad(self, shape, dtype, mode, pad_width_rank, constant_values_rank,
            rng_factory, irng_factory):
  """Compares lnp.pad with onp.pad across modes and argument ranks."""
  rng = rng_factory()
  irng = irng_factory()
  # Taking the trailing `pad_width_rank` entries of [ndim, 2] gives the shape
  # (), (2,) or (ndim, 2) for ranks 0, 1 and 2 respectively.
  pad_width_shape = [len(shape), 2][2 - pad_width_rank:]
  pad_width = irng(pad_width_shape, onp.int32)

  def onp_fun(x, kwargs):
    # With an empty pad_width the reference simply returns its operand.
    if pad_width.size != 0:
      return onp.pad(x, pad_width, mode=mode, **kwargs)
    return x

  def lnp_fun(x, kwargs):
    return lnp.pad(x, pad_width, mode=mode, **kwargs)

  def args_maker():
    # constant_values (when requested) is drawn before the operand.
    kwargs = {}
    if constant_values_rank:
      constant_values_shape = [len(shape), 2][2 - constant_values_rank:]
      kwargs["constant_values"] = rng(constant_values_shape, dtype)
    return rng(shape, dtype), kwargs

  check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE
  self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker,
                          check_dtypes=check_dtypes)
  self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape=[{}]_reps={}".format(
        jtu.format_shape_dtype_string(shape, dtype), reps),
     "shape": shape, "dtype": dtype, "reps": reps,
     "rng_factory": jtu.rand_default}
    for reps in [(), (2,), (3, 4), (2, 3, 4)]
    for shape, dtype in _shape_and_dtypes(all_shapes, default_dtypes)
))
def testTile(self, shape, dtype, reps, rng_factory):
  """Compares lnp.tile with onp.tile for each `reps` tuple."""
  rng = rng_factory()

  def onp_fun(arg):
    return onp.tile(arg, reps)

  def lnp_fun(arg):
    return lnp.tile(arg, reps)

  def args_maker():
    return [rng(shape, dtype)]

  # Dtypes are only compared against NumPy for non-Python-scalar operands.
  compare_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE
  self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker,
                          check_dtypes=compare_dtypes)
  self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_axis={}_baseshape=[{}]_dtypes=[{}]".format(
        axis, ",".join(str(d) for d in base_shape),
        ",".join(onp.dtype(dtype).name for dtype in arg_dtypes)),
     "axis": axis, "base_shape": base_shape, "arg_dtypes": arg_dtypes,
     "rng_factory": jtu.rand_default}
    for num_arrs in [3]
    for arg_dtypes in CombosWithReplacement(default_dtypes, num_arrs)
    for base_shape in [(4,), (3, 4), (2, 3, 4)]
    for axis in range(-len(base_shape)+1, len(base_shape))))
def testConcatenate(self, axis, base_shape, arg_dtypes, rng_factory):
  """Compares lnp.concatenate with onp.concatenate on mixed-dtype operands."""
  rng = rng_factory()
  concat_dim = axis % len(base_shape)
  leading = base_shape[:concat_dim]
  trailing = base_shape[concat_dim + 1:]
  # Each operand gets its own extent (3, 1, 4, repeating) along the
  # concatenated dimension.
  shapes = [leading + (size,) + trailing
            for size, _ in zip(itertools.cycle([3, 1, 4]), arg_dtypes)]

  def onp_fun(*args):
    # bfloat16 operands are handed to the reference as float32.
    widened = [x.astype(onp.float32) if x.dtype == lnp.bfloat16 else x
               for x in args]
    expected_dtype = functools.reduce(lnp.promote_types, arg_dtypes)
    return onp.concatenate(widened, axis=axis).astype(expected_dtype)

  def lnp_fun(*args):
    return lnp.concatenate(args, axis=axis)

  def args_maker():
    return [rng(shape, dtype) for shape, dtype in zip(shapes, arg_dtypes)]

  self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
  self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_axis={}_baseshape=[{}]_dtypes=[{}]".format(
        axis, ",".join(str(d) for d in base_shape),
        ",".join(onp.dtype(dtype).name for dtype in arg_dtypes)),
     "axis": axis, "base_shape": base_shape, "arg_dtypes": arg_dtypes,
     "rng_factory": jtu.rand_default}
    for arg_dtypes in CombosWithReplacement(default_dtypes, 2)
    for base_shape in [(4,), (3, 4), (2, 3, 4)]
    for axis in range(-len(base_shape)+1, len(base_shape))))
def testAppend(self, axis, base_shape, arg_dtypes, rng_factory):
  """Compares lnp.append with onp.append on a pair of mixed-dtype operands."""
  rng = rng_factory()
  append_dim = axis % len(base_shape)
  leading = base_shape[:append_dim]
  trailing = base_shape[append_dim + 1:]
  # The operands differ only in their extent along the appended dimension.
  shapes = [leading + (size,) + trailing
            for size, _ in zip(itertools.cycle([3, 1, 4]), arg_dtypes)]

  def onp_fun(arr, values):
    # bfloat16 operands are handed to the reference as float32.
    if arr.dtype == lnp.bfloat16:
      arr = arr.astype(onp.float32)
    if values.dtype == lnp.bfloat16:
      values = values.astype(onp.float32)
    appended = onp.append(arr, values, axis=axis)
    return appended.astype(lnp.promote_types(*arg_dtypes))

  def lnp_fun(arr, values):
    return lnp.append(arr, values, axis=axis)

  def args_maker():
    return [rng(shape, dtype) for shape, dtype in zip(shapes, arg_dtypes)]

  self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
  self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape=[{}]_axis={}_repeats={}".format(
        jtu.format_shape_dtype_string(shape, dtype), axis, repeats),
     "axis": axis, "shape": shape, "dtype": dtype, "repeats": repeats,
     "rng_factory": jtu.rand_default}
    for repeats in [0, 1, 2]
    for shape, dtype in _shape_and_dtypes(all_shapes, default_dtypes)
    for axis in [None] + list(range(-len(shape), len(shape)))))
def testRepeat(self, axis, shape, dtype, repeats, rng_factory):
  """Compares lnp.repeat with onp.repeat for integer `repeats`."""
  rng = rng_factory()

  def onp_fun(arg):
    return onp.repeat(arg, repeats=repeats, axis=axis)

  def lnp_fun(arg):
    return lnp.repeat(arg, repeats=repeats, axis=axis)

  def args_maker():
    return [rng(shape, dtype)]

  self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
  self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
def testIssue1233(self):
  """Exercises lnp.repeat with scalar, list and array `repeats` arguments.

  Following numpy test suite from `test_repeat` at https://github.com/numpy/numpy/blob/master/numpy/core/tests/test_multiarray.py
  """
  tol = 1e-5

  def check_repeat(operand, args_maker, repeats, axis):
    # Eager comparison against NumPy, followed by a compiled-vs-eager check.
    lax_ans = lnp.repeat(operand, repeats, axis)
    numpy_ans = onp.repeat(operand, repeats, axis)
    self.assertAllClose(lax_ans, numpy_ans, check_dtypes=True, rtol=tol, atol=tol)

    def lnp_fun(arg):
      return lnp.repeat(arg, repeats=repeats, axis=axis)

    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  m = lnp.array([1,2,3,4,5,6])
  flat_args_maker = lambda: [m]

  flat_cases = [2, [1,3,2,1,1,2], [1,3,0,1,1,2], [2],
                lnp.array([1,3,2,1,1,2]), lnp.array([2])]
  for repeats in flat_cases:
    check_repeat(m, flat_args_maker, repeats, None)

  m_rect = m.reshape((2,3))
  rect_args_maker = lambda: [m_rect]

  row_cases = [2, [2,1], [2], lnp.array([2,1]), lnp.array([2])]
  for repeats in row_cases:
    check_repeat(m_rect, rect_args_maker, repeats, axis=0)

  column_cases = [2, [1,3,2], [2], lnp.array([1,3,2]), lnp.array([2])]
  for repeats in column_cases:
    check_repeat(m_rect, rect_args_maker, repeats, axis=1)
@parameterized.named_parameters(jtu.cases_from_list(
    # The name starts with "_" like every other parameterized test in this
    # file, so the generated suffix is separated from the method name.
    {"testcase_name": "_op={}_shape=[{}]_axis={}_out_dtype={}".format(
        op, jtu.format_shape_dtype_string(shape, dtype), axis, out_dtype),
     "axis": axis, "shape": shape, "dtype": dtype, "out_dtype": out_dtype,
     "rng_factory": jtu.rand_default, "lnp_op": getattr(lnp, op),
     "onp_op": getattr(onp, op)}
    for op in ["cumsum", "cumprod"]
    for dtype in default_dtypes
    for out_dtype in default_dtypes
    for shape in all_shapes
    for axis in [None] + list(range(-len(shape), len(shape)))))
def testCumSumProd(self, axis, shape, dtype, out_dtype, onp_op, lnp_op, rng_factory):
  """Compares lnp.cumsum/cumprod with the NumPy op of the same name.

  `onp_op` and `lnp_op` are the matching functions looked up by name on `onp`
  and `lnp`; `out_dtype` is forwarded to both as the `dtype` argument.
  """
  rng = rng_factory()
  onp_fun = lambda arg: onp_op(arg, axis=axis, dtype=out_dtype)
  lnp_fun = lambda arg: lnp_op(arg, axis=axis, dtype=out_dtype)
  args_maker = lambda: [rng(shape, dtype)]

  # Use the looser of the input-dtype and output-dtype tolerances.
  tol = max(jtu.tolerance(dtype), jtu.tolerance(out_dtype))
  self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True,
                          tol=tol)
  self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_dtype={}_m={}_n={}_k={}".format(
        onp.dtype(dtype).name, m, n, k),
     "m": m, "n": n, "k": k, "dtype": dtype, "rng_factory": jtu.rand_default}
    for dtype in default_dtypes
    for n in [0, 4]
    for m in [None, 0, 1, 3, 4]
    for k in list(range(-4, 4))))
def testTri(self, m, n, k, dtype, rng_factory):
  """Compares lnp.tri with onp.tri; both are called without array arguments.

  `rng_factory` is accepted because the parameter dict supplies it, but no
  random input is drawn, so it is intentionally left unused.
  """
  del rng_factory  # Unused: lnp.tri/onp.tri take no array operands here.
  onp_fun = lambda: onp.tri(n, M=m, k=k, dtype=dtype)
  lnp_fun = lambda: lnp.tri(n, M=m, k=k, dtype=dtype)
  args_maker = lambda: []
  self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
  self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_op={}_shape={}_k={}".format(
        op, jtu.format_shape_dtype_string(shape, dtype), k),
     "dtype": dtype, "shape": shape, "op": op, "k": k,
     "rng_factory": jtu.rand_default}
    for dtype in default_dtypes
    for shape in [shape for shape in all_shapes if len(shape) >= 2]
    for op in ["tril", "triu"]
    for k in list(range(-3, 3))))
def testTriLU(self, dtype, shape, op, k, rng_factory):
  """Compares lnp.tril/triu with the NumPy function of the same name."""
  rng = rng_factory()

  def onp_fun(arg):
    return getattr(onp, op)(arg, k=k)

  def lnp_fun(arg):
    return getattr(lnp, op)(arg, k=k)

  def args_maker():
    return [rng(shape, dtype)]

  self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
  self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_k={}".format(
        jtu.format_shape_dtype_string(shape, dtype), k),
     "dtype": dtype, "shape": shape, "k": k, "rng_factory": jtu.rand_default}
    for dtype in default_dtypes
    for shape in [shape for shape in all_shapes if len(shape) in (1, 2)]
    for k in list(range(-4, 4))))
def testDiag(self, shape, dtype, k, rng_factory):
  """Compares lnp.diag with onp.diag on rank-1 and rank-2 operands."""
  rng = rng_factory()

  def onp_fun(arg):
    return onp.diag(arg, k)

  def lnp_fun(arg):
    return lnp.diag(arg, k)

  def args_maker():
    return [rng(shape, dtype)]

  self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
  self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_offset={}_axis1={}_axis2={}".format(
        jtu.format_shape_dtype_string(shape, dtype), offset, axis1, axis2),
     "dtype": dtype, "shape": shape, "offset": offset, "axis1": axis1,
     "axis2": axis2, "rng_factory": jtu.rand_default}
    for dtype in default_dtypes
    for shape in [shape for shape in all_shapes if len(shape) >= 2]
    for axis1 in range(-len(shape), len(shape))
    for axis2 in [a for a in range(-len(shape), len(shape))
                  if a % len(shape) != axis1 % len(shape)]
    for offset in list(range(-4, 4))))
def testDiagonal(self, shape, dtype, offset, axis1, axis2, rng_factory):
  """Compares lnp.diagonal with onp.diagonal over distinct axis pairs."""
  rng = rng_factory()

  def onp_fun(arg):
    return onp.diagonal(arg, offset, axis1, axis2)

  def lnp_fun(arg):
    return lnp.diagonal(arg, offset, axis1, axis2)

  def args_maker():
    return [rng(shape, dtype)]

  self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
  self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
    # The first field is the dtype name, so it is labelled "dtype" (the
    # previous "shape" label did not describe the value being formatted).
    {"testcase_name": "_dtype={}_n={}".format(onp.dtype(dtype).name, n),
     "dtype": dtype, "n": n}
    for dtype in default_dtypes
    for n in list(range(4))))
def testIdentity(self, n, dtype):
  """Compares lnp.identity with onp.identity for sizes 0 through 3."""
  onp_fun = lambda: onp.identity(n, dtype)
  lnp_fun = lambda: lnp.identity(n, dtype)
  args_maker = lambda: []  # Neither function takes an array operand.
  self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
  self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_dtype_{}_offset={}_axis1={}_axis2={}".format(
        jtu.format_shape_dtype_string(shape, dtype),
        out_dtype, offset, axis1, axis2),
     "dtype": dtype, "out_dtype": out_dtype, "shape": shape, "offset": offset,
     "axis1": axis1, "axis2": axis2, "rng_factory": jtu.rand_default}
    for dtype in default_dtypes
    for out_dtype in [None] + number_dtypes
    for shape in [shape for shape in all_shapes if len(shape) >= 2]
    for axis1 in range(-len(shape), len(shape))
    for axis2 in range(-len(shape), len(shape))
    if (axis1 % len(shape)) != (axis2 % len(shape))
    for offset in list(range(-4, 4))))
def testTrace(self, shape, dtype, out_dtype, offset, axis1, axis2, rng_factory):
  """Compares lnp.trace with onp.trace, including an explicit output dtype."""
  rng = rng_factory()

  def onp_fun(arg):
    if out_dtype != lnp.bfloat16:
      return onp.trace(arg, offset, axis1, axis2, out_dtype)
    # For a bfloat16 result the reference accumulates in float32 and the
    # outcome is cast to bfloat16 afterwards.
    return onp.trace(arg, offset, axis1, axis2, onp.float32).astype(lnp.bfloat16)

  def lnp_fun(arg):
    return lnp.trace(arg, offset, axis1, axis2, out_dtype)

  def args_maker():
    return [rng(shape, dtype)]

  self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
  self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_axis={}".format(
        jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes), axis),
     "shape": shape, "axis": axis, "dtypes": dtypes, "rng_factory": rng_factory}
    for dtypes in [
        [onp.float32],
        [onp.float32, onp.float32],
        [onp.float32, onp.int32, onp.float32],
        [onp.float32, onp.int64, onp.float32],
        [onp.float32, onp.int32, onp.float64],
    ]
    for shape in [(), (2,), (3, 4), (1, 100)]
    for axis in range(-len(shape), len(shape) + 1)
    for rng_factory in [jtu.rand_default]))
def testStack(self, shape, axis, dtypes, rng_factory):
  """Compares lnp.stack with onp.stack on a list of same-shape operands.

  The single argument produced by `args_maker` is a list holding one array
  per entry of `dtypes`, all of shape `shape`.
  """
  rng = rng_factory()
  args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
  onp_fun = partial(onp.stack, axis=axis)
  lnp_fun = partial(lnp.stack, axis=axis)
  # Pass the NumPy reference first and the lnp function second, the order
  # used by every other _CheckAgainstNumpy call in this file.
  self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_op={}_{}".format(
        op, jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes)),
     "shape": shape, "op": op, "dtypes": dtypes, "rng_factory": rng_factory}
    for op in ["hstack", "vstack", "dstack"]
    for dtypes in [
        [onp.float32],
        [onp.float32, onp.float32],
        [onp.float32, onp.int32, onp.float32],
        [onp.float32, onp.int64, onp.float32],
        [onp.float32, onp.int32, onp.float64],
    ]
    for shape in [(), (2,), (3, 4), (1, 100), (2, 3, 4)]
    for rng_factory in [jtu.rand_default]))
def testHVDStack(self, shape, op, dtypes, rng_factory):
  """Compares lnp.hstack/vstack/dstack with the NumPy function of the same name.

  The single argument produced by `args_maker` is a list holding one array
  per entry of `dtypes`, all of shape `shape`.
  """
  rng = rng_factory()
  args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
  onp_fun = getattr(onp, op)
  lnp_fun = getattr(lnp, op)
  # Pass the NumPy reference first and the lnp function second, the order
  # used by every other _CheckAgainstNumpy call in this file.
  self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_inshape={}_outdtype={}".format(
        jtu.format_shape_dtype_string(shape, fill_value_dtype),
        onp.dtype(out_dtype).name if out_dtype else "None"),
     "shape": shape, "fill_value_dtype": fill_value_dtype,
     "out_dtype": out_dtype, "rng_factory": jtu.rand_default}
    for shape in array_shapes
    for fill_value_dtype in default_dtypes
    for out_dtype in [None] + default_dtypes))
def testFull(self, shape, fill_value_dtype, out_dtype, rng_factory):
  """Compares lnp.full with onp.full for a randomly drawn scalar fill value."""
  rng = rng_factory()

  def onp_fun(fill_value):
    return onp.full(shape, fill_value, dtype=out_dtype)

  def lnp_fun(fill_value):
    return lnp.full(shape, fill_value, dtype=out_dtype)

  def args_maker():
    # The only runtime argument is the zero-dimensional fill value.
    return [rng((), fill_value_dtype)]

  self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
  self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_inshape={}_filldtype={}_outdtype={}".format(
        jtu.format_shape_dtype_string(shape, in_dtype),
        onp.dtype(fill_value_dtype).name,
        onp.dtype(out_dtype).name),
     "shape": shape, "in_dtype": in_dtype,
     "fill_value_dtype": fill_value_dtype, "out_dtype": out_dtype,
     "rng_factory": jtu.rand_default}
    for shape in array_shapes
    for in_dtype in default_dtypes
    for fill_value_dtype in default_dtypes
    for out_dtype in default_dtypes))
def testFullLike(self, shape, in_dtype, fill_value_dtype, out_dtype, rng_factory):
  """Compares lnp.full_like with onp.full_like for an explicit output dtype."""
  rng = rng_factory()

  def onp_fun(x, fill_value):
    return onp.full_like(x, fill_value, dtype=out_dtype)

  def lnp_fun(x, fill_value):
    return lnp.full_like(x, fill_value, dtype=out_dtype)

  def args_maker():
    # Template array first, then the zero-dimensional fill value.
    return [rng(shape, in_dtype), rng((), fill_value_dtype)]

  self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
  self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_axis={}_{}sections".format(
        jtu.format_shape_dtype_string(shape, dtype), axis, num_sections),
     "shape": shape, "num_sections": num_sections, "axis": axis,
     "dtype": dtype, "rng_factory": jtu.rand_default}
    for shape, axis, num_sections in [
        ((3,), 0, 3), ((12,), 0, 3), ((12, 4), 0, 4), ((12, 4), 1, 2),
        ((2, 3, 4), -1, 2), ((2, 3, 4), -2, 3)]
    for dtype in default_dtypes))
def testSplitStaticInt(self, shape, num_sections, axis, dtype, rng_factory):
  """Compares lnp.split with onp.split for an integer section count."""
  rng = rng_factory()

  def onp_fun(x):
    return onp.split(x, num_sections, axis=axis)

  def lnp_fun(x):
    return lnp.split(x, num_sections, axis=axis)

  def args_maker():
    return [rng(shape, dtype)]

  self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
  self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_axis={}_{}sections".format(
        jtu.format_shape_dtype_string(shape, dtype), axis, num_sections),
     "shape": shape, "num_sections": num_sections, "axis": axis,
     "dtype": dtype, "rng_factory": jtu.rand_default}
    for shape, axis, num_sections in [
        ((12, 4), 0, 4), ((12, 4), 1, 2),
        ((2, 3, 4), 2, 2), ((4, 3, 4), 0, 2)]
    for dtype in default_dtypes))
def testHVDSplit(self, shape, num_sections, axis, dtype, rng_factory):
  """Compares lnp.vsplit/hsplit/dsplit with NumPy, chosen by `axis` (0/1/2)."""
  rng = rng_factory()

  def splitter(module):
    # axis 0 -> vsplit, axis 1 -> hsplit, axis 2 -> dsplit.
    if axis == 0:
      return module.vsplit
    if axis == 1:
      return module.hsplit
    assert axis == 2
    return module.dsplit

  def onp_fun(x):
    return splitter(onp)(x, num_sections)

  def lnp_fun(x):
    return splitter(lnp)(x, num_sections)

  def args_maker():
    return [rng(shape, dtype)]

  self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
  self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_inshape={}_outshape={}_order={}".format(
        jtu.format_shape_dtype_string(arg_shape, dtype),
        jtu.format_shape_dtype_string(out_shape, dtype),
        order),
     "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype,
     "order": order, "rng_factory": jtu.rand_default}
    for dtype in default_dtypes
    for order in ["C", "F"]
    for arg_shape, out_shape in [
        (jtu.NUMPY_SCALAR_SHAPE, (1, 1, 1)),
        ((), (1, 1, 1)),
        ((7, 0), (0, 42, 101)),
        ((3, 4), 12),
        ((3, 4), (12,)),
        ((3, 4), -1),
        ((2, 1, 4), (-1,)),
        ((2, 2, 4), (2, 8))
    ]))
def testReshape(self, arg_shape, out_shape, dtype, order, rng_factory):
  """Compares lnp.reshape with onp.reshape for "C" and "F" element order."""
  rng = rng_factory()

  def onp_fun(x):
    return onp.reshape(x, out_shape, order=order)

  def lnp_fun(x):
    return lnp.reshape(x, out_shape, order=order)

  def args_maker():
    return [rng(arg_shape, dtype)]

  self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
  self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_inshape={}_outshape={}".format(
        jtu.format_shape_dtype_string(arg_shape, dtype),
        jtu.format_shape_dtype_string(out_shape, dtype)),
     "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype,
     "rng_factory": jtu.rand_default}
    for dtype in default_dtypes
    for arg_shape, out_shape in [
        ((7, 0), (0, 42, 101)),
        ((2, 1, 4), (-1,)),
        ((2, 2, 4), (2, 8))
    ]))
def testReshapeMethod(self, arg_shape, out_shape, dtype, rng_factory):
  """Compares the `.reshape(*dims)` method form with onp.reshape."""
  rng = rng_factory()

  def onp_fun(x):
    return onp.reshape(x, out_shape)

  def lnp_fun(x):
    # The target dimensions are passed as separate positional arguments.
    return x.reshape(*out_shape)

  def args_maker():
    return [rng(arg_shape, dtype)]

  self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
  self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_inshape={}_expanddim={}".format(
        jtu.format_shape_dtype_string(arg_shape, dtype), dim),
     "arg_shape": arg_shape, "dtype": dtype, "dim": dim,
     "rng_factory": jtu.rand_default}
    for arg_shape in [(), (3,), (3, 4)]
    for dtype in default_dtypes
    for dim in range(-len(arg_shape)+1, len(arg_shape))))
def testExpandDimsStaticDim(self, arg_shape, dtype, dim, rng_factory):
  """Compares lnp.expand_dims with onp.expand_dims for a fixed `dim`."""
  rng = rng_factory()

  def onp_fun(x):
    return onp.expand_dims(x, dim)

  def lnp_fun(x):
    return lnp.expand_dims(x, dim)

  def args_maker():
    return [rng(arg_shape, dtype)]

  self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
  self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_inshape={}_axes=({},{})".format(
        jtu.format_shape_dtype_string(arg_shape, dtype), ax1, ax2),
     "arg_shape": arg_shape, "dtype": dtype, "ax1": ax1, "ax2": ax2,
     "rng_factory": jtu.rand_default}
    for arg_shape, ax1, ax2 in [
        ((3, 4), 0, 1), ((3, 4), 1, 0), ((3, 4, 5), 1, 2),
        ((3, 4, 5), -1, -2), ((3, 4, 5), 0, 1)]
    for dtype in default_dtypes))
def testSwapAxesStaticAxes(self, arg_shape, dtype, ax1, ax2, rng_factory):
  """Compares lnp.swapaxes with onp.swapaxes for fixed axis pairs."""
  rng = rng_factory()

  def onp_fun(x):
    return onp.swapaxes(x, ax1, ax2)

  def lnp_fun(x):
    return lnp.swapaxes(x, ax1, ax2)

  def args_maker():
    return [rng(arg_shape, dtype)]

  self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
  self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_inshape={}_axis={}".format(
        jtu.format_shape_dtype_string(arg_shape, dtype), ax),
     "arg_shape": arg_shape, "dtype": dtype, "ax": ax,
     "rng_factory": jtu.rand_default}
    for arg_shape, ax in [
        ((3, 1), None),
        ((3, 1), 1),
        ((1, 3, 1), (0, 2)),
        ((1, 4, 1), (0,))]
    for dtype in default_dtypes))
def testSqueeze(self, arg_shape, dtype, ax, rng_factory):
  """Compares lnp.squeeze with onp.squeeze for None, int and tuple axes."""
  rng = rng_factory()

  def onp_fun(x):
    return onp.squeeze(x, ax)

  def lnp_fun(x):
    return lnp.squeeze(x, ax)

  def args_maker():
    return [rng(arg_shape, dtype)]

  self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
  self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_axis={}_weights={}_returned={}".format(
        jtu.format_shape_dtype_string(shape, dtype),
        axis,
        (None if weights_shape is None else jtu.format_shape_dtype_string(weights_shape, dtype)),
        returned),
     "rng_factory": jtu.rand_default, "shape": shape, "dtype": dtype, "axis": axis,
     "weights_shape": weights_shape, "returned": returned}
    for shape, dtype in _shape_and_dtypes(nonempty_shapes, number_dtypes)
    # A list (as in testFlip) rather than a set: set iteration order is
    # unspecified, so a set would not give a stable order of generated cases.
    for axis in list(range(-len(shape), len(shape))) + [None]
    # `weights_shape` is either `None`, same as the averaged axis, or same as
    # that of the input
    for weights_shape in ([None, shape] if axis is None or len(shape) == 1
                          else [None, (shape[axis],), shape])
    for returned in [False, True]))
def testAverage(self, shape, dtype, axis, weights_shape, returned, rng_factory):
  """Compares lnp.average with onp.average, with and without weights.

  When `weights_shape` is None the functions are called with the operand
  only; otherwise a second random array of that shape is passed as weights.
  The case is skipped if the NumPy comparison raises ZeroDivisionError.
  """
  rng = rng_factory()
  if weights_shape is None:
    onp_fun = lambda x: onp.average(x, axis, returned=returned)
    lnp_fun = lambda x: lnp.average(x, axis, returned=returned)
    args_maker = lambda: [rng(shape, dtype)]
  else:
    onp_fun = lambda x, weights: onp.average(x, axis, weights, returned)
    lnp_fun = lambda x, weights: lnp.average(x, axis, weights, returned)
    args_maker = lambda: [rng(shape, dtype), rng(weights_shape, dtype)]

  tol = {lnp.bfloat16: 1e-1, onp.float16: 1e-1, onp.float32: 1e-3,
         onp.float64: 1e-10, onp.complex64: 1e-3, onp.complex128: 1e-10}
  # Dtypes are only compared when the operand is not a Python scalar.
  check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE
  try:
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker,
                            check_dtypes=check_dtypes, tol=tol)
  except ZeroDivisionError:
    self.skipTest("don't support checking for ZeroDivisionError")
  self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=check_dtypes,
                        rtol=tol, atol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_arg{}_ndmin={}".format(i, ndmin),
     "arg": arg, "ndmin": ndmin, "dtype": dtype}
    for i, (arg, dtype) in enumerate([
        ([True, False, True], lnp.bool_),
        (3., lnp.float_),
        ([1, 2, 3], lnp.int_),
        ([1., 2., 3.], lnp.float_),
        ([[1, 2], [3, 4], [5, 6]], lnp.int_),
        ([[1, 2.], [3, 4], [5, 6]], lnp.float64),
        ([[1., 2j], [3., 4.], [5., 6.]], lnp.complex_),
        ([[3, onp.array(2), 1], onp.arange(3.)], onp.float_),
    ])
    for ndmin in [None, onp.ndim(arg), onp.ndim(arg) + 1, onp.ndim(arg) + 2]))
def testArray(self, arg, ndmin, dtype):
  """Compares lnp.array with onp.array on nested Python/NumPy inputs.

  The reference is given the canonicalized expected dtype explicitly, while
  lnp.array is left to infer it.
  """
  def args_maker():
    return [arg]

  expected_dtype = dtypes.canonicalize_dtype(dtype)
  if ndmin is None:
    onp_fun = partial(onp.array, dtype=expected_dtype)
    lnp_fun = lnp.array
  else:
    onp_fun = partial(onp.array, ndmin=ndmin, dtype=expected_dtype)
    lnp_fun = partial(lnp.array, ndmin=ndmin)

  self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
  self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
def testIssue121(self):
assert not onp.isscalar(lnp.array(3))
def testArrayMethod(self):
class arraylike(object):
dtype = onp.float32
def __array__(self, dtype=None):
return 3.
a = arraylike()
ans = lnp.array(a)
assert ans == 3.
@jtu.skip_on_devices("tpu")  # TODO(b/32368900): TPUs don't support uint8 yet.
def testMemoryView(self):
  """lnp.array of a one-byte bytearray matches the equivalent uint8 array."""
  ans = lnp.array(bytearray(b'\x2a'))
  expected = onp.array([0x2a], dtype=onp.uint8)
  self.assertAllClose(ans, expected, check_dtypes=True)
def testAllClose(self):
  """lnp.allclose gives the same verdicts eagerly and under api.jit."""
  rng = onp.random.RandomState(0)
  x = rng.randn(2, 2)
  y = rng.randn(2)

  def same(list1, list2):
    # Element-wise closeness of the paired entries, reduced with lnp.all.
    elements_close = [lnp.allclose(lhs, rhs, atol=1e-3, rtol=1e-3)
                      for lhs, rhs in zip(list1, list2)]
    return lnp.all(lnp.array(elements_close))

  csame = api.jit(same)

  eager_equal = same((x, y), (x, y))
  jit_equal = csame((x, y), (x, y))
  jit_different = csame((x, y), (x, 2 * y))

  self.assertTrue(eager_equal)
  self.assertTrue(jit_equal)
  self.assertFalse(jit_different)
@jtu.skip_on_devices("tpu")  # TODO(mattjj): investigate this failure
def testOnesBroadcastingConstantHandler(self):
  """Currently skipped unconditionally; the body below is kept for reference."""
  # TODO(mattjj): update this test for jax3
  self.skipTest("test needs jax3 update")

  def fun(x):
    ones = lnp.ones((3, 4))
    assert isinstance(ones, onp.ndarray) and ones.strides == (0, 0)

    # To check that the constant handler generates a Broadcast for stride-zero
    # arrays, we monkey-patch the client instance.
    # TODO(mattjj): once we have better HLO dumping and inspecting facilities,
    # we can check the HLO more directly.
    c = x._node.c
    Broadcast = c.Broadcast  # pylint: disable=invalid-name
    was_called = []

    def recording_broadcast(*args):
      # Record the call, then delegate to the original Broadcast.
      was_called.append(True)
      return Broadcast(*args)

    c.Broadcast = recording_broadcast
    out = x + ones  # the ndarray constant handler should call Broadcast here
    assert was_called, "Broadcast was not called."

    return out

  fun = api.jit(fun)
  out_val = fun(lnp.ones(4))
  expected = onp.full((3, 4), 2.)
  self.assertAllClose(out_val, expected, check_dtypes=False)
def testZeroStridesConstantHandler(self):
  """A jitted function can close over an onp.broadcast_to constant."""
  seed_values = onp.random.RandomState(0).randn(1, 2, 1, 1, 5, 1)
  const = onp.broadcast_to(seed_values, (3, 2, 3, 4, 5, 6))

  @api.jit
  def fun(x):
    return x * const

  out_val = fun(3.)
  expected = 3. * const
  self.assertAllClose(out_val, expected, check_dtypes=False)
def testIsInstanceNdarrayDuringTracing(self):
  """Inside api.jit the argument must pass isinstance(x, lnp.ndarray)."""
  def f(x):
    self.assertIsInstance(x, lnp.ndarray)
    return lnp.sum(x)

  jitted = api.jit(f)
  jitted(onp.ones(3))
def testNonArrayErrorMessage(self):
  """lnp.add and lnp.dot raise TypeError for a Python-list operand."""
  x = [1., 2.]
  y = onp.array([3., 4.])

  def g(x, y):
    return lnp.add(x, y)

  def f(x, y):
    return lnp.dot(x, y)

  # Eager calls first, then the same functions under api.jit.
  with self.assertRaises(TypeError):
    g(x, y)
  with self.assertRaises(TypeError):
    f(x, y)
  with self.assertRaises(TypeError):
    api.jit(g)(x, y)
  with self.assertRaises(TypeError):
    api.jit(f)(x, y)
def testAbstractionErrorMessage(self):
  """Python control flow on jit arguments raises TypeError."""
  @api.jit
  def f(x, n):
    # `range(n)` uses a jit argument as a Python loop bound.
    for _ in range(n):
      x = x * x
    return x

  with self.assertRaises(TypeError):
    f(3., 3)

  @api.jit
  def g(x):
    # Branching on a jit argument with a Python `if`.
    if x > 0.:
      return x * 2
    else:
      return x + 2

  with self.assertRaises(TypeError):
    g(3.)
def testTracingPrimitiveWithNoTranslationErrorMessage(self):
  """Currently skipped unconditionally; the body below is kept for reference."""
  # TODO(mattjj): update this for jax3
  self.skipTest("test needs jax3 update")
  foo = lnp._not_implemented(lambda x: x)

  # No error if there's no tracing.
  foo(onp.arange(3))

  cfoo = api.jit(foo)
  with self.assertRaises(NotImplementedError):
    cfoo(onp.arange(3))
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_axis={}".format(
        jtu.format_shape_dtype_string(shape, dtype), axis),
     "rng_factory": rng_factory, "shape": shape, "dtype": dtype, "axis": axis}
    for shape in [(3,), (2, 3)]
    for dtype in default_dtypes
    for axis in list(range(-len(shape), len(shape))) + [None]  # Test negative axes
    for rng_factory in [jtu.rand_default]))
def testFlip(self, shape, dtype, axis, rng_factory):
  """Compares lnp.flip with onp.flip, including negative axes and axis=None."""
  rng = rng_factory()
  args_maker = self._GetArgsMaker(rng, [shape], [dtype])

  def lnp_op(x):
    return lnp.flip(x, axis)

  def onp_op(x):
    return onp.flip(x, axis)

  self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True)
  self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name": "_{}".format(
            jtu.format_shape_dtype_string(shape, dtype)),
        "rng_factory": rng_factory,
        "shape": shape,
        "dtype": dtype,
    }
    for shape in [(3,), (2, 3), (3, 2, 4)]
    for dtype in default_dtypes
    for rng_factory in [jtu.rand_default]))
def testFlipud(self, shape, dtype, rng_factory):
  """Runs lnp.flipud and onp.flipud through the class's comparison helpers."""
  make_args = self._GetArgsMaker(rng_factory(), [shape], [dtype])

  def jax_flipud(x):
    return lnp.flipud(x)

  def numpy_flipud(x):
    return onp.flipud(x)

  self._CheckAgainstNumpy(numpy_flipud, jax_flipud, make_args,
                          check_dtypes=True)
  self._CompileAndCheck(jax_flipud, make_args, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name": "_{}".format(
            jtu.format_shape_dtype_string(shape, dtype)),
        "rng_factory": rng_factory,
        "shape": shape,
        "dtype": dtype,
    }
    for shape in [(3, 2), (2, 3), (3, 2, 4)]
    for dtype in default_dtypes
    for rng_factory in [jtu.rand_default]))
def testFliplr(self, shape, dtype, rng_factory):
  """Runs lnp.fliplr and onp.fliplr through the class's comparison helpers."""
  make_args = self._GetArgsMaker(rng_factory(), [shape], [dtype])

  def jax_fliplr(x):
    return lnp.fliplr(x)

  def numpy_fliplr(x):
    return onp.fliplr(x)

  self._CheckAgainstNumpy(numpy_fliplr, jax_fliplr, make_args,
                          check_dtypes=True)
  self._CompileAndCheck(jax_fliplr, make_args, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name": "_{}_k={}_axes={}".format(
            jtu.format_shape_dtype_string(shape, dtype), k, axes),
        "rng_factory": rng_factory,
        "shape": shape,
        "dtype": dtype,
        "k": k,
        "axes": axes,
    }
    for shape, axes in [
        [(2, 3), (0, 1)],
        [(2, 3), (1, 0)],
        [(4, 3, 2), (0, 2)],
        [(4, 3, 2), (2, 1)],
    ]
    for k in range(-3, 4)
    for dtype in default_dtypes
    for rng_factory in [jtu.rand_default]))
def testRot90(self, shape, dtype, k, axes, rng_factory):
  """Runs lnp.rot90 and onp.rot90 with the given k and axes pair."""
  make_args = self._GetArgsMaker(rng_factory(), [shape], [dtype])

  def jax_rot90(x):
    return lnp.rot90(x, k, axes)

  def numpy_rot90(x):
    return onp.rot90(x, k, axes)

  self._CheckAgainstNumpy(numpy_rot90, jax_rot90, make_args,
                          check_dtypes=True)
  self._CompileAndCheck(jax_rot90, make_args, check_dtypes=True)
# TODO(mattjj): test infix operator overrides
def testRavel(self):
  """Passes the .ravel() method call to the class's _CompileAndCheck helper."""
  prng = onp.random.RandomState(0)

  def args_maker():
    # One float32 array of shape (3, 4) per call.
    return [prng.randn(3, 4).astype("float32")]

  self._CompileAndCheck(lambda x: x.ravel(), args_maker, check_dtypes=True)
def testAstype(self):
  """Exercises the .astype(lnp.int32) method on a float32 (3, 4) input."""
  prng = onp.random.RandomState(0)

  def args_maker():
    return [prng.randn(3, 4).astype("float32")]

  def to_int32(x):
    return x.astype(lnp.int32)

  # The same callable is passed in both the reference and jax positions.
  self._CheckAgainstNumpy(to_int32, to_int32, args_maker, check_dtypes=True)
  self._CompileAndCheck(to_int32, args_maker, check_dtypes=True)
# TODO(mattjj): test other ndarray-like method overrides
def testOnpMean(self):
  """onp.mean applied to a jax-produced 3x3 identity yields 1/3."""
  # from https://github.com/google/jax/issues/125
  identity = lax.add(lnp.eye(3, dtype=lnp.float_), 0.)
  self.assertAllClose(onp.mean(identity), onp.array(1./3),
                      check_dtypes=False)
def testArangeOnFloats(self):
  """lnp.arange with float arguments matches onp.arange cast to lnp.float_."""
  # from https://github.com/google/jax/issues/145
  bounds = (0.0, 1.0, 0.1)
  expected = onp.arange(*bounds, dtype=lnp.float_)
  self.assertAllClose(expected, lnp.arange(*bounds), check_dtypes=True)
def testSortManually(self):
  """Checks lnp.sort against hand-written expected outputs."""
  # manual tests for sort are nice because we don't have to worry about ties.
  # lax.sort is tested combinatorially.
  square = [[1, 4], [3, 1]]
  cases = [
      # (operand, keyword arguments for lnp.sort, expected result)
      (onp.array([16, 15, 23, 42, 8, 4]), {},
       onp.array([4, 8, 15, 16, 23, 42])),
      (onp.array(square), {"axis": None}, onp.array([1, 1, 3, 4])),
      (onp.array(square), {}, onp.array([[1, 4], [1, 3]])),  # last axis
      (onp.array(square), {"axis": 0}, onp.array([[1, 1], [3, 4]])),
  ]
  for operand, kwargs, expected in cases:
    self.assertAllClose(expected, lnp.sort(operand, **kwargs),
                        check_dtypes=True)
def testArgsortManually(self):
  """Compares lnp.argsort with onp.argsort on small fixed inputs."""
  vector = [16, 15, 23, 42, 8, 4]
  matrix = [[16, 15, 23], [42, 8, 4]]
  cases = [
      # (operand data, keyword arguments forwarded to both argsorts)
      (vector, {}),
      (matrix, {"axis": 0}),
      (matrix, {"axis": 1}),
      (matrix, {"axis": None}),
      (matrix, {}),
  ]
  for data, kwargs in cases:
    operand = onp.array(data)
    actual = lnp.argsort(operand, **kwargs)
    reference = onp.argsort(operand, **kwargs)
    self.assertAllClose(reference, actual, check_dtypes=False)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name": "_{}_shifts={}_axis={}".format(
            jtu.format_shape_dtype_string(shape, dtype),
            shifts, axis),
        "rng_factory": rng_factory,
        "shape": shape,
        "dtype": dtype,
        "shifts": shifts,
        "axis": axis,
    }
    for dtype in all_dtypes
    for shape in [(3, 4), (3, 4, 5), (7, 4, 0)]
    for shifts, axis in [
        (3, None),
        (1, 1),
        ((3,), (0,)),
        ((-2,), (-2,)),
        ((1, 2), (0, -1))
    ]
    for rng_factory in [jtu.rand_default]))
def testRoll(self, shape, dtype, shifts, axis, rng_factory):
  """Runs lnp.roll and onp.roll with the shifts passed as an array operand."""
  rng = rng_factory()

  def args_maker():
    return [rng(shape, dtype), onp.array(shifts)]

  def jax_roll(*operands):
    return lnp.roll(*operands, axis=axis)

  def numpy_roll(*operands):
    return onp.roll(*operands, axis=axis)

  # NOTE(review): the jax op is passed first here, whereas most tests in this
  # file pass the numpy op first -- confirm the helper treats them the same.
  self._CheckAgainstNumpy(jax_roll, numpy_roll, args_maker, check_dtypes=True)
  self._CompileAndCheck(jax_roll, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name": "_{}_index={}_axis={}_mode={}".format(
            jtu.format_shape_dtype_string(shape, dtype),
            jtu.format_shape_dtype_string(index_shape, index_dtype),
            axis, mode),
        "rng_factory": rng_factory,
        "rng_indices_factory": rng_indices_factory,
        "shape": shape,
        "index_shape": index_shape,
        "dtype": dtype,
        "index_dtype": index_dtype,
        "axis": axis,
        "mode": mode,
    }
    for shape in [(3,), (3, 4), (3, 4, 5)]
    for index_shape in scalar_shapes + [(3,), (2, 1, 3)]
    for axis in itertools.chain(range(-len(shape), len(shape)), [None])
    for dtype in all_dtypes
    for index_dtype in int_dtypes
    for mode in ['wrap', 'clip']
    for rng_factory in [jtu.rand_default]
    for rng_indices_factory in [partial(jtu.rand_int, -5, 5)]))
def testTake(self, shape, dtype, index_shape, index_dtype, axis, mode,
             rng_factory, rng_indices_factory):
  """Runs lnp.take and onp.take with the given axis and out-of-bounds mode."""
  rng = rng_factory()
  rng_indices = rng_indices_factory()

  def args_maker():
    # Operand first, then the indices, each from its own generator.
    operand = rng(shape, dtype)
    indices = rng_indices(index_shape, index_dtype)
    return operand, indices

  def jax_take(x, i):
    return lnp.take(x, i, axis=axis, mode=mode)

  def numpy_take(x, i):
    return onp.take(x, i, axis=axis, mode=mode)

  # NOTE(review): the jax op is passed first here, whereas most tests in this
  # file pass the numpy op first -- confirm the helper treats them the same.
  self._CheckAgainstNumpy(jax_take, numpy_take, args_maker, check_dtypes=True)
  self._CompileAndCheck(jax_take, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name": "_{}_ishape={}_axis={}".format(
            jtu.format_shape_dtype_string(x_shape, dtype), i_shape, axis),
        "rng_factory": rng_factory,
        "x_shape": x_shape,
        "i_shape": i_shape,
        "dtype": dtype,
        "axis": axis,
    }
    for x_shape, i_shape in filter(
        _shapes_are_equal_length,
        filter(_shapes_are_broadcast_compatible,
               CombosWithReplacement(nonempty_nonscalar_array_shapes, 2)))
    for axis in itertools.chain(range(len(x_shape)), [-1], [None])
    for dtype in default_dtypes
    for rng_factory in [jtu.rand_default]))
def testTakeAlongAxis(self, x_shape, i_shape, dtype, axis, rng_factory):
  """Runs lnp.take_along_axis, comparing with numpy when numpy provides it."""
  rng = rng_factory()

  dims = onp.array(i_shape)
  if axis is None:
    # With axis=None the index array is one-dimensional.
    index_shape = [onp.prod(dims, dtype=onp.int64)]
  else:
    # Test the case where the size of the axis doesn't necessarily broadcast.
    dims[axis] *= 3
    index_shape = list(dims)

  def args_maker():
    operand = rng(x_shape, dtype)
    if axis is None:
      n = onp.prod(x_shape, dtype=onp.int32)
    else:
      n = x_shape[axis]
    # Fold the random integers into the range [-(n - 1), n - 1].
    indices = rng(index_shape, onp.int32) % (2 * n - 1) - (n - 1)
    return operand, indices

  def jax_take_along_axis(x, i):
    return lnp.take_along_axis(x, i, axis=axis)

  if hasattr(onp, "take_along_axis"):
    def numpy_take_along_axis(x, i):
      return onp.take_along_axis(x, i, axis=axis)

    # NOTE(review): the jax op is passed first here, whereas most tests in
    # this file pass the numpy op first -- confirm the helper treats them the
    # same.
    self._CheckAgainstNumpy(jax_take_along_axis, numpy_take_along_axis,
                            args_maker, check_dtypes=True)
  self._CompileAndCheck(jax_take_along_axis, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name": "_shape={}_n={}_increasing={}".format(
            jtu.format_shape_dtype_string([shape], dtype),
            n, increasing),
        "dtype": dtype,
        "shape": shape,
        "n": n,
        "increasing": increasing,
        "rng_factory": jtu.rand_default,
    }
    for dtype in inexact_dtypes
    for shape in [0, 5]
    for n in [2, 4]
    for increasing in [False, True]))
def testVander(self, shape, dtype, n, increasing, rng_factory):
  """Runs lnp.vander and onp.vander on a 1-D input of length `shape`."""
  rng = rng_factory()

  def args_maker():
    return [rng([shape], dtype)]

  def onp_fun(arg):
    # The numpy reference is given float32 data in place of bfloat16.
    if dtype == lnp.bfloat16:
      arg = arg.astype(onp.float32)
    return onp.vander(arg, N=n, increasing=increasing)

  def lnp_fun(arg):
    return lnp.vander(arg, N=n, increasing=increasing)

  # np.vander seems to return float64 for all floating types. We could obey
  # those semantics, but they seem like a bug.
  self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=False,
                          tol={onp.float32: 1e-3})
  self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=False)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name": jtu.format_test_name_suffix("nan_to_num", [shape],
                                                     [dtype]),
        "rng_factory": jtu.rand_some_inf_and_nan,
        "shape": shape,
        "dtype": dtype,
    }
    for shape in all_shapes
    for dtype in inexact_dtypes))
def testNanToNum(self, rng_factory, shape, dtype):
  """Runs lnp.nan_to_num against a numpy reference on nan/inf-bearing input."""
  rng = rng_factory()
  dtype = onp.dtype(dtypes.canonicalize_dtype(dtype)).type

  def onp_fun(x):
    if dtype == lnp.bfloat16:
      # bfloat16 takes a hand-written path instead of onp.nan_to_num.
      limits = lnp.finfo(dtype)
      x = onp.where(onp.isnan(x), dtype(0), x)
      x = onp.where(onp.isposinf(x), limits.max, x)
      x = onp.where(onp.isneginf(x), limits.min, x)
      return x
    return onp.nan_to_num(x).astype(dtype)

  def args_maker():
    return [rng(shape, dtype)]

  # Dtype checking is switched off only for the Python-scalar shape.
  check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE
  self._CheckAgainstNumpy(onp_fun, lnp.nan_to_num, args_maker,
                          check_dtypes=check_dtypes)
  self._CompileAndCheck(lnp.nan_to_num, args_maker,
                        check_dtypes=check_dtypes)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name": jtu.format_test_name_suffix("ix_", shapes, dtypes),
        "rng_factory": jtu.rand_default,
        "shapes": shapes,
        "dtypes": dtypes,
    }
    for shapes, dtypes in (
        ((), ()),
        (((7,),), (onp.int32,)),
        (((3,), (4,)), (onp.int32, onp.int32)),
        (((3,), (1,), (4,)), (onp.int32, onp.int32, onp.int32)),
    )))
def testIx_(self, rng_factory, shapes, dtypes):
  """Runs lnp.ix_ and onp.ix_ on zero to three 1-D int32 operands."""
  rng = rng_factory()

  def args_maker():
    # One operand per (shape, dtype) pair, generated in order.
    return [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]

  self._CheckAgainstNumpy(onp.ix_, lnp.ix_, args_maker, check_dtypes=True)
  self._CompileAndCheck(lnp.ix_, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name":
            "_op={}_a_shape={}_q_shape={}_axis={}_keepdims={}".format(
                op,
                jtu.format_shape_dtype_string(a_shape, a_dtype),
                jtu.format_shape_dtype_string(q_shape, q_dtype),
                axis, keepdims),
        "a_rng": jtu.rand_default(),
        "q_rng": q_rng,
        "op": op,
        "a_shape": a_shape,
        "a_dtype": a_dtype,
        "q_shape": q_shape,
        "q_dtype": q_dtype,
        "axis": axis,
        "keepdims": keepdims,
    }
    for (op, q_rng) in (
        ("percentile", jtu.rand_uniform(low=0., high=100.)),
        ("quantile", jtu.rand_uniform(low=0., high=1.)),
        ("median", jtu.rand_uniform(low=0., high=1.)),
    )
    for a_dtype in float_dtypes
    for a_shape, axis in (
        ((7,), None),
        ((47, 7), 0),
        ((4, 101), 1),
    )
    for q_dtype in [onp.float32]
    for q_shape in scalar_shapes + [(4,)]
    for keepdims in [False, True]))
def testQuantile(self, op, a_rng, q_rng, a_shape, a_dtype, q_shape, q_dtype,
                 axis, keepdims):
  """Runs the lnp and onp functions named by `op` (percentile/quantile/median)."""
  if op == "quantile" and numpy_version < (1, 15):
    raise SkipTest("Numpy < 1.15 does not have np.quantile")

  def args_maker():
    # "median" takes only the data operand; the others also take q.
    if op == "median":
      return [a_rng(a_shape, a_dtype)]
    return [a_rng(a_shape, a_dtype), q_rng(q_shape, q_dtype)]

  def onp_fun(*args):
    # Operands whose result type is bfloat16 are handed to numpy as float32.
    converted = []
    for x in args:
      if lnp.result_type(x) != lnp.bfloat16:
        converted.append(x)
      else:
        converted.append(onp.asarray(x, onp.float32))
    return getattr(onp, op)(*converted, axis=axis, keepdims=keepdims)

  lnp_fun = partial(getattr(lnp, op), axis=axis, keepdims=keepdims)

  # TODO(phawkins): we currently set dtype=False because we aren't as
  # aggressive about promoting to float64. It's not clear we want to mimic
  # Numpy here.
  tol_spec = {onp.float32: 1e-4, onp.float64: 5e-6}
  a_tol = jtu.tolerance(a_dtype, tol_spec)
  q_tol = jtu.tolerance(q_dtype, tol_spec)
  tol = max(a_tol, q_tol)
  self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=False,
                          tol=tol)
  self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True, rtol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name": jtu.format_test_name_suffix(
            "select", shapes, (onp.bool_,) * n + dtypes),
        "rng_factory": jtu.rand_default,
        "shapes": shapes,
        "dtypes": dtypes,
    }
    for n in range(0, 3)
    for shapes in filter(
        _shapes_are_broadcast_compatible,
        CombosWithReplacement(all_shapes, 2 * n + 1))
    for dtypes in CombosWithReplacement(all_dtypes, n + 1)))
def test(self, rng_factory, shapes, dtypes):
  """Runs lnp.select against onp.select with n conditions and n choices.

  `shapes` holds n condition shapes, n choice shapes and one default shape;
  `dtypes` holds n choice dtypes followed by the default's dtype.
  """
  # NOTE(review): the method is named just `test` while its case names say
  # "select" -- consider whether a more descriptive name was intended.
  rng = rng_factory()
  n = len(dtypes) - 1

  def args_maker():
    condlist = []
    for shape in shapes[:n]:
      condlist.append(rng(shape, onp.bool_))
    choicelist = []
    for shape, dtype in zip(shapes[n:-1], dtypes[:n]):
      choicelist.append(rng(shape, dtype))
    default = rng(shapes[-1], dtypes[-1])
    return condlist, choicelist, default

  # TODO(phawkins): float32/float64 type mismatches
  def onp_fun(condlist, choicelist, default):
    # bfloat16 choices are given to numpy as float32; the output is cast to
    # the result type of the default and the converted choices.
    converted = []
    for x in choicelist:
      if lnp.result_type(x) != lnp.bfloat16:
        converted.append(x)
      else:
        converted.append(x.astype(onp.float32))
    dtype = lnp.result_type(default, *converted)
    return onp.select(condlist, converted, default).astype(dtype)

  self._CheckAgainstNumpy(onp_fun, lnp.select, args_maker,
                          check_dtypes=False)
  self._CompileAndCheck(lnp.select, args_maker, check_dtypes=True)
def testIssue330(self):
  """lnp.full accepts an element indexed out of an lnp array as fill value."""
  fill_value = lnp.array([1])[0]
  filled = lnp.full((1, 1), fill_value)  # doesn't crash
  self.assertEqual(filled[0, 0], 1)
def testScalarDtypePromotion(self):
  """Always skipped; the comparison after the raise is kept but not run."""
  # disabled this test after https://github.com/google/jax/issues/732
  raise SkipTest(
      "jax.numpy differs from numpy in promotion rules for Python scalars."
      " See https://github.com/google/jax/issues/732.")
  numpy_dtype = (1 + onp.eye(1, dtype=onp.float32)).dtype
  jax_dtype = (1 + lnp.eye(1, dtype=lnp.float32)).dtype
  self.assertEqual(numpy_dtype, jax_dtype)
def testSymmetrizeDtypePromotion(self):
  """(x + x.T) / 2 on a float32 identity has the same dtype in numpy and jax."""
  def symmetrized_dtype(matrix):
    return ((matrix + matrix.T) / 2).dtype

  numpy_dtype = symmetrized_dtype(onp.eye(3, dtype=onp.float32))
  jax_dtype = symmetrized_dtype(lnp.eye(3, dtype=lnp.float32))
  self.assertEqual(numpy_dtype, jax_dtype)
def testIssue347(self):
  """The gradient of a where-masked row norm contains no NaNs."""
  # https://github.com/google/jax/issues/347
  def masked_norm_sum(x):
    norms = lnp.sqrt(lnp.sum(x ** 2, axis=1))
    # Norms that are not above 0.5 are replaced by one.
    selected = lnp.where(norms > 0.5, norms, lnp.ones_like(norms))
    return lnp.sum(selected)

  # The last row is all zeros.
  rows = lnp.array([[1, 2], [3, 4], [0, 0]], dtype=lnp.float64)
  gradient = api.grad(masked_norm_sum)(rows)
  assert not onp.any(onp.isnan(gradient))
def testIssue453(self):
  """lnp.reshape with order='F' matches onp.reshape."""
  # https://github.com/google/jax/issues/453
  values = onp.arange(6) + 1
  target = (3, 2)
  actual = lnp.reshape(values, target, order='F')
  reference = onp.reshape(values, target, order='F')
  self.assertAllClose(actual, reference, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
    {
        "testcase_name": "_op={}_dtype={}".format(op, pytype.__name__),
        "pytype": pytype,
        "dtype": dtype,
        "op": op,
    }
    for pytype, dtype in [(int, lnp.int_), (float, lnp.float_),
                          (bool, lnp.bool_), (complex, lnp.complex_)]
    for op in ["atleast_1d", "atleast_2d", "atleast_3d"]))
def testAtLeastNdLiterals(self, pytype, dtype, op):
  """Runs the atleast_Nd function named by `op` on a Python scalar literal."""
  # Fixes: https://github.com/google/jax/issues/634
  def onp_fun(arg):
    # The numpy result is cast to the matching lnp default dtype.
    return getattr(onp, op)(arg).astype(dtype)

  def lnp_fun(arg):
    return getattr(lnp, op)(arg)

  def args_maker():
    return [pytype(2)]

  self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
  self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
def testLongLong(self):
  """A jitted identity on onp.longlong(7) is compared with onp.int64(7)."""
  identity = api.jit(lambda x: x)
  self.assertAllClose(onp.int64(7), identity(onp.longlong(7)),
                      check_dtypes=True)
def testArange(self):
  """Checks lnp.arange values against onp.arange and its returned type."""
  # test cases inspired by dask tests at
  # https://github.com/dask/dask/blob/master/dask/array/tests/test_creation.py#L92
  for bounds in [(77,), (2, 13), (4, 21, 9), (53, 5, -3)]:
    self.assertAllClose(lnp.arange(*bounds),
                        onp.arange(*bounds, dtype=lnp.int_),
                        check_dtypes=True)
  # TODO(mattjj): make these tests work when jax_enable_x64=True
  # self.assertAllClose(lnp.arange(77, dtype=float),
  #                     onp.arange(77, dtype=float), check_dtypes=True)
  # self.assertAllClose(lnp.arange(2, 13, dtype=int),
  #                     onp.arange(2, 13, dtype=int), check_dtypes=True)
  self.assertAllClose(lnp.arange(0, 1, -0.5),
                      onp.arange(0, 1, -0.5), check_dtypes=True)

  with self.assertRaises(TypeError):
    lnp.arange()

  # test that lnp.arange(N) doesn't instantiate an ndarray
  default_type = type(lnp.arange(77))
  self.assertFalse(default_type == type(onp.arange(77)))
  self.assertTrue(default_type == type(lax.iota(onp.int32, 77)))

  # test that lnp.arange(N, dtype=int32) doesn't instantiate an ndarray
  int32_type = type(lnp.arange(77, dtype=lnp.int32))
  self.assertFalse(int32_type == type(onp.arange(77, dtype=onp.int32)))
  self.assertTrue(int32_type == type(lax.iota(onp.int32, 77)))
def testIssue830(self):
  """lnp.arange keeps a requested complex64 dtype."""
  result = lnp.arange(4, dtype=lnp.complex64)
  self.assertEqual(result.dtype, lnp.complex64)
def testIssue728(self):
  """lnp.eye matches onp.eye at sizes 5000 and 1050."""
  large_jax, large_numpy = lnp.eye(5000), onp.eye(5000)
  assert lnp.allclose(large_jax, large_numpy)
  difference = lnp.eye(1050) - onp.eye(1050)
  self.assertEqual(0, onp.sum(difference))
def testIssue746(self):
lnp.arange(12).reshape(3, 4) # doesn't crash
def testIssue764(self):
  """Gradient of sum(tanh(x)) for x in [190, 200] matches reference values."""
  points = lnp.linspace(190, 200, 4)
  tanh_sum_grad = api.grad(lambda x: lnp.sum(lnp.tanh(x)))
  # Expected values computed with autograd in float64 precision.
  reference = onp.array(
      [3.71669453e-165, 4.72999108e-168, 6.01954653e-171, 7.66067839e-174],
      onp.float64)
  self.assertAllClose(tanh_sum_grad(points), reference, check_dtypes=False)
def testIssue776(self):
  """Checks that the scatter-add transpose rule instantiates symbolic zeros."""
  def scatter_then_constant(u):
    scattered = jax.ops.index_add(onp.ones(10,), [2, 4, 5], u)
    # The transpose rule for lax.tie_in returns a symbolic zero for its first
    # argument.
    return lax.tie_in(scattered, 7.)

  gradient = api.grad(scatter_then_constant)(onp.ones(3,))
  self.assertAllClose(onp.zeros(3,), gradient, check_dtypes=True)
def testIssue777(self):
  """Gradient of sum(1 / (1 + exp(-x))) on [-200, 0] is [0, 0, 0, 0.25]."""
  points = lnp.linspace(-200, 0, 4, dtype=onp.float32)
  logistic_sum_grad = api.grad(lambda x: lnp.sum(1 / (1 + lnp.exp(-x))))
  reference = onp.array([0., 0., 0., 0.25], dtype=onp.float32)
  self.assertAllClose(logistic_sum_grad(points), reference,
                      check_dtypes=True)
@parameterized.named_parameters(
    jtu.cases_from_list(
        {
            "testcase_name": jtu.format_test_name_suffix(op, [()], [dtype]),
            "dtype": dtype,
            "op": op,
        }
        for dtype in float_dtypes
        for op in ("sqrt", "arccos", "arcsin", "arctan", "sin", "cos", "tan",
                   "sinh", "cosh", "tanh", "arccosh", "arcsinh", "arctanh",
                   "exp", "log", "expm1", "log1p")))
def testMathSpecialFloatValues(self, op, dtype):
  """Compares the lnp and onp versions of `op` on nan, infinities and extremes."""
  onp_op = getattr(onp, op)
  lnp_op = getattr(lnp, op)
  dtype = onp.dtype(dtypes.canonicalize_dtype(dtype)).type
  tol = jtu.tolerance(dtype, {onp.float32: 1e-3, onp.float64: 1e-7})

  largest = lnp.finfo(dtype).max
  special_values = (
      onp.nan, -onp.inf, -100., -2., -1., 0., 1., 2., 100., onp.inf,
      largest, onp.sqrt(largest), onp.sqrt(largest) * 2.)

  for raw in special_values:
    if onp.isnan(raw) and op in ("sinh", "cosh", "expm1", "exp"):
      # TODO(b/133842876, b/133842870): these return wrong outputs on CPU for
      # NaN inputs.
      continue
    if (op in ("sin", "cos", "tan", "arctan") and
        jtu.device_under_test() == "tpu"):
      continue  # TODO(b/132196789, b/134175194): fix and reenable.
    x = dtype(raw)
    self.assertAllClose(onp_op(x), lnp_op(x), check_dtypes=True, atol=tol,
                        rtol=tol)
def testIssue883(self):
  """Calling a jitted function twice with an array static argument works."""
  # from https://github.com/google/jax/issues/883
  def first_arg(x, v):
    return x

  # Argument 1 (`v`) is marked static.
  jitted = partial(api.jit, static_argnums=(1,))(first_arg)

  operand = lnp.ones((10, 10))
  static_value = lnp.array([1, 2, 3])
  jitted(operand, static_value)
  jitted(operand, static_value)  # doesn't crash
def testReductionOfOutOfBoundsAxis(self):  # Issue 888
  """lnp.sum over axis=2 of a 2-D array raises ValueError."""
  matrix = lnp.ones((3, 4))
  with self.assertRaises(ValueError):
    lnp.sum(matrix, axis=2)
def testIssue956(self):
  """Calling lnp.ndarray directly with a shape raises TypeError."""
  with self.assertRaises(TypeError):
    lnp.ndarray((1, 1))
@parameterized.named_parameters(
    jtu.cases_from_list(
        {
            "testcase_name":
                "_shape={}_dtype={}_out_dtype={}_axis={}_ddof={}_keepdims={}"
                .format(shape, dtype, out_dtype, axis, ddof, keepdims),
            "shape": shape,
            "dtype": dtype,
            "out_dtype": out_dtype,
            "axis": axis,
            "ddof": ddof,
            "keepdims": keepdims,
            "rng_factory": rng_factory,
        }
        for shape in [(5,), (10, 5)]
        for dtype in all_dtypes
        for out_dtype in inexact_dtypes
        for axis in [None, 0, -1]
        for ddof in [0, 1, 2]
        for keepdims in [False, True]
        for rng_factory in [jtu.rand_default]))
def testVar(self, shape, dtype, out_dtype, axis, ddof, keepdims, rng_factory):
  """Runs lnp.var with an explicit output dtype against an onp.var reference."""
  make_args = self._GetArgsMaker(rng_factory(), [shape], [dtype])

  def onp_fun(x):
    # The reference computes in the promotion of float32 with the input
    # dtype, then casts the result to the requested output dtype.
    compute_dtype = lnp.promote_types(onp.float32, dtype)
    variance = onp.var(x.astype(compute_dtype),
                       axis=axis, ddof=ddof, keepdims=keepdims)
    return variance.astype(out_dtype)

  lnp_fun = partial(lnp.var, dtype=out_dtype, axis=axis, ddof=ddof,
                    keepdims=keepdims)
  tol = jtu.tolerance(out_dtype, {onp.float16: 1e-1, onp.float32: 1e-3,
                                  onp.float64: 1e-3, onp.complex128: 1e-6})
  self._CheckAgainstNumpy(onp_fun, lnp_fun, make_args, check_dtypes=True,
                          tol=tol)
  self._CompileAndCheck(lnp_fun, make_args, check_dtypes=True, rtol=tol,
                        atol=tol)
@parameterized.named_parameters(
    jtu.cases_from_list(
        {
            "testcase_name":
                "_shape={}_dtype={}_rowvar={}_ddof={}_bias={}".format(
                    shape, dtype, rowvar, ddof, bias),
            "shape": shape,
            "dtype": dtype,
            "rowvar": rowvar,
            "ddof": ddof,
            "bias": bias,
            "rng_factory": rng_factory,
        }
        for shape in [(5,), (10, 5), (5, 10)]
        for dtype in all_dtypes
        for rowvar in [True, False]
        for bias in [True, False]
        for ddof in [None, 2, 3]
        for rng_factory in [jtu.rand_default]))
@jtu.skip_on_devices("gpu")  # TODO(b/138003641): test fails on GPU.
def testCov(self, shape, dtype, rowvar, ddof, bias, rng_factory):
  """Runs lnp.cov and onp.cov with matching rowvar, ddof and bias settings."""
  make_args = self._GetArgsMaker(rng_factory(), [shape], [dtype])
  options = dict(rowvar=rowvar, ddof=ddof, bias=bias)
  onp_fun = partial(onp.cov, **options)
  lnp_fun = partial(lnp.cov, **options)

  # A single looser tolerance is used on TPU; otherwise a per-dtype table.
  if jtu.device_under_test() == "tpu":
    base_tol = 7e-2
  else:
    base_tol = {onp.float32: 1e-5, onp.float64: 1e-13, onp.complex128: 1e-13}
  tol = jtu.join_tolerance(base_tol, jtu.tolerance(dtype))

  self._CheckAgainstNumpy(
      onp_fun, lnp_fun, make_args, check_dtypes=True, tol=tol)
  self._CompileAndCheck(lnp_fun, make_args, check_dtypes=True, atol=tol,
                        rtol=tol)
def testIssue967(self):
  """lnp.zeros with a float shape argument raises TypeError."""
  with self.assertRaises(TypeError):
    lnp.zeros(1.5)
@parameterized.named_parameters(
    jtu.cases_from_list(
        {
            "testcase_name":
                "_shape={}_dtype={}_rowvar={}_ddof={}_bias={}".format(
                    shape, dtype, rowvar, ddof, bias),
            "shape": shape,
            "dtype": dtype,
            "rowvar": rowvar,
            "ddof": ddof,
            "bias": bias,
            "rng_factory": rng_factory,
        }
        for shape in [(5,), (10, 5), (3, 10)]
        for dtype in number_dtypes
        for rowvar in [True, False]
        for bias in [True, False]
        for ddof in [None, 2, 3]
        for rng_factory in [jtu.rand_default]))
def testCorrCoef(self, shape, dtype, rowvar, ddof, bias, rng_factory):
  """Runs lnp.corrcoef and onp.corrcoef unless a probe sample has ~zero std."""
  rng = rng_factory()
  make_args = self._GetArgsMaker(rng, [shape], [dtype])
  # A separate sample drawn here decides whether the checks run at all.
  probe = onp.asarray([rng(shape, dtype)])
  options = dict(rowvar=rowvar, ddof=ddof, bias=bias)
  onp_fun = partial(onp.corrcoef, **options)
  lnp_fun = partial(lnp.corrcoef, **options)

  if onp.any(onp.isclose(onp.std(probe), 0.0)):
    return

  if jtu.device_under_test() == "tpu":
    tol = 1e-2
  else:
    tol = None
  self._CheckAgainstNumpy(
      onp_fun, lnp_fun, make_args, check_dtypes=True, tol=tol)
  self._CompileAndCheck(lnp_fun, make_args, check_dtypes=True)
@parameterized.named_parameters(
    jtu.cases_from_list(
        {"testcase_name": "_shapes={}_dtype={}_indexing={}_sparse={}".format(
            shapes, dtype, indexing, sparse),
         "shapes": shapes, "dtype": dtype, "indexing": indexing,
         "sparse": sparse, "rng_factory": rng_factory}
        for shapes in [(), (5,), (5, 3)]
        for dtype in number_dtypes
        for indexing in ['xy', 'ij']
        for sparse in [True, False]
        for rng_factory in [jtu.rand_default]))
def testMeshGrid(self, shapes, dtype, indexing, sparse, rng_factory):
  """Runs lnp.meshgrid against onp.meshgrid and through _CompileAndCheck.

  Args:
    shapes: tuple of lengths; each length becomes one 1-D input array.
    dtype: dtype used for every input array.
    indexing: 'xy' or 'ij', forwarded to both meshgrid implementations.
    sparse: forwarded to both meshgrid implementations.
    rng_factory: zero-argument callable returning the random generator.
  """
  rng = rng_factory()
  args_maker = self._GetArgsMaker(rng, [(x,) for x in shapes],
                                  [dtype] * len(shapes))
  onp_fun = partial(onp.meshgrid, indexing=indexing, sparse=sparse)
  lnp_fun = partial(lnp.meshgrid, indexing=indexing, sparse=sparse)
  # Compare with the numpy reference, as the sibling tests in this file do;
  # without this call `onp_fun` would be built but never exercised.
  self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
  self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
def assertDowncastDtypeEqual(self, x, y):
  """Heuristic for comparing numpy and jax downcast dtypes.

  The dtypes of `x` and `y` are obtained with jtu._dtype. When the device
  under test is a TPU, or the 'jax_enable_x64' config value is falsy, each
  dtype listed in the table below is replaced by its 32-bit counterpart
  before the two are compared.

  Args:
    x: first value whose dtype is compared.
    y: second value whose dtype is compared.

  Raises:
    AssertionError: if the (possibly downcast) dtypes differ.
  """
  x_dt = jtu._dtype(x)
  y_dt = jtu._dtype(y)
  testing_tpu = jtu.device_under_test().startswith("tpu")
  testing_x32 = not jax.config.read('jax_enable_x64')
  to32dtype = {onp.int64: onp.int32, onp.uint64: onp.uint32,
               onp.float64: onp.float32, onp.complex128: onp.complex64}
  # numpy only defines the extended-precision aliases on some platforms, so
  # look them up defensively instead of failing with AttributeError.
  for wide_name, narrow in (("float128", onp.float32),
                            ("complex256", onp.complex64)):
    wide = getattr(onp, wide_name, None)
    if wide is not None:
      to32dtype[wide] = narrow
  to32dtype = {onp.dtype(k): onp.dtype(v) for k, v in to32dtype.items()}
  if testing_tpu or testing_x32:
    x_dt = to32dtype.get(x_dt, x_dt)
    y_dt = to32dtype.get(y_dt, y_dt)
  assert x_dt == y_dt, "truncated dtypes %s != %s" % (x_dt, y_dt)
@parameterized.named_parameters(
    jtu.cases_from_list(
        {
            "testcase_name": ("_start_shape={}_stop_shape={}_num={}_endpoint={}"
                              "_retstep={}_dtype={}").format(
                start_shape, stop_shape, num, endpoint, retstep, dtype),
            "start_shape": start_shape,
            "stop_shape": stop_shape,
            "num": num,
            "endpoint": endpoint,
            "retstep": retstep,
            "dtype": dtype,
            "rng_factory": rng_factory,
        }
        for start_shape in [(), (2,), (2, 2)]
        for stop_shape in [(), (2,), (2, 2)]
        for num in [0, 1, 2, 5, 20]
        for endpoint in [True, False]
        for retstep in [True, False]
        for dtype in number_dtypes + [None,]
        for rng_factory in [jtu.rand_default]))
def testLinspace(self, start_shape, stop_shape, num, endpoint,
                 retstep, dtype, rng_factory):
  """Runs lnp.linspace against onp.linspace for every valid `axis` value."""
  rng = rng_factory()
  # relax default tolerances slightly
  tol = jtu.tolerance(dtype if dtype else onp.float32) * 10
  args_maker = self._GetArgsMaker(rng,
                                  [start_shape, stop_shape],
                                  [dtype, dtype])
  # One sample pair is drawn up front; it fixes the axis range and is reused
  # for the dtype comparison inside the loop.
  start, stop = args_maker()
  ndim = len(onp.shape(start + stop))
  for axis in range(-ndim, ndim):
    def lnp_op(lo, hi):
      return lnp.linspace(
          lo, hi, num,
          endpoint=endpoint, retstep=retstep, dtype=dtype, axis=axis)

    def onp_op(lo, hi):
      return onp.linspace(
          lo, hi, num,
          endpoint=endpoint, retstep=retstep, dtype=dtype, axis=axis)

    self._CheckAgainstNumpy(onp_op, lnp_op, args_maker,
                            check_dtypes=False, tol=tol)
    # Check dtype equivalence within expected 32bit downcasting.
    jax_out = lnp_op(start, stop)
    numpy_out = onp_op(start, stop)
    if retstep:
      # With retstep both results are indexed as (samples, step).
      self.assertDowncastDtypeEqual(jax_out[0], numpy_out[0])
      self.assertDowncastDtypeEqual(jax_out[1], numpy_out[1])
    else:
      self.assertDowncastDtypeEqual(jax_out, numpy_out)
    # floating-point compute between jitted platforms and non-jit + rounding
    # cause unavoidable variation in integer truncation for some inputs.
    if dtype in (inexact_dtypes + [None,]):
      self._CompileAndCheck(lnp_op, args_maker,
                            check_dtypes=False, atol=tol, rtol=tol)
@parameterized.named_parameters(
    jtu.cases_from_list(
        {
            "testcase_name": ("_start_shape={}_stop_shape={}_num={}_endpoint={}"
                              "_base={}_dtype={}").format(
                start_shape, stop_shape, num, endpoint, base,
                dtype.__name__ if dtype else "None"),
            "start_shape": start_shape,
            "stop_shape": stop_shape,
            "num": num,
            "endpoint": endpoint,
            "base": base,
            "dtype": dtype,
            "rng_factory": rng_factory,
        }
        for start_shape in [(), (2,), (2, 2)]
        for stop_shape in [(), (2,), (2, 2)]
        for num in [0, 1, 2, 5, 20]
        for endpoint in [True, False]
        for base in [10.0, 2, onp.e]
        for dtype in inexact_dtypes + [None,]
        for rng_factory in [jtu.rand_default]))
def testLogspace(self, start_shape, stop_shape, num,
                 endpoint, base, dtype, rng_factory):
  """Runs lnp.logspace against onp.logspace for every valid `axis` value."""
  on_accelerator = jtu.device_under_test() in ("gpu", "tpu")
  if (dtype in int_dtypes and on_accelerator and not FLAGS.jax_enable_x64):
    raise unittest.SkipTest("GPUx32 truncated exponentiation"
                            " doesn't exactly match other platforms.")
  rng = rng_factory()
  # relax default tolerances slightly
  tol = {onp.float16: 2e-2, onp.float32: 1e-2, onp.float64: 1e-6,
         onp.complex64: 1e-3, onp.complex128: 1e-6}
  args_maker = self._GetArgsMaker(rng,
                                  [start_shape, stop_shape],
                                  [dtype, dtype])
  # One sample pair is drawn up front; it fixes the axis range and is reused
  # for the dtype comparison inside the loop.
  start, stop = args_maker()
  ndim = len(onp.shape(start + stop))
  for axis in range(-ndim, ndim):
    def lnp_op(lo, hi):
      return lnp.logspace(
          lo, hi, num, endpoint=endpoint, base=base, dtype=dtype, axis=axis)

    def onp_op(lo, hi):
      return onp.logspace(
          lo, hi, num, endpoint=endpoint, base=base, dtype=dtype, axis=axis)

    self._CheckAgainstNumpy(onp_op, lnp_op, args_maker,
                            check_dtypes=False, tol=tol)
    # Check dtype equivalence within expected 32bit downcasting.
    jax_out = lnp_op(start, stop)
    numpy_out = onp_op(start, stop)
    self.assertDowncastDtypeEqual(jax_out, numpy_out)
    if dtype in (inexact_dtypes + [None,]):
      # Why do compiled and op-by-op float16 np.power numbers differ
      # slightly more than expected?
      atol = {onp.float16: 1e-2}
      self._CompileAndCheck(lnp_op, args_maker,
                            check_dtypes=False, atol=atol, rtol=tol)
@parameterized.named_parameters(
    jtu.cases_from_list(
        {
            "testcase_name": ("_start_shape={}_stop_shape={}_num={}_endpoint={}"
                              "_dtype={}").format(
                start_shape, stop_shape, num, endpoint, dtype),
            "start_shape": start_shape,
            "stop_shape": stop_shape,
            "num": num,
            "endpoint": endpoint,
            "dtype": dtype,
            "rng_factory": rng_factory,
        }
        for start_shape in [(), (2,), (2, 2)]
        for stop_shape in [(), (2,), (2, 2)]
        for num in [0, 1, 2, 5, 20]
        for endpoint in [True, False]
        # NB: numpy's geomspace gives nonsense results on integer types
        for dtype in inexact_dtypes + [None,]
        for rng_factory in [jtu.rand_default]))
def testGeomspace(self, start_shape, stop_shape, num,
                  endpoint, dtype, rng_factory):
  """Runs lnp.geomspace against onp.geomspace for every valid `axis` value."""
  rng = rng_factory()
  # relax default tolerances slightly
  tol = {onp.float16: 4e-3, onp.float32: 2e-3, onp.complex128: 1e-14}
  raw_args_maker = self._GetArgsMaker(rng,
                                      [start_shape, stop_shape],
                                      [dtype, dtype])

  def args_maker():
    """Produces start/stop pairs restricted to what onp.geomspace accepts."""
    lo, hi = raw_args_maker()
    # onp.geomspace can't handle differently ranked tensors
    # w. negative numbers!
    lo, hi = lnp.broadcast_arrays(lo, hi)
    if dtype in complex_dtypes:
      return lo, hi
    # to avoid NaNs, non-complex start and stop cannot
    # differ in sign, elementwise
    lo = lo * lnp.sign(lo) * lnp.sign(hi)
    return lo, hi

  # One sample pair is drawn up front; it fixes the axis range and is reused
  # for the dtype comparison inside the loop.
  start, stop = args_maker()
  ndim = len(onp.shape(start + stop))
  for axis in range(-ndim, ndim):
    def lnp_op(lo, hi):
      return lnp.geomspace(lo, hi, num, endpoint=endpoint, dtype=dtype,
                           axis=axis)

    def onp_op(lo, hi):
      # numpy is given float32 in place of bfloat16, and its result is cast
      # back to the requested dtype.
      if dtype == lnp.bfloat16:
        lo = lo.astype(onp.float32)
      if dtype == lnp.bfloat16:
        hi = hi.astype(onp.float32)
      numpy_dtype = dtype if dtype != lnp.bfloat16 else onp.float32
      result = onp.geomspace(
          lo, hi, num, endpoint=endpoint, dtype=numpy_dtype, axis=axis)
      return result.astype(dtype)

    self._CheckAgainstNumpy(onp_op, lnp_op, args_maker,
                            check_dtypes=False, tol=tol)
    # Check dtype equivalence within expected 32bit downcasting.
    jax_out = lnp_op(start, stop)
    numpy_out = onp_op(start, stop)
    self.assertDowncastDtypeEqual(jax_out, numpy_out)
    if dtype in (inexact_dtypes + [None,]):
      self._CompileAndCheck(lnp_op, args_maker,
                            check_dtypes=False, atol=tol, rtol=tol)
def testDisableNumpyRankPromotionBroadcasting(self):
  """Exercises the "allow", "raise" and "warn" rank-promotion flag values.

  For each value, FLAGS.jax_numpy_rank_promotion is set, a (2,) + (1, 2)
  addition is attempted, and the previous flag value is restored afterwards.
  """
  # Read the original value before entering any try block: if the read itself
  # failed inside the try, the finally clause would raise an unbound-name
  # error and hide the real one.
  prev_flag = FLAGS.jax_numpy_rank_promotion

  try:
    FLAGS.jax_numpy_rank_promotion = "allow"
    lnp.ones(2) + lnp.ones((1, 2))  # works just fine
  finally:
    FLAGS.jax_numpy_rank_promotion = prev_flag

  try:
    FLAGS.jax_numpy_rank_promotion = "raise"
    self.assertRaises(ValueError, lambda: lnp.ones(2) + lnp.ones((1, 2)))
  finally:
    FLAGS.jax_numpy_rank_promotion = prev_flag

  try:
    FLAGS.jax_numpy_rank_promotion = "warn"
    with warnings.catch_warnings(record=True) as w:
      warnings.simplefilter("always")
      lnp.ones(2) + lnp.ones((1, 2))
      # Use a unittest assertion so the check is not stripped under -O.
      self.assertGreater(len(w), 0)
      msg = str(w[-1].message)
      expected_msg = ("Following NumPy automatic rank promotion for add on "
                      "shapes (2,) (1, 2).")
      self.assertEqual(msg[:len(expected_msg)], expected_msg)

      prev_len = len(w)
      lnp.ones(2) + 3
      self.assertEqual(len(w), prev_len)  # don't want to warn for scalars
  finally:
    FLAGS.jax_numpy_rank_promotion = prev_flag
def testStackArrayArgument(self):
  """lnp.stack and lnp.concatenate accept a single array argument under jit."""
  # tests https://github.com/google/jax/issues/1271
  @api.jit
  def stack_rows(x):
    return lnp.stack(x)

  @api.jit
  def concatenate_rows(x):
    return lnp.concatenate(x)

  stack_rows(onp.zeros(2))  # doesn't crash
  concatenate_rows(onp.zeros((2, 2)))  # doesn't crash
def testReluGradientConstants(self):
  """Checks no traced constant equals a (3, 4) float32 array filled with 2."""
  # This is a regression test that verifies that constants associated with the
  # gradient of np.maximum (from lax._balanced_eq) aren't hoisted into the
  # outermost jaxpr. This was producing some large materialized constants for
  # every relu activation in a model.
  def body(i, xy):
    x, acc = xy
    relu_grad = jax.grad(lambda z: lnp.sum(lnp.maximum(z, 0.)))(x)
    return x, acc + relu_grad

  def looped(y):
    return lax.fori_loop(0, 5, body, (y, y))

  wrapped = linear_util.wrap_init(looped)
  abstract_input = jax.ShapedArray((3, 4), onp.float32)
  pv = partial_eval.PartialVal((abstract_input, jax.core.unit))
  _, _, consts = partial_eval.trace_to_jaxpr(wrapped, [pv])

  forbidden = onp.full((3, 4), 2., dtype=onp.float32)
  self.assertFalse(any(onp.array_equal(c, forbidden) for c in consts))
@parameterized.named_parameters(
    dict(testcase_name="_from={}_to={}".format(from_shape, to_shape),
         rng_factory=rng_factory, from_shape=from_shape, to_shape=to_shape)
    for from_shape, to_shape in (
        ((1, 3), (4, 3)),
        ((3,), (2, 1, 3)),
        ((3,), (3, 3)),
        ((1,), (3,)),
    )
    for rng_factory in (jtu.rand_default,))
def testBroadcastTo(self, from_shape, to_shape, rng_factory):
  # Compares lnp.broadcast_to with onp.broadcast_to on float32 inputs of shape
  # `from_shape`, then checks the lnp version under compilation.
  def onp_op(x):
    return onp.broadcast_to(x, to_shape)

  def lnp_op(x):
    return lnp.broadcast_to(x, to_shape)

  args_maker = self._GetArgsMaker(rng_factory(), [from_shape], [onp.float32])
  self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True)
  self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)
def testBroadcastToIssue1522(self):
  # Broadcasting a (2, 3) array to (1, 3) must be rejected with a ValueError
  # whose message names the incompatible shapes.
  with self.assertRaisesRegex(
      ValueError, "Incompatible shapes for broadcasting: .*"):
    lnp.broadcast_to(onp.ones((2, 3)), (1, 3))
def testBroadcastToIntIssue1548(self):
  # A Python int broadcast to (3, 2) should equal an array of ones; dtypes are
  # deliberately not compared.
  actual = lnp.broadcast_to(1, (3, 2))
  expected = onp.ones((3, 2))
  self.assertAllClose(actual, expected, check_dtypes=False)
def testBroadcastToOnScalar(self):
  # Broadcasting a Python float to the empty shape must yield each library's
  # own ndarray type.
  for np_module in (lnp, onp):
    result = np_module.broadcast_to(10.0, ())
    self.assertIsInstance(result, np_module.ndarray)
def testPrecision(self):
  # Checks that the `precision` argument of several lnp contraction functions
  # ends up as the 'precision' parameter of the single dot_general equation in
  # the traced jaxpr.
  HIGHEST = lax.Precision.HIGHEST
  vec = onp.ones((2,))
  mat = onp.ones((2, 2))
  cube = onp.ones((2, 2, 2))

  def walk_eqns(jaxpr):
    # Yields every equation of `jaxpr`, descending into bound sub-jaxprs.
    for eqn in jaxpr.eqns:
      yield eqn
      for inner_jaxpr, _, _ in eqn.bound_subjaxprs:
        for inner_eqn in walk_eqns(inner_jaxpr):
          yield inner_eqn

  def assert_precision(expected, fun, *args):
    traced = jax.make_jaxpr(fun)(*args)
    dot_precisions = [eqn.params['precision']
                      for eqn in walk_eqns(traced.jaxpr)
                      if eqn.primitive == lax.dot_general_p]
    # The unpacking fails unless exactly one dot_general equation was found.
    (precision,) = dot_precisions
    self.assertEqual(precision, expected)

  # Without an explicit precision the parameter is None.
  assert_precision(None, lnp.dot, vec, vec)

  # (function, operand) pairs; each function is called with the operand passed
  # twice and with precision=HIGHEST.
  highest_cases = [
      (lnp.dot, vec),
      (lnp.dot, cube),
      (lnp.matmul, mat),
      (lnp.vdot, vec),
      (partial(lnp.tensordot, axes=2), mat),
      (partial(lnp.tensordot, axes=(0, 0)), vec),
      (partial(lnp.tensordot, axes=((0,), (0,))), vec),
      (partial(lnp.einsum, 'i,i'), vec),
      (partial(lnp.einsum, 'ij,ij'), mat),
      (lnp.inner, vec),
  ]
  for fun, operand in highest_cases:
    assert_precision(
        HIGHEST, partial(fun, precision=HIGHEST), operand, operand)
# Most grad tests are at the lax level (see lax_test.py), but we add some here
# as needed for e.g. particular compound ops of interest.
# Record describing one parameterized gradient test.
GradTestSpec = collections.namedtuple(
    "GradTestSpec", "op nargs order rng_factory dtypes name tol")
def grad_test_spec(op, nargs, order, rng_factory, dtypes, name=None, tol=None):
  """Builds a GradTestSpec, using `op.__name__` when `name` is falsy."""
  if not name:
    name = op.__name__
  return GradTestSpec(op=op, nargs=nargs, order=order,
                      rng_factory=rng_factory, dtypes=dtypes, name=name,
                      tol=tol)
# Ops covered by NumpyGradTests.testOpGrad. Each entry is a one-argument op
# checked to second order on float64 and complex64 inputs.
GRAD_TEST_RECORDS = [
    grad_test_spec(
        op=lnp.arcsinh,
        nargs=1,
        order=2,
        rng_factory=jtu.rand_positive,
        dtypes=[onp.float64, onp.complex64],
        tol=1e-4),
    grad_test_spec(
        op=lnp.arccosh,
        nargs=1,
        order=2,
        rng_factory=jtu.rand_positive,
        dtypes=[onp.float64, onp.complex64],
        tol=1e-4),
    grad_test_spec(
        op=lnp.arctanh,
        nargs=1,
        order=2,
        # Inputs are drawn with bounds -0.9 and 0.9.
        rng_factory=partial(jtu.rand_uniform, -0.9, 0.9),
        dtypes=[onp.float64, onp.complex64],
        tol=1e-4),
]
# Record pairing an op with the input values at which its gradient is checked
# and the derivative order to check.
GradSpecialValuesTestSpec = collections.namedtuple(
    "GradSpecialValuesTestSpec", "op values order")
# Ops and the individual scalar inputs used by
# NumpyGradTests.testOpGradSpecialValue.
GRAD_SPECIAL_VALUE_TEST_RECORDS = [
    GradSpecialValuesTestSpec(op=lnp.arcsinh, values=[0., 1000.], order=2),
    GradSpecialValuesTestSpec(op=lnp.arccosh, values=[1000.], order=2),
    GradSpecialValuesTestSpec(op=lnp.arctanh, values=[0.], order=2),
    GradSpecialValuesTestSpec(op=lnp.sinc, values=[0.], order=1),
]
def num_float_bits(dtype):
  """Returns `finfo(...).bits` for `dtype` after dtype canonicalization."""
  canonical_dtype = dtypes.canonicalize_dtype(dtype)
  float_info = lnp.finfo(canonical_dtype)
  return float_info.bits
class NumpyGradTests(jtu.JaxTestCase):
  """Gradient checks for a few lax_numpy ops, driven by the record tables."""

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
          dict(testcase_name=jtu.format_test_name_suffix(
                   rec.name, shapes, itertools.repeat(dtype)),
               op=rec.op, rng_factory=rec.rng_factory, shapes=shapes,
               dtype=dtype, order=rec.order, tol=rec.tol)
          for shapes in CombosWithReplacement(nonempty_shapes, rec.nargs)
          for dtype in rec.dtypes)
      for rec in GRAD_TEST_RECORDS))
  def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol):
    # NOTE(review): the per-record `tol` argument is not consulted; the fixed
    # per-dtype tolerances below are what is handed to check_grads as both
    # atol and rtol. Confirm whether the record value was meant to be used.
    fixed_tol = {onp.float32: 1e-1, onp.complex64: 1e-1}
    rng = rng_factory()
    args = tuple([rng(shape, dtype) for shape in shapes])
    check_grads(op, args, order, ["fwd", "rev"], fixed_tol, fixed_tol)

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
          dict(testcase_name="_{}_{}".format(rec.op.__name__, special_value),
               op=rec.op, special_value=special_value, order=rec.order)
          for special_value in rec.values)
      for rec in GRAD_SPECIAL_VALUE_TEST_RECORDS))
  def testOpGradSpecialValue(self, op, special_value, order):
    # Checks forward- and reverse-mode gradients of `op` at a single scalar
    # input.
    modes = ["fwd", "rev"]
    check_grads(op, (special_value,), order, modes,
                atol={onp.float32: 3e-3})

  def testTakeAlongAxisIssue1521(self):
    # https://github.com/google/jax/issues/1521
    row_ids = lnp.repeat(lnp.arange(3), 10).reshape((30, 1))

    def gather_sum(x):
      scaled = x * lnp.arange(3.).reshape((1, 3))
      gathered = lnp.take_along_axis(scaled, row_ids, -1)
      return gathered.sum()

    check_grads(gather_sum, (1.,), order=1)
# Run the tests through absltest when this file is executed as a script.
if __name__ == "__main__":
  absltest.main()