# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import functools
from functools import partial
import itertools
import operator
import unittest
from unittest import SkipTest

from absl.testing import absltest
from absl.testing import parameterized
import six

import numpy as onp

import jax.ops
from jax import api
from jax import lax
from jax import numpy as lnp
from jax import test_util as jtu
from jax.lib import xla_bridge

from jax.config import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS

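# Shape and dtype grids from which the parameterized tests below draw their
# inputs.  The jtu scalar-shape sentinels stand in for numpy and Python scalar
# arguments; several tests relax strict dtype checks when a Python scalar is
# among the inputs.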
nonempty_nonscalar_array_shapes = [(4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)]
nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes
empty_array_shapes = [(0,), (0, 4), (3, 0),]

scalar_shapes = [jtu.NUMPY_SCALAR_SHAPE, jtu.PYTHON_SCALAR_SHAPE]
array_shapes = nonempty_array_shapes + empty_array_shapes
nonzerodim_shapes = nonempty_nonscalar_array_shapes + empty_array_shapes
nonempty_shapes = scalar_shapes + nonempty_array_shapes
all_shapes = scalar_shapes + array_shapes

float_dtypes = [onp.float32, onp.float64]
complex_dtypes = [onp.complex64, onp.complex128]
int_dtypes = [onp.int32, onp.int64]
unsigned_dtypes = [onp.uint32, onp.uint64]
bool_dtypes = [onp.bool_]
default_dtypes = float_dtypes + int_dtypes
inexact_dtypes = float_dtypes + complex_dtypes
number_dtypes = float_dtypes + complex_dtypes + int_dtypes
all_dtypes = number_dtypes + bool_dtypes

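# Each OpRecord describes one numpy function to test: its name, arity, the
# dtypes and shapes to sample arguments from, a random-number generator
# factory, the autodiff modes exercised elsewhere, an optional test name, and
# whether result dtypes must match numpy exactly.  For example,
# op_record("add", 2, number_dtypes, all_shapes, jtu.rand_default(), ["rev"])
# checks lnp.add against onp.add on two arguments drawn from those grids.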
OpRecord = collections.namedtuple(
    "OpRecord",
    ["name", "nargs", "dtypes", "shapes", "rng", "diff_modes", "test_name",
     "check_dtypes"])


def op_record(name, nargs, dtypes, shapes, rng, diff_modes, test_name=None,
              check_dtypes=True):
  test_name = test_name or name
  return OpRecord(name, nargs, dtypes, shapes, rng, diff_modes, test_name,
                  check_dtypes)

JAX_ONE_TO_ONE_OP_RECORDS = [
    op_record("abs", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("add", 2, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("ceil", 1, float_dtypes, all_shapes, jtu.rand_default(), []),
    op_record("conj", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal(), []),
    op_record("exp", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("fabs", 1, float_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("float_power", 2, inexact_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("floor", 1, float_dtypes, all_shapes, jtu.rand_default(), []),
    op_record("greater", 2, number_dtypes, all_shapes, jtu.rand_some_equal(), []),
    op_record("greater_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal(), []),
    op_record("less", 2, number_dtypes, all_shapes, jtu.rand_some_equal(), []),
    op_record("less_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal(), []),
    op_record("log", 1, number_dtypes, all_shapes, jtu.rand_positive(), ["rev"]),
    op_record("logical_and", 2, all_dtypes, all_shapes, jtu.rand_bool(), []),
    op_record("logical_not", 1, all_dtypes, all_shapes, jtu.rand_bool(), []),
    op_record("logical_or", 2, all_dtypes, all_shapes, jtu.rand_bool(), []),
    op_record("logical_xor", 2, all_dtypes, all_shapes, jtu.rand_bool(), []),
    op_record("maximum", 2, number_dtypes, all_shapes, jtu.rand_some_inf(), []),
    op_record("minimum", 2, number_dtypes, all_shapes, jtu.rand_some_inf(), []),
    op_record("multiply", 2, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("negative", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("not_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal(), ["rev"]),
    op_record("array_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal(), ["rev"]),
    op_record("reciprocal", 1, inexact_dtypes, all_shapes, jtu.rand_default(), []),
    op_record("subtract", 2, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("sin", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("cos", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("tan", 1, number_dtypes, all_shapes, jtu.rand_uniform(-1.5, 1.5),
              ["rev"]),
    op_record("sinh", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("cosh", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("tanh", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("arcsin", 1, float_dtypes, all_shapes, jtu.rand_small(), ["rev"]),
    op_record("arccos", 1, float_dtypes, all_shapes, jtu.rand_small(), ["rev"]),
    op_record("arctan", 1, float_dtypes, all_shapes, jtu.rand_small(), ["rev"]),
    op_record("arctan2", 2, float_dtypes, all_shapes, jtu.rand_small(), ["rev"]),
    op_record("arcsinh", 1, number_dtypes, all_shapes, jtu.rand_positive(), ["rev"]),
    op_record("arccosh", 1, number_dtypes, all_shapes, jtu.rand_positive(), ["rev"]),
    op_record("arctanh", 1, number_dtypes, all_shapes, jtu.rand_small(), ["rev"]),
]

JAX_COMPOUND_OP_RECORDS = [
    # angle has inconsistent 32/64-bit return types across numpy versions.
    op_record("angle", 1, number_dtypes, all_shapes, jtu.rand_default(), [],
              check_dtypes=False),
    op_record("atleast_1d", 1, default_dtypes, all_shapes, jtu.rand_default(), []),
    op_record("atleast_2d", 1, default_dtypes, all_shapes, jtu.rand_default(), []),
    op_record("atleast_3d", 1, default_dtypes, all_shapes, jtu.rand_default(), []),
    op_record("cbrt", 1, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("conjugate", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("deg2rad", 1, float_dtypes, all_shapes, jtu.rand_default(), []),
    op_record("divide", 2, number_dtypes, all_shapes, jtu.rand_nonzero(), ["rev"]),
    op_record("exp2", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("expm1", 1, number_dtypes, all_shapes, jtu.rand_positive(), [],
              test_name="expm1_large"),
    op_record("expm1", 1, number_dtypes, all_shapes, jtu.rand_small_positive(), []),
    op_record("fix", 1, float_dtypes, all_shapes, jtu.rand_default(), []),
    op_record("floor_divide", 2, number_dtypes, all_shapes, jtu.rand_nonzero(), ["rev"]),
    op_record("heaviside", 2, default_dtypes, all_shapes, jtu.rand_default(), []),
    op_record("hypot", 2, default_dtypes, all_shapes, jtu.rand_default(), []),
    op_record("kron", 2, number_dtypes, nonempty_shapes, jtu.rand_default(), []),
    op_record("outer", 2, number_dtypes, all_shapes, jtu.rand_default(), []),
    op_record("imag", 1, number_dtypes, all_shapes, jtu.rand_some_inf(), []),
    op_record("iscomplex", 1, number_dtypes, all_shapes, jtu.rand_some_inf(), []),
    op_record("isfinite", 1, inexact_dtypes, all_shapes, jtu.rand_some_inf_and_nan(), []),
    op_record("isinf", 1, inexact_dtypes, all_shapes, jtu.rand_some_inf_and_nan(), []),
    op_record("isnan", 1, inexact_dtypes, all_shapes, jtu.rand_some_inf_and_nan(), []),
    op_record("isneginf", 1, float_dtypes, all_shapes, jtu.rand_some_inf_and_nan(), []),
    op_record("isposinf", 1, float_dtypes, all_shapes, jtu.rand_some_inf_and_nan(), []),
    op_record("isreal", 1, number_dtypes, all_shapes, jtu.rand_some_inf(), []),
    op_record("isrealobj", 1, number_dtypes, all_shapes, jtu.rand_some_inf(), []),
    op_record("log2", 1, number_dtypes, all_shapes, jtu.rand_positive(), ["rev"]),
    op_record("log10", 1, number_dtypes, all_shapes, jtu.rand_positive(), ["rev"]),
    op_record("log1p", 1, number_dtypes, all_shapes, jtu.rand_positive(), [],
              test_name="log1p_large"),
    op_record("log1p", 1, number_dtypes, all_shapes, jtu.rand_small_positive(), []),
    op_record("logaddexp", 2, float_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("logaddexp2", 2, float_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("polyval", 2, number_dtypes, nonempty_nonscalar_array_shapes, jtu.rand_default(), []),
    op_record("positive", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("power", 2, number_dtypes, all_shapes, jtu.rand_positive(), ["rev"]),
    op_record("rad2deg", 1, float_dtypes, all_shapes, jtu.rand_default(), []),
    op_record("ravel", 1, all_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("real", 1, number_dtypes, all_shapes, jtu.rand_some_inf(), []),
    op_record("remainder", 2, default_dtypes, all_shapes, jtu.rand_nonzero(), []),
    op_record("mod", 2, default_dtypes, all_shapes, jtu.rand_nonzero(), []),
    op_record("sinc", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("square", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("sqrt", 1, number_dtypes, all_shapes, jtu.rand_positive(), ["rev"]),
    op_record("transpose", 1, all_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("true_divide", 2, all_dtypes, all_shapes, jtu.rand_nonzero(), ["rev"]),
    op_record("where", 3, (onp.float32, onp.int64), all_shapes, jtu.rand_some_zero(), []),
    op_record("diff", 1, number_dtypes, nonzerodim_shapes, jtu.rand_default(), ["rev"]),
]

JAX_BITWISE_OP_RECORDS = [
    op_record("bitwise_and", 2, int_dtypes + unsigned_dtypes, all_shapes,
              jtu.rand_bool(), []),
    op_record("bitwise_not", 1, int_dtypes + unsigned_dtypes, all_shapes,
              jtu.rand_bool(), []),
    op_record("bitwise_or", 2, int_dtypes + unsigned_dtypes, all_shapes,
              jtu.rand_bool(), []),
    op_record("bitwise_xor", 2, int_dtypes + unsigned_dtypes, all_shapes,
              jtu.rand_bool(), []),
]

JAX_REDUCER_RECORDS = [
    op_record("mean", 1, number_dtypes, nonempty_shapes, jtu.rand_default(), []),
    op_record("prod", 1, number_dtypes, all_shapes, jtu.rand_small_positive(), []),
    op_record("sum", 1, number_dtypes, all_shapes, jtu.rand_default(), []),
    op_record("var", 1, number_dtypes, nonempty_shapes, jtu.rand_default(), []),
    op_record("std", 1, inexact_dtypes, nonempty_shapes, jtu.rand_default(), []),
]

JAX_REDUCER_NO_DTYPE_RECORDS = [
    op_record("all", 1, all_dtypes, all_shapes, jtu.rand_some_zero(), []),
    op_record("any", 1, all_dtypes, all_shapes, jtu.rand_some_zero(), []),
    op_record("max", 1, all_dtypes, nonempty_shapes, jtu.rand_default(), []),
    op_record("min", 1, all_dtypes, nonempty_shapes, jtu.rand_default(), []),
]

JAX_ARGMINMAX_RECORDS = [
    op_record("argmin", 1, all_dtypes, nonempty_shapes, jtu.rand_some_equal(), []),
    op_record("argmax", 1, all_dtypes, nonempty_shapes, jtu.rand_some_equal(), []),
]

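# These records exercise the operator dunder methods (__add__, __lt__, ...) on
# jax arrays via the corresponding Python operators; the right-hand variants
# (__radd__, ...) are listed separately because the tests invoke them
# explicitly on the second operand.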
JAX_OPERATOR_OVERLOADS = [
    op_record("__add__", 2, number_dtypes, all_shapes, jtu.rand_default(), []),
    op_record("__sub__", 2, number_dtypes, all_shapes, jtu.rand_default(), []),
    op_record("__mul__", 2, number_dtypes, all_shapes, jtu.rand_default(), []),
    op_record("__eq__", 2, number_dtypes, all_shapes, jtu.rand_default(), []),
    op_record("__ne__", 2, number_dtypes, all_shapes, jtu.rand_default(), []),
    op_record("__lt__", 2, default_dtypes, all_shapes, jtu.rand_default(), []),
    op_record("__gt__", 2, default_dtypes, all_shapes, jtu.rand_default(), []),
    op_record("__ge__", 2, default_dtypes, all_shapes, jtu.rand_default(), []),
    op_record("__neg__", 1, number_dtypes, all_shapes, jtu.rand_default(), []),
    op_record("__pow__", 2, inexact_dtypes, all_shapes, jtu.rand_positive(), []),
    op_record("__mod__", 2, default_dtypes, all_shapes, jtu.rand_nonzero(), []),
    op_record("__floordiv__", 2, default_dtypes, all_shapes, jtu.rand_nonzero(), []),
    op_record("__truediv__", 2, number_dtypes, all_shapes, jtu.rand_nonzero(), []),
    op_record("__abs__", 1, number_dtypes, all_shapes, jtu.rand_default(), []),
    # TODO(mattjj): __invert__ fails on bool dtypes because ~True == -2
    op_record("__invert__", 1, int_dtypes, all_shapes, jtu.rand_default(), []),
    # TODO(mattjj): investigate these failures
    # op_record("__or__", 2, number_dtypes, all_shapes, jtu.rand_bool(), []),
    # op_record("__and__", 2, number_dtypes, all_shapes, jtu.rand_default(), []),
    # op_record("__xor__", 2, number_dtypes, all_shapes, jtu.rand_bool(), []),
    # op_record("__divmod__", 2, number_dtypes, all_shapes, jtu.rand_nonzero(), []),
    # TODO(mattjj): lshift, rshift
]

JAX_RIGHT_OPERATOR_OVERLOADS = [
    op_record("__radd__", 2, number_dtypes, all_shapes, jtu.rand_default(), []),
    op_record("__rsub__", 2, number_dtypes, all_shapes, jtu.rand_default(), []),
    op_record("__rmul__", 2, number_dtypes, all_shapes, jtu.rand_default(), []),
    op_record("__rpow__", 2, inexact_dtypes, all_shapes, jtu.rand_positive(), []),
    op_record("__rmod__", 2, default_dtypes, all_shapes, jtu.rand_nonzero(), []),
    op_record("__rfloordiv__", 2, default_dtypes, all_shapes, jtu.rand_nonzero(), []),
    op_record("__rtruediv__", 2, number_dtypes, all_shapes, jtu.rand_nonzero(), []),
    # op_record("__ror__", 2, number_dtypes, all_shapes, jtu.rand_bool(), []),
    # op_record("__rand__", 2, number_dtypes, all_shapes, jtu.rand_default(), []),
    # op_record("__rxor__", 2, number_dtypes, all_shapes, jtu.rand_bool(), []),
    # op_record("__rdivmod__", 2, number_dtypes, all_shapes, jtu.rand_nonzero(), []),
]

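# Ops that only exist (or only behave consistently) in newer numpy releases
# are appended conditionally, so the suite still runs against older numpy
# versions.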
numpy_version = tuple(map(int, onp.version.version.split('.')))
if numpy_version >= (1, 15):
  JAX_COMPOUND_OP_RECORDS += [
      op_record("isclose", 2, all_dtypes, all_shapes, jtu.rand_small_positive(), []),
      op_record("gcd", 2, int_dtypes, all_shapes, jtu.rand_default(), []),
      op_record("lcm", 2, int_dtypes, all_shapes, jtu.rand_default(), []),
  ]
  JAX_REDUCER_NO_DTYPE_RECORDS += [
      op_record("ptp", 1, number_dtypes, nonempty_shapes, jtu.rand_default(), []),
  ]

if six.PY2:
  JAX_OPERATOR_OVERLOADS += [
      op_record("__div__", 2, number_dtypes, all_shapes, jtu.rand_nonzero(), []),
  ]
  JAX_RIGHT_OPERATOR_OVERLOADS += [
      op_record("__rdiv__", 2, number_dtypes, all_shapes, jtu.rand_nonzero(), []),
  ]


CombosWithReplacement = itertools.combinations_with_replacement

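# Predicates used to prune the generated argument combinations down to the
# ones numpy itself accepts.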
def _dtypes_are_compatible_for_bitwise_ops(args):
  if len(args) <= 1:
    return True
  is_signed = lambda dtype: onp.issubdtype(dtype, onp.signedinteger)
  width = lambda dtype: onp.iinfo(dtype).bits
  x, y = args
  if width(x) > width(y):
    x, y = y, x
  # The following condition seems a little ad hoc, but seems to capture what
  # numpy actually implements.
  return (
      is_signed(x) == is_signed(y)
      or (width(x) == 32 and width(y) == 32)
      or (width(x) == 32 and width(y) == 64 and is_signed(y)))

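# Shapes are considered broadcast-compatible if numpy can actually broadcast
# zeros of those shapes together; attempting the addition and catching
# ValueError keeps this predicate in sync with numpy's own rules.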
def _shapes_are_broadcast_compatible(shapes):
  accumulator = onp.zeros([])
  for shape in shapes:
    try:
      accumulator = accumulator + onp.zeros(shape)
    except ValueError:
      return False
  return True

def _shapes_are_equal_length(shapes):
  return all(len(shape) == len(shapes[0]) for shape in shapes[1:])


class LaxBackedNumpyTests(jtu.JaxTestCase):
  """Tests for LAX-backed Numpy implementation."""

  def _GetArgsMaker(self, rng, shapes, dtypes):
    return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]

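  # Most tests below follow the same pattern: build an args_maker from the
  # parameterized shapes and dtypes, use _CheckAgainstNumpy to compare the
  # lnp result with onp's, and use _CompileAndCheck to verify that the
  # jit-compiled function agrees with the un-jitted one.
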
  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
          {"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes,
                                                        dtypes),
           "rng": rec.rng, "shapes": shapes, "dtypes": dtypes,
           "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name),
           "check_dtypes": rec.check_dtypes}
          for shapes in filter(
              _shapes_are_broadcast_compatible,
              CombosWithReplacement(rec.shapes, rec.nargs))
          for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs))
      for rec in itertools.chain(JAX_ONE_TO_ONE_OP_RECORDS,
                                 JAX_COMPOUND_OP_RECORDS)))
  def testOp(self, onp_op, lnp_op, rng, shapes, dtypes, check_dtypes):
    args_maker = self._GetArgsMaker(rng, shapes, dtypes)
    py_scalar_arg = jtu.PYTHON_SCALAR_SHAPE in shapes
    self._CheckAgainstNumpy(onp_op, lnp_op, args_maker,
                            check_dtypes=check_dtypes and not py_scalar_arg)
    self._CompileAndCheck(lnp_op, args_maker, check_dtypes=check_dtypes)

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
          {"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes,
                                                        dtypes),
           "rng": rec.rng, "shapes": shapes, "dtypes": dtypes, "name": rec.name}
          for shapes in filter(
              _shapes_are_broadcast_compatible,
              CombosWithReplacement(rec.shapes, rec.nargs))
          for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs))
      for rec in JAX_OPERATOR_OVERLOADS))
  def testOperatorOverload(self, name, rng, shapes, dtypes):
    args_maker = self._GetArgsMaker(rng, shapes, dtypes)
    fun = lambda *xs: getattr(operator, name.strip('_'))(*xs)
    self._CompileAndCheck(fun, args_maker,
                          check_dtypes=jtu.PYTHON_SCALAR_SHAPE not in shapes)

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
          {"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes,
                                                        dtypes),
           "rng": rec.rng, "shapes": shapes, "dtypes": dtypes, "name": rec.name}
          for shapes in filter(
              _shapes_are_broadcast_compatible,
              CombosWithReplacement(rec.shapes, rec.nargs))
          for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs))
      for rec in JAX_RIGHT_OPERATOR_OVERLOADS))
  def testRightOperatorOverload(self, name, rng, shapes, dtypes):
    if shapes[1] is jtu.PYTHON_SCALAR_SHAPE:
      raise SkipTest()  # TODO(mattjj): clean up
    args_maker = self._GetArgsMaker(rng, shapes, dtypes)
    fun = lambda fst, snd: getattr(snd, name)(fst)
    self._CompileAndCheck(fun, args_maker,
                          check_dtypes=jtu.PYTHON_SCALAR_SHAPE not in shapes)

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
          {"testcase_name": jtu.format_test_name_suffix(
              rec.test_name, shapes, dtypes),
           "rng": rec.rng, "shapes": shapes, "dtypes": dtypes,
           "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name)}
          for shapes in filter(
              _shapes_are_broadcast_compatible,
              CombosWithReplacement(rec.shapes, rec.nargs))
          for dtypes in filter(
              _dtypes_are_compatible_for_bitwise_ops,
              CombosWithReplacement(rec.dtypes, rec.nargs)))
      for rec in JAX_BITWISE_OP_RECORDS))
  def testBitwiseOp(self, onp_op, lnp_op, rng, shapes, dtypes):
    if not FLAGS.jax_enable_x64 and any(
        onp.iinfo(dtype).bits == 64 for dtype in dtypes):
      self.skipTest("x64 types are disabled by jax_enable_x64")
    args_maker = self._GetArgsMaker(rng, shapes, dtypes)
    self._CheckAgainstNumpy(onp_op, lnp_op, args_maker,
                            check_dtypes=jtu.PYTHON_SCALAR_SHAPE not in shapes)
    self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "{}_inshape={}_axis={}_dtype={}_keepdims={}".format(
          rec.test_name.capitalize(),
          jtu.format_shape_dtype_string(shape, dtype), axis,
          "None" if out_dtype is None else onp.dtype(out_dtype).name, keepdims),
       "rng": rec.rng, "shape": shape, "dtype": dtype, "out_dtype": out_dtype,
       "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name),
       "axis": axis, "keepdims": keepdims}
      for rec in JAX_REDUCER_RECORDS
      for shape in rec.shapes for dtype in rec.dtypes
      for out_dtype in [None] + rec.dtypes
      for axis in set(range(-len(shape), len(shape))) | set([None])
      for keepdims in [False, True]))
  def testReducer(self, onp_op, lnp_op, rng, shape, dtype, out_dtype, axis, keepdims):
    onp_fun = lambda x: onp_op(x, axis, dtype=out_dtype, keepdims=keepdims)
    lnp_fun = lambda x: lnp_op(x, axis, dtype=out_dtype, keepdims=keepdims)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "{}_inshape={}_axis={}_keepdims={}".format(
          rec.test_name.capitalize(),
          jtu.format_shape_dtype_string(shape, dtype), axis, keepdims),
       "rng": rec.rng, "shape": shape, "dtype": dtype,
       "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name),
       "axis": axis, "keepdims": keepdims}
      for rec in JAX_REDUCER_NO_DTYPE_RECORDS
      for shape in rec.shapes for dtype in rec.dtypes
      for axis in set(range(-len(shape), len(shape))) | set([None])
      for keepdims in [False, True]))
  def testReducerNoDtype(self, onp_op, lnp_op, rng, shape, dtype, axis, keepdims):
    onp_fun = lambda x: onp_op(x, axis, keepdims=keepdims)
    lnp_fun = lambda x: lnp_op(x, axis, keepdims=keepdims)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_axis={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis),
       "shape": shape, "dtype": dtype, "axis": axis}
      for shape in all_shapes for dtype in all_dtypes
      for axis in set(range(-len(shape), len(shape))) | set([None])))
  def testCountNonzero(self, shape, dtype, axis):
    rng = jtu.rand_some_zero()
    onp_fun = lambda x: onp.count_nonzero(x, axis)
    lnp_fun = lambda x: lnp.count_nonzero(x, axis)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "{}_inshape={}_axis={}".format(
          rec.test_name.capitalize(),
          jtu.format_shape_dtype_string(shape, dtype), axis),
       "rng": rec.rng, "shape": shape, "dtype": dtype,
       "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name),
       "axis": axis}
      for rec in JAX_ARGMINMAX_RECORDS
      for shape in rec.shapes for dtype in rec.dtypes
      for axis in range(-len(shape), len(shape))))
  def testArgMinMax(self, onp_op, lnp_op, rng, shape, dtype, axis):
    if (dtype == onp.complex128 and FLAGS.jax_test_dut and
        FLAGS.jax_test_dut.startswith("gpu")):
      raise unittest.SkipTest("complex128 reductions not supported on GPU")

    def onp_fun(array_to_reduce):
      return onp_op(array_to_reduce, axis)

    def lnp_fun(array_to_reduce):
      return lnp_op(array_to_reduce, axis)

    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_{}_{}".format(
          jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
          jtu.format_shape_dtype_string(rhs_shape, rhs_dtype),
          axes),
       "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
       "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
       "axes": axes, "rng": rng}
      for rng in [jtu.rand_default()]
      for lhs_shape, rhs_shape, axes in [
          [(2,), (2,), (-1, -1, -1, None)],  # scalar output
          [(2, 4), (2, 4), (-1, -1, -1, 0)],  # 2D vectors
          [(3, 4), (3, 4), (-1, -1, -1, 0)],  # 3D vectors
          [(3, 4), (3, 6, 5, 4), (-1, -1, -1, 0)],  # broadcasting
          [(4, 3), (3, 6, 5, 4), (1, 0, -1, None)],  # different axes
          [(6, 1, 3), (5, 3), (-1, -1, -1, None)],  # more broadcasting
          [(6, 1, 2), (5, 3), (-1, -1, -1, None)],  # mixed 2D and 3D vectors
          [(10, 5, 2, 8), (1, 5, 1, 3), (-2, -1, -3, None)],  # axes/broadcasting
          [(4, 5, 2), (4, 5, 2), (-1, -1, 0, None)],  # axisc should do nothing
          [(4, 5, 2), (4, 5, 2), (-1, -1, -1, None)]  # same as before
      ]
      for lhs_dtype, rhs_dtype in CombosWithReplacement(number_dtypes, 2)))
  def testCross(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, axes, rng):
    args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
    axisa, axisb, axisc, axis = axes
    lnp_fun = lambda a, b: lnp.cross(a, b, axisa, axisb, axisc, axis)
    onp_fun = lambda a, b: onp.cross(a, b, axisa, axisb, axisc, axis)
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_{}_{}".format(
          name,
          jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
          jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)),
       "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
       "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
       "rng": rng}
      for rng in [jtu.rand_default()]
      for name, lhs_shape, rhs_shape in [
          ("matrix-scalar", (3, 3), ()),
          ("scalar-matrix", (), (3, 3)),
          ("matrix-vector", (4, 5), (5,)),
          ("vector-matrix", (6,), (6, 4)),
          ("matrix-matrix", (3, 4), (4, 5)),
          ("tensor-vector", (4, 3, 2), (2,)),
          ("vector-tensor", (2,), (3, 2, 4)),
          ("tensor-matrix", (4, 3, 2), (2, 5)),
          ("matrix-tensor", (5, 2), (3, 2, 4)),
          ("tensor-tensor", (2, 3, 4), (5, 4, 1))]
      for lhs_dtype, rhs_dtype in CombosWithReplacement(number_dtypes, 2)))
  def testDot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng):
    args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
    self._CheckAgainstNumpy(onp.dot, lnp.dot, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp.dot, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_{}_{}".format(
          name,
          jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
          jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)),
       "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
       "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
       "rng": rng}
      for rng in [jtu.rand_default()]
      for name, lhs_shape, rhs_shape in [
          ("vector-vector", (3,), (3,)),
          ("matrix-vector", (3, 3), (3,)),
          ("vector-matrix", (3,), (3, 3)),
          ("matrix-matrix", (3, 3), (3, 3)),
          ("vector-tensor", (3,), (5, 3, 2)),
          ("tensor-vector", (5, 3, 2), (2,)),
          ("matrix-tensor", (5, 2), (3, 2, 4)),
          ("tensor-matrix", (5, 2, 3), (3, 2)),
          ("tensor-tensor", (5, 3, 4), (5, 4, 1)),
          ("tensor-tensor-broadcast", (3, 1, 3, 4), (5, 4, 1))]
      for lhs_dtype, rhs_dtype in CombosWithReplacement(number_dtypes, 2)))
  def testMatmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng):
    args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
    self._CheckAgainstNumpy(onp.matmul, lnp.matmul, args_maker,
                            check_dtypes=True)
    self._CompileAndCheck(lnp.matmul, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_{}_{}".format(
          jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
          jtu.format_shape_dtype_string(rhs_shape, rhs_dtype),
          axes),
       "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
       "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
       "axes": axes, "rng": rng}
      for rng in [jtu.rand_default()]
      for lhs_shape, rhs_shape, axes in [
          [(2, 3, 4), (5, 6, 7), 0],  # from issue #740
          [(2, 3, 4), (3, 4, 5, 6), 2],
          [(2, 3, 4), (5, 4, 3, 6), [1, 2]],
          [(2, 3, 4), (5, 4, 3, 6), [[1, 2], [2, 1]]],
          [(1, 2, 3, 4), (4, 5, 3, 6), [[2, 3], [2, 0]]],
      ]
      for lhs_dtype, rhs_dtype in CombosWithReplacement(number_dtypes, 2)))
  def testTensordot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, axes, rng):
    args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
    lnp_fun = lambda a, b: lnp.tensordot(a, b, axes)
    onp_fun = lambda a, b: onp.tensordot(a, b, axes)
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_{}".format(
          jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
          jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)),
       "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
       "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
       "rng": jtu.rand_default()}
      # TODO(phawkins): support integer dtypes too.
      for lhs_dtype, rhs_dtype in CombosWithReplacement(inexact_dtypes, 2)
      for lhs_shape, rhs_shape in [
          (l, r) for l, r in CombosWithReplacement(all_shapes, 2)
          if len(jtu._dims_of_shape(l)) == 0
          or len(jtu._dims_of_shape(r)) == 0
          or l[-1] == r[-1]]))
  def testInner(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng):
    args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
    onp_fun = lambda lhs, rhs: onp.inner(lhs, rhs)
    lnp_fun = lambda lhs, rhs: lnp.inner(lhs, rhs)
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_amin={}_amax={}".format(
          jtu.format_shape_dtype_string(shape, dtype), a_min, a_max),
       "shape": shape, "dtype": dtype, "a_min": a_min, "a_max": a_max,
       "rng": jtu.rand_default()}
      for shape in all_shapes for dtype in number_dtypes
      for a_min, a_max in [(-1, None), (None, 1), (-1, 1)]))
  def testClipStaticBounds(self, shape, dtype, a_min, a_max, rng):
    onp_fun = lambda x: onp.clip(x, a_min=a_min, a_max=a_max)
    lnp_fun = lambda x: lnp.clip(x, a_min=a_min, a_max=a_max)
    args_maker = lambda: [rng(shape, dtype)]
    # TODO(phawkins): the promotion behavior changed in Numpy 1.17.
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=False)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_decimals={}".format(
          jtu.format_shape_dtype_string(shape, dtype), decimals),
       "shape": shape, "dtype": dtype, "decimals": decimals,
       "rng": jtu.rand_default()}
      for shape in all_shapes for dtype in number_dtypes
      for decimals in [0, 1, -2]))
  def testRoundStaticDecimals(self, shape, dtype, decimals, rng):
    if onp.issubdtype(dtype, onp.integer) and decimals < 0:
      self.skipTest("Integer rounding with decimals < 0 not implemented")
    onp_fun = lambda x: onp.round(x, decimals=decimals)
    lnp_fun = lambda x: lnp.round(x, decimals=decimals)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_mode={}_rpadwidth={}_rconstantvalues={}".format(
          jtu.format_shape_dtype_string(shape, dtype), mode, pad_width_rank,
          constant_values_rank),
       "shape": shape, "dtype": dtype, "mode": mode,
       "pad_width_rank": pad_width_rank,
       "constant_values_rank": constant_values_rank, "rng": jtu.rand_default(),
       "irng": jtu.rand_int(3)}
      for mode, constant_values_rank, shapes in [
          ('constant', 0, all_shapes),
          ('constant', 1, all_shapes),
          ('constant', 2, all_shapes),
          ('symmetric', None, nonempty_shapes),
          ('reflect', None, nonempty_shapes),
          ('wrap', None, nonempty_shapes),
      ]
      for shape in shapes for dtype in all_dtypes
      for pad_width_rank in range(3)))
  def testPad(self, shape, dtype, mode, pad_width_rank, constant_values_rank,
              rng, irng):
    pad_width = irng([len(shape), 2][2 - pad_width_rank:], onp.int32)
    def onp_fun(x, kwargs):
      if pad_width.size == 0:
        return x
      return onp.pad(x, pad_width, mode=mode, **kwargs)
    def lnp_fun(x, kwargs):
      return lnp.pad(x, pad_width, mode=mode, **kwargs)

    def args_maker():
      kwargs = {}
      if constant_values_rank:
        kwargs["constant_values"] = rng(
            [len(shape), 2][2 - constant_values_rank:], dtype)
      return rng(shape, dtype), kwargs

    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape=[{}]_reps={}".format(
          jtu.format_shape_dtype_string(shape, dtype), reps),
       "shape": shape, "dtype": dtype, "reps": reps,
       "rng": jtu.rand_default()}
      for reps in [(), (2,), (3, 4), (2, 3, 4)]
      for dtype in default_dtypes
      for shape in all_shapes
      ))
  def testTile(self, shape, dtype, reps, rng):
    onp_fun = lambda arg: onp.tile(arg, reps)
    lnp_fun = lambda arg: lnp.tile(arg, reps)

    args_maker = lambda: [rng(shape, dtype)]

    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_axis={}_baseshape=[{}]_dtypes=[{}]".format(
          axis, ",".join(str(d) for d in base_shape),
          ",".join(onp.dtype(dtype).name for dtype in dtypes)),
       "axis": axis, "base_shape": base_shape, "dtypes": dtypes,
       "rng": jtu.rand_default()}
      for num_arrs in [3]
      for dtypes in CombosWithReplacement(default_dtypes, num_arrs)
      for base_shape in [(4,), (3, 4), (2, 3, 4)]
      for axis in range(-len(base_shape)+1, len(base_shape))))
  def testConcatenate(self, axis, base_shape, dtypes, rng):
    wrapped_axis = axis % len(base_shape)
    shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:]
              for size, _ in zip(itertools.cycle([3, 1, 4]), dtypes)]
    onp_fun = lambda *args: onp.concatenate(args, axis=axis)
    lnp_fun = lambda *args: lnp.concatenate(args, axis=axis)

    def args_maker():
      return [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]

    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_axis={}_baseshape=[{}]_dtypes=[{}]".format(
          axis, ",".join(str(d) for d in base_shape),
          ",".join(onp.dtype(dtype).name for dtype in dtypes)),
       "axis": axis, "base_shape": base_shape, "dtypes": dtypes,
       "rng": jtu.rand_default()}
      for dtypes in CombosWithReplacement(default_dtypes, 2)
      for base_shape in [(4,), (3, 4), (2, 3, 4)]
      for axis in range(-len(base_shape)+1, len(base_shape))))
  def testAppend(self, axis, base_shape, dtypes, rng):
    wrapped_axis = axis % len(base_shape)
    shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:]
              for size, _ in zip(itertools.cycle([3, 1, 4]), dtypes)]
    onp_fun = lambda arr, values: onp.append(arr, values, axis=axis)
    lnp_fun = lambda arr, values: lnp.append(arr, values, axis=axis)

    def args_maker():
      return [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]

    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape=[{}]_axis={}_repeats={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, repeats),
       "axis": axis, "shape": shape, "dtype": dtype, "repeats": repeats,
       "rng": jtu.rand_default()}
      for repeats in [0, 1, 2]
      for dtype in default_dtypes
      for shape in all_shapes
      for axis in [None] + list(range(-len(shape), len(shape)))))
  def testRepeat(self, axis, shape, dtype, repeats, rng):
    onp_fun = lambda arg: onp.repeat(arg, repeats=repeats, axis=axis)
    lnp_fun = lambda arg: lnp.repeat(arg, repeats=repeats, axis=axis)

    args_maker = lambda: [rng(shape, dtype)]

    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "op={}_shape=[{}]_axis={}_out_dtype={}".format(
          op, jtu.format_shape_dtype_string(shape, dtype), axis, out_dtype),
       "axis": axis, "shape": shape, "dtype": dtype, "out_dtype": out_dtype,
       "rng": jtu.rand_default(), "lnp_op": getattr(lnp, op),
       "onp_op": getattr(onp, op)}
      for op in ["cumsum", "cumprod"]
      for dtype in default_dtypes
      for out_dtype in default_dtypes
      for shape in all_shapes
      for axis in [None] + list(range(-len(shape), len(shape)))))
  def testCumSumProd(self, axis, shape, dtype, out_dtype, onp_op, lnp_op, rng):
    onp_fun = lambda arg: onp_op(arg, axis=axis, dtype=out_dtype)
    lnp_fun = lambda arg: lnp_op(arg, axis=axis, dtype=out_dtype)

    args_maker = lambda: [rng(shape, dtype)]

    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_dtype={}_m={}_n={}_k={}".format(
          onp.dtype(dtype).name, m, n, k),
       "m": m, "n": n, "k": k, "dtype": dtype, "rng": jtu.rand_default()}
      for dtype in default_dtypes
      for n in [0, 4]
      for m in [None, 0, 1, 3, 4]
      for k in list(range(-4, 4))))
  def testTri(self, m, n, k, dtype, rng):
    onp_fun = lambda: onp.tri(n, M=m, k=k, dtype=dtype)
    lnp_fun = lambda: lnp.tri(n, M=m, k=k, dtype=dtype)
    args_maker = lambda: []
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_shape={}_k={}".format(
          op, jtu.format_shape_dtype_string(shape, dtype), k),
       "dtype": dtype, "shape": shape, "op": op, "k": k,
       "rng": jtu.rand_default()}
      for dtype in default_dtypes
      for shape in [shape for shape in all_shapes if len(shape) >= 2]
      for op in ["tril", "triu"]
      for k in list(range(-3, 3))))
  def testTriLU(self, dtype, shape, op, k, rng):
    onp_fun = lambda arg: getattr(onp, op)(arg, k=k)
    lnp_fun = lambda arg: getattr(lnp, op)(arg, k=k)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_k={}".format(
          jtu.format_shape_dtype_string(shape, dtype), k),
       "dtype": dtype, "shape": shape, "k": k, "rng": jtu.rand_default()}
      for dtype in default_dtypes
      for shape in [shape for shape in all_shapes if len(shape) in (1, 2)]
      for k in list(range(-4, 4))))
  def testDiag(self, shape, dtype, k, rng):
    onp_fun = lambda arg: onp.diag(arg, k)
    lnp_fun = lambda arg: lnp.diag(arg, k)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_offset={}_axis1={}_axis2={}".format(
          jtu.format_shape_dtype_string(shape, dtype), offset, axis1, axis2),
       "dtype": dtype, "shape": shape, "offset": offset, "axis1": axis1,
       "axis2": axis2, "rng": jtu.rand_default()}
      for dtype in default_dtypes
      for shape in [shape for shape in all_shapes if len(shape) >= 2]
      for axis1 in range(-len(shape), len(shape))
      for axis2 in [a for a in range(-len(shape), len(shape))
                    if a % len(shape) != axis1 % len(shape)]
      for offset in list(range(-4, 4))))
  def testDiagonal(self, shape, dtype, offset, axis1, axis2, rng):
    onp_fun = lambda arg: onp.diagonal(arg, offset, axis1, axis2)
    lnp_fun = lambda arg: lnp.diagonal(arg, offset, axis1, axis2)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_n={}".format(onp.dtype(dtype).name, n),
       "dtype": dtype, "n": n}
      for dtype in default_dtypes
      for n in list(range(4))))
  def testIdentity(self, n, dtype):
    onp_fun = lambda: onp.identity(n, dtype)
    lnp_fun = lambda: lnp.identity(n, dtype)
    args_maker = lambda: []
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_dtype_{}_offset={}_axis1={}_axis2={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          out_dtype, offset, axis1, axis2),
       "dtype": dtype, "out_dtype": out_dtype, "shape": shape, "offset": offset,
       "axis1": axis1, "axis2": axis2, "rng": jtu.rand_default()}
      for dtype in default_dtypes
      for out_dtype in [None] + number_dtypes
      for shape in [shape for shape in all_shapes if len(shape) >= 2]
      for axis1 in range(-len(shape), len(shape))
      for axis2 in range(-len(shape), len(shape))
      if (axis1 % len(shape)) != (axis2 % len(shape))
      for offset in list(range(-4, 4))))
  def testTrace(self, shape, dtype, out_dtype, offset, axis1, axis2, rng):
    onp_fun = lambda arg: onp.trace(arg, offset, axis1, axis2, out_dtype)
    lnp_fun = lambda arg: lnp.trace(arg, offset, axis1, axis2, out_dtype)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_axis={}".format(
          jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes), axis),
       "shape": shape, "axis": axis, "dtypes": dtypes, "rng": rng}
      for dtypes in [
          [onp.float32],
          [onp.float32, onp.float32],
          [onp.float32, onp.int32, onp.float32],
          [onp.float32, onp.int64, onp.float32],
          [onp.float32, onp.int32, onp.float64],
      ]
      for shape in [(), (2,), (3, 4), (1, 100)]
      for axis in range(-len(shape), len(shape) + 1)
      for rng in [jtu.rand_default()]))
  def testStack(self, shape, axis, dtypes, rng):
    args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
    onp_fun = partial(onp.stack, axis=axis)
    lnp_fun = partial(lnp.stack, axis=axis)
    self._CheckAgainstNumpy(lnp_fun, onp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_{}".format(
          op, jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes)),
       "shape": shape, "op": op, "dtypes": dtypes, "rng": rng}
      for op in ["hstack", "vstack", "dstack"]
      for dtypes in [
          [onp.float32],
          [onp.float32, onp.float32],
          [onp.float32, onp.int32, onp.float32],
          [onp.float32, onp.int64, onp.float32],
          [onp.float32, onp.int32, onp.float64],
      ]
      for shape in [(), (2,), (3, 4), (1, 100), (2, 3, 4)]
      for rng in [jtu.rand_default()]))
  def testHVDStack(self, shape, op, dtypes, rng):
    args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
    onp_fun = getattr(onp, op)
    lnp_fun = getattr(lnp, op)
    self._CheckAgainstNumpy(lnp_fun, onp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_outdtype={}".format(
          jtu.format_shape_dtype_string(shape, fill_value_dtype),
          onp.dtype(out_dtype).name if out_dtype else "None"),
       "shape": shape, "fill_value_dtype": fill_value_dtype,
       "out_dtype": out_dtype, "rng": jtu.rand_default()}
      for shape in array_shapes
      for fill_value_dtype in default_dtypes
      for out_dtype in [None] + default_dtypes))
  def testFull(self, shape, fill_value_dtype, out_dtype, rng):
    onp_fun = lambda fill_value: onp.full(shape, fill_value, dtype=out_dtype)
    lnp_fun = lambda fill_value: lnp.full(shape, fill_value, dtype=out_dtype)
    args_maker = lambda: [rng((), fill_value_dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_filldtype={}_outdtype={}".format(
          jtu.format_shape_dtype_string(shape, in_dtype),
          onp.dtype(fill_value_dtype).name,
          onp.dtype(out_dtype).name),
       "shape": shape, "in_dtype": in_dtype,
       "fill_value_dtype": fill_value_dtype, "out_dtype": out_dtype,
       "rng": jtu.rand_default()}
      for shape in array_shapes
      for in_dtype in default_dtypes
      for fill_value_dtype in default_dtypes
      for out_dtype in default_dtypes))
  def testFullLike(self, shape, in_dtype, fill_value_dtype, out_dtype, rng):
    onp_fun = lambda x, fill_value: onp.full_like(x, fill_value, dtype=out_dtype)
    lnp_fun = lambda x, fill_value: lnp.full_like(x, fill_value, dtype=out_dtype)
    args_maker = lambda: [rng(shape, in_dtype), rng((), fill_value_dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_axis={}_{}sections".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, num_sections),
       "shape": shape, "num_sections": num_sections, "axis": axis,
       "dtype": dtype, "rng": jtu.rand_default()}
      for shape, axis, num_sections in [
          ((3,), 0, 3), ((12,), 0, 3), ((12, 4), 0, 4), ((12, 4), 1, 2),
          ((2, 3, 4), -1, 2), ((2, 3, 4), -2, 3)]
      for dtype in default_dtypes))
  def testSplitStaticInt(self, shape, num_sections, axis, dtype, rng):
    onp_fun = lambda x: onp.split(x, num_sections, axis=axis)
    lnp_fun = lambda x: lnp.split(x, num_sections, axis=axis)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_axis={}_{}sections".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, num_sections),
       "shape": shape, "num_sections": num_sections, "axis": axis,
       "dtype": dtype, "rng": jtu.rand_default()}
      for shape, axis, num_sections in [
          ((12, 4), 0, 4), ((12, 4), 1, 2),
          ((2, 3, 4), 2, 2), ((4, 3, 4), 0, 2)]
      for dtype in default_dtypes))
  def testHVDSplit(self, shape, num_sections, axis, dtype, rng):
    def fn(module, axis):
      if axis == 0:
        return module.vsplit
      elif axis == 1:
        return module.hsplit
      else:
        assert axis == 2
        return module.dsplit

    onp_fun = lambda x: fn(onp, axis)(x, num_sections)
    lnp_fun = lambda x: fn(lnp, axis)(x, num_sections)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_outshape={}_order={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          jtu.format_shape_dtype_string(out_shape, dtype),
          order),
       "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype,
       "order": order, "rng": jtu.rand_default()}
      for dtype in default_dtypes
      for order in ["C", "F"]
      for arg_shape, out_shape in [
          (jtu.NUMPY_SCALAR_SHAPE, (1, 1, 1)),
          ((), (1, 1, 1)),
          ((7, 0), (0, 42, 101)),
          ((3, 4), 12),
          ((3, 4), (12,)),
          ((3, 4), -1),
          ((2, 1, 4), (-1,)),
          ((2, 2, 4), (2, 8))
      ]))
  def testReshape(self, arg_shape, out_shape, dtype, order, rng):
    onp_fun = lambda x: onp.reshape(x, out_shape, order=order)
    lnp_fun = lambda x: lnp.reshape(x, out_shape, order=order)
    args_maker = lambda: [rng(arg_shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_outshape={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          jtu.format_shape_dtype_string(out_shape, dtype)),
       "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype,
       "rng": jtu.rand_default()}
      for dtype in default_dtypes
      for arg_shape, out_shape in [
          ((7, 0), (0, 42, 101)),
          ((2, 1, 4), (-1,)),
          ((2, 2, 4), (2, 8))
      ]))
  def testReshapeMethod(self, arg_shape, out_shape, dtype, rng):
    onp_fun = lambda x: onp.reshape(x, out_shape)
    lnp_fun = lambda x: x.reshape(*out_shape)
    args_maker = lambda: [rng(arg_shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_expanddim={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype), dim),
       "arg_shape": arg_shape, "dtype": dtype, "dim": dim,
       "rng": jtu.rand_default()}
      for arg_shape in [(), (3,), (3, 4)]
      for dtype in default_dtypes
      for dim in range(-len(arg_shape)+1, len(arg_shape))))
  def testExpandDimsStaticDim(self, arg_shape, dtype, dim, rng):
    onp_fun = lambda x: onp.expand_dims(x, dim)
    lnp_fun = lambda x: lnp.expand_dims(x, dim)
    args_maker = lambda: [rng(arg_shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_axes=({},{})".format(
          jtu.format_shape_dtype_string(arg_shape, dtype), ax1, ax2),
       "arg_shape": arg_shape, "dtype": dtype, "ax1": ax1, "ax2": ax2,
       "rng": jtu.rand_default()}
      for arg_shape, ax1, ax2 in [
          ((3, 4), 0, 1), ((3, 4), 1, 0), ((3, 4, 5), 1, 2),
          ((3, 4, 5), -1, -2), ((3, 4, 5), 0, 1)]
      for dtype in default_dtypes))
  def testSwapAxesStaticAxes(self, arg_shape, dtype, ax1, ax2, rng):
    onp_fun = lambda x: onp.swapaxes(x, ax1, ax2)
    lnp_fun = lambda x: lnp.swapaxes(x, ax1, ax2)
    args_maker = lambda: [rng(arg_shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_axis={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype), ax),
       "arg_shape": arg_shape, "dtype": dtype, "ax": ax,
       "rng": jtu.rand_default()}
      for arg_shape, ax in [
          ((3, 1), None),
          ((3, 1), 1),
          ((1, 3, 1), (0, 2)),
          ((1, 4, 1), (0,))]
      for dtype in default_dtypes))
  def testSqueeze(self, arg_shape, dtype, ax, rng):
    onp_fun = lambda x: onp.squeeze(x, ax)
    lnp_fun = lambda x: lnp.squeeze(x, ax)
    args_maker = lambda: [rng(arg_shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_axis={}_weights={}_returned={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          axis,
          (None if weights_shape is None else jtu.format_shape_dtype_string(weights_shape, dtype)),
          returned),
       "rng": jtu.rand_default(), "shape": shape, "dtype": dtype, "axis": axis,
       "weights_shape": weights_shape, "returned": returned}
      for shape in nonempty_shapes
      for dtype in number_dtypes
      for axis in set(range(-len(shape), len(shape))) | set([None])
      # `weights_shape` is either `None`, same as the averaged axis, or same as
      # that of the input
      for weights_shape in ([None, shape] if axis is None else [None, (shape[axis],), shape])
      for returned in [False, True]))
  def testAverage(self, shape, dtype, axis, weights_shape, returned, rng):
    onp_fun = lambda x, weights: onp.average(x, axis, weights, returned)
    lnp_fun = lambda x, weights: lnp.average(x, axis, weights, returned)
    args_maker = lambda: [rng(shape, dtype),
                          None if weights_shape is None else rng(weights_shape, dtype)]

    try:
      self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    except ZeroDivisionError:
      self.skipTest("don't support checking for ZeroDivisionError")
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_arg{}_ndmin={}".format(i, ndmin),
       "arg": arg, "ndmin": ndmin}
      for i, arg in enumerate([
          3., [1, 2, 3], [1., 2., 3.],
          [[1, 2], [3, 4], [5, 6]], [[1, 2.], [3, 4], [5, 6]],
          [[3, onp.array(2), 1], onp.arange(3.)],
      ])
      for ndmin in [None, onp.ndim(arg), onp.ndim(arg) + 1, onp.ndim(arg) + 2]))
  def testArray(self, arg, ndmin):
    args_maker = lambda: [arg]
    if ndmin is not None:
      onp_fun = partial(onp.array, ndmin=ndmin)
      lnp_fun = partial(lnp.array, ndmin=ndmin)
    else:
      onp_fun = onp.array
      lnp_fun = lnp.array
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  def testIssue121(self):
    assert not onp.isscalar(lnp.array(3))

  def testArrayMethod(self):
    class arraylike(object):
      dtype = onp.float32
      def __array__(self, dtype=None):
        return 3.
    a = arraylike()
    ans = lnp.array(a)
    assert ans == 3.

  @jtu.skip_on_devices("tpu")  # TODO(b/32368900): TPUs don't support uint8 yet.
  def testMemoryView(self):
    ans = lnp.array(bytearray(b'\x2a'))
    self.assertAllClose(
        ans,
        onp.array([0x2a], dtype=onp.uint8),
        check_dtypes=True)

  def testAllClose(self):
    rng = onp.random.RandomState(0)
    x = rng.randn(2, 2)
    y = rng.randn(2)

    def same(list1, list2):
      allclose = functools.partial(lnp.allclose, atol=1e-3, rtol=1e-3)
      elements_close = list(map(allclose, list1, list2))
      return lnp.all(lnp.array(elements_close))

    csame = api.jit(same)

    a1 = same((x, y), (x, y))
    a2 = csame((x, y), (x, y))
    a3 = csame((x, y), (x, 2 * y))

    self.assertTrue(a1)
    self.assertTrue(a2)
    self.assertFalse(a3)

  @jtu.skip_on_devices("tpu")  # TODO(mattjj): investigate this failure
  def testOnesBroadcastingConstantHandler(self):
    # TODO(mattjj): update this test for jax3
    self.skipTest("test needs jax3 update")

    def fun(x):
      ones = lnp.ones((3, 4))
      assert isinstance(ones, onp.ndarray) and ones.strides == (0, 0)

      # To check that the constant handler generates a Broadcast for stride-zero
      # arrays, we monkey-patch the client instance.
      # TODO(mattjj): once we have better HLO dumping and inspecting facilities,
      # we can check the HLO more directly.
      c = x._node.c
      Broadcast = c.Broadcast  # pylint: disable=invalid-name
      was_called = []
      c.Broadcast = lambda *args: was_called.append(True) or Broadcast(*args)
      out = x + ones  # the ndarray constant handler should call Broadcast here
      assert was_called, "Broadcast was not called."

      return out

    fun = api.jit(fun)
    out_val = fun(lnp.ones(4))
    self.assertAllClose(out_val, onp.full((3, 4), 2.), check_dtypes=False)

  def testZeroStridesConstantHandler(self):
    raw_const = onp.random.RandomState(0).randn(1, 2, 1, 1, 5, 1)
    const = onp.broadcast_to(raw_const, (3, 2, 3, 4, 5, 6))

    def fun(x):
      return x * const

    fun = api.jit(fun)
    out_val = fun(3.)
    self.assertAllClose(out_val, 3. * const, check_dtypes=False)


  def testIsInstanceNdarrayDuringTracing(self):
    arr = onp.ones(3)

    @api.jit
    def f(x):
      self.assertIsInstance(x, lnp.ndarray)
      return lnp.sum(x)

    f(arr)

  def testNonArrayErrorMessage(self):
    x = [1., 2.]
    y = onp.array([3., 4.])

    def g(x, y):
      return lnp.add(x, y)

    def f(x, y):
      return lnp.dot(x, y)

    self.assertRaises(TypeError, lambda: g(x, y))
    self.assertRaises(TypeError, lambda: f(x, y))
    self.assertRaises(TypeError, lambda: api.jit(g)(x, y))
    self.assertRaises(TypeError, lambda: api.jit(f)(x, y))

  def testAbstractionErrorMessage(self):

    @api.jit
    def f(x, n):
      for _ in range(n):
        x = x * x
      return x

    self.assertRaises(TypeError, lambda: f(3., 3))

    @api.jit
    def g(x):
      if x > 0.:
        return x * 2
      else:
        return x + 2

    self.assertRaises(TypeError, lambda: g(3.))

  def testTracingPrimitiveWithNoTranslationErrorMessage(self):
    # TODO(mattjj): update this for jax3
    self.skipTest("test needs jax3 update")
    foo = lnp._not_implemented(lambda x: x)

    # No error if there's no tracing.
    foo(onp.arange(3))

    cfoo = api.jit(foo)
    self.assertRaises(NotImplementedError, lambda: cfoo(onp.arange(3)))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_axis={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis),
       "rng": rng, "shape": shape, "dtype": dtype, "axis": axis}
      for shape in [(3,), (2, 3)]
      for dtype in default_dtypes
      for axis in range(-len(shape), len(shape))  # Test negative axes
      for rng in [jtu.rand_default()]))
  def testFlip(self, shape, dtype, axis, rng):
    args_maker = self._GetArgsMaker(rng, [shape], [dtype])
    lnp_op = lambda x: lnp.flip(x, axis)
    onp_op = lambda x: onp.flip(x, axis)
    self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(
          jtu.format_shape_dtype_string(shape, dtype)),
       "rng": rng, "shape": shape, "dtype": dtype}
      for shape in [(3,), (2, 3), (3, 2, 4)]
      for dtype in default_dtypes
      for rng in [jtu.rand_default()]))
  def testFlipud(self, shape, dtype, rng):
    args_maker = self._GetArgsMaker(rng, [shape], [dtype])
    lnp_op = lambda x: lnp.flipud(x)
    onp_op = lambda x: onp.flipud(x)
    self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(
          jtu.format_shape_dtype_string(shape, dtype)),
       "rng": rng, "shape": shape, "dtype": dtype}
      for shape in [(3, 2), (2, 3), (3, 2, 4)]
      for dtype in default_dtypes
      for rng in [jtu.rand_default()]))
  def testFliplr(self, shape, dtype, rng):
    args_maker = self._GetArgsMaker(rng, [shape], [dtype])
    lnp_op = lambda x: lnp.fliplr(x)
    onp_op = lambda x: onp.fliplr(x)
    self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_k={}_axes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), k, axes),
       "rng": rng, "shape": shape, "dtype": dtype, "k": k, "axes": axes}
      for shape, axes in [
          [(2, 3), (0, 1)],
          [(2, 3), (1, 0)],
          [(4, 3, 2), (0, 2)],
          [(4, 3, 2), (2, 1)],
      ]
      for k in range(-3, 4)
      for dtype in default_dtypes
      for rng in [jtu.rand_default()]))
  def testRot90(self, shape, dtype, k, axes, rng):
    args_maker = self._GetArgsMaker(rng, [shape], [dtype])
    lnp_op = lambda x: lnp.rot90(x, k, axes)
    onp_op = lambda x: onp.rot90(x, k, axes)
    self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)

  # TODO(mattjj): test infix operator overrides

  def testRavel(self):
    rng = onp.random.RandomState(0)
    args_maker = lambda: [rng.randn(3, 4).astype("float32")]
    self._CompileAndCheck(lambda x: x.ravel(), args_maker, check_dtypes=True)

  def testAstype(self):
    rng = onp.random.RandomState(0)
    args_maker = lambda: [rng.randn(3, 4).astype("float32")]
    op = lambda x: x.astype(lnp.int32)
    self._CheckAgainstNumpy(op, op, args_maker, check_dtypes=True)
    self._CompileAndCheck(op, args_maker, check_dtypes=True)

  # TODO(mattjj): test other ndarray-like method overrides

  def testOnpMean(self):
    # from https://github.com/google/jax/issues/125
    x = lax.add(lnp.eye(3), 0.)
    ans = onp.mean(x)
    self.assertAllClose(ans, onp.array(1./3), check_dtypes=False)

  def testArangeOnFloats(self):
    # from https://github.com/google/jax/issues/145
    expected = onp.arange(0.0, 1.0, 0.1)
    ans = lnp.arange(0.0, 1.0, 0.1)
    self.assertAllClose(expected, ans, check_dtypes=True)

  def testSortManually(self):
    # manual tests for sort are nice because we don't have to worry about ties.
    # lax.sort is tested combinatorially.
    ans = lnp.sort(onp.array([16, 15, 23, 42, 8, 4]))
    expected = onp.array([4, 8, 15, 16, 23, 42])
    self.assertAllClose(expected, ans, check_dtypes=True)

    a = onp.array([[1, 4], [3, 1]])
    ans = lnp.sort(a, axis=None)
    expected = onp.array([1, 1, 3, 4])
    self.assertAllClose(expected, ans, check_dtypes=True)

    a = onp.array([[1, 4], [3, 1]])
    ans = lnp.sort(a)  # last axis
    expected = onp.array([[1, 4], [1, 3]])
    self.assertAllClose(expected, ans, check_dtypes=True)

    a = onp.array([[1, 4], [3, 1]])
    ans = lnp.sort(a, axis=0)
    expected = onp.array([[1, 1], [3, 4]])
    self.assertAllClose(expected, ans, check_dtypes=True)

  def testArgsortManually(self):
    x = onp.array([16, 15, 23, 42, 8, 4])
    ans = lnp.argsort(x)
    expected = onp.argsort(x)
    self.assertAllClose(expected, ans, check_dtypes=False)

    x = onp.array([[16, 15, 23], [42, 8, 4]])
    ans = lnp.argsort(x, axis=0)
    expected = onp.argsort(x, axis=0)
    self.assertAllClose(expected, ans, check_dtypes=False)

    x = onp.array([[16, 15, 23], [42, 8, 4]])
    ans = lnp.argsort(x, axis=1)
    expected = onp.argsort(x, axis=1)
    self.assertAllClose(expected, ans, check_dtypes=False)

    x = onp.array([[16, 15, 23], [42, 8, 4]])
    ans = lnp.argsort(x, axis=None)
    expected = onp.argsort(x, axis=None)
    self.assertAllClose(expected, ans, check_dtypes=False)

    x = onp.array([[16, 15, 23], [42, 8, 4]])
    ans = lnp.argsort(x)
    expected = onp.argsort(x)
    self.assertAllClose(expected, ans, check_dtypes=False)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_shifts={}_axis={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          shifts, axis),
       "rng": rng, "shape": shape, "dtype": dtype, "shifts": shifts,
       "axis": axis}
      for dtype in all_dtypes
      for shape in [(3, 4), (3, 4, 5), (7, 4, 0)]
      for shifts, axis in [
          (3, None),
          (1, 1),
          ((3,), (0,)),
          ((-2,), (-2,)),
          ((1, 2), (0, -1))
      ]
      for rng in [jtu.rand_default()]))
  def testRoll(self, shape, dtype, shifts, axis, rng):
    args_maker = lambda: [rng(shape, dtype), onp.array(shifts)]
    lnp_op = partial(lnp.roll, axis=axis)
    onp_op = partial(onp.roll, axis=axis)
    self._CheckAgainstNumpy(lnp_op, onp_op, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_index={}_axis={}_mode={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          jtu.format_shape_dtype_string(index_shape, index_dtype),
          axis, mode),
       "rng": rng, "rng_indices": rng_indices, "shape": shape,
       "index_shape": index_shape, "dtype": dtype, "index_dtype": index_dtype,
       "axis": axis, "mode": mode}
      for shape in [(3,), (3, 4), (3, 4, 5)]
      for index_shape in scalar_shapes + [(3,), (2, 1, 3)]
      for axis in itertools.chain(range(-len(shape), len(shape)), [None])
      for dtype in all_dtypes
      for index_dtype in int_dtypes
      for mode in ['wrap', 'clip']
      for rng in [jtu.rand_default()]
      for rng_indices in [jtu.rand_int(-5, 5)]))
  def testTake(self, shape, dtype, index_shape, index_dtype, axis, mode, rng,
               rng_indices):
    def args_maker():
      x = rng(shape, dtype)
      i = rng_indices(index_shape, index_dtype)
      return x, i

    lnp_op = lambda x, i: lnp.take(x, i, axis=axis, mode=mode)
    onp_op = lambda x, i: onp.take(x, i, axis=axis, mode=mode)
    self._CheckAgainstNumpy(lnp_op, onp_op, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_ishape={}_axis={}".format(
          jtu.format_shape_dtype_string(x_shape, dtype), i_shape, axis),
       "rng": rng, "x_shape": x_shape, "i_shape": i_shape, "dtype": dtype,
       "axis": axis}
      for x_shape, i_shape in filter(
          _shapes_are_equal_length,
          filter(_shapes_are_broadcast_compatible,
                 CombosWithReplacement(nonempty_nonscalar_array_shapes, 2)))
      for axis in itertools.chain(range(len(x_shape)), [-1], [None])
      for dtype in default_dtypes
      for rng in [jtu.rand_default()]))
  def testTakeAlongAxis(self, x_shape, i_shape, dtype, axis, rng):
    i_shape = onp.array(i_shape)
    if axis is None:
      i_shape = [onp.prod(i_shape, dtype=onp.int64)]
    else:
      # Test the case where the size of the axis doesn't necessarily broadcast.
      i_shape[axis] *= 3
      i_shape = list(i_shape)
    def args_maker():
      x = rng(x_shape, dtype)
      n = onp.prod(x_shape, dtype=onp.int32) if axis is None else x_shape[axis]
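      # Draw indices in [-(n - 1), n - 1] so both negative and positive
      # indexing is exercised.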
      i = rng(i_shape, onp.int32) % (2 * n - 1) - (n - 1)
      return x, i

    lnp_op = lambda x, i: lnp.take_along_axis(x, i, axis=axis)

    if hasattr(onp, "take_along_axis"):
      onp_op = lambda x, i: onp.take_along_axis(x, i, axis=axis)
      self._CheckAgainstNumpy(lnp_op, onp_op, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_n={}_increasing={}".format(
          jtu.format_shape_dtype_string([shape], dtype),
          n, increasing),
       "dtype": dtype, "shape": shape, "n": n, "increasing": increasing,
       "rng": jtu.rand_default()}
      for dtype in inexact_dtypes
      for shape in [0, 5]
      for n in [2, 4]
      for increasing in [False, True]))
  def testVander(self, shape, dtype, n, increasing, rng):
    onp_fun = lambda arg: onp.vander(arg, N=n, increasing=increasing)
    lnp_fun = lambda arg: lnp.vander(arg, N=n, increasing=increasing)
    args_maker = lambda: [rng([shape], dtype)]
    # np.vander seems to return float64 for all floating types. We could obey
    # those semantics, but they seem like a bug.
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=False)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=False)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": jtu.format_test_name_suffix("nan_to_num", [shape],
                                                    [dtype]),
       "rng": jtu.rand_some_inf_and_nan(), "shape": shape, "dtype": dtype}
      for shape in all_shapes
      for dtype in inexact_dtypes))
  def testNanToNum(self, rng, shape, dtype):
    dtype = onp.dtype(xla_bridge.canonicalize_dtype(dtype)).type
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp.nan_to_num, lnp.nan_to_num, args_maker,
                            check_dtypes=True)
    self._CompileAndCheck(lnp.nan_to_num, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": jtu.format_test_name_suffix("ix_", shapes, dtypes),
       "rng": jtu.rand_default(), "shapes": shapes, "dtypes": dtypes}
      for shapes, dtypes in (
          ((), ()),
          (((7,),), (onp.int32,)),
          (((3,), (4,)), (onp.int32, onp.int32)),
          (((3,), (1,), (4,)), (onp.int32, onp.int32, onp.int32)),
      )))
  def testIx_(self, rng, shapes, dtypes):
    args_maker = lambda: [rng(shape, dtype)
                          for shape, dtype in zip(shapes, dtypes)]
    self._CheckAgainstNumpy(onp.ix_, lnp.ix_, args_maker,
                            check_dtypes=True)
    self._CompileAndCheck(lnp.ix_, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_op={}_a_shape={}_q_shape={}_axis={}_keepdims={}".format(
           op,
           jtu.format_shape_dtype_string(a_shape, a_dtype),
           jtu.format_shape_dtype_string(q_shape, q_dtype),
           axis, keepdims),
       "a_rng": jtu.rand_default(), "q_rng": q_rng, "op": op,
       "a_shape": a_shape, "a_dtype": a_dtype,
       "q_shape": q_shape, "q_dtype": q_dtype, "axis": axis,
       "keepdims": keepdims}
      for (op, q_rng) in (
          ("percentile", jtu.rand_uniform(low=0., high=100.)),
          ("quantile", jtu.rand_uniform(low=0., high=1.)),
          ("median", jtu.rand_uniform(low=0., high=1.)),
      )
      for a_dtype in float_dtypes
      for a_shape, axis in (
          ((7,), None),
          ((47, 7), 0),
          ((4, 101), 1),
      )
      for q_dtype in [onp.float32]
      for q_shape in scalar_shapes + [(4,)]
      for keepdims in [False, True]))
  @jtu.skip_on_devices("tpu")  # TODO(phawkins): investigate this failure
  def testQuantile(self, op, a_rng, q_rng, a_shape, a_dtype, q_shape, q_dtype,
                   axis, keepdims):
    if op == "quantile" and numpy_version < (1, 15):
      raise SkipTest("Numpy < 1.15 does not have np.quantile")
    if op == "median":
      args_maker = lambda: [a_rng(a_shape, a_dtype)]
    else:
      args_maker = lambda: [a_rng(a_shape, a_dtype), q_rng(q_shape, q_dtype)]
    onp_fun = partial(getattr(onp, op), axis=axis, keepdims=keepdims)
    lnp_fun = partial(getattr(lnp, op), axis=axis, keepdims=keepdims)
    # TODO(phawkins): we currently set dtype=False because we aren't as
    # aggressive about promoting to float64. It's not clear we want to mimic
    # Numpy here.
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=False)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": jtu.format_test_name_suffix("select", shapes,
                                                    (onp.bool_,) * n + dtypes),
       "rng": jtu.rand_default(), "shapes": shapes, "dtypes": dtypes}
      for n in range(0, 3)
      for shapes in filter(
          _shapes_are_broadcast_compatible,
          CombosWithReplacement(all_shapes, 2 * n + 1))
      for dtypes in CombosWithReplacement(all_dtypes, n + 1)))
  def testSelect(self, rng, shapes, dtypes):
    n = len(dtypes) - 1
    def args_maker():
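      # Build the (condlist, choicelist, default) triple that onp.select and
      # lnp.select take as positional arguments.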
      condlist = [rng(shape, onp.bool_) for shape in shapes[:n]]
      choicelist = [rng(shape, dtype)
                    for shape, dtype in zip(shapes[n:-1], dtypes[:n])]
      default = rng(shapes[-1], dtypes[-1])
      return condlist, choicelist, default
    self._CheckAgainstNumpy(onp.select, lnp.select, args_maker,
                            check_dtypes=True)
    self._CompileAndCheck(lnp.select, args_maker, check_dtypes=True)

  def testIssue330(self):
    x = lnp.full((1, 1), lnp.array([1])[0])  # doesn't crash
    self.assertEqual(x[0, 0], 1)

  def testScalarDtypePromotion(self):
    # disabled this test after https://github.com/google/jax/issues/732
    msg = ("jax.numpy differs from numpy in promotion rules for Python scalars."
           " See https://github.com/google/jax/issues/732.")
    raise SkipTest(msg)
    orig_numpy_result = (1 + onp.eye(1, dtype=onp.float32)).dtype
    jax_numpy_result = (1 + lnp.eye(1, dtype=lnp.float32)).dtype
    self.assertEqual(orig_numpy_result, jax_numpy_result)

  def testSymmetrizeDtypePromotion(self):
    x = onp.eye(3, dtype=onp.float32)
    orig_numpy_result = ((x + x.T) / 2).dtype

    x = lnp.eye(3, dtype=lnp.float32)
    jax_numpy_result = ((x + x.T) / 2).dtype
    self.assertEqual(orig_numpy_result, jax_numpy_result)

  def testIssue347(self):
    # https://github.com/google/jax/issues/347
    def test_fail(x):
      x = lnp.sqrt(lnp.sum(x ** 2, axis=1))
      ones = lnp.ones_like(x)
      x = lnp.where(x > 0.5, x, ones)
      return lnp.sum(x)

    x = lnp.array([[1, 2], [3, 4], [0, 0]], dtype=lnp.float64)
    result = api.grad(test_fail)(x)
    assert not onp.any(onp.isnan(result))

  def testIssue453(self):
    # https://github.com/google/jax/issues/453
    a = onp.arange(6) + 1
    ans = lnp.reshape(a, (3, 2), order='F')
    expected = onp.reshape(a, (3, 2), order='F')
    self.assertAllClose(ans, expected, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_dtype={}".format(
          op, {bool: "bool", int: "int", float: "float"}[dtype]),
       "dtype": dtype, "op": op}
      for dtype in [int, float, bool]
      for op in ["atleast_1d", "atleast_2d", "atleast_3d"]))
  def testAtLeastNdLiterals(self, dtype, op):
    # Fixes: https://github.com/google/jax/issues/634
    onp_fun = lambda arg: getattr(onp, op)(arg)
    lnp_fun = lambda arg: getattr(lnp, op)(arg)
    args_maker = lambda: [dtype(2)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  def testLongLong(self):
    self.assertAllClose(onp.int64(7), api.jit(lambda x: x)(onp.longlong(7)),
                        check_dtypes=True)

  def testArange(self):
    # test cases inspired by dask tests at
    # https://github.com/dask/dask/blob/master/dask/array/tests/test_creation.py#L92
    self.assertAllClose(lnp.arange(77),
                        onp.arange(77), check_dtypes=True)
    self.assertAllClose(lnp.arange(2, 13),
                        onp.arange(2, 13), check_dtypes=True)
    self.assertAllClose(lnp.arange(4, 21, 9),
                        onp.arange(4, 21, 9), check_dtypes=True)
    self.assertAllClose(lnp.arange(53, 5, -3),
                        onp.arange(53, 5, -3), check_dtypes=True)
    # TODO(mattjj): make these tests work when jax_enable_x64=True
    # self.assertAllClose(lnp.arange(77, dtype=float),
    #                     onp.arange(77, dtype=float), check_dtypes=True)
    # self.assertAllClose(lnp.arange(2, 13, dtype=int),
    #                     onp.arange(2, 13, dtype=int), check_dtypes=True)
    self.assertAllClose(lnp.arange(0, 1, -0.5),
                        onp.arange(0, 1, -0.5), check_dtypes=True)

    self.assertRaises(TypeError, lambda: lnp.arange())

    # test that lnp.arange(N) doesn't instantiate an ndarray
    self.assertFalse(type(lnp.arange(77)) == type(onp.arange(77)))
    self.assertTrue(type(lnp.arange(77)) == type(lax.iota(onp.int32, 77)))

    # test that lnp.arange(N, dtype=int32) doesn't instantiate an ndarray
    self.assertFalse(type(lnp.arange(77, dtype=lnp.int32)) ==
                     type(onp.arange(77, dtype=onp.int32)))
    self.assertTrue(type(lnp.arange(77, dtype=lnp.int32)) ==
                    type(lax.iota(onp.int32, 77)))

  def testIssue830(self):
    a = lnp.arange(4, dtype=lnp.complex64)
    self.assertEqual(a.dtype, lnp.complex64)

  def testIssue728(self):
    assert lnp.allclose(lnp.eye(5000), onp.eye(5000))
    self.assertEqual(0, onp.sum(lnp.eye(1050) - onp.eye(1050)))

  def testIssue746(self):
    lnp.arange(12).reshape(3, 4)  # doesn't crash

  def testIssue764(self):
    x = lnp.linspace(190, 200, 4)
    f = api.grad(lambda x: lnp.sum(lnp.tanh(x)))
    # Expected values computed with autograd in float64 precision.
    expected = onp.array([3.71669453e-165, 4.72999108e-168, 6.01954653e-171,
                          7.66067839e-174], onp.float64)
    self.assertAllClose(f(x), expected, check_dtypes=False)

  def testIssue776(self):
    """Tests that the scatter-add transpose rule instantiates symbolic zeros."""
    def f(u):
      y = jax.ops.index_add(onp.ones(10,), [2, 4, 5], u)
      # The transpose rule for lax.tie_in returns a symbolic zero for its first
      # argument.
      return lax.tie_in(y, 7.)

    self.assertAllClose(onp.zeros(3,), api.grad(f)(onp.ones(3,)),
                        check_dtypes=True)

  def testIssue777(self):
    x = lnp.linspace(-200, 0, 4, dtype=onp.float32)
    f = api.grad(lambda x: lnp.sum(1 / (1 + lnp.exp(-x))))
    self.assertAllClose(f(x), onp.array([0., 0., 0., 0.25], dtype=onp.float32),
                        check_dtypes=True)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name": jtu.format_test_name_suffix(op, [()], [dtype]),
           "dtype": dtype, "op": op}
          for dtype in float_dtypes
          for op in ("sqrt", "arccos", "arcsin", "arctan", "sin", "cos", "tan",
                     "sinh", "cosh", "tanh", "arccosh", "arcsinh", "arctanh",
                     "exp", "log", "expm1", "log1p")))
  def testMathSpecialFloatValues(self, op, dtype):
    onp_op = getattr(onp, op)
    lnp_op = getattr(lnp, op)
    dtype = onp.dtype(xla_bridge.canonicalize_dtype(dtype)).type
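    # Sweep each op over special values: NaN, infinities, zero, and values
    # near the dtype's maximum.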
    for x in (onp.nan, -onp.inf, -100., -2., -1., 0., 1., 2., 100., onp.inf,
              onp.finfo(dtype).max, onp.sqrt(onp.finfo(dtype).max),
              onp.sqrt(onp.finfo(dtype).max) * 2.):
      if onp.isnan(x) and op in ("sinh", "cosh", "expm1", "exp"):
        # TODO(b/133842876, b/133842870): these return wrong outputs on CPU for
        # NaN inputs.
        continue
      if (op in ("sin", "cos", "tan", "arctan") and FLAGS.jax_test_dut and
          FLAGS.jax_test_dut.startswith("tpu")):
        continue  # TODO(b/132196789, b/134175194): fix and reenable.
      x = dtype(x)
      expected = onp_op(x)
      actual = lnp_op(x)
      self.assertAllClose(expected, actual, check_dtypes=True)

  def testIssue883(self):
    # from https://github.com/google/jax/issues/883

    @partial(api.jit, static_argnums=(1,))
    def f(x, v):
      return x

    x = lnp.ones((10, 10))
    v = lnp.array([1, 2, 3])
    first_call = f(x, v)
    second_call = f(x, v)  # doesn't crash

  def testReductionOfOutOfBoundsAxis(self):  # Issue 888
    x = lnp.ones((3, 4))
    self.assertRaises(ValueError, lambda: lnp.sum(x, axis=2))

  def testIssue956(self):
    self.assertRaises(TypeError, lambda: lnp.ndarray((1, 1)))

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name":
           "_shape={}_dtype={}_out_dtype={}_axis={}_ddof={}_keepdims={}"
           .format(shape, dtype, out_dtype, axis, ddof, keepdims),
           "shape": shape, "dtype": dtype, "out_dtype": out_dtype, "axis": axis,
           "ddof": ddof, "keepdims": keepdims, "rng": rng}
          for shape in [(5,), (10, 5)]
          for dtype in all_dtypes
          for out_dtype in number_dtypes
          for axis in [None, 0, -1]
          for ddof in [0, 1, 2]
          for keepdims in [False, True]
          for rng in [jtu.rand_default()]))
  def testVar(self, shape, dtype, out_dtype, axis, ddof, keepdims, rng):
    args_maker = self._GetArgsMaker(rng, [shape], [dtype])
    onp_fun = partial(onp.var, dtype=out_dtype, axis=axis, ddof=ddof, keepdims=keepdims)
    lnp_fun = partial(lnp.var, dtype=out_dtype, axis=axis, ddof=ddof, keepdims=keepdims)
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name": "_shape={}_dtype={}_rowvar={}_ddof={}_bias={}".format(
              shape, dtype, rowvar, ddof, bias),
           "shape": shape, "dtype": dtype, "rowvar": rowvar, "ddof": ddof,
           "bias": bias, "rng": rng}
          for shape in [(5,), (10, 5), (3, 10)]
          for dtype in all_dtypes
          for rowvar in [True, False]
          for bias in [True, False]
          for ddof in [None, 2, 3]
          for rng in [jtu.rand_default()]))
  @jtu.skip_on_devices("gpu")  # TODO(b/138003641): test fails on GPU.
  def testCov(self, shape, dtype, rowvar, ddof, bias, rng):
    args_maker = self._GetArgsMaker(rng, [shape], [dtype])
    onp_fun = partial(onp.cov, rowvar=rowvar, ddof=ddof, bias=bias)
    lnp_fun = partial(lnp.cov, rowvar=rowvar, ddof=ddof, bias=bias)
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  def testIssue967(self):
    self.assertRaises(TypeError, lambda: lnp.zeros(1.5))

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name": "_shape={}_dtype={}_rowvar={}_ddof={}_bias={}".format(
              shape, dtype, rowvar, ddof, bias),
           "shape": shape, "dtype": dtype, "rowvar": rowvar, "ddof": ddof,
           "bias": bias, "rng": rng}
          for shape in [(5,), (10, 5), (3, 10)]
          for dtype in number_dtypes
          for rowvar in [True, False]
          for bias in [True, False]
          for ddof in [None, 2, 3]
          for rng in [jtu.rand_default()]))
  def testCorrCoef(self, shape, dtype, rowvar, ddof, bias, rng):
    args_maker = self._GetArgsMaker(rng, [shape], [dtype])
    mat = onp.asarray([rng(shape, dtype)])
    onp_fun = partial(onp.corrcoef, rowvar=rowvar, ddof=ddof, bias=bias)
    lnp_fun = partial(lnp.corrcoef, rowvar=rowvar, ddof=ddof, bias=bias)
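    # Only compare against onp when the sample is not (near-)constant, since
    # corrcoef divides by the standard deviation and is ill-conditioned there.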
    if not onp.any(onp.isclose(onp.std(mat), 0.0)):
      self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)


if __name__ == "__main__":
  absltest.main()