# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import functools
from functools import partial
import itertools
from unittest import skip, SkipTest

from absl.testing import absltest
from absl.testing import parameterized

import numpy as onp
import numpy.random as npr
import six

from jax import api
from jax import core
from jax import dtypes
from jax import lax
from jax import test_util as jtu
from jax import lax_reference
from jax.test_util import check_grads
from jax.interpreters import xla
from jax.lib import xla_client

from jax.config import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS


def num_float_bits(dtype):
  return dtypes.finfo(dtypes.canonicalize_dtype(dtype)).bits


### lax tests

# For standard unops and binops, we can generate a large number of tests on
# arguments of appropriate shapes and dtypes using the following table.

float_dtypes = list(jtu.supported_dtypes().intersection(
  {dtypes.bfloat16, onp.float16, onp.float32, onp.float64}))
complex_elem_dtypes = list(jtu.supported_dtypes().intersection(
  {onp.float32, onp.float64}))
complex_dtypes = list(jtu.supported_dtypes().intersection(
  {onp.complex64, onp.complex128}))
inexact_dtypes = float_dtypes + complex_dtypes
int_dtypes = list(jtu.supported_dtypes().intersection({onp.int32, onp.int64}))
bool_dtypes = [onp.bool_]
default_dtypes = float_dtypes + int_dtypes
all_dtypes = float_dtypes + complex_dtypes + int_dtypes + bool_dtypes

compatible_shapes = [[(3,)], [(3, 4), (3, 1), (1, 4)], [(2, 3, 4), (2, 1, 4)]]


OpRecord = collections.namedtuple(
    "OpRecord", ["op", "nargs", "dtypes", "rng_factory", "tol"])

def op_record(op, nargs, dtypes, rng_factory, tol=None):
  return OpRecord(op, nargs, dtypes, rng_factory, tol)
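
# Each record in LAX_OPS names a lax op, its arity, the dtypes to exercise, a
# factory for the random-input generator, and an optional numerical tolerance;
# the parameterized tests below expand every record into one test case per
# compatible shape group and dtype.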

LAX_OPS = [
    op_record("neg", 1, default_dtypes + complex_dtypes, jtu.rand_small),
    op_record("sign", 1, default_dtypes, jtu.rand_small),
    op_record("floor", 1, float_dtypes, jtu.rand_small),
    op_record("ceil", 1, float_dtypes, jtu.rand_small),
    op_record("round", 1, float_dtypes, jtu.rand_default),
    op_record("nextafter", 2, [f for f in float_dtypes if f != dtypes.bfloat16],
              jtu.rand_default, tol=0),

    op_record("is_finite", 1, float_dtypes, jtu.rand_small),

    op_record("exp", 1, float_dtypes + complex_dtypes, jtu.rand_small),
    # TODO(b/142975473): on CPU, expm1 for float64 is only accurate to ~float32
    # precision.
    op_record("expm1", 1, float_dtypes + complex_dtypes, jtu.rand_small,
              {onp.float64: 1e-8}),
    op_record("log", 1, float_dtypes + complex_dtypes, jtu.rand_positive),
    op_record("log1p", 1, float_dtypes + complex_dtypes, jtu.rand_positive),
    # TODO(b/142975473): on CPU, tanh for complex128 is only accurate to
    # ~float32 precision.
    # TODO(b/143135720): on GPU, tanh has only ~float32 precision.
    op_record("tanh", 1, float_dtypes + complex_dtypes, jtu.rand_small,
              {onp.float64: 1e-9, onp.complex128: 1e-7}),
    op_record("sin", 1, float_dtypes + complex_dtypes, jtu.rand_default),
    op_record("cos", 1, float_dtypes + complex_dtypes, jtu.rand_default),
    op_record("atan2", 2, float_dtypes, jtu.rand_default),

    op_record("sqrt", 1, float_dtypes + complex_dtypes, jtu.rand_positive),
    op_record("rsqrt", 1, float_dtypes + complex_dtypes, jtu.rand_positive),
    op_record("square", 1, float_dtypes + complex_dtypes, jtu.rand_default),
    op_record("reciprocal", 1, float_dtypes + complex_dtypes, jtu.rand_positive),
    op_record("tan", 1, float_dtypes, jtu.rand_default, {onp.float32: 1e-5}),
    op_record("asin", 1, float_dtypes, jtu.rand_small),
    op_record("acos", 1, float_dtypes, jtu.rand_small),
    op_record("atan", 1, float_dtypes, jtu.rand_small),
    op_record("sinh", 1, float_dtypes + complex_dtypes, jtu.rand_default),
    op_record("cosh", 1, float_dtypes + complex_dtypes, jtu.rand_default),

    op_record("lgamma", 1, float_dtypes, jtu.rand_positive,
              {onp.float32: 1e-3 if jtu.device_under_test() == "tpu" else 1e-5,
               onp.float64: 1e-14}),
    op_record("digamma", 1, float_dtypes, jtu.rand_positive,
              {onp.float64: 1e-14}),
    op_record("erf", 1, float_dtypes, jtu.rand_small),
    op_record("erfc", 1, float_dtypes, jtu.rand_small),
    # TODO(b/142976030): the approximation of erfinf used by XLA is only
    # accurate to float32 precision.
    op_record("erf_inv", 1, float_dtypes, jtu.rand_small,
              {onp.float64: 1e-9}),
    op_record("bessel_i0e", 1, float_dtypes, jtu.rand_default),
    op_record("bessel_i1e", 1, float_dtypes, jtu.rand_default),

    op_record("real", 1, complex_dtypes, jtu.rand_default),
    op_record("imag", 1, complex_dtypes, jtu.rand_default),
    op_record("complex", 2, complex_elem_dtypes, jtu.rand_default),
    op_record("conj", 1, complex_elem_dtypes + complex_dtypes,
              jtu.rand_default),
    op_record("abs", 1, default_dtypes + complex_dtypes, jtu.rand_default),
    op_record("pow", 2, float_dtypes + complex_dtypes, jtu.rand_positive),

    op_record("bitwise_and", 2, bool_dtypes, jtu.rand_small),
    op_record("bitwise_not", 1, bool_dtypes, jtu.rand_small),
    op_record("bitwise_or", 2, bool_dtypes, jtu.rand_small),
    op_record("bitwise_xor", 2, bool_dtypes, jtu.rand_small),

    op_record("add", 2, default_dtypes + complex_dtypes, jtu.rand_small),
    op_record("sub", 2, default_dtypes + complex_dtypes, jtu.rand_small),
    op_record("mul", 2, default_dtypes + complex_dtypes, jtu.rand_small),
    op_record("div", 2, default_dtypes + complex_dtypes, jtu.rand_nonzero),
    op_record("rem", 2, default_dtypes, jtu.rand_nonzero),

    op_record("max", 2, all_dtypes, jtu.rand_small),
    op_record("min", 2, all_dtypes, jtu.rand_small),

    op_record("eq", 2, all_dtypes, jtu.rand_some_equal),
    op_record("ne", 2, all_dtypes, jtu.rand_small),
    op_record("ge", 2, default_dtypes, jtu.rand_small),
    op_record("gt", 2, default_dtypes, jtu.rand_small),
    op_record("le", 2, default_dtypes, jtu.rand_small),
    op_record("lt", 2, default_dtypes, jtu.rand_small),
]

CombosWithReplacement = itertools.combinations_with_replacement
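
# Most operations get two kinds of coverage below: a `_CompileAndCheck` test
# that compares the op's jit-compiled output against its un-jitted execution,
# and an `...AgainstNumpy` test that compares it against the pure-NumPy
# reference implementation in `lax_reference`.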


class LaxTest(jtu.JaxTestCase):
  """Numerical tests for LAX operations."""
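
  # Each test method below is expanded by `parameterized.named_parameters`
  # into one named test case per parameter dict produced by
  # `jtu.cases_from_list`.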

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
        {"testcase_name": jtu.format_test_name_suffix(
            rec.op, shapes, itertools.repeat(dtype)),
         "op_name": rec.op, "rng_factory": rec.rng_factory, "shapes": shapes,
         "dtype": dtype}
        for shape_group in compatible_shapes
        for shapes in CombosWithReplacement(shape_group, rec.nargs)
        for dtype in rec.dtypes)
      for rec in LAX_OPS))
  def testOp(self, op_name, rng_factory, shapes, dtype):
    rng = rng_factory()
    args_maker = lambda: [rng(shape, dtype) for shape in shapes]
    op = getattr(lax, op_name)
    self._CompileAndCheck(op, args_maker, check_dtypes=True)

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
        {"testcase_name": jtu.format_test_name_suffix(
            rec.op, shapes, itertools.repeat(dtype)),
         "op_name": rec.op, "rng_factory": rec.rng_factory, "shapes": shapes,
         "dtype": dtype, "tol": rec.tol}
        for shape_group in compatible_shapes
        for shapes in CombosWithReplacement(shape_group, rec.nargs)
        for dtype in rec.dtypes)
      for rec in LAX_OPS))
  def testOpAgainstNumpy(self, op_name, rng_factory, shapes, dtype, tol):
    rng = rng_factory()
    args_maker = lambda: [rng(shape, dtype) for shape in shapes]
    op = getattr(lax, op_name)
    numpy_op = getattr(lax_reference, op_name)
    self._CheckAgainstNumpy(op, numpy_op, args_maker, tol=tol)

  # TODO test shift_left, shift_right_arithmetic, shift_right_logical

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_from_dtype={}_to_dtype={}".format(
          from_dtype, to_dtype),
       "from_dtype": from_dtype, "to_dtype": to_dtype, "rng_factory": rng_factory}
      for from_dtype, to_dtype in itertools.product(
          [onp.float32, onp.int32, "float32", "int32"], repeat=2)
      for rng_factory in [jtu.rand_default]))
  def testConvertElementType(self, from_dtype, to_dtype, rng_factory):
    rng = rng_factory()
    args_maker = lambda: [rng((2, 3), from_dtype)]
    op = lambda x: lax.convert_element_type(x, to_dtype)
    self._CompileAndCheck(op, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_from_dtype={}_to_dtype={}"
       .format(from_dtype, to_dtype),
       "from_dtype": from_dtype, "to_dtype": to_dtype, "rng_factory": rng_factory}
      for from_dtype, to_dtype in itertools.product(
          [onp.float32, onp.int32, "float32", "int32"], repeat=2)
      for rng_factory in [jtu.rand_default]))
  def testConvertElementTypeAgainstNumpy(self, from_dtype, to_dtype, rng_factory):
    rng = rng_factory()
    args_maker = lambda: [rng((2, 3), from_dtype)]
    op = lambda x: lax.convert_element_type(x, to_dtype)
    numpy_op = lambda x: lax_reference.convert_element_type(x, to_dtype)
    self._CheckAgainstNumpy(op, numpy_op, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_from_dtype={}_to_dtype={}"
       .format(from_dtype, to_dtype),
       "from_dtype": from_dtype, "to_dtype": to_dtype, "rng_factory": rng_factory}
      for from_dtype, to_dtype in itertools.product(
          [onp.float32, onp.int32, "float32", "int32"], repeat=2)
      for rng_factory in [jtu.rand_default]))
  def testBitcastConvertType(self, from_dtype, to_dtype, rng_factory):
    rng = rng_factory()
    args_maker = lambda: [rng((2, 3), from_dtype)]
    op = lambda x: lax.bitcast_convert_type(x, to_dtype)
    self._CompileAndCheck(op, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_from_dtype={}_to_dtype={}"
       .format(from_dtype, to_dtype),
       "from_dtype": from_dtype, "to_dtype": to_dtype, "rng_factory": rng_factory}
      for from_dtype, to_dtype in itertools.product(
          [onp.float32, onp.int32, "float32", "int32"], repeat=2)
      for rng_factory in [jtu.rand_default]))
  def testBitcastConvertTypeAgainstNumpy(self, from_dtype, to_dtype, rng_factory):
    rng = rng_factory()
    args_maker = lambda: [rng((2, 3), from_dtype)]
    op = lambda x: lax.bitcast_convert_type(x, to_dtype)
    numpy_op = lambda x: lax_reference.bitcast_convert_type(x, to_dtype)
    self._CheckAgainstNumpy(op, numpy_op, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_min_shape={}_operand_shape={}_max_shape={}".format(
          jtu.format_shape_dtype_string(min_shape, dtype),
          jtu.format_shape_dtype_string(operand_shape, dtype),
          jtu.format_shape_dtype_string(max_shape, dtype)),
       "min_shape": min_shape, "operand_shape": operand_shape,
       "max_shape": max_shape, "dtype": dtype, "rng_factory": rng_factory}
      for min_shape, operand_shape, max_shape in [
          [(), (2, 3), ()],
          [(2, 3), (2, 3), ()],
          [(), (2, 3), (2, 3)],
          [(2, 3), (2, 3), (2, 3)],
      ]
      for dtype in default_dtypes
      for rng_factory in [jtu.rand_default]))
  def testClamp(self, min_shape, operand_shape, max_shape, dtype, rng_factory):
    rng = rng_factory()
    shapes = [min_shape, operand_shape, max_shape]
    args_maker = lambda: [rng(shape, dtype) for shape in shapes]
    self._CompileAndCheck(lax.clamp, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_min_shape={}_operand_shape={}_max_shape={}".format(
          jtu.format_shape_dtype_string(min_shape, dtype),
          jtu.format_shape_dtype_string(operand_shape, dtype),
          jtu.format_shape_dtype_string(max_shape, dtype)),
       "min_shape": min_shape, "operand_shape": operand_shape,
       "max_shape": max_shape, "dtype": dtype, "rng_factory": rng_factory}
      for min_shape, operand_shape, max_shape in [
          [(), (2, 3), ()],
          [(2, 3), (2, 3), ()],
          [(), (2, 3), (2, 3)],
          [(2, 3), (2, 3), (2, 3)],
      ]
      for dtype in default_dtypes
      for rng_factory in [jtu.rand_default]))
  def testClampAgainstNumpy(self, min_shape, operand_shape, max_shape, dtype,
                            rng_factory):
    rng = rng_factory()
    shapes = [min_shape, operand_shape, max_shape]
    args_maker = lambda: [rng(shape, dtype) for shape in shapes]
    self._CheckAgainstNumpy(lax.clamp, lax_reference.clamp, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_dim={}_baseshape=[{}]_dtype={}_narrs={}".format(
          dim, ",".join(str(d) for d in base_shape), onp.dtype(dtype).name,
          num_arrs),
       "dim": dim, "base_shape": base_shape, "dtype": dtype,
       "num_arrs": num_arrs, "rng_factory": rng_factory}
      for num_arrs in [3]
      for dtype in default_dtypes
      for base_shape in [(4,), (3, 4), (2, 3, 4)]
      for dim in range(len(base_shape))
      for rng_factory in [jtu.rand_default]))
  def testConcatenate(self, dim, base_shape, dtype, num_arrs, rng_factory):
    rng = rng_factory()
    shapes = [base_shape[:dim] + (size,) + base_shape[dim+1:]
              for size, _ in zip(itertools.cycle([3, 1, 4]), range(num_arrs))]
    args_maker = lambda: [rng(shape, dtype) for shape in shapes]
    op = lambda *args: lax.concatenate(args, dim)
    self._CompileAndCheck(op, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_dim={}_baseshape=[{}]_dtype={}_narrs={}".format(
          dim, ",".join(str(d) for d in base_shape), onp.dtype(dtype).name,
          num_arrs),
       "dim": dim, "base_shape": base_shape, "dtype": dtype,
       "num_arrs": num_arrs, "rng_factory": rng_factory}
      for num_arrs in [3]
      for dtype in default_dtypes
      for base_shape in [(4,), (3, 4), (2, 3, 4)]
      for dim in range(len(base_shape))
      for rng_factory in [jtu.rand_default]))
  def testConcatenateAgainstNumpy(self, dim, base_shape, dtype, num_arrs, rng_factory):
    rng = rng_factory()
    shapes = [base_shape[:dim] + (size,) + base_shape[dim+1:]
              for size, _ in zip(itertools.cycle([3, 1, 4]), range(num_arrs))]
    args_maker = lambda: [rng(shape, dtype) for shape in shapes]
    op = lambda *args: lax.concatenate(args, dim)
    numpy_op = lambda *args: lax_reference.concatenate(args, dim)
    self._CheckAgainstNumpy(op, numpy_op, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_strides={}_padding={}".format(
           jtu.format_shape_dtype_string(lhs_shape, dtype),
           jtu.format_shape_dtype_string(rhs_shape, dtype), strides, padding),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "strides": strides, "padding": padding, "rng_factory": rng_factory}
      for lhs_shape, rhs_shape in [
          ((b, i, 9, 10), (j, i, 4, 5))
          for b, i, j in itertools.product([2, 3], repeat=3)]
      for dtype in float_dtypes
      for strides in [(1, 1), (1, 2), (2, 1)]
      for padding in ["VALID", "SAME"]
      for rng_factory in [jtu.rand_small]))
  def testConv(self, lhs_shape, rhs_shape, dtype, strides, padding, rng_factory):
    rng = rng_factory()
    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]

    def fun(lhs, rhs):
      return lax.conv(lhs, rhs, strides, padding)

    self._CompileAndCheck(fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_strides={}_padding={}".format(
           jtu.format_shape_dtype_string(lhs_shape, dtype),
           jtu.format_shape_dtype_string(rhs_shape, dtype), strides, padding),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "strides": strides, "padding": padding, "rng_factory": rng_factory}
      for lhs_shape, rhs_shape in [
          ((b, i, 9, 10), (j, i, 4, 5))
          for b, i, j in itertools.product([2, 3], repeat=3)]
      for dtype in float_dtypes
      for strides in [(1, 1), (1, 2), (2, 1)]
      for padding in ["VALID", "SAME"]
      for rng_factory in [jtu.rand_small]))
  def testConvAgainstNumpy(self, lhs_shape, rhs_shape, dtype, strides, padding,
                           rng_factory):
    rng = rng_factory()
    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
    op = lambda lhs, rhs: lax.conv(lhs, rhs, strides, padding)
    numpy_op = lambda lhs, rhs: lax_reference.conv(lhs, rhs, strides, padding)
    self._CheckAgainstNumpy(op, numpy_op, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_lhs_shape={}_rhs_shape={}_strides={}_padding={}"
       "_lhs_dilation={}_rhs_dilation={}".format(
           jtu.format_shape_dtype_string(lhs_shape, dtype),
           jtu.format_shape_dtype_string(rhs_shape, dtype),
           strides, padding, lhs_dilation, rhs_dilation),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "strides": strides, "padding": padding, "lhs_dilation": lhs_dilation,
       "rhs_dilation": rhs_dilation, "rng_factory": rng_factory}
      for lhs_shape, rhs_shape in [
          ((b, i, 9, 10), (j, i, 4, 5))
          for b, i, j in itertools.product([1, 2, 3], repeat=3)]
      for dtype in float_dtypes
      for strides in [(1, 1), (1, 2), (2, 1)]
      for padding in [((0, 0), (0, 0)), ((1, 2), (2, 0))]
      for lhs_dilation, rhs_dilation in itertools.product(
          [(1, 1), (1, 2), (2, 2)], repeat=2)
      for rng_factory in [jtu.rand_small]))
  def testConvWithGeneralPadding(self, lhs_shape, rhs_shape, dtype, strides,
                                 padding, lhs_dilation, rhs_dilation, rng_factory):
    rng = rng_factory()
    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]

    def fun(lhs, rhs):
      return lax.conv_with_general_padding(
          lhs, rhs, strides, padding, lhs_dilation, rhs_dilation)

    self._CompileAndCheck(fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_lhs_shape={}_rhs_shape={}_strides={}_padding={}"
       "_lhs_dilation={}_rhs_dilation={}".format(
           jtu.format_shape_dtype_string(lhs_shape, dtype),
           jtu.format_shape_dtype_string(rhs_shape, dtype),
           strides, padding, lhs_dilation, rhs_dilation),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "strides": strides, "padding": padding, "lhs_dilation": lhs_dilation,
       "rhs_dilation": rhs_dilation, "rng_factory": rng_factory}
      for lhs_shape, rhs_shape in [
          ((b, i, 9, 10), (j, i, 4, 5))
          for b, i, j in itertools.product([1, 2, 3], repeat=3)]
      for dtype in [onp.float32] for strides in [(1, 1), (1, 2), (2, 1)]
      for padding in [((0, 0), (0, 0)), ((1, 2), (2, 0))]
      for lhs_dilation, rhs_dilation in itertools.product(
          [(1, 1), (1, 2), (2, 2)], repeat=2)
      for rng_factory in [jtu.rand_small]))
  def DISABLED_testConvWithGeneralPaddingAgainstNumpy(
      self, lhs_shape, rhs_shape, dtype, strides, padding, lhs_dilation,
      rhs_dilation, rng_factory):
    rng = rng_factory()
    # TODO(mattjj): make this test pass
    raise SkipTest("this test is incomplete")
    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]

    def fun(lhs, rhs):
      return lax.conv_with_general_padding(
          lhs, rhs, strides, padding, lhs_dilation, rhs_dilation)

    def numpy_fun(lhs, rhs):
      return lax_reference.conv_with_general_padding(
          lhs, rhs, strides, padding, lhs_dilation, rhs_dilation)

    self._CheckAgainstNumpy(fun, numpy_fun, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_lhs_shape={}_rhs_shape={}_strides={}_padding={}"
       "_lhs_dilation={}_rhs_dilation={}"
       "_dims={}".format(
           jtu.format_shape_dtype_string(lhs_shape, dtype),
           jtu.format_shape_dtype_string(rhs_shape, dtype),
           strides, padding, lhs_dilation, rhs_dilation,
           ",".join(dim_nums)),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "strides": strides, "padding": padding, "lhs_dilation": lhs_dilation,
       "rhs_dilation": rhs_dilation, "dimension_numbers": dim_nums,
       "perms": perms, "rng_factory": rng_factory}
      for lhs_shape, rhs_shape in [
          ((b, i, 9, w), (j, i, 4, 5))
          for w in [0, 10]
          for b, i, j in itertools.product([2, 3], repeat=3)]
      for dtype in float_dtypes for strides in [(1, 1), (2, 1)]
      for padding in [((1, 2), (2, 0)), ((10, 8), (7, 13))]
      for lhs_dilation, rhs_dilation in itertools.product(
          [(1, 1), (1, 2), (1, 4)], repeat=2)
      for rng_factory in [jtu.rand_small]
      for dim_nums, perms in [
          (("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])),
          (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])),
          (("NCHW", "HWIO", "NHWC"), ([0, 1, 2, 3], [2, 3, 1, 0])),
      ]))
  def testConvGeneralDilated(self, lhs_shape, rhs_shape, dtype, strides,
                             padding, lhs_dilation, rhs_dilation,
                             dimension_numbers, perms, rng_factory):
    rng = rng_factory()
    lhs_perm, rhs_perm = perms  # permute to compatible shapes

    def args_maker():
      return [lax.transpose(rng(lhs_shape, dtype), lhs_perm),
              lax.transpose(rng(rhs_shape, dtype), rhs_perm)]

    def fun(lhs, rhs):
      return lax.conv_general_dilated(
          lhs, rhs, strides, padding, lhs_dilation, rhs_dilation,
          dimension_numbers)

    self._CompileAndCheck(fun, args_maker, check_dtypes=True)

  # TODO(mattjj): test conv_general_dilated against numpy
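
  # The conv_transpose tests below have no NumPy reference implementation;
  # instead they check lax.conv_transpose against an equivalent computation
  # built from the VJP (lhs gradient) of a forward convolution, using the two
  # helpers that follow.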

  @staticmethod
  def _conv_transpose_via_grad(data, kernel, strides, padding,
                               rhs_dilation=None, dimension_numbers=None):
    """Helper method: calculates conv transpose via grad for testing."""
    assert len(data.shape) == len(kernel.shape)
    nspatial = len(data.shape) - 2
    one = (1,) * nspatial
    rhs_dilation = rhs_dilation or one
    dn = lax.conv_dimension_numbers(data.shape, kernel.shape,
                                    dimension_numbers)
    in_shape = onp.take(data.shape, dn.lhs_spec)
    in_sdims = in_shape[2:]
    k_shape = onp.take(kernel.shape, dn.rhs_spec)
    k_sdims = k_shape[2:]
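    # Effective (dilated) kernel size along each spatial dimension, used below
    # to derive the expected output spatial dims for 'VALID'/'SAME' padding.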
    e_k_sdims = [(k-1) * r + 1 for k, r in zip(k_sdims, rhs_dilation)]
    if padding == 'VALID':
      o_sdims = [in_sdims[i]*strides[i] + max(e_k_sdims[i]-strides[i],0)
                 for i in range(nspatial)]
    elif padding == 'SAME':
      o_sdims = [in_sdims[i]*strides[i] for i in range(nspatial)]
    o_shape = [in_shape[0], k_shape[1]] + o_sdims
    out_spec_inv = [x[0] for x in
                    sorted(enumerate(dn.out_spec), key=lambda x: x[1])]
    o_layout = onp.take(onp.array(o_shape), out_spec_inv)
    placeholder = onp.ones(o_layout, data.dtype)
    conv = lambda x: lax.conv_general_dilated(x, kernel, strides, padding,
                                              one, rhs_dilation, dn)
    _, g = api.vjp(conv, placeholder)
    return g(data)[0]
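
  # Helper used by the conv_transpose tests: flipping the kernel's spatial
  # axes and swapping its input/output channel dimensions turns a
  # transpose-convolution kernel into the corresponding forward-convolution
  # kernel.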

  @staticmethod
  def _transpose_conv_kernel(data, kernel, dimension_numbers):
    dn = lax.conv_dimension_numbers(data.shape, kernel.shape,
                                    dimension_numbers)
    spatial_axes = onp.array(dn.rhs_spec)[2:]
    for axis in spatial_axes:
      kernel = onp.flip(kernel, axis)
    kernel = onp.swapaxes(kernel, dn.rhs_spec[0], dn.rhs_spec[1])
    return kernel

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_rhs_dilation={}".format(
           jtu.format_shape_dtype_string(lhs_shape, dtype),
           jtu.format_shape_dtype_string(rhs_shape, dtype), strides, padding, rhs_dilation),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "strides": strides, "padding": padding, "rhs_dilation": rhs_dilation,
       "rng_factory": rng_factory, 'dspec': dspec}
      for lhs_shape, rhs_shape in [
          ((b, 9, 10, i), (k, k, j, i))  # NB: i,j flipped in RHS for transpose
          for b, i, j, k in itertools.product([2,3],[2,3],[2,3],[3,4,5])]
      for dtype in float_dtypes
      for strides in [(1, 1), (1, 2), (2, 1), (2, 2), (3, 3)]
      for padding in ["VALID", "SAME"]
      for dspec in [('NHWC', 'HWIO', 'NHWC'),]
      for rhs_dilation in [None, (2, 2)]
      for rng_factory in [jtu.rand_small]))
  def testConvTranspose2DT(self, lhs_shape, rhs_shape, dtype, strides,
                           padding, dspec, rhs_dilation, rng_factory):
    rng = rng_factory()
    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]

    # NB: this test calculates conv_transpose performing identically to the
    # lhs-grad of conv.
    def fun(lhs, rhs):
      return lax.conv_transpose(lhs, rhs, strides, padding,
                                rhs_dilation=rhs_dilation,
                                dimension_numbers=dspec,
                                transpose_kernel=True)

    def fun_via_grad(lhs, rhs):
      return self._conv_transpose_via_grad(lhs, rhs, strides, padding,
                                           rhs_dilation=rhs_dilation,
                                           dimension_numbers=dspec)

    # NB: below just checks for agreement, we're not calling numpy.
    self._CheckAgainstNumpy(fun, fun_via_grad, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_rhs_dilation={}".format(
           jtu.format_shape_dtype_string(lhs_shape, dtype),
           jtu.format_shape_dtype_string(rhs_shape, dtype), strides, padding, rhs_dilation),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "strides": strides, "padding": padding, "rhs_dilation": rhs_dilation,
       "rng_factory": rng_factory, 'dspec': dspec}
      for lhs_shape, rhs_shape in [
          ((b, 9, 10, i), (k, k, i, j))
          for b, i, j, k in itertools.product([2,3],[2,3],[2,3],[3,4,5])]
      for dtype in float_dtypes
      for strides in [(1, 1), (1, 2), (2, 1), (2, 2), (3, 3)]
      for padding in ["VALID", "SAME"]
      for dspec in [('NHWC', 'HWIO', 'NHWC'),]
      for rhs_dilation in [None, (2, 2)]
      for rng_factory in [jtu.rand_small]))
  def testConvTranspose2D(self, lhs_shape, rhs_shape, dtype, strides,
                          padding, dspec, rhs_dilation, rng_factory):
    rng = rng_factory()
    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]

    def fun(lhs, rhs):
      return lax.conv_transpose(lhs, rhs, strides, padding,
                                rhs_dilation=rhs_dilation,
                                dimension_numbers=dspec,
                                transpose_kernel=False)

    def fun_via_grad(lhs, rhs):
      rhs_t = self._transpose_conv_kernel(lhs, rhs, dimension_numbers=dspec)
      return self._conv_transpose_via_grad(lhs, rhs_t, strides, padding,
                                           rhs_dilation=rhs_dilation,
                                           dimension_numbers=dspec)

    # NB: below just checks for agreement, we're not calling numpy.
    self._CheckAgainstNumpy(fun, fun_via_grad, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_rhs_dilation={}".format(
           jtu.format_shape_dtype_string(lhs_shape, dtype),
           jtu.format_shape_dtype_string(rhs_shape, dtype), strides, padding, rhs_dilation),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "strides": strides, "padding": padding, "rhs_dilation": rhs_dilation,
       "rng_factory": rng_factory, 'dspec': dspec}
      for lhs_shape, rhs_shape in [
          ((b, 10, i), (k, i, j))
          for b, i, j, k in itertools.product([2,3],[2,3],[2,3],[3,4,5])]
      for dtype in float_dtypes
      for strides in [(1,), (2,), (3,)]
      for padding in ["VALID", "SAME"]
      for dspec in [('NHC', 'HIO', 'NHC'),]
      for rhs_dilation in [None, (2,)]
      for rng_factory in [jtu.rand_small]))
  def testConvTranspose1D(self, lhs_shape, rhs_shape, dtype, strides,
                          padding, dspec, rhs_dilation, rng_factory):
    rng = rng_factory()
    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]

    def fun(lhs, rhs):
      return lax.conv_transpose(lhs, rhs, strides, padding,
                                dimension_numbers=dspec,
                                rhs_dilation=rhs_dilation,
                                transpose_kernel=False)

    def fun_via_grad(lhs, rhs):
      rhs_t = self._transpose_conv_kernel(lhs, rhs, dimension_numbers=dspec)
      return self._conv_transpose_via_grad(lhs, rhs_t, strides, padding,
                                           rhs_dilation=rhs_dilation,
                                           dimension_numbers=dspec)

    # NB: below just checks for agreement, we're not calling numpy.
    self._CheckAgainstNumpy(fun, fun_via_grad, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_lhs_shape={}_rhs_shape={}_precision={}".format(
          jtu.format_shape_dtype_string(lhs_shape, dtype),
          jtu.format_shape_dtype_string(rhs_shape, dtype),
          precision),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "precision": precision, "rng_factory": rng_factory}
      for lhs_shape in [(3,), (4, 3)] for rhs_shape in [(3,), (3, 6)]
      for dtype in all_dtypes
      for precision in [None, lax.Precision.DEFAULT, lax.Precision.HIGH,
                        lax.Precision.HIGHEST]
      for rng_factory in [jtu.rand_default]))
  def testDot(self, lhs_shape, rhs_shape, dtype, precision, rng_factory):
    rng = rng_factory()
    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
    self._CompileAndCheck(partial(lax.dot, precision=precision), args_maker,
                          check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_lhs_shape={}_rhs_shape={}".format(
          jtu.format_shape_dtype_string(lhs_shape, dtype),
          jtu.format_shape_dtype_string(rhs_shape, dtype)),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "rng_factory": rng_factory}
      for lhs_shape in [(3,), (4, 3)] for rhs_shape in [(3,), (3, 6)]
      for dtype in all_dtypes
      for rng_factory in [jtu.rand_default]))
  def testDotAgainstNumpy(self, lhs_shape, rhs_shape, dtype, rng_factory):
    rng = rng_factory()
    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
    tol = {
      onp.float16: 1e-2,
      onp.float64: max(jtu.default_tolerance()[onp.dtype(onp.float64)], 1e-14),
      onp.complex128: max(jtu.default_tolerance()[onp.dtype(onp.complex128)],
                          1e-14)
    }
    lax_op = partial(lax.dot, precision=lax.Precision.HIGHEST)
    self._CheckAgainstNumpy(lax_op, lax_reference.dot, args_maker, tol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_lhs_contracting={}_rhs_contracting={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               lhs_contracting, rhs_contracting),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "lhs_contracting": lhs_contracting, "rhs_contracting": rhs_contracting,
       "rng_factory": rng_factory}
      for lhs_shape, rhs_shape, lhs_contracting, rhs_contracting in [
          [(3, 5), (2, 5), [1], [1]],
          [(5, 3), (5, 2), [0], [0]],
          [(5, 3, 2), (5, 2, 4), [0], [0]],
          [(5, 3, 2), (5, 2, 4), [0,2], [0,1]],
          [(1, 2, 2, 3), (1, 2, 3, 1), [1], [1]],
          [(3, 2), (2, 4), [1], [0]],
      ]
      for dtype in all_dtypes
      for rng_factory in [jtu.rand_small]))
  def testDotGeneralContractOnly(self, lhs_shape, rhs_shape, dtype,
                                 lhs_contracting, rhs_contracting, rng_factory):
    rng = rng_factory()
    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
    dimension_numbers = ((lhs_contracting, rhs_contracting), ([], []))

    def fun(lhs, rhs):
      return lax.dot_general(lhs, rhs, dimension_numbers)

    self._CompileAndCheck(fun, args_maker, check_dtypes=False)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_dimension_numbers={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               dimension_numbers),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "dimension_numbers": dimension_numbers, "rng_factory": rng_factory}
      for lhs_shape, rhs_shape, dimension_numbers in [
          ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
          ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1]))),
      ]
      for dtype in all_dtypes
      for rng_factory in [jtu.rand_small]))
  def testDotGeneralContractAndBatch(self, lhs_shape, rhs_shape, dtype,
                                     dimension_numbers, rng_factory):
    rng = rng_factory()
    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]

    def fun(lhs, rhs):
      return lax.dot_general(lhs, rhs, dimension_numbers)

    self._CompileAndCheck(fun, args_maker, check_dtypes=False)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_dimension_numbers={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               dimension_numbers),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "dimension_numbers": dimension_numbers, "rng_factory": rng_factory}
      for lhs_shape, rhs_shape, dimension_numbers in [
          ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
          ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1]))),
      ]
      for dtype in all_dtypes
      for rng_factory in [jtu.rand_small]))
  def testDotGeneralAgainstNumpy(self, lhs_shape, rhs_shape, dtype,
                                 dimension_numbers, rng_factory):
    rng = rng_factory()
    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
    op = lambda x, y: lax.dot_general(x, y, dimension_numbers)
    numpy_op = lambda x, y: lax_reference.dot_general(x, y, dimension_numbers)
    self._CheckAgainstNumpy(op, numpy_op, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_dtype={}_broadcast_sizes={}".format(
          shape, onp.dtype(dtype).name, broadcast_sizes),
       "shape": shape, "dtype": dtype, "broadcast_sizes": broadcast_sizes,
       "rng_factory": rng_factory}
      for shape in [(), (2, 3)]
      for dtype in default_dtypes
      for broadcast_sizes in [(), (2,), (1, 2)]
      for rng_factory in [jtu.rand_default]))
  def testBroadcast(self, shape, dtype, broadcast_sizes, rng_factory):
    rng = rng_factory()
    args_maker = lambda: [rng(shape, dtype)]
    op = lambda x: lax.broadcast(x, broadcast_sizes)
    self._CompileAndCheck(op, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_broadcast_sizes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), broadcast_sizes),
       "shape": shape, "dtype": dtype, "broadcast_sizes": broadcast_sizes,
       "rng_factory": rng_factory}
      for shape in [(), (2, 3)]
      for dtype in default_dtypes
      for broadcast_sizes in [(), (2,), (1, 2)]
      for rng_factory in [jtu.rand_default]))
  def testBroadcastAgainstNumpy(self, shape, dtype, broadcast_sizes, rng_factory):
    rng = rng_factory()
    args_maker = lambda: [rng(shape, dtype)]
    op = lambda x: lax.broadcast(x, broadcast_sizes)
    numpy_op = lambda x: lax_reference.broadcast(x, broadcast_sizes)
    self._CheckAgainstNumpy(op, numpy_op, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_outshape={}_bcdims={}".format(
          jtu.format_shape_dtype_string(inshape, dtype),
          outshape, broadcast_dimensions),
       "inshape": inshape, "dtype": dtype, "outshape": outshape,
       "dimensions": broadcast_dimensions, "rng_factory": rng_factory}
      for inshape, outshape, broadcast_dimensions in [
          ([2], [2, 2], [0]),
          ([2], [2, 2], [1]),
          ([2], [2, 3], [0]),
          ([], [2, 3], []),
      ]
      for dtype in default_dtypes
      for rng_factory in [jtu.rand_default]))
  def testBroadcastInDim(self, inshape, dtype, outshape, dimensions, rng_factory):
    rng = rng_factory()
    args_maker = lambda: [rng(inshape, dtype)]
    op = lambda x: lax.broadcast_in_dim(x, outshape, dimensions)
    self._CompileAndCheck(op, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_outshape={}_bcdims={}".format(
          jtu.format_shape_dtype_string(inshape, dtype),
          outshape, broadcast_dimensions),
       "inshape": inshape, "dtype": dtype, "outshape": outshape,
       "dimensions": broadcast_dimensions, "rng_factory": rng_factory}
      for inshape, outshape, broadcast_dimensions in [
          ([2], [2, 2], [0]),
          ([2], [2, 2], [1]),
          ([2], [2, 3], [0]),
          ([], [2, 3], []),
      ]
      for dtype in default_dtypes
      for rng_factory in [jtu.rand_default]))
  def testBroadcastInDimAgainstNumpy(self, inshape, dtype, outshape,
                                     dimensions, rng_factory):
    rng = rng_factory()
    args_maker = lambda: [rng(inshape, dtype)]
    op = lambda x: lax.broadcast_in_dim(x, outshape, dimensions)
    numpy_op = lambda x: lax_reference.broadcast_in_dim(x, outshape, dimensions)
    self._CheckAgainstNumpy(op, numpy_op, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_outshape={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          jtu.format_shape_dtype_string(out_shape, dtype)),
       "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype,
       "rng_factory": rng_factory}
      for dtype in default_dtypes
      for arg_shape, out_shape in [
          [(3, 4), (12,)], [(2, 1, 4), (8,)], [(2, 2, 4), (2, 8)]
      ]
      for rng_factory in [jtu.rand_default]))
  def testReshape(self, arg_shape, out_shape, dtype, rng_factory):
    rng = rng_factory()
    args_maker = lambda: [rng(arg_shape, dtype)]
    op = lambda x: lax.reshape(x, out_shape)
    self._CompileAndCheck(op, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_outshape={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          jtu.format_shape_dtype_string(out_shape, dtype)),
       "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype,
       "rng_factory": rng_factory}
      for dtype in default_dtypes
      for arg_shape, out_shape in [
          [(3, 4), (12,)], [(2, 1, 4), (8,)], [(2, 2, 4), (2, 8)]
      ]
      for rng_factory in [jtu.rand_default]))
  def testReshapeAgainstNumpy(self, arg_shape, out_shape, dtype, rng_factory):
    rng = rng_factory()
    args_maker = lambda: [rng(arg_shape, dtype)]
    op = lambda x: lax.reshape(x, out_shape)
    numpy_op = lambda x: lax_reference.reshape(x, out_shape)
    self._CheckAgainstNumpy(op, numpy_op, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_pads={}"
       .format(jtu.format_shape_dtype_string(shape, dtype), pads),
       "shape": shape, "dtype": dtype, "pads": pads, "rng_factory": jtu.rand_small}
      for shape in [(2, 3)]
      for dtype in default_dtypes
      for pads in [[(1, 2, 1), (0, 1, 0)]]))
  def testPad(self, shape, dtype, pads, rng_factory):
    rng = rng_factory()
    args_maker = lambda: [rng(shape, dtype)]
    fun = lambda operand: lax.pad(operand, onp.array(0, dtype), pads)
    self._CompileAndCheck(fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_pads={}"
       .format(jtu.format_shape_dtype_string(shape, dtype), pads),
       "shape": shape, "dtype": dtype, "pads": pads, "rng_factory": jtu.rand_small}
      for shape in [(2, 3)]
      for dtype in default_dtypes
      for pads in [[(1, 2, 1), (0, 1, 0)]]))
  def testPadAgainstNumpy(self, shape, dtype, pads, rng_factory):
    rng = rng_factory()
    args_maker = lambda: [rng(shape, dtype)]
    op = lambda x: lax.pad(x, onp.array(0, dtype), pads)
    numpy_op = lambda x: lax_reference.pad(x, onp.array(0, dtype), pads)
    self._CheckAgainstNumpy(op, numpy_op, args_maker)

  def testReverse(self):
    rev = api.jit(lambda operand: lax.rev(operand, dimensions))

    dimensions = [0]
    self.assertAllClose(onp.array([3, 2, 1]), rev(onp.array([1, 2, 3])),
                        check_dtypes=False)

    dimensions = [0, 1]
    self.assertAllClose(onp.array([[6, 5, 4], [3, 2, 1]]),
                        rev(onp.array([[1, 2, 3], [4, 5, 6]])),
                        check_dtypes=False)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_predshape={}_argshapes={}".format(
          jtu.format_shape_dtype_string(pred_shape, onp.bool_),
          jtu.format_shape_dtype_string(arg_shape, arg_dtype)),
       "pred_shape": pred_shape, "arg_shape": arg_shape, "arg_dtype": arg_dtype,
       "rng_factory": rng_factory}
      for arg_shape in [(), (3,), (2, 3)]
      for pred_shape in ([(), arg_shape] if arg_shape else [()])
      for arg_dtype in default_dtypes
      for rng_factory in [jtu.rand_default]))
  def testSelect(self, pred_shape, arg_shape, arg_dtype, rng_factory):
    rng = rng_factory()
    def args_maker():
      return [rng(pred_shape, onp.bool_), rng(arg_shape, arg_dtype),
              rng(arg_shape, arg_dtype)]
    return self._CompileAndCheck(lax.select, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_predshape={}_argshapes={}".format(
          jtu.format_shape_dtype_string(pred_shape, onp.bool_),
          jtu.format_shape_dtype_string(arg_shape, arg_dtype)),
       "pred_shape": pred_shape, "arg_shape": arg_shape, "arg_dtype": arg_dtype,
       "rng_factory": rng_factory}
      for arg_shape in [(), (3,), (2, 3)]
      for pred_shape in ([(), arg_shape] if arg_shape else [()])
      for arg_dtype in default_dtypes
      for rng_factory in [jtu.rand_default]))
  def testSelectAgainstNumpy(self, pred_shape, arg_shape, arg_dtype, rng_factory):
    rng = rng_factory()
    def args_maker():
      return [rng(pred_shape, onp.bool_), rng(arg_shape, arg_dtype),
              rng(arg_shape, arg_dtype)]
    return self._CheckAgainstNumpy(lax.select, lax_reference.select, args_maker)
|
|
|
|
|
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_start_indices={}_limit_indices={}_strides={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          start_indices, limit_indices, strides),
       "shape": shape, "dtype": dtype, "starts": start_indices,
       "limits": limit_indices, "strides": strides, "rng_factory": rng_factory}
      for shape, start_indices, limit_indices, strides in [
          [(3,), (1,), (2,), None],
          [(7,), (4,), (7,), None],
          [(5,), (1,), (5,), (2,)],
          [(8,), (1,), (6,), (2,)],
          [(5, 3), (1, 1), (3, 2), None],
          [(5, 3), (1, 1), (3, 1), None],
          [(7, 5, 3), (4, 0, 1), (7, 1, 3), None],
          [(5, 3), (1, 1), (2, 1), (1, 1)],
          [(5, 3), (1, 1), (5, 3), (2, 1)],
      ]
      for dtype in default_dtypes
      for rng_factory in [jtu.rand_default]))
  def testSlice(self, shape, dtype, starts, limits, strides, rng_factory):
    rng = rng_factory()
    args_maker = lambda: [rng(shape, dtype)]
    op = lambda x: lax.slice(x, starts, limits, strides)
    self._CompileAndCheck(op, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_start_indices={}_limit_indices={}_strides={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          start_indices, limit_indices, strides),
       "shape": shape, "dtype": dtype, "starts": start_indices,
       "limits": limit_indices, "strides": strides, "rng_factory": rng_factory}
      for shape, start_indices, limit_indices, strides in [
          [(3,), (1,), (2,), None],
          [(7,), (4,), (7,), None],
          [(5,), (1,), (5,), (2,)],
          [(8,), (1,), (6,), (2,)],
          [(5, 3), (1, 1), (3, 2), None],
          [(5, 3), (1, 1), (3, 1), None],
          [(7, 5, 3), (4, 0, 1), (7, 1, 3), None],
          [(5, 3), (1, 1), (2, 1), (1, 1)],
          [(5, 3), (1, 1), (5, 3), (2, 1)],
      ]
      for dtype in default_dtypes
      for rng_factory in [jtu.rand_default]))
  def testSliceAgainstNumpy(self, shape, dtype, starts, limits,
                            strides, rng_factory):
    rng = rng_factory()
    args_maker = lambda: [rng(shape, dtype)]
    op = lambda x: lax.slice(x, starts, limits, strides)
    numpy_op = lambda x: lax_reference.slice(x, starts, limits, strides)
    self._CheckAgainstNumpy(op, numpy_op, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_start_indices={}_size_indices={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          start_indices, size_indices),
       "shape": shape, "dtype": dtype, "start_indices": start_indices,
       "size_indices": size_indices, "rng_factory": rng_factory}
      for shape, start_indices, size_indices in [
          [(3,), (1,), (1,)],
          [(5, 3), (1, 1), (3, 1)],
          [(7, 5, 3), (4, 1, 0), (2, 0, 1)],
      ]
      for dtype in default_dtypes
      for rng_factory in [jtu.rand_default]))
  def testDynamicSlice(self, shape, dtype, start_indices, size_indices, rng_factory):
    rng = rng_factory()
    args_maker = lambda: [rng(shape, dtype), onp.array(start_indices)]
    op = lambda x, starts: lax.dynamic_slice(x, starts, size_indices)
    self._CompileAndCheck(op, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_start_indices={}_size_indices={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          start_indices, size_indices),
       "shape": shape, "dtype": dtype, "start_indices": start_indices,
       "size_indices": size_indices, "rng_factory": rng_factory}
      for shape, start_indices, size_indices in [
          [(3,), (1,), (1,)],
          [(5, 3), (1, 1), (3, 1)],
          [(7, 5, 3), (4, 1, 0), (2, 0, 1)],
      ]
      for dtype in default_dtypes
      for rng_factory in [jtu.rand_default]))
  def testDynamicSliceAgainstNumpy(self, shape, dtype, start_indices,
                                   size_indices, rng_factory):
    rng = rng_factory()
    args_maker = lambda: [rng(shape, dtype), onp.array(start_indices)]
    op = lambda x, s: lax.dynamic_slice(x, s, size_indices)
    numpy_op = lambda x, s: lax_reference.dynamic_slice(x, s, size_indices)
    self._CheckAgainstNumpy(op, numpy_op, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_start_indices={}_update_shape={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          start_indices, update_shape),
       "shape": shape, "dtype": dtype, "start_indices": start_indices,
       "update_shape": update_shape, "rng_factory": rng_factory}
      for shape, start_indices, update_shape in [
          [(3,), (1,), (1,)],
          [(5, 3), (1, 1), (3, 1)],
          [(7, 5, 3), (4, 1, 0), (2, 0, 1)],
      ]
      for dtype in default_dtypes
      for rng_factory in [jtu.rand_default]))
  def testDynamicUpdateSlice(self, shape, dtype, start_indices, update_shape,
                             rng_factory):
    rng = rng_factory()

    def args_maker():
      return [rng(shape, dtype), rng(update_shape, dtype),
              onp.array(start_indices)]

    self._CompileAndCheck(lax.dynamic_update_slice, args_maker,
                          check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_start_indices={}_update_shape={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          start_indices, update_shape),
       "shape": shape, "dtype": dtype, "start_indices": start_indices,
       "update_shape": update_shape, "rng_factory": rng_factory}
      for shape, start_indices, update_shape in [
          [(3,), (1,), (1,)],
          [(5, 3), (1, 1), (3, 1)],
          [(7, 5, 3), (4, 1, 0), (2, 0, 1)],
      ]
      for dtype in default_dtypes
      for rng_factory in [jtu.rand_default]))
  def testDynamicUpdateSliceAgainstNumpy(self, shape, dtype, start_indices,
                                         update_shape, rng_factory):
    rng = rng_factory()

    def args_maker():
      return [rng(shape, dtype), rng(update_shape, dtype),
              onp.array(start_indices)]

    self._CheckAgainstNumpy(lax.dynamic_update_slice,
                            lax_reference.dynamic_update_slice, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_perm={}".format(
          jtu.format_shape_dtype_string(shape, dtype), perm),
       "shape": shape, "dtype": dtype, "perm": perm, "rng_factory": rng_factory}
      for shape, perm in [
          [(3, 4), (1, 0)],
          [(3, 4), (0, 1)],
          [(3, 4, 5), (2, 1, 0)],
          [(3, 4, 5), (1, 0, 2)],
      ]
      for dtype in default_dtypes
      for rng_factory in [jtu.rand_default]))
  def testTranspose(self, shape, dtype, perm, rng_factory):
    rng = rng_factory()
    args_maker = lambda: [rng(shape, dtype)]
    op = lambda x: lax.transpose(x, perm)
    self._CompileAndCheck(op, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_perm={}".format(
          jtu.format_shape_dtype_string(shape, dtype), perm),
       "shape": shape, "dtype": dtype, "perm": perm, "rng_factory": rng_factory}
      for shape, perm in [
          [(3, 4), (1, 0)],
          [(3, 4), (0, 1)],
          [(3, 4, 5), (2, 1, 0)],
          [(3, 4, 5), (1, 0, 2)],
      ]
      for dtype in default_dtypes
      for rng_factory in [jtu.rand_default]))
  def testTransposeAgainstNumpy(self, shape, dtype, perm, rng_factory):
    rng = rng_factory()
    args_maker = lambda: [rng(shape, dtype)]
    op = lambda x: lax.transpose(x, perm)
    numpy_op = lambda x: lax_reference.transpose(x, perm)
    self._CheckAgainstNumpy(op, numpy_op, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_inshape={}_reducedims={}_initval={}"
       .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims,
               init_val),
       "op": op, "init_val": init_val, "shape": shape, "dtype": dtype,
       "dims": dims, "rng_factory": rng_factory}
      for init_val, op, types in [
          (0, lax.add, default_dtypes),
          (1, lax.mul, default_dtypes),
          (0, lax.max, all_dtypes),  # non-monoidal
          (-onp.inf, lax.max, float_dtypes),
          (dtypes.iinfo(onp.int32).min, lax.max, [onp.int32]),
          # (dtypes.iinfo(onp.int64).min, lax.max, [onp.int64]),  # TODO fails
          (dtypes.iinfo(onp.uint32).min, lax.max, [onp.uint32]),
          (dtypes.iinfo(onp.uint64).min, lax.max, [onp.uint64]),
          (onp.inf, lax.min, float_dtypes),
          (dtypes.iinfo(onp.int32).max, lax.min, [onp.int32]),
          # (dtypes.iinfo(onp.int64).max, lax.min, [onp.int64]),  # TODO fails
          (dtypes.iinfo(onp.uint32).max, lax.min, [onp.uint32]),
          (dtypes.iinfo(onp.uint64).max, lax.min, [onp.uint64]),
      ]
      for dtype in types
      for shape, dims in [
          [(3, 4, 5), (0,)], [(3, 4, 5), (1, 2)],
          [(3, 4, 5), (0, 2)], [(3, 4, 5), (0, 1, 2)]
      ]
      for rng_factory in [
          jtu.rand_default if dtypes.issubdtype(dtype, onp.integer)
          else jtu.rand_small]))
  def testReduce(self, op, init_val, shape, dtype, dims, rng_factory):
    rng = rng_factory()
    init_val = onp.asarray(init_val, dtype=dtype)
    fun = lambda operand, init_val: lax.reduce(operand, init_val, op, dims)
    args_maker = lambda: [rng(shape, dtype), init_val]
    self._CompileAndCheck(fun, args_maker, check_dtypes=True)

    # we separately test the version that uses a concrete init_val because it
    # can hit different code paths
    fun = lambda operand: lax.reduce(operand, init_val, op, dims)
    args_maker = lambda: [rng(shape, dtype)]
    self._CompileAndCheck(fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_dtype={}_padding={}"
       .format(op.__name__, onp.dtype(dtype).name, padding),
       "op": op, "init_val": init_val, "dtype": dtype, "padding": padding,
       "rng_factory": rng_factory}
      for init_val, op, dtypes in [
          (0, lax.add, [onp.float32]),
          (-onp.inf, lax.max, [onp.float32]),
          (onp.inf, lax.min, [onp.float32]),
      ]
      for dtype in dtypes
      for padding in ["VALID", "SAME"]
      for rng_factory in [jtu.rand_small]))
  def testReduceWindow(self, op, init_val, dtype, padding, rng_factory):
    rng = rng_factory()
    init_val = onp.asarray(init_val, dtype=dtype)

    # each config is (operand shape, window dimensions, window strides); use a
    # list rather than a bare chain iterator so both loops below see all cases
    all_configs = list(itertools.chain(
        itertools.product(
            [(4, 6)],
            [(2, 1), (1, 2)],
            [(1, 1), (2, 1), (1, 2)]),
        itertools.product(
            [(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)],
            [(1, 2, 2, 1), (1, 1, 1, 1)])))

    def fun(operand, init_val):
      return lax.reduce_window(operand, init_val, op, dims, strides, padding)

    # pylint: disable=cell-var-from-loop
    for shape, dims, strides in all_configs:
      args_maker = lambda: [rng(shape, dtype), init_val]
      self._CompileAndCheck(fun, args_maker, check_dtypes=True)
    # pylint: enable=cell-var-from-loop

    # we separately test the version that uses a concrete init_val because it
    # can hit different code paths
    def fun(operand):
      return lax.reduce_window(operand, init_val, op, dims, strides, padding)

    # pylint: disable=cell-var-from-loop
    for shape, dims, strides in all_configs:
      args_maker = lambda: [rng(shape, dtype)]
      self._CompileAndCheck(fun, args_maker, check_dtypes=True)
    # pylint: enable=cell-var-from-loop

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_axis={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis),
       "rng_factory": rng_factory, "shape": shape, "dtype": dtype, "axis": axis}
      for dtype in [onp.float32, onp.int32, onp.uint32]
      for shape in [(5,), (5, 7)]
      for axis in [-1, len(shape) - 1]
      for rng_factory in [jtu.rand_default]))
  def testSort(self, shape, dtype, axis, rng_factory):
    rng = rng_factory()
    args_maker = lambda: [rng(shape, dtype)]
    fun = lambda x: lax.sort(x, axis)
    self._CompileAndCheck(fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_axis={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis),
       "rng_factory": rng_factory, "shape": shape, "dtype": dtype, "axis": axis}
      for dtype in [onp.float32, onp.int32, onp.uint32]
      for shape in [(5,), (5, 7)]
      for axis in [-1, len(shape) - 1]
      for rng_factory in [jtu.rand_default]))
  def testSortAgainstNumpy(self, shape, dtype, axis, rng_factory):
    rng = rng_factory()
    args_maker = lambda: [rng(shape, dtype)]
    op = lambda x: lax.sort(x, axis)
    numpy_op = lambda x: lax_reference.sort(x, axis)
    self._CheckAgainstNumpy(op, numpy_op, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_keyshape={}_valshape={}_axis={}".format(
          jtu.format_shape_dtype_string(shape, key_dtype),
          jtu.format_shape_dtype_string(shape, val_dtype),
          axis),
       "rng_factory": rng_factory, "shape": shape,
       "key_dtype": key_dtype, "val_dtype": val_dtype, "axis": axis}
      for key_dtype in [onp.float32, onp.int32, onp.uint32]
      for val_dtype in [onp.float32, onp.int32, onp.uint32]
      for shape in [(3,), (5, 3)]
      for axis in [-1, len(shape) - 1]
      for rng_factory in [jtu.rand_default]))
  def testSortKeyVal(self, shape, key_dtype, val_dtype, axis, rng_factory):
    rng = rng_factory()
    # This test relies on the property that wherever keys are tied, values are
    # too, since we don't guarantee the same ordering of values with equal keys.
    # To avoid that case, we generate unique keys (globally in the key array).
    perm_rng = onp.random.RandomState(0)
    def args_maker():
      flat_keys = onp.arange(onp.prod(shape, dtype=int), dtype=key_dtype)
      keys = perm_rng.permutation(flat_keys).reshape(shape)
      values = rng(shape, val_dtype)
      return keys, values

    fun = lambda keys, values: lax.sort_key_val(keys, values, axis)
    self._CompileAndCheck(fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_keyshape={}_valshape={}_axis={}".format(
          jtu.format_shape_dtype_string(shape, key_dtype),
          jtu.format_shape_dtype_string(shape, val_dtype),
          axis),
       "rng_factory": rng_factory, "shape": shape,
       "key_dtype": key_dtype, "val_dtype": val_dtype, "axis": axis}
      for key_dtype in [onp.float32, onp.int32, onp.uint32]
      for val_dtype in [onp.float32, onp.int32, onp.uint32]
      for shape in [(3,), (5, 3)]
      for axis in [-1, len(shape) - 1]
      for rng_factory in [jtu.rand_default]))
  def testSortKeyValAgainstNumpy(self, shape, key_dtype, val_dtype, axis, rng_factory):
    rng = rng_factory()
    # This test relies on the property that wherever keys are tied, values are
    # too, since we don't guarantee the same ordering of values with equal keys.
    # To avoid that case, we generate unique keys (globally in the key array).
    perm_rng = onp.random.RandomState(0)
    def args_maker():
      flat_keys = onp.arange(onp.prod(shape, dtype=int), dtype=key_dtype)
      keys = perm_rng.permutation(flat_keys).reshape(shape)
      values = rng(shape, val_dtype)
      return keys, values

    op = lambda ks, vs: lax.sort_key_val(ks, vs, axis)
    numpy_op = lambda ks, vs: lax_reference.sort_key_val(ks, vs, axis)
    self._CheckAgainstNumpy(op, numpy_op, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_lhs_shape={}_rhs_shape={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype)),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "rng_factory": rng_factory}
      for lhs_shape, rhs_shape in [((3, 2), (2, 4)),
                                   ((5, 3, 2), (5, 2, 4)),
                                   ((1, 2, 2, 3), (1, 2, 3, 1))]
      for dtype in float_dtypes
      for rng_factory in [jtu.rand_small]))
  def testBatchMatMul(self, lhs_shape, rhs_shape, dtype, rng_factory):
    rng = rng_factory()
    arg_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
    self._CompileAndCheck(lax.batch_matmul, arg_maker, check_dtypes=True)

  def testCollapse(self):

    @api.jit
    def collapse_first_two(x):
      return lax.collapse(x, 0, 2)

    self.assertEqual((6,), collapse_first_two(onp.zeros((2, 3))).shape)
    self.assertEqual((6, 4), collapse_first_two(onp.zeros((2, 3, 4))).shape)
    self.assertEqual((2, 3, 4),
                     collapse_first_two(onp.zeros((1, 2, 3, 4))).shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_axes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), idxs, axes),
       "shape": shape, "dtype": dtype, "idxs": idxs, "axes": axes, "rng_factory": rng_factory}
      for dtype in all_dtypes
      for shape, idxs, axes in [
          [(3, 4, 5), (onp.array([0, 2, 1]),), (0,)],
          [(3, 4, 5), (onp.array([-1, -2]),), (0,)],
          [(3, 4, 5), (onp.array([0, 2]), onp.array([1, 3])), (0, 1)],
          [(3, 4, 5), (onp.array([0, 2]), onp.array([1, 3])), (0, 2)],
      ]
      for rng_factory in [jtu.rand_default]))
  def testIndexTake(self, shape, dtype, idxs, axes, rng_factory):
    rng = rng_factory()
    rand_idxs = lambda: tuple(rng(e.shape, e.dtype) for e in idxs)
    args_maker = lambda: [rng(shape, dtype), rand_idxs()]
    fun = lambda src, idxs: lax.index_take(src, idxs, axes)
    self._CompileAndCheck(fun, args_maker, check_dtypes=True)

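  # The GatherDimensionNumbers cases below cover scalar gathers (collapsed
  # slice dims with empty offset_dims), sliced windows (nonempty offset_dims),
  # and a two-component start_index_map.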
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_dnums={}_slice_sizes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), idxs, dnums,
          slice_sizes),
       "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
       "slice_sizes": slice_sizes, "rng_factory": rng_factory,
       "rng_idx_factory": rng_idx_factory}
      for dtype in all_dtypes
      for shape, idxs, dnums, slice_sizes in [
          ((5,), onp.array([[0], [2]]), lax.GatherDimensionNumbers(
            offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
            (1,)),
          ((10,), onp.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
            (2,)),
          ((10, 5,), onp.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
            (1, 3)),
          ((10, 5), onp.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)),
            (1, 3)),
      ]
      for rng_idx_factory in [partial(jtu.rand_int, max(shape))]
      for rng_factory in [jtu.rand_default]))
  def testGather(self, shape, dtype, idxs, dnums, slice_sizes, rng_factory,
                 rng_idx_factory):
    rng = rng_factory()
    rng_idx = rng_idx_factory()
    rand_idxs = lambda: rng_idx(idxs.shape, idxs.dtype)
    args_maker = lambda: [rng(shape, dtype), rand_idxs()]
    fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
    self._CompileAndCheck(fun, args_maker, check_dtypes=True)

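  # The four scatter variants below (add, min, max, plain scatter) share the
  # same operand, index, and update shapes and the same dimension numbers; only
  # the scatter primitive under test differs.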
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          idxs, update_shape, dnums),
       "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
       "update_shape": update_shape, "dnums": dnums,
       "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory}
      for dtype in float_dtypes
      for arg_shape, idxs, update_shape, dnums in [
          ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
          ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(),
            scatter_dims_to_operand_dims=(0,))),
          ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
      ]
      for rng_idx_factory in [partial(jtu.rand_int, max(arg_shape))]
      for rng_factory in [jtu.rand_default]))
  def testScatterAdd(self, arg_shape, dtype, idxs, update_shape, dnums,
                     rng_factory, rng_idx_factory):
    rng = rng_factory()
    rng_idx = rng_idx_factory()
    rand_idxs = lambda: rng_idx(idxs.shape, idxs.dtype)
    args_maker = lambda: [rng(arg_shape, dtype), rand_idxs(),
                          rng(update_shape, dtype)]
    fun = partial(lax.scatter_add, dimension_numbers=dnums)
    self._CompileAndCheck(fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          idxs, update_shape, dnums),
       "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
       "update_shape": update_shape, "dnums": dnums,
       "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory}
      for dtype in float_dtypes
      for arg_shape, idxs, update_shape, dnums in [
          ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
          ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(),
            scatter_dims_to_operand_dims=(0,))),
          ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
      ]
      for rng_idx_factory in [partial(jtu.rand_int, max(arg_shape))]
      for rng_factory in [jtu.rand_default]))
  def testScatterMin(self, arg_shape, dtype, idxs, update_shape, dnums,
                     rng_factory, rng_idx_factory):
    rng = rng_factory()
    rng_idx = rng_idx_factory()
    rand_idxs = lambda: rng_idx(idxs.shape, idxs.dtype)
    args_maker = lambda: [rng(arg_shape, dtype), rand_idxs(),
                          rng(update_shape, dtype)]
    fun = partial(lax.scatter_min, dimension_numbers=dnums)
    self._CompileAndCheck(fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          idxs, update_shape, dnums),
       "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
       "update_shape": update_shape, "dnums": dnums,
       "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory}
      for dtype in float_dtypes
      for arg_shape, idxs, update_shape, dnums in [
          ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
          ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(),
            scatter_dims_to_operand_dims=(0,))),
          ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
      ]
      for rng_idx_factory in [partial(jtu.rand_int, max(arg_shape))]
      for rng_factory in [jtu.rand_default]))
  def testScatterMax(self, arg_shape, dtype, idxs, update_shape, dnums,
                     rng_factory, rng_idx_factory):
    rng = rng_factory()
    rng_idx = rng_idx_factory()
    rand_idxs = lambda: rng_idx(idxs.shape, idxs.dtype)
    args_maker = lambda: [rng(arg_shape, dtype), rand_idxs(),
                          rng(update_shape, dtype)]
    fun = partial(lax.scatter_max, dimension_numbers=dnums)
    self._CompileAndCheck(fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          idxs, update_shape, dnums),
       "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
       "update_shape": update_shape, "dnums": dnums,
       "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory}
      for dtype in float_dtypes
      for arg_shape, idxs, update_shape, dnums in [
          ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
          ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(),
            scatter_dims_to_operand_dims=(0,))),
          ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
      ]
      for rng_idx_factory in [partial(jtu.rand_int, max(arg_shape))]
      for rng_factory in [jtu.rand_default]))
  def testScatter(self, arg_shape, dtype, idxs, update_shape, dnums,
                  rng_factory, rng_idx_factory):
    rng = rng_factory()
    rng_idx = rng_idx_factory()
    rand_idxs = lambda: rng_idx(idxs.shape, idxs.dtype)
    args_maker = lambda: [rng(arg_shape, dtype), rand_idxs(),
                          rng(update_shape, dtype)]
    fun = partial(lax.scatter, dimension_numbers=dnums)
    self._CompileAndCheck(fun, args_maker, check_dtypes=True)

  def testLongConstantHandling(self):
    if six.PY3:
      self.skipTest("Test is Python 2 specific")
    self.assertTrue(api.jit(lambda x: lax.lt(x, long(10)))(long(3)))  # noqa: F821

  def testIssue831(self):
    # Tests the DeviceTuple constant handler
    def f(x):
      g = lambda *args: args[1]
      return api.jit(lax.fori_loop, static_argnums=(2,))(0, 10, g, x)

    api.jit(f)(1.)  # doesn't crash

  def testReshapeWithUnusualShapes(self):
    ans = lax.reshape(onp.ones((3,), onp.float32), (lax.add(1, 2), 1))
    self.assertAllClose(ans, onp.ones((3, 1), onp.float32), check_dtypes=True)

    self.assertRaisesRegex(
        TypeError,
        "Shapes must be 1D sequences of concrete values of integer type.*",
        lambda: lax.reshape(onp.ones(3,), (onp.array([3, 1]),)))

    self.assertRaisesRegex(
        TypeError,
        "Shapes must be 1D sequences of concrete values of integer type.*",
        lambda: lax.reshape(onp.ones(3,), (1.5, 2.0)))

  @jtu.skip_on_devices("tpu")  # S16 not supported on TPU
  def testDynamicSliceTypeErrors(self):
    self.assertRaisesRegex(
        TypeError,
        "index arguments to dynamic_slice must be integers of the same type",
        lambda: lax.dynamic_slice(onp.ones((3, 4), dtype=onp.float32),
                                  (onp.int32(1), onp.int16(2)), (2, 2)))

  @jtu.skip_on_devices("tpu")  # S16 not supported on TPU
  def testDynamicUpdateSliceTypeErrors(self):
    self.assertRaisesRegex(
        TypeError,
        "index arguments to dynamic_update_slice must be integers of the same "
        "type",
        lambda: lax.dynamic_update_slice(onp.ones((3, 4), dtype=onp.float32),
                                         onp.zeros((2, 2), dtype=onp.float32),
                                         (onp.int32(1), onp.int16(2))))


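# Tests for on-device constants (full, iota, eye). _CheckDeviceConstant
# evaluates each constant three ways: via onp.asarray, as an argument to
# lax.add, and closed over inside a jitted computation, and checks all three
# against the expected ndarray.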
class DeviceConstantTest(jtu.JaxTestCase):
  def _CheckDeviceConstant(self, make_const, expected):
    # check casting to ndarray works
    asarray_result = onp.asarray(make_const())

    # check passing as an argument works (should hit constant handler)
    zero = onp.array(0, expected.dtype)
    argument_result = lax.add(zero, make_const())

    # check looping into a compiled computation works
    jit_result = api.jit(lambda x: lax.add(x, make_const()))(zero)

    # ensure they're all the same
    self.assertAllClose(asarray_result, expected, check_dtypes=True)
    self.assertAllClose(argument_result, expected, check_dtypes=True)
    self.assertAllClose(jit_result, expected, check_dtypes=True)

    # ensure repr doesn't crash
    repr(make_const())

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_fill={}".format(
          jtu.format_shape_dtype_string(shape, dtype) if dtype else shape,
          fill_value),
       "shape": shape, "dtype": dtype, "fill_value": fill_value}
      for dtype in itertools.chain(default_dtypes, [None])
      for shape in [(), (3,), (2, 3), (2, 3, 4), (1001, 1001)]
      for fill_value in [0, 1, onp.pi]))
  def testFilledConstant(self, shape, fill_value, dtype):
    make_const = lambda: lax.full(shape, fill_value, dtype)
    expected = onp.full(shape, fill_value,
                        dtype or dtypes.result_type(fill_value))
    self._CheckDeviceConstant(make_const, expected)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_dim={}".format(
          jtu.format_shape_dtype_string(shape, dtype), dimension),
       "shape": shape, "dtype": dtype, "dimension": dimension}
      for dtype in default_dtypes
      for shape in [(), (3,), (2, 3), (2, 3, 4),
                    # TODO(mattjj): re-enable
                    # (1001, 1001), (101, 101, 101),
                    ]
      for dimension in range(len(shape))))
  def testIotaConstant(self, dtype, shape, dimension):
    make_const = lambda: lax.broadcasted_iota(dtype, shape, dimension)

    arr = onp.arange(shape[dimension], dtype=dtypes.canonicalize_dtype(dtype))
    singleton_shape = [1] * len(shape)
    singleton_shape[dimension] = shape[dimension]
    expected = onp.broadcast_to(arr.reshape(singleton_shape), shape)

    self._CheckDeviceConstant(make_const, expected)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_axes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axes),
       "shape": shape, "dtype": dtype, "axes": axes}
      for dtype in default_dtypes
      for shape, axes in [
          [(2, 3), (0, 1)],
          [(2, 3, 4), (0, 1)],
          [(2, 3, 4), (0, 2)],
          [(2, 3, 4), (1, 2)],
          [(2, 3, 4), (0, 1, 2)],
          [(2, 3, 4, 2), (0, 1, 2)],
          [(2, 3, 4, 2), (0, 2, 3)],
          [(1001, 1001), (0, 1)],
      ]))
  def testEyeConstant(self, dtype, shape, axes):
    make_const = lambda: lax.broadcasted_eye(dtype, shape, axes)

    # don't check the asarray case, just assume it's right
    expected = onp.asarray(make_const())

    self._CheckDeviceConstant(make_const, expected)


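# A GradTestSpec bundles everything needed to differentiate one lax primitive:
# the op itself, its arity (nargs), the differentiation order to check, an rng
# factory for sampling inputs, the dtypes to cover, a display name, and an
# optional numerical tolerance.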
GradTestSpec = collections.namedtuple(
    "GradTestSpec",
    ["op", "nargs", "order", "rng_factory", "dtypes", "name", "tol"])
def grad_test_spec(op, nargs, order, rng_factory, dtypes, name=None, tol=None):
  return GradTestSpec(
      op, nargs, order, rng_factory, dtypes, name or op.__name__, tol)

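# Gradient tests only cover dtypes the backend under test supports; these are
# the float and complex subsets of jtu.supported_dtypes().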
grad_float_dtypes = list(jtu.supported_dtypes().intersection(
    {onp.float32, onp.float64}))
grad_complex_dtypes = list(jtu.supported_dtypes().intersection(
    {onp.complex64, onp.complex128}))
grad_inexact_dtypes = grad_float_dtypes + grad_complex_dtypes

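# Table of primitives whose gradients are checked by
# LaxAutodiffTest.testOpGrad below; each entry fixes the input rng, dtypes,
# differentiation order, and tolerance.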
LAX_GRAD_OPS = [
    grad_test_spec(lax.neg, nargs=1, order=2, rng_factory=jtu.rand_default,
                   dtypes=grad_inexact_dtypes),
    grad_test_spec(lax.floor, nargs=1, order=2, rng_factory=jtu.rand_default,
                   dtypes=grad_float_dtypes),
    grad_test_spec(lax.ceil, nargs=1, order=2, rng_factory=jtu.rand_default,
                   dtypes=grad_float_dtypes),
    grad_test_spec(lax.round, nargs=1, order=2, rng_factory=jtu.rand_default,
                   dtypes=grad_float_dtypes),

    grad_test_spec(lax.exp, nargs=1, order=2, rng_factory=jtu.rand_small,
                   dtypes=grad_inexact_dtypes),
    grad_test_spec(lax.expm1, nargs=1, order=2, rng_factory=jtu.rand_default,
                   dtypes=grad_inexact_dtypes),
    grad_test_spec(lax.log, nargs=1, order=2, rng_factory=jtu.rand_positive,
                   dtypes=grad_inexact_dtypes),
    grad_test_spec(lax.log1p, nargs=1, order=2, rng_factory=jtu.rand_positive,
                   dtypes=grad_inexact_dtypes),
    grad_test_spec(lax.sinh, nargs=1, order=2, rng_factory=jtu.rand_default,
                   dtypes=grad_float_dtypes + [onp.complex64], tol=1e-5),
    grad_test_spec(lax.cosh, nargs=1, order=2, rng_factory=jtu.rand_default,
                   dtypes=grad_inexact_dtypes, tol=1e-5),
    grad_test_spec(lax.tanh, nargs=1, order=2, rng_factory=jtu.rand_default,
                   dtypes=grad_inexact_dtypes, tol=1e-5),
    grad_test_spec(lax.sin, nargs=1, order=2, rng_factory=jtu.rand_default,
                   dtypes=grad_inexact_dtypes),
    grad_test_spec(lax.cos, nargs=1, order=2, rng_factory=jtu.rand_default,
                   dtypes=grad_inexact_dtypes),
    grad_test_spec(lax.tan, nargs=1, order=2,
                   rng_factory=partial(jtu.rand_uniform, -1.3, 1.3),
                   dtypes=grad_inexact_dtypes, tol=1e-3),
    grad_test_spec(lax.asin, nargs=1, order=2,
                   rng_factory=partial(jtu.rand_uniform, -1.3, 1.3),
                   dtypes=grad_float_dtypes, tol=1e-3),
    grad_test_spec(lax.acos, nargs=1, order=2,
                   rng_factory=partial(jtu.rand_uniform, -1.3, 1.3),
                   dtypes=grad_float_dtypes, tol=1e-3),
    # TODO(proteneer): atan2 input is already a representation of a
    # complex number. Need to think harder about what this even means
    # if each input itself is a complex number.
    grad_test_spec(lax.atan2, nargs=2, order=2, rng_factory=jtu.rand_default,
                   dtypes=grad_float_dtypes),

    grad_test_spec(lax.erf, nargs=1, order=2, rng_factory=jtu.rand_small,
                   dtypes=grad_float_dtypes),
    grad_test_spec(lax.erfc, nargs=1, order=2, rng_factory=jtu.rand_small,
                   dtypes=grad_float_dtypes),
    grad_test_spec(lax.erf_inv, nargs=1, order=2, rng_factory=jtu.rand_small,
                   dtypes=grad_float_dtypes),
    # grad_test_spec(lax.lgamma, nargs=1, order=2, rng_factory=jtu.rand_small,
    #                dtypes=grad_float_dtypes),  # TODO(mattjj): enable
    grad_test_spec(lax.bessel_i0e, nargs=1, order=2, rng_factory=jtu.rand_default,
                   dtypes=grad_float_dtypes),
    grad_test_spec(lax.bessel_i1e, nargs=1, order=2, rng_factory=jtu.rand_default,
                   dtypes=grad_float_dtypes),

    grad_test_spec(lax.real, nargs=1, order=2, rng_factory=jtu.rand_default,
                   dtypes=grad_complex_dtypes),
    grad_test_spec(lax.imag, nargs=1, order=2, rng_factory=jtu.rand_default,
                   dtypes=grad_complex_dtypes),
    grad_test_spec(lax.complex, nargs=2, order=2, rng_factory=jtu.rand_default,
                   dtypes=grad_float_dtypes),
    grad_test_spec(lax.conj, nargs=1, order=2, rng_factory=jtu.rand_default,
                   dtypes=grad_inexact_dtypes),
    grad_test_spec(lax.abs, nargs=1, order=2, rng_factory=jtu.rand_positive,
                   dtypes=grad_inexact_dtypes),
    grad_test_spec(lax.pow, nargs=2, order=2, rng_factory=jtu.rand_positive,
                   dtypes=grad_inexact_dtypes),

    grad_test_spec(lax.add, nargs=2, order=2, rng_factory=jtu.rand_default,
                   dtypes=grad_inexact_dtypes),
    grad_test_spec(lax.sub, nargs=2, order=2, rng_factory=jtu.rand_default,
                   dtypes=grad_inexact_dtypes),
    grad_test_spec(lax.mul, nargs=2, order=2, rng_factory=jtu.rand_default,
                   dtypes=grad_inexact_dtypes),
    grad_test_spec(lax.div, nargs=2, order=1, rng_factory=jtu.rand_not_small,
                   dtypes=grad_inexact_dtypes),

    grad_test_spec(lax.max, nargs=2, order=2, rng_factory=jtu.rand_default,
                   dtypes=grad_float_dtypes),
    grad_test_spec(lax.min, nargs=2, order=2, rng_factory=jtu.rand_default,
                   dtypes=grad_float_dtypes),
    # TODO(mattjj): make some-equal checks more robust, enable second-order
    # grad_test_spec(lax.max, nargs=2, order=1, rng_factory=jtu.rand_some_equal,
    #                dtypes=grad_float_dtypes, name="MaxSomeEqual"),
    # grad_test_spec(lax.min, nargs=2, order=1, rng_factory=jtu.rand_some_equal,
    #                dtypes=grad_float_dtypes, name="MinSomeEqual"),
]


GradSpecialValuesTestSpec = collections.namedtuple(
    "GradSpecialValuesTestSpec", ["op", "values", "tol"])
def grad_special_values_test_spec(op, values, tol=None):
  return GradSpecialValuesTestSpec(op, values, tol)

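# Input points at which gradients are checked explicitly by
# testOpGradSpecialValue (e.g. 0. for sinh/cosh, multiples of pi for sin/cos),
# with looser per-dtype tolerances where the TPU needs them.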
LAX_GRAD_SPECIAL_VALUE_TESTS = [
    grad_special_values_test_spec(
      lax.sinh, [0.],
      tol={onp.float32: 1e-2} if jtu.device_under_test() == "tpu" else None),
    grad_special_values_test_spec(
      lax.cosh, [0.],
      tol={onp.float32: 1e-2} if jtu.device_under_test() == "tpu" else None),
    grad_special_values_test_spec(lax.tanh, [0., 1000.]),
    grad_special_values_test_spec(lax.sin, [0., onp.pi, onp.pi/2., onp.pi/4.]),
    grad_special_values_test_spec(lax.cos, [0., onp.pi, onp.pi/2., onp.pi/4.]),
    grad_special_values_test_spec(lax.tan, [0.]),
    grad_special_values_test_spec(lax.asin, [0.]),
    grad_special_values_test_spec(lax.acos, [0.]),
    grad_special_values_test_spec(lax.atan, [0., 1000.]),
    grad_special_values_test_spec(lax.erf, [0., 10.]),
    grad_special_values_test_spec(lax.erfc, [0., 10.]),
]


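# check_grads_bilinear differentiates f with respect to each operand
# separately, holding the other fixed, which is what lets a bilinear op use a
# large finite-difference eps (see the comment in the body).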
def check_grads_bilinear(f, args, order,
                         modes=["fwd", "rev"], atol=None, rtol=None):
  # Can use large eps to make up for numerical inaccuracies since the op is
  # bilinear (relying on the fact that we only check one arg at a time)
  lhs, rhs = args
  check_grads(lambda lhs: f(lhs, rhs), (lhs,), order,
              modes=modes, atol=atol, rtol=rtol, eps=1.)
  check_grads(lambda rhs: f(lhs, rhs), (rhs,), order,
              modes=modes, atol=atol, rtol=rtol, eps=1.)


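# Autodiff coverage for lax primitives, driven by the LAX_GRAD_OPS and
# LAX_GRAD_SPECIAL_VALUE_TESTS tables above.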
class LaxAutodiffTest(jtu.JaxTestCase):

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
        {"testcase_name": jtu.format_test_name_suffix(
            rec.name, shapes, itertools.repeat(dtype)),
         "op": rec.op, "rng_factory": rec.rng_factory, "shapes": shapes, "dtype": dtype,
         "order": rec.order, "tol": rec.tol}
        for shape_group in compatible_shapes
        for shapes in CombosWithReplacement(shape_group, rec.nargs)
        for dtype in rec.dtypes)
      for rec in LAX_GRAD_OPS))
  def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol):
    rng = rng_factory()
    if jtu.device_under_test() == "tpu" and op is lax.pow:
      raise SkipTest("pow grad imprecise on tpu")
    tol = 1e-1 if num_float_bits(dtype) == 32 else tol
    args = tuple(rng(shape, dtype) for shape in shapes)
    check_grads(op, args, order, ["fwd", "rev"], tol, tol)

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
        {"testcase_name": "_{}_{}".format(rec.op.__name__, special_value),
         "op": rec.op, "special_value": special_value, "tol": rec.tol}
        for special_value in rec.values)
      for rec in LAX_GRAD_SPECIAL_VALUE_TESTS))
  def testOpGradSpecialValue(self, op, special_value, tol):
    check_grads(op, (special_value,), 2, ["fwd", "rev"], rtol=tol, atol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_from_dtype={}_to_dtype={}".format(
          jtu.dtype_str(from_dtype), jtu.dtype_str(to_dtype)),
       "from_dtype": from_dtype, "to_dtype": to_dtype, "rng_factory": rng_factory}
      for from_dtype, to_dtype in itertools.product(
          float_dtypes + complex_dtypes, repeat=2)
      for rng_factory in [jtu.rand_default]))
  def testConvertElementTypeGrad(self, from_dtype, to_dtype, rng_factory):
    rng = rng_factory()
    tol = max(jtu.tolerance(to_dtype, jtu.default_gradient_tolerance),
              jtu.tolerance(from_dtype, jtu.default_gradient_tolerance))
    args = (rng((2, 3), from_dtype),)
    convert_element_type = lambda x: lax.convert_element_type(x, to_dtype)
    check_grads(convert_element_type, args, 2, ["fwd", "rev"], tol, tol, eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_min_shape={}_operand_shape={}_max_shape={}".format(
          jtu.format_shape_dtype_string(min_shape, dtype),
          jtu.format_shape_dtype_string(operand_shape, dtype),
          jtu.format_shape_dtype_string(max_shape, dtype)),
       "min_shape": min_shape, "operand_shape": operand_shape,
       "max_shape": max_shape, "dtype": dtype, "rng_factory": rng_factory}
      for min_shape, operand_shape, max_shape in [
          [(), (), ()],
          [(), (2, 3), ()],
          [(2, 3), (2, 3), (2, 3)],
      ]
      # TODO(phawkins): this test fails for bfloat16.
      for dtype in [t for t in float_dtypes if t != dtypes.bfloat16]
      for rng_factory in [jtu.rand_default]))
  def testClampGrad(self, min_shape, operand_shape, max_shape, dtype, rng_factory):
    rng = rng_factory()
    tol = {dtypes.bfloat16: 1e-1, onp.float16: 1e-1, onp.float32: 1e-2}
    shapes = [min_shape, operand_shape, max_shape]
    min, operand, max = (rng(shape, dtype) for shape in shapes)
    min, max = onp.minimum(min, max), onp.maximum(min, max)  # broadcast
    eps = 1e-1 if dtypes.finfo(dtype).bits == 16 else 1e-2
    check_grads(lax.clamp, (min, operand, max), 2, ["fwd", "rev"], tol, tol,
                eps=eps)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_dim={}_baseshape=[{}]_dtype={}_narrs={}".format(
          dim, ",".join(str(d) for d in base_shape), onp.dtype(dtype).name,
          num_arrs),
       "dim": dim, "base_shape": base_shape, "dtype": dtype,
       "num_arrs": num_arrs, "rng_factory": rng_factory}
      for num_arrs in [3]
      for dtype in float_dtypes
      for base_shape in [(4,), (3, 4), (2, 3, 4)]
      for dim in range(len(base_shape))
      for rng_factory in [jtu.rand_default]))
  def testConcatenateGrad(self, dim, base_shape, dtype, num_arrs, rng_factory):
    rng = rng_factory()
    shapes = [base_shape[:dim] + (size,) + base_shape[dim+1:]
              for size, _ in zip(itertools.cycle([3, 1, 4]), range(num_arrs))]
    operands = tuple(rng(shape, dtype) for shape in shapes)
    concatenate = lambda *args: lax.concatenate(args, dim)
    check_grads(concatenate, operands, 2, ["fwd", "rev"], eps=1.)

2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
|
2018-11-17 18:03:33 -08:00
|
|
|
{"testcase_name":
|
|
|
|
"_lhs_shape={}_rhs_shape={}_strides={}_padding={}"
|
|
|
|
.format(jtu.format_shape_dtype_string(lhs_shape, dtype),
|
|
|
|
jtu.format_shape_dtype_string(rhs_shape, dtype),
|
|
|
|
strides, padding),
|
|
|
|
"lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
|
2019-11-11 12:51:15 -08:00
|
|
|
"strides": strides, "padding": padding, "rng_factory": rng_factory,}
|
2018-11-17 18:03:33 -08:00
|
|
|
for lhs_shape, rhs_shape, all_strides in itertools.chain(
|
|
|
|
[((b, i, 3, 4), (j, i, 1, 2), [(1, 1), (1, 2), (2, 1)])
|
|
|
|
for b, i, j in itertools.product([2, 3], repeat=3)],
|
|
|
|
[((4, 2, 1), (3, 2, 1), [(1,)])])
|
|
|
|
for strides in all_strides
|
2019-06-27 17:17:04 -04:00
|
|
|
for dtype in float_dtypes
|
2018-11-17 18:03:33 -08:00
|
|
|
for padding in ["VALID", "SAME"]
|
2019-11-11 12:51:15 -08:00
|
|
|
for rng_factory in [jtu.rand_small]))
|
|
|
|
def testConvGrad(self, lhs_shape, rhs_shape, dtype, strides, padding, rng_factory):
|
|
|
|
rng = rng_factory()
|
2018-11-17 18:03:33 -08:00
|
|
|
lhs = rng(lhs_shape, dtype)
|
|
|
|
rhs = rng(rhs_shape, dtype)
|
2019-06-28 09:00:32 -04:00
|
|
|
conv = partial(lax.conv, window_strides=strides, padding=padding,
|
|
|
|
precision=lax.Precision.HIGHEST)
|
2019-05-21 17:22:33 -07:00
|
|
|
check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"],
|
|
|
|
atol=1e-2, rtol=1e-2)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_"
       "rhs_dilation={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               strides, padding, lhs_dil, rhs_dil),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "strides": strides, "padding": padding, "lhs_dil": lhs_dil,
       "rhs_dil": rhs_dil, "rng_factory": rng_factory}
      for lhs_shape, rhs_shape, all_strides, all_pads, lhs_dils, rhs_dils in
      itertools.chain(
          [((b, i, 3, 4), (j, i, 1, 2), [(1, 1), (1, 2), (2, 1)],
            [((0, 0), (0, 0)), ((-1, 0), (0, -1)), ((1, 0), (0, 1))],
            [(1, 1), (2, 1)], [(1, 1)])
           for b, i, j in itertools.product([2, 3], repeat=3)],
          [((4, 2, 1), (3, 2, 1), [(1,)], [((1, 1),), ((0, 0),)],
            [(1,), (2,)], [(1,), (2,)])])
      for strides in all_strides
      for rhs_dil in rhs_dils
      for lhs_dil in lhs_dils
      for dtype in float_dtypes
      for padding in all_pads
      for rng_factory in [jtu.rand_small]))
  def testConvWithGeneralPaddingGrad(self, lhs_shape, rhs_shape, dtype, strides,
                                     padding, lhs_dil, rhs_dil, rng_factory):
    rng = rng_factory()
    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    conv = partial(lax.conv_with_general_padding, window_strides=strides,
                   padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil,
                   precision=lax.Precision.HIGHEST)
    check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"],
                         atol=1e-2, rtol=1e-2)

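  # Note on the grouped-convolution cases below: the dimension-number strings
  # describe each operand's layout (e.g. "NCHW" = batch, feature, spatial dims),
  # and dim_spec.lhs_spec[1] / dim_spec.rhs_spec[0] locate the input-feature and
  # kernel output-feature dimensions in that layout; the test scales both by
  # feature_group_count so the shapes stay valid for a grouped convolution.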
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_"
       "rhs_dilation={}_dims={}_feature_group_count={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               strides, padding, lhs_dil, rhs_dil, ",".join(dim_nums),
               feature_group_count),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "strides": strides, "padding": padding, "lhs_dil": lhs_dil,
       "rhs_dil": rhs_dil, "rng_factory": rng_factory, "dimension_numbers": dim_nums,
       "perms": perms, "feature_group_count": feature_group_count}
      for lhs_shapes, rhs_shape, all_strides, lhs_dils, rhs_dils in [
          ([(b, i, 6, 7), (b, i, 0, 4)],  # lhs_shape
           (j, i, 1, 2),  # rhs_shape
           [(1, 1), (1, 2), (2, 1)],  # strides
           [(1, 1), (2, 1)],  # lhs_dils
           [(1, 1), (2, 2)])  # rhs_dils
          for b, i, j in itertools.product([1, 2], repeat=3)]
      for lhs_shape in lhs_shapes
      for feature_group_count in [1, 2]
      for strides in all_strides
      for rhs_dil in rhs_dils
      for lhs_dil in lhs_dils
      for dtype in float_dtypes
      for padding in ([((0, 0), (0, 0)), ((1, 0), (0, 1))] +
                      ([((0, -1), (0, 0))] if lhs_shape[2] != 0 else []))
      for dim_nums, perms in [
          (("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])),
          (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])),
          (("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3]))]
      for rng_factory in [jtu.rand_default]
  ))
  @jtu.skip_on_devices("tpu")  # TODO(phawkins): precision problems on TPU.
  def testConvGeneralDilatedGrad(self, lhs_shape, rhs_shape, dtype, strides,
                                 padding, lhs_dil, rhs_dil, dimension_numbers,
                                 perms, feature_group_count, rng_factory):
    rng = rng_factory()
    tol = {dtypes.bfloat16: 1e-0, onp.float16: 5e-1, onp.float32: 1e-4}

    # permute shapes to match dim_spec, scale by feature_group_count
    lhs_perm, rhs_perm = perms
    lhs_shape = list(onp.take(lhs_shape, lhs_perm))
    rhs_shape = list(onp.take(rhs_shape, rhs_perm))
    dim_spec = lax.conv_dimension_numbers(lhs_shape, rhs_shape, dimension_numbers)
    lhs_shape[dim_spec.lhs_spec[1]] *= feature_group_count
    rhs_shape[dim_spec.rhs_spec[0]] *= feature_group_count

    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    conv = partial(lax.conv_general_dilated, window_strides=strides,
                   padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil,
                   dimension_numbers=dimension_numbers,
                   feature_group_count=feature_group_count,
                   precision=lax.Precision.HIGHEST)
    check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"],
                         atol=tol, rtol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_lhs_shape={}_rhs_shape={}".format(
          jtu.format_shape_dtype_string(lhs_shape, dtype),
          jtu.format_shape_dtype_string(rhs_shape, dtype)),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "rng_factory": jtu.rand_default}
      for lhs_shape in [(2,), (3, 2)] for rhs_shape in [(2,), (2, 4)]
      for dtype in float_dtypes))
  def testDotGrad(self, lhs_shape, rhs_shape, dtype, rng_factory):
    rng = rng_factory()
    tol = {onp.float16: 1e-1, onp.float32: 1e-4}
    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    dot = partial(lax.dot, precision=lax.Precision.HIGHEST)
    check_grads_bilinear(dot, (lhs, rhs), order=2, modes=["fwd", "rev"],
                         atol=tol, rtol=tol)
    # check that precision config is preserved
    result, pullback = api.vjp(dot, lhs, rhs)
    gresult = lax.zeros_like_array(result)
    s = str(api.make_jaxpr(pullback)(gresult))
    assert "precision=HIGHEST" in s

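  # Note on dot_general's dimension_numbers, as exercised below: they are read
  # as ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims,
  # rhs_batch_dims)). For example, the (3, 3, 2) x (3, 2, 4) case contracts lhs
  # axis 2 against rhs axis 1 and batches over axis 0 of both operands, so the
  # primal output has shape (3, 3, 4).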
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_dimension_numbers={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               dimension_numbers),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "dimension_numbers": dimension_numbers, "rng_factory": jtu.rand_small}
      for lhs_shape, rhs_shape, dimension_numbers in [
          ((3, 2), (2, 4), (([1], [0]), ([], []))),
          ((3, 5), (2, 5), (([1], [1]), ([], []))),
          ((5, 3), (5, 2), (([0], [0]), ([], []))),
          ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
      ]
      for dtype in float_dtypes))
  def testDotGeneralContractAndBatchGrads(self, lhs_shape, rhs_shape, dtype,
                                          dimension_numbers, rng_factory):
    rng = rng_factory()
    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    dot_general = partial(lax.dot_general, dimension_numbers=dimension_numbers,
                          precision=lax.Precision.HIGHEST)
    check_grads_bilinear(dot_general, (lhs, rhs), order=2, modes=["fwd", "rev"])
    # check that precision config is preserved
    result, pullback = api.vjp(dot_general, lhs, rhs)
    gresult = lax.zeros_like_array(result)
    s = str(api.make_jaxpr(pullback)(gresult))
    assert "precision=HIGHEST" in s

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_dtype={}_broadcast_sizes={}".format(
          shape, onp.dtype(dtype).name, broadcast_sizes),
       "shape": shape, "dtype": dtype, "broadcast_sizes": broadcast_sizes,
       "rng_factory": rng_factory}
      for shape in [(), (2, 3)]
      for dtype in float_dtypes
      for broadcast_sizes in [(), (2,), (1, 2)]
      for rng_factory in [jtu.rand_default]))
  def testBroadcastGrad(self, shape, dtype, broadcast_sizes, rng_factory):
    rng = rng_factory()
    args = (rng(shape, dtype),)
    broadcast = lambda x: lax.broadcast(x, broadcast_sizes)
    check_grads(broadcast, args, 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_outshape={}_bcdims={}".format(
          jtu.format_shape_dtype_string(inshape, dtype),
          outshape, broadcast_dimensions),
       "inshape": inshape, "dtype": dtype, "outshape": outshape,
       "dimensions": broadcast_dimensions, "rng_factory": rng_factory}
      for inshape, outshape, broadcast_dimensions in [
          ([2], [2, 2], [0]),
          ([2], [2, 2], [1]),
          ([2], [2, 3], [0]),
          ([], [2, 3], []),
      ]
      for dtype in float_dtypes
      for rng_factory in [jtu.rand_default]))
  def testBroadcastInDimGrad(self, inshape, dtype, outshape, dimensions, rng_factory):
    rng = rng_factory()
    operand = rng(inshape, dtype)
    broadcast_in_dim = lambda x: lax.broadcast_in_dim(x, outshape, dimensions)
    check_grads(broadcast_in_dim, (operand,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_outshape={}_perm={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          jtu.format_shape_dtype_string(out_shape, dtype),
          permutation),
       "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype,
       "rng_factory": rng_factory, "permutation": permutation}
      for dtype in float_dtypes
      for arg_shape, out_shape, permutation in [
          [(3, 4), (12,), None],
          [(2, 1, 4), (8,), None],
          [(2, 2, 4), (2, 8), None],
          [(3, 4), (12,), (0, 1)],
          [(3, 4), (12,), (1, 0)],
          [(2, 1, 4), (8,), (0, 2, 1)],
          [(2, 1, 4), (8,), (2, 0, 1)],
          [(2, 2, 4), (2, 8), (0, 2, 1)],
          [(2, 2, 4), (2, 8), (2, 0, 1)],
      ]
      for rng_factory in [jtu.rand_default]))
  def testReshapeGrad(self, arg_shape, out_shape, permutation, dtype, rng_factory):
    rng = rng_factory()
    operand = rng(arg_shape, dtype)
    reshape = lambda x: lax.reshape(x, out_shape, permutation)
    check_grads(reshape, (operand,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_pads={}"
       .format(jtu.format_shape_dtype_string(shape, dtype), pads),
       "shape": shape, "dtype": dtype, "pads": pads, "rng_factory": jtu.rand_small}
      for shape in [(2, 3)]
      for dtype in float_dtypes
      for pads in [[(1, 2, 1), (0, 1, 0)], [(-1, 0, 0), (-1, 0, 2)]]))
  def testPadGrad(self, shape, dtype, pads, rng_factory):
    rng = rng_factory()
    operand = rng(shape, dtype)
    pad = lambda operand: lax.pad(operand, onp.array(0, dtype), pads)
    check_grads(pad, (operand,), 2, ["fwd", "rev"], eps=1.)

    operand = rng(shape, dtype)
    padding_value = onp.array(0., dtype)
    pad = lambda operand, padding_value: lax.pad(operand, padding_value, pads)
    check_grads(pad, (operand, padding_value), 2, ["fwd", "rev"], eps=1.)

  def testReverseGrad(self):
    rev = lambda operand: lax.rev(operand, dimensions)

    dimensions = [0]
    check_grads(rev, (onp.array([3., 2., 1.]),), 2)

    dimensions = [0, 1]
    check_grads(rev, (onp.array([[6., 5., 4.], [3., 2., 1.]]),), 2,
                rtol={onp.float32: 3e-3})

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_predshape={}_argshapes={}".format(
          jtu.format_shape_dtype_string(pred_shape, onp.bool_),
          jtu.format_shape_dtype_string(arg_shape, dtype)),
       "pred_shape": pred_shape, "arg_shape": arg_shape, "dtype": dtype,
       "rng_factory": rng_factory}
      for arg_shape in [(), (3,), (2, 3)]
      for pred_shape in ([(), arg_shape] if arg_shape else [()])
      for dtype in float_dtypes
      for rng_factory in [jtu.rand_default]))
  def testSelectGrad(self, pred_shape, arg_shape, dtype, rng_factory):
    rng = rng_factory()
    pred = rng(pred_shape, onp.bool_)
    on_true = rng(arg_shape, dtype)
    on_false = rng(arg_shape, dtype)
    select = lambda on_true, on_false: lax.select(pred, on_true, on_false)
    check_grads(select, (on_true, on_false), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_start_indices={}_limit_indices={}_strides={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          start_indices, limit_indices, strides),
       "shape": shape, "dtype": dtype, "starts": start_indices,
       "limits": limit_indices, "strides": strides, "rng_factory": rng_factory}
      for shape, start_indices, limit_indices, strides in [
          [(3,), (1,), (2,), None],
          [(7,), (4,), (7,), None],
          [(5,), (1,), (5,), (2,)],
          [(8,), (1,), (6,), (2,)],
          [(5, 3), (1, 1), (3, 2), None],
          [(5, 3), (1, 1), (3, 1), None],
          [(7, 5, 3), (4, 0, 1), (7, 1, 3), None],
          [(5, 3), (1, 1), (2, 1), (1, 1)],
          [(5, 3), (1, 1), (5, 3), (2, 1)],
      ]
      for dtype in float_dtypes
      for rng_factory in [jtu.rand_default]))
  def testSliceGrad(self, shape, dtype, starts, limits, strides, rng_factory):
    rng = rng_factory()
    operand = rng(shape, dtype)
    slice = lambda x: lax.slice(x, starts, limits, strides)
    check_grads(slice, (operand,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_start_indices={}_size_indices={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          start_indices, size_indices),
       "shape": shape, "dtype": dtype, "start_indices": start_indices,
       "size_indices": size_indices, "rng_factory": rng_factory}
      for shape, start_indices, size_indices in [
          [(3,), (1,), (1,)],
          [(5, 3), (1, 1), (3, 1)],
          [(7, 5, 3), (4, 1, 0), (2, 0, 1)],
      ]
      for dtype in float_dtypes
      for rng_factory in [jtu.rand_default]))
  def testDynamicSliceGrad(self, shape, dtype, start_indices, size_indices,
                           rng_factory):
    rng = rng_factory()
    operand = rng(shape, dtype)
    dynamic_slice = lambda x: lax.dynamic_slice(x, start_indices, size_indices)
    check_grads(dynamic_slice, (operand,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_start_indices={}_update_shape={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          start_indices, update_shape),
       "shape": shape, "dtype": dtype, "start_indices": start_indices,
       "update_shape": update_shape, "rng_factory": rng_factory}
      for shape, start_indices, update_shape in [
          [(3,), (1,), (1,)],
          [(5, 3), (1, 1), (3, 1)],
          [(7, 5, 3), (4, 1, 0), (2, 0, 1)],
      ]
      for dtype in float_dtypes
      for rng_factory in [jtu.rand_default]))
  def testDynamicUpdateSliceGrad(self, shape, dtype, start_indices,
                                 update_shape, rng_factory):
    rng = rng_factory()
    operand = rng(shape, dtype)
    update = rng(update_shape, dtype)
    start_indices = onp.array(start_indices)

    dus = lambda x, y: lax.dynamic_update_slice(x, y, start_indices)
    check_grads(dus, (operand, update), 2, ["fwd", "rev"], eps=1.)

    dus = lambda x: lax.dynamic_update_slice(x, update, start_indices)
    check_grads(dus, (operand,), 2, ["fwd", "rev"], eps=1.)

    dus = lambda y: lax.dynamic_update_slice(operand, y, start_indices)
    check_grads(dus, (update,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_perm={}".format(
          jtu.format_shape_dtype_string(shape, dtype), perm),
       "shape": shape, "dtype": dtype, "perm": perm, "rng_factory": rng_factory}
      for shape, perm in [
          [(3, 4), (1, 0)],
          [(3, 4), (0, 1)],
          [(3, 4, 5), (2, 1, 0)],
          [(3, 4, 5), (1, 0, 2)],
      ]
      for dtype in float_dtypes
      for rng_factory in [jtu.rand_default]))
  def testTransposeGrad(self, shape, dtype, perm, rng_factory):
    rng = rng_factory()
    operand = rng(shape, dtype)
    transpose = lambda x: lax.transpose(x, perm)
    check_grads(transpose, (operand,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_inshape={}_reducedims={}"
       .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims),
       "op": op, "init_val": init_val, "shape": shape, "dtype": dtype,
       "dims": dims, "rng_factory": rng_factory}
      for init_val, op, dtypes in [
          (0, lax.add, inexact_dtypes),
          # Precision problems for float16 tests.
          (-onp.inf, lax.max, [t for t in inexact_dtypes if t != onp.float16]),
          (onp.inf, lax.min, [t for t in inexact_dtypes if t != onp.float16]),
          # The mul test overflows the range of a float16.
          (1, lax.mul, [t for t in inexact_dtypes
                        if t not in (onp.float16, dtypes.bfloat16)]),
      ]
      for dtype in dtypes
      for shape, dims in [
          [(3, 4, 5), ()],
          [(3, 4, 5), (0,)],
          [(3, 4, 5), (1, 2)],
          [(3, 4, 5), (0, 2)],
          [(3, 4, 5), (0, 1, 2)],
          [(3, 1), (1,)],
      ]
      for rng_factory in [jtu.rand_default]))
  def testReduceGrad(self, op, init_val, shape, dtype, dims, rng_factory):
    rng = rng_factory()
    if jtu.device_under_test() == "tpu" and op is lax.mul:
      raise SkipTest("unimplemented case")
    tol = {dtypes.bfloat16: 2e-1, onp.float16: 1e-1, onp.float32: 4e-2,
           onp.float64: 1e-3, onp.complex64: 1e-2}
    operand = rng(shape, dtype)
    init_val = onp.asarray(init_val, dtype=dtype)
    reduce = lambda operand: lax.reduce(operand, init_val, op, dims)
    eps = (1.0 if dtypes.finfo(dtype).bits == 16 and op is lax.add else
           1e-1 if dtype == dtypes.bfloat16 else
           1e-2 if dtypes.finfo(dtype).bits == 32 else None)
    check_grads(reduce, (operand,), 1, ["fwd", "rev"], tol, tol, eps)

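  # The reduce-window gradient checks below use a large finite-difference step
  # for lax.add (the reduction is linear in the operand) but a small step and
  # unique operand entries for lax.max / lax.min, whose gradients are piecewise
  # constant and change abruptly when window elements tie.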
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_dtype={}_padding={}"
       .format(op.__name__, onp.dtype(dtype).name, padding),
       "op": op, "init_val": init_val, "dtype": dtype, "padding": padding,
       "rng_factory": rng_factory}
      for init_val, op, dtypes, rng in [
          (0, lax.add, float_dtypes, jtu.rand_small),
          (-onp.inf, lax.max, [onp.float32], jtu.rand_default),
          (onp.inf, lax.min, [onp.float32], jtu.rand_default),
      ]
      for dtype in dtypes
      for padding in ["VALID", "SAME"]
      for rng_factory in [jtu.rand_default]))
  def testReduceWindowGrad(self, op, init_val, dtype, padding, rng_factory):
    rng = rng_factory()
    tol = {onp.float16: 1e-1, onp.float32: 1e-3}
    init_val = onp.asarray(init_val, dtype=dtype)

    # We need this conditional and the corresponding loop logic to be in the
    # test method, rather than at the parameterized test level, because it
    # depends on FLAGS for the device under test.
    # TODO(b/31565929): enable when fixed.
    if jtu.device_under_test() == "tpu" and op is not lax.add:
      all_configs = [((6, 5, 4, 3), (2, 2, 1, 1), (1, 2, 1, 1))]

      # TODO(b/73062247): need variadic reduce-window for better precision.
      gradient_order = 1
    else:
      all_configs = itertools.chain(
          itertools.product(
              [(4, 6)],  # shapes
              [(2, 1), (1, 2)],  # window_dimensions
              [(1, 1), (2, 1), (1, 2)]  # strides
          ),
          itertools.product(
              [(3, 2, 4, 6)],  # shapes
              [(1, 1, 2, 1), (2, 1, 2, 1)],  # window_dimensions
              [(1, 2, 2, 1), (1, 1, 1, 1)]),  # strides
      )
      gradient_order = 3

    def fun(operand):
      return lax.reduce_window(operand, init_val, op, dims, strides, padding)

    for shape, dims, strides in all_configs:
      operand = rng(shape, dtype)
      if op is lax.add:
        eps = 1.
      else:
        # this test can fail if there are duplicates in operand
        self.assertEqual(onp.unique(operand).size, operand.size,
                         msg="test requires operand elements to be unique.")
        eps = 1e-2
      check_grads(fun, (operand,), gradient_order, ["fwd", "rev"], tol, tol,
                  eps)

  # TODO(b/205052657): enable more tests when supported
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_axis={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis),
       "rng_factory": rng_factory, "shape": shape, "dtype": dtype, "axis": axis}
      for dtype in [onp.float32]
      for shape in [(5,), (5, 7)]
      for axis in [len(shape) - 1]
      for rng_factory in [jtu.rand_default]))
  def testSortGrad(self, shape, dtype, axis, rng_factory):
    rng = rng_factory()
    operand = rng(shape, dtype)
    sort = lambda x: lax.sort(x, axis)
    check_grads(sort, (operand,), 2, ["fwd", "rev"], eps=1e-2)

  # TODO(b/205052657): enable more tests when supported
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_keyshape={}_valshape={}_axis={}".format(
          jtu.format_shape_dtype_string(shape, key_dtype),
          jtu.format_shape_dtype_string(shape, val_dtype),
          axis),
       "rng_factory": rng_factory, "shape": shape,
       "key_dtype": key_dtype, "val_dtype": val_dtype, "axis": axis}
      for key_dtype in [onp.float32]
      for val_dtype in [onp.float32]
      for shape in [(3,), (5, 3)]
      for axis in [len(shape) - 1]
      for rng_factory in [jtu.rand_default]))
  def testSortKeyValGrad(self, shape, key_dtype, val_dtype, axis, rng_factory):
    rng = rng_factory()
    # This test relies on the property that wherever keys are tied, values are
    # too, since we don't guarantee the same ordering of values with equal keys.
    # To avoid that case, we generate unique keys (globally in the key array).
    perm_rng = onp.random.RandomState(0)
    def args_maker():
      flat_keys = onp.arange(onp.prod(shape, dtype=int), dtype=key_dtype)
      keys = perm_rng.permutation(flat_keys).reshape(shape)
      values = rng(shape, val_dtype)
      return keys, values
    keys, values = args_maker()

    fun = lambda keys, values: lax.sort_key_val(keys, values, axis)
    check_grads(fun, (keys, values), 2, ["fwd", "rev"], 1e-2, 1e-2, 1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_axes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), idxs, axes),
       "shape": shape, "dtype": dtype, "idxs": idxs, "axes": axes,
       "rng_factory": rng_factory}
      for dtype in float_dtypes
      for shape, idxs, axes in [
          [(3, 4, 5), (onp.array([0, 2, 1]),), (0,)],
          [(3, 4, 5), (onp.array([-1, -2]),), (0,)],
          [(3, 4, 5), (onp.array([0, 2]), onp.array([1, 3])), (0, 1)],
          [(3, 4, 5), (onp.array([0, 2]), onp.array([1, 3])), (0, 2)],
      ]
      for rng_factory in [jtu.rand_default]))
  def testIndexTakeGrad(self, shape, dtype, idxs, axes, rng_factory):
    rng = rng_factory()
    src = rng(shape, dtype)
    index_take = lambda src: lax.index_take(src, idxs, axes)
    check_grads(index_take, (src,), 2, ["fwd", "rev"], eps=1.)

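  # A worked reading of the first gather case below: the operand has shape (5,),
  # the index array has shape (2, 1) (two start indices), slice_sizes is (1,),
  # and with offset_dims=() and collapsed_slice_dims=(0,) each gathered slice
  # collapses to a scalar, so the gather output has shape (2,).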
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_dnums={}_slice_sizes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), idxs, dnums,
          slice_sizes),
       "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
       "slice_sizes": slice_sizes, "rng_factory": rng_factory,
       "rng_idx_factory": rng_idx_factory}
      for dtype in float_dtypes
      for shape, idxs, dnums, slice_sizes in [
          ((5,), onp.array([[0], [2]]), lax.GatherDimensionNumbers(
            offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
            (1,)),
          ((10,), onp.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
            (2,)),
          ((10, 5,), onp.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
            (1, 3)),
      ]
      for rng_idx_factory in [partial(jtu.rand_int, max(shape))]
      for rng_factory in [jtu.rand_default]))
  def testGatherGrad(self, shape, dtype, idxs, dnums, slice_sizes, rng_factory,
                     rng_idx_factory):
    rng = rng_factory()
    rng_idx = rng_idx_factory()
    idxs = rng_idx(idxs.shape, idxs.dtype)
    gather = lambda x: lax.gather(x, idxs, dimension_numbers=dnums,
                                  slice_sizes=slice_sizes)
    x = rng(shape, dtype)
    check_grads(gather, (x,), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          idxs, update_shape, dnums),
       "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
       "update_shape": update_shape, "dnums": dnums, "rng_factory": rng_factory,
       "rng_idx_factory": rng_idx_factory}
      for dtype in float_dtypes
      for arg_shape, idxs, update_shape, dnums in [
          ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
          ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(),
            scatter_dims_to_operand_dims=(0,))),
          ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
      ]
      for rng_idx_factory in [partial(jtu.rand_int, max(arg_shape))]
      for rng_factory in [jtu.rand_default]))
  def testScatterAddGrad(self, arg_shape, dtype, idxs, update_shape, dnums,
                         rng_factory, rng_idx_factory):
    rng = rng_factory()
    rng_idx = rng_idx_factory()
    idxs = rng_idx(idxs.shape, idxs.dtype)
    scatter_add = lambda x, y: lax.scatter_add(x, idxs, y,
                                               dimension_numbers=dnums)
    x = rng(arg_shape, dtype)
    y = rng(update_shape, dtype)
    check_grads(scatter_add, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          idxs, update_shape, dnums),
       "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
       "update_shape": update_shape, "dnums": dnums, "rng_factory": rng_factory,
       "rng_idx_factory": rng_idx_factory}
      for dtype in float_dtypes
      for arg_shape, idxs, update_shape, dnums in [
          ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
          ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(),
            scatter_dims_to_operand_dims=(0,))),
          ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
      ]
      for rng_idx_factory in [partial(jtu.rand_int, max(arg_shape))]
      for rng_factory in [jtu.rand_default]))
  def testScatterGrad(self, arg_shape, dtype, idxs, update_shape, dnums,
                      rng_factory, rng_idx_factory):
    rng = rng_factory()
    rng_idx = rng_idx_factory()
    idxs = rng_idx(idxs.shape, idxs.dtype)
    scatter = lambda x, y: lax.scatter(x, idxs, y, dimension_numbers=dnums)
    x = rng(arg_shape, dtype)
    y = rng(update_shape, dtype)
    check_grads(scatter, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.)

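  # For the stop_gradient check below: since lax.stop_gradient treats its
  # argument as a constant under differentiation, f(x) = sin(x) * cos(stop_gradient(x))
  # should differentiate like f2(x, y) = sin(x) * cos(y) in x alone, i.e.
  # grad(f)(x) == cos(x) * cos(x), which is what grad(f2)(x, x) computes.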
  def testStopGradient(self):
    def f(x):
      return lax.sin(x) * lax.cos(lax.stop_gradient(x))

    def f2(x, y):
      return lax.sin(x) * lax.cos(y)

    x = 3.14
    ans = api.grad(f)(x)
    expected = api.grad(f2)(x, x)
    self.assertAllClose(ans, expected, check_dtypes=True)

    ans = api.grad(api.grad(f))(x)
    expected = api.grad(api.grad(f2))(x, x)
    self.assertAllClose(ans, expected, check_dtypes=True)

    ans = api.grad(lambda x: lax.stop_gradient({'foo':x})['foo'])(3.)
    expected = onp.array(0.0)
    self.assertAllClose(ans, expected, check_dtypes=False)

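  # For the remainder checks below: lax.rem follows the C-style convention
  # rem(x, y) = x - y * trunc(x / y), so wherever x / y is not an integer the
  # derivative is 1 with respect to x and -trunc(x / y) with respect to y.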
  # TODO(mattjj): make this a more systematic test
  def testRemainder(self):
    rng = onp.random.RandomState(0)
    x = rng.uniform(-0.9, 9, size=(3, 4))
    y = rng.uniform(0.7, 1.9, size=(3, 1))
    assert not set(onp.unique(x)) & set(onp.unique(y))
    tol = 1e-1 if num_float_bits(onp.float64) == 32 else 1e-3
    check_grads(lax.rem, (x, y), 2, ["fwd", "rev"], tol, tol)

    rng = onp.random.RandomState(0)
    x = rng.uniform(-0.9, 9, size=(1, 4))
    y = rng.uniform(0.7, 1.9, size=(3, 4))
    assert not set(onp.unique(x)) & set(onp.unique(y))
    tol = 1e-1 if num_float_bits(onp.float64) == 32 else 1e-3
    check_grads(lax.rem, (x, y), 2, ["fwd", "rev"], tol, tol)


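# The helpers below enumerate and apply batch dimensions for the vmap tests:
# all_bdims yields one candidate batch-dim assignment per argument (None means
# "not batched", and at least one argument is batched), add_bdim inserts the
# batch size into a shape (e.g. add_bdim(10, 1, (2, 3)) gives (2, 10, 3)), and
# slicer/args_slicer pull out the i-th example from each batched argument so
# LaxVmapTest._CheckBatching can compare api.vmap(op, bdims) against an
# explicit Python loop over examples stacked with onp.stack.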
def all_bdims(*shapes):
  bdims = (itertools.chain([None], range(len(shape) + 1)) for shape in shapes)
  return (t for t in itertools.product(*bdims) if not all(e is None for e in t))

def add_bdim(bdim_size, bdim, shape):
  shape = list(shape)
  if bdim is not None:
    shape.insert(bdim, bdim_size)
  return tuple(shape)

def slicer(x, bdim):
  if bdim is None:
    return lambda _: x
  else:
    return lambda i: lax.index_in_dim(x, i, bdim, keepdims=False)

def args_slicer(args, bdims):
  slicers = list(map(slicer, args, bdims))
  return lambda i: [sl(i) for sl in slicers]


class LaxVmapTest(jtu.JaxTestCase):

  def _CheckBatching(self, op, bdim_size, bdims, shapes, dtype, rng,
                     rtol=None, atol=None):
    batched_shapes = map(partial(add_bdim, bdim_size), bdims, shapes)
    args = [rng(shape, dtype) for shape in batched_shapes]
    args_slice = args_slicer(args, bdims)
    ans = api.vmap(op, bdims)(*args)
    expected = onp.stack([op(*args_slice(i)) for i in range(bdim_size)])
    self.assertAllClose(ans, expected, check_dtypes=True, rtol=rtol, atol=atol)

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
        {"testcase_name": "{}_bdims={}".format(
            jtu.format_test_name_suffix(rec.op, shapes,
                                        itertools.repeat(dtype)), bdims),
         "op_name": rec.op, "rng_factory": rec.rng_factory, "shapes": shapes,
         "dtype": dtype, "bdims": bdims}
        for shape_group in compatible_shapes
        for shapes in CombosWithReplacement(shape_group, rec.nargs)
        for bdims in all_bdims(*shapes)
        for dtype in rec.dtypes)
      for rec in LAX_OPS))
  def testOp(self, op_name, rng_factory, shapes, dtype, bdims):
    rng = rng_factory()
    op = getattr(lax, op_name)
    self._CheckBatching(op, 10, bdims, shapes, dtype, rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_"
       "rhs_dilation={}_dims={}_feature_group_count={}_lhs_bdim={}_rhs_bdim={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               strides, padding, lhs_dil, rhs_dil, ",".join(dim_nums),
               feature_group_count, lhs_bdim, rhs_bdim),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "strides": strides, "padding": padding, "lhs_dil": lhs_dil,
       "rhs_dil": rhs_dil, "rng_factory": rng_factory, "dimension_numbers": dim_nums,
       "perms": perms, "lhs_bdim": lhs_bdim, "rhs_bdim": rhs_bdim,
       "feature_group_count": feature_group_count}
      for lhs_shape, rhs_shape, all_strides, all_pads, lhs_dils, rhs_dils in [
          ((b, i, 6, 7),  # lhs_shape
           (j, i, 1, 2),  # rhs_shape
           [(1, 1), (1, 2), (2, 1)],  # strides
           [((0, 0), (0, 0)), ((1, 0), (0, 1)), ((0, -1), (0, 0))],  # pads
           [(1, 1), (2, 1)],  # lhs_dils
           [(1, 1), (2, 2)])  # rhs_dils
          for b, i, j in itertools.product([1, 2], repeat=3)]
      for feature_group_count in [1, 2]
      for strides in all_strides
      for rhs_dil in rhs_dils
      for lhs_dil in lhs_dils
      for dtype in [onp.float32]
      for padding in all_pads
      for dim_nums, perms in [
          (("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])),
          (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])),
          (("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3]))]
      for lhs_bdim in itertools.chain([None], range(len(lhs_shape) + 1))
      for rhs_bdim in itertools.chain([None], range(len(rhs_shape) + 1))
      if (lhs_bdim, rhs_bdim) != (None, None)
      for rng_factory in [jtu.rand_default]
  ))
  # TODO(mattjj): some cases fail on TPU just due to numerical tolerances
  @jtu.skip_on_devices("tpu")
  def testConvGeneralDilatedBatching(
      self, lhs_shape, rhs_shape, dtype, strides, padding, lhs_dil, rhs_dil,
      dimension_numbers, perms, feature_group_count, lhs_bdim, rhs_bdim, rng_factory):
    rng = rng_factory()
    tol = 1e-1 if dtypes.finfo(dtype).bits <= 32 else 1e-3

    # permute shapes to match dim_spec, scale by feature_group_count
    lhs_perm, rhs_perm = perms
    lhs_shape = list(onp.take(lhs_shape, lhs_perm))
    rhs_shape = list(onp.take(rhs_shape, rhs_perm))
    dim_spec = lax.conv_dimension_numbers(lhs_shape, rhs_shape, dimension_numbers)
    lhs_shape[dim_spec.lhs_spec[1]] *= feature_group_count
    rhs_shape[dim_spec.rhs_spec[0]] *= feature_group_count

    conv = partial(lax.conv_general_dilated, window_strides=strides,
                   padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil,
                   dimension_numbers=dimension_numbers,
                   feature_group_count=feature_group_count,
                   precision=lax.Precision.HIGHEST)
    self._CheckBatching(conv, 5, (lhs_bdim, rhs_bdim), (lhs_shape, rhs_shape),
                        dtype, rng, rtol=tol, atol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_from_dtype={}_to_dtype={}_bdims={}".format(
          shape, from_dtype, to_dtype, bdims),
       "shape": shape, "from_dtype": from_dtype, "to_dtype": to_dtype,
       "bdims": bdims, "rng_factory": rng_factory}
      for from_dtype, to_dtype in itertools.product(
          [onp.float32, onp.int32, "float32", "int32"], repeat=2)
      for shape in [(2, 3)]
      for bdims in all_bdims(shape)
      for rng_factory in [jtu.rand_default]))
  def testConvertElementType(self, shape, from_dtype, to_dtype, bdims, rng_factory):
    rng = rng_factory()
    op = lambda x: lax.convert_element_type(x, to_dtype)
    self._CheckBatching(op, 10, bdims, (shape,), from_dtype, rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_from_dtype={}_to_dtype={}_bdims={}".format(
          shape, from_dtype, to_dtype, bdims),
       "shape": shape, "from_dtype": from_dtype, "to_dtype": to_dtype,
       "bdims": bdims, "rng_factory": rng_factory}
      for from_dtype, to_dtype in itertools.product(
          [onp.float32, onp.int32, "float32", "int32"], repeat=2)
      for shape in [(2, 3)]
      for bdims in all_bdims(shape)
      for rng_factory in [jtu.rand_default]))
  def testBitcastElementType(self, shape, from_dtype, to_dtype, bdims, rng_factory):
    rng = rng_factory()
    op = lambda x: lax.bitcast_convert_type(x, to_dtype)
    self._CheckBatching(op, 10, bdims, (shape,), from_dtype, rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_min_shape={}_operand_shape={}_max_shape={}_bdims={}"
       .format(jtu.format_shape_dtype_string(min_shape, dtype),
               jtu.format_shape_dtype_string(operand_shape, dtype),
               jtu.format_shape_dtype_string(max_shape, dtype),
               bdims),
       "min_shape": min_shape, "operand_shape": operand_shape,
       "max_shape": max_shape, "dtype": dtype, "bdims": bdims, "rng_factory": rng_factory}
      for min_shape, operand_shape, max_shape in [
          [(), (2, 3), ()],
          [(2, 3), (2, 3), ()],
          [(), (2, 3), (2, 3)],
          [(2, 3), (2, 3), (2, 3)],
      ]
      for dtype in default_dtypes
      for bdims in all_bdims(min_shape, operand_shape, max_shape)
      for rng_factory in [jtu.rand_default]))
  def testClamp(self, min_shape, operand_shape, max_shape, dtype, bdims, rng_factory):
    rng = rng_factory()
    raise SkipTest("batching rule for clamp not implemented")  # TODO(mattj)
    shapes = [min_shape, operand_shape, max_shape]
    self._CheckBatching(lax.clamp, 10, bdims, shapes, dtype, rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_lhs_shape={}_rhs_shape={}_bdims={}".format(
          jtu.format_shape_dtype_string(lhs_shape, dtype),
          jtu.format_shape_dtype_string(rhs_shape, dtype),
          bdims),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "bdims": bdims, "rng_factory": rng_factory}
      for lhs_shape in [(3,), (4, 3)] for rhs_shape in [(3,), (3, 6)]
      for bdims in all_bdims(lhs_shape, rhs_shape)
      for dtype in default_dtypes
      for rng_factory in [jtu.rand_default]))
  def testDot(self, lhs_shape, rhs_shape, dtype, bdims, rng_factory):
    rng = rng_factory()
    op = partial(lax.dot, precision=lax.Precision.HIGHEST)
    self._CheckBatching(op, 5, bdims, (lhs_shape, rhs_shape), dtype, rng,
                        rtol={onp.float16: 5e-2})

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_lhs_contracting={}_rhs_contracting={}_bdims={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               lhs_contracting, rhs_contracting, bdims),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "lhs_contracting": lhs_contracting, "rhs_contracting": rhs_contracting,
       "bdims": bdims, "rng_factory": rng_factory}
      for lhs_shape, rhs_shape, lhs_contracting, rhs_contracting in [
          [(3, 5), (2, 5), [1], [1]],
          [(5, 3), (5, 2), [0], [0]],
          [(5, 3, 2), (5, 2, 4), [0], [0]],
          [(5, 3, 2), (5, 2, 4), [0,2], [0,1]],
          [(1, 2, 2, 3), (1, 2, 3, 1), [1], [1]],
          [(3, 2), (2, 4), [1], [0]],
      ]
      for bdims in all_bdims(lhs_shape, rhs_shape)
      for dtype in default_dtypes
      for rng_factory in [jtu.rand_small]))
  def testDotGeneralContractOnly(self, lhs_shape, rhs_shape, dtype,
                                 lhs_contracting, rhs_contracting, bdims, rng_factory):
    rng = rng_factory()
    dimension_numbers = ((lhs_contracting, rhs_contracting), ([], []))
    dot = partial(lax.dot_general, dimension_numbers=dimension_numbers)
    self._CheckBatching(dot, 5, bdims, (lhs_shape, rhs_shape), dtype, rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_dimension_numbers={}_bdims={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               dimension_numbers, bdims),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "dimension_numbers": dimension_numbers, "bdims": bdims, "rng_factory": rng_factory}
      for lhs_shape, rhs_shape, dimension_numbers in [
          ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
          ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1]))),
      ]
      for bdims in all_bdims(lhs_shape, rhs_shape)
      for dtype in default_dtypes
      for rng_factory in [jtu.rand_small]))
  def testDotGeneralContractAndBatch(self, lhs_shape, rhs_shape, dtype,
                                     dimension_numbers, bdims, rng_factory):
    rng = rng_factory()
    dot = partial(lax.dot_general, dimension_numbers=dimension_numbers)
    self._CheckBatching(dot, 5, bdims, (lhs_shape, rhs_shape), dtype, rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_dtype={}_broadcast_sizes={}_bdims={}".format(
          shape, onp.dtype(dtype).name, broadcast_sizes, bdims),
       "shape": shape, "dtype": dtype, "broadcast_sizes": broadcast_sizes,
       "bdims": bdims, "rng_factory": rng_factory}
      for shape in [(), (2, 3)]
      for dtype in default_dtypes
      for broadcast_sizes in [(), (2,), (1, 2)]
      for bdims in all_bdims(shape)
      for rng_factory in [jtu.rand_default]))
  def testBroadcast(self, shape, dtype, broadcast_sizes, bdims, rng_factory):
    rng = rng_factory()
    op = lambda x: lax.broadcast(x, broadcast_sizes)
    self._CheckBatching(op, 5, bdims, (shape,), dtype, rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_outshape={}_bcdims={}_bdims={}".format(
          jtu.format_shape_dtype_string(inshape, dtype),
          outshape, broadcast_dimensions, bdims),
       "inshape": inshape, "dtype": dtype, "outshape": outshape,
       "dimensions": broadcast_dimensions, "bdims": bdims,
       "rng_factory": rng_factory}
      for inshape, outshape, broadcast_dimensions in [
          ([2], [2, 2], [0]),
          ([2], [2, 2], [1]),
          ([2], [2, 3], [0]),
          ([], [2, 3], []),
      ]
      for dtype in default_dtypes
      for bdims in all_bdims(inshape)
      for rng_factory in [jtu.rand_default]))
  def testBroadcastInDim(self, inshape, dtype, outshape, dimensions, bdims, rng_factory):
    rng = rng_factory()
    raise SkipTest("this test has failures in some cases")  # TODO(mattjj)
    op = lambda x: lax.broadcast_in_dim(x, outshape, dimensions)
    self._CheckBatching(op, 5, bdims, (inshape,), dtype, rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_outshape={}_dims={}_bdims={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          jtu.format_shape_dtype_string(out_shape, dtype),
          dimensions, bdims),
       "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype,
       "dimensions": dimensions, "bdims": bdims, "rng_factory": rng_factory}
      for dtype in default_dtypes
      for arg_shape, dimensions, out_shape in [
          [(3, 4), None, (12,)],
          [(2, 1, 4), None, (8,)],
          [(2, 2, 4), None, (2, 8)],
          [(2, 2, 4), (0, 1, 2), (2, 8)],
          [(2, 2, 4), (1, 0, 2), (8, 2)],
          [(2, 2, 4), (2, 1, 0), (4, 2, 2)]
      ]
      for bdims in all_bdims(arg_shape)
      for rng_factory in [jtu.rand_default]))
  def testReshape(self, arg_shape, out_shape, dtype, dimensions, bdims, rng_factory):
    rng = rng_factory()
    op = lambda x: lax.reshape(x, out_shape, dimensions=dimensions)
    self._CheckBatching(op, 10, bdims, (arg_shape,), dtype, rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_pads={}_bdims={}"
       .format(jtu.format_shape_dtype_string(shape, dtype), pads, bdims),
       "shape": shape, "dtype": dtype, "pads": pads,
       "rng_factory": jtu.rand_small, "bdims": bdims}
      for shape in [(2, 3)]
      for bdims in all_bdims(shape)
      for dtype in default_dtypes
      for pads in [[(1, 2, 1), (0, 1, 0)]]))
  def testPad(self, shape, dtype, pads, bdims, rng_factory):
    rng = rng_factory()
    fun = lambda operand: lax.pad(operand, onp.array(0, dtype), pads)
    self._CheckBatching(fun, 5, bdims, (shape,), dtype, rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_predshape={}_argshapes={}_bdims={}".format(
          jtu.format_shape_dtype_string(pred_shape, onp.bool_),
          jtu.format_shape_dtype_string(arg_shape, arg_dtype),
          bdims),
       "pred_shape": pred_shape, "arg_shape": arg_shape, "arg_dtype": arg_dtype,
       "bdims": bdims, "rng_factory": rng_factory}
      for arg_shape in [(), (3,), (2, 3)]
      for pred_shape in ([(), arg_shape] if arg_shape else [()])
      for bdims in all_bdims(pred_shape, arg_shape, arg_shape)
      for arg_dtype in default_dtypes
      for rng_factory in [jtu.rand_default]))
  def testSelect(self, pred_shape, arg_shape, arg_dtype, bdims, rng_factory):
    rng = rng_factory()
    op = lambda c, x, y: lax.select(c < 0, x, y)
    self._CheckBatching(op, 5, bdims, (pred_shape, arg_shape, arg_shape,),
                        arg_dtype, rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_start_indices={}_limit_indices={}_strides={}_bdims={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          start_indices, limit_indices, strides, bdims),
       "shape": shape, "dtype": dtype, "starts": start_indices,
       "limits": limit_indices, "strides": strides, "bdims": bdims, "rng_factory": rng_factory}
      for shape, start_indices, limit_indices, strides in [
          [(3,), (1,), (2,), None],
          [(7,), (4,), (7,), None],
          [(5,), (1,), (5,), (2,)],
          [(8,), (1,), (6,), (2,)],
          [(5, 3), (1, 1), (3, 2), None],
          [(5, 3), (1, 1), (3, 1), None],
          [(7, 5, 3), (4, 0, 1), (7, 1, 3), None],
          [(5, 3), (1, 1), (2, 1), (1, 1)],
          [(5, 3), (1, 1), (5, 3), (2, 1)],
      ]
      for bdims in all_bdims(shape)
      for dtype in default_dtypes
      for rng_factory in [jtu.rand_default]))
  def testSlice(self, shape, dtype, starts, limits, strides, bdims, rng_factory):
    rng = rng_factory()
    op = lambda x: lax.slice(x, starts, limits, strides)
    self._CheckBatching(op, 5, bdims, (shape,), dtype, rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_perm={}_bdims={}".format(
          jtu.format_shape_dtype_string(shape, dtype), perm, bdims),
       "shape": shape, "dtype": dtype, "perm": perm, "bdims": bdims, "rng_factory": rng_factory}
      for shape, perm in [
        [(3, 4), (1, 0)],
        [(3, 4), (0, 1)],
        [(3, 4, 5), (2, 1, 0)],
        [(3, 4, 5), (1, 0, 2)],
      ]
      for bdims in all_bdims(shape)
      for dtype in default_dtypes
      for rng_factory in [jtu.rand_default]))
  def testTranspose(self, shape, dtype, perm, bdims, rng_factory):
    rng = rng_factory()
    op = lambda x: lax.transpose(x, perm)
    self._CheckBatching(op, 5, bdims, (shape,), dtype, rng)

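  # Note: in the generator expression below, `dtypes` is rebound to each
  # per-case dtype list, shadowing the jax.dtypes module inside the
  # comprehension; the dtypes.iinfo(...) calls still resolve to the module
  # because the outermost iterable is evaluated in the enclosing scope.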
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_inshape={}_reducedims={}_initval={}_bdims={}"
       .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims,
               init_val, bdims),
       "op": op, "init_val": init_val, "shape": shape, "dtype": dtype,
       "dims": dims, "bdims": bdims, "rng_factory": rng_factory}
      for init_val, op, dtypes in [
          (0, lax.add, default_dtypes),
          (1, lax.mul, default_dtypes),
          (0, lax.max, all_dtypes), # non-monoidal
          (-onp.inf, lax.max, float_dtypes),
          (dtypes.iinfo(onp.int32).min, lax.max, [onp.int32]),
          (dtypes.iinfo(onp.int64).min, lax.max, [onp.int64]),
          (dtypes.iinfo(onp.uint32).min, lax.max, [onp.uint32]),
          (dtypes.iinfo(onp.uint64).min, lax.max, [onp.uint64]),
          (onp.inf, lax.min, float_dtypes),
          (dtypes.iinfo(onp.int32).max, lax.min, [onp.int32]),
          (dtypes.iinfo(onp.int64).max, lax.min, [onp.int64]),
          (dtypes.iinfo(onp.uint32).max, lax.min, [onp.uint32]),
          (dtypes.iinfo(onp.uint64).max, lax.min, [onp.uint64]),
      ]
      for dtype in dtypes
      for shape, dims in [
          [(3, 4, 5), (0,)], [(3, 4, 5), (1, 2)],
          [(3, 4, 5), (0, 2)], [(3, 4, 5), (0, 1, 2)]
      ]
      for bdims in all_bdims(shape)
      for rng_factory in [jtu.rand_small]))
  def testReduce(self, op, init_val, shape, dtype, dims, bdims, rng_factory):
    rng = rng_factory()
    init_val = onp.asarray(init_val, dtype=dtype)
    fun = lambda operand: lax.reduce(operand, init_val, op, dims)
    self._CheckBatching(fun, 5, bdims, (shape,), dtype, rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_dtype={}_padding={}"
       .format(op.__name__, onp.dtype(dtype).name, padding),
       "op": op, "init_val": init_val, "dtype": dtype, "padding": padding,
       "rng_factory": rng_factory}
      for init_val, op, dtypes in [
          (0, lax.add, [onp.float32]),
          (-onp.inf, lax.max, [onp.float32]),
          (onp.inf, lax.min, [onp.float32]),
      ]
      for dtype in dtypes
      for padding in ["VALID", "SAME"]
      for rng_factory in [jtu.rand_small]))
  def testReduceWindow(self, op, init_val, dtype, padding, rng_factory):
    rng = rng_factory()
    init_val = onp.asarray(init_val, dtype=dtype)

    all_configs = itertools.chain(
        itertools.product(
            [(4, 6)],
            [(2, 1), (1, 2)],
            [(1, 1), (2, 1), (1, 2)]),
        itertools.product(
            [(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)],
            [(1, 2, 2, 1), (1, 1, 1, 1)]))

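    # `fun` closes over `dims`, `strides`, and `padding`; `dims` and `strides`
    # are rebound by the loop below before _CheckBatching traces the function.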
    def fun(operand):
      return lax.reduce_window(operand, init_val, op, dims, strides, padding)

    for shape, dims, strides in all_configs:
      for bdims in all_bdims(shape):
        self._CheckBatching(fun, 3, bdims, (shape,), dtype, rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_dtype={}_padding={}".format(onp.dtype(dtype).name,
                                                      padding),
       "dtype": dtype, "padding": padding, "rng_factory": rng_factory}
      for dtype in float_dtypes
      for padding in ["VALID", "SAME"]
      for rng_factory in [jtu.rand_small]))
  def testSelectAndGatherAdd(self, dtype, padding, rng_factory):
    if jtu.device_under_test() == "tpu" and dtype == dtypes.bfloat16:
      raise SkipTest("bfloat16 _select_and_gather_add doesn't work on tpu")
    rng = rng_factory()

    all_configs = itertools.chain(
        itertools.product(
            [(4, 6)],
            [(2, 1), (1, 2)],
            [(1, 1), (2, 1), (1, 2)]),
        itertools.product(
            [(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)],
            [(1, 2, 2, 1), (1, 1, 1, 1)]))

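    # lax._select_and_gather_add is a private lax helper (used internally when
    # differentiating windowed max/min reductions); as in testReduceWindow,
    # `dims` and `strides` are supplied by the loop below.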
    def fun(operand, tangents):
      return lax._select_and_gather_add(operand, tangents, lax.ge_p, dims,
                                        strides, padding)

    for shape, dims, strides in all_configs:
      for bdims in all_bdims(shape, shape):
        self._CheckBatching(fun, 3, bdims, (shape, shape), dtype, rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_bdims={}_fft_ndims={}"
       .format(shape, bdims, fft_ndims),
       "shape": shape, "bdims": bdims, "fft_ndims": fft_ndims, "rng_factory": rng_factory}
      for shape in [(5,), (3, 4, 5), (2, 3, 4, 5)]
      for bdims in all_bdims(shape)
      for fft_ndims in range(0, min(3, len(shape)) + 1)
      for rng_factory in [jtu.rand_default]))
  @jtu.skip_on_devices("tpu")  # TODO(b/137993701): unimplemented cases.
  def testFft(self, fft_ndims, shape, bdims, rng_factory):
    rng = rng_factory()
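    # The FFT is applied over the trailing fft_ndims axes, so fft_lengths must
    # match those axis sizes.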
    ndims = len(shape)
    axes = range(ndims - fft_ndims, ndims)
    fft_lengths = [shape[axis] for axis in axes]
    op = lambda x: lax.fft(x, xla_client.FftType.FFT, fft_lengths)
    self._CheckBatching(op, 5, bdims, [shape], onp.complex64, rng)

  # TODO Concatenate
  # TODO Reverse
  # TODO DynamicSlice
  # TODO DynamicUpdateSlice
  # TODO Sort
  # TODO SortKeyVal
  # TODO Collapse
  # TODO ScatterAdd
  # TODO Scatter
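
  # A minimal sketch of how the Concatenate TODO above could be covered,
  # following the pattern of the tests in this class. The parameter grid and
  # the test name are illustrative assumptions, not settled coverage.
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_dim={}_baseshape={}_bdims={}".format(
          dim, jtu.format_shape_dtype_string(base_shape, dtype), bdims),
       "base_shape": base_shape, "dtype": dtype, "dim": dim, "bdims": bdims,
       "rng_factory": jtu.rand_default}
      for base_shape in [(4,), (3, 4), (2, 3, 4)]
      for dim in range(len(base_shape))
      for bdims in all_bdims(base_shape, base_shape)
      for dtype in default_dtypes))
  def testConcatenateSketch(self, base_shape, dtype, dim, bdims, rng_factory):
    rng = rng_factory()
    # Concatenate two operands of the same shape along `dim` and check that
    # vmapping over each combination of batch dimensions matches the
    # manually batched reference computed by _CheckBatching.
    op = lambda x, y: lax.concatenate([x, y], dim)
    self._CheckBatching(op, 5, bdims, (base_shape, base_shape), dtype, rng)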


if __name__ == '__main__':
  absltest.main()