2018-11-17 18:03:33 -08:00
|
|
|
# Copyright 2018 Google LLC
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
|
|
|
|
import collections
|
|
|
|
from functools import partial
|
|
|
|
import itertools
|
2020-06-01 11:49:35 -07:00
|
|
|
from unittest import SkipTest
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
from absl.testing import absltest
|
|
|
|
from absl.testing import parameterized
|
|
|
|
|
2020-07-14 13:03:24 -07:00
|
|
|
import numpy as np
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-01-18 08:26:23 -05:00
|
|
|
import jax
|
2018-11-17 18:03:33 -08:00
|
|
|
from jax import api
|
2020-06-01 14:47:14 -07:00
|
|
|
from jax import core
|
Change scalar promotion rules to prefer array types over scalar types. (#1709)
* Change scalar promotion rules to prefer array types over scalar types.
Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and the type promotion of a np.float64 and an np.float32 is an np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular, preferring the type of an array over the type of a scalar.
This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.
In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
2019-11-18 14:51:10 -05:00
|
|
|
from jax import dtypes
|
2018-11-17 18:03:33 -08:00
|
|
|
from jax import lax
|
|
|
|
from jax import test_util as jtu
|
|
|
|
from jax import lax_reference
|
2018-12-17 17:20:52 -08:00
|
|
|
from jax.test_util import check_grads
|
2019-12-17 21:42:37 -05:00
|
|
|
import jax.util
|
2020-08-18 10:17:38 -07:00
|
|
|
from jax.util import prod
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-12 09:00:39 -08:00
|
|
|
from jax.config import config
|
2018-12-06 18:37:59 -05:00
|
|
|
config.parse_flags_with_absl()
FLAGS = config.FLAGS


### lax tests

# The elementwise unary/binary op tests below are generated from a table
# (LAX_OPS) crossed with the dtype groups and shape groups defined here.

# Dtype groups, each read from ``jtu.dtypes``.
float_dtypes = jtu.dtypes.all_floating
complex_elem_dtypes = jtu.dtypes.floating  # real inputs for the `complex` op
complex_dtypes = jtu.dtypes.complex
inexact_dtypes = jtu.dtypes.all_inexact
int_dtypes = jtu.dtypes.all_integer
uint_dtypes = jtu.dtypes.all_unsigned
bool_dtypes = jtu.dtypes.boolean

# Composite groups. Note that neither includes the unsigned integer dtypes.
default_dtypes = [*float_dtypes, *int_dtypes]
all_dtypes = [*float_dtypes, *complex_dtypes, *int_dtypes, *bool_dtypes]

# Groups of equal-rank shapes whose dimensions agree or are 1. Every operand
# of a generated op test is drawn from a single group.
compatible_shapes = [
    [(3,)],
    [(3, 4), (3, 1), (1, 4)],
    [(2, 3, 4), (2, 1, 4)],
]
|
|
|
|
|
2019-10-22 19:53:59 -04:00
|
|
|
|
2019-11-11 12:51:15 -08:00
|
|
|
# One row of the elementwise-op table: the lax function name (``op``), its
# arity (``nargs``), the dtypes to try, a factory for the random-input
# generator, and an optional per-dtype tolerance override (``tol``).
OpRecord = collections.namedtuple(
    "OpRecord", ["op", "nargs", "dtypes", "rng_factory", "tol"])


def op_record(op, nargs, dtypes, rng_factory, tol=None):
  """Returns an ``OpRecord`` holding the given fields; ``tol`` defaults to None."""
  return OpRecord(op=op, nargs=nargs, dtypes=dtypes,
                  rng_factory=rng_factory, tol=tol)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
# Table of elementwise lax ops exercised by testOp / testOpAgainstNumpy.
# Each entry is op_record(name, arity, dtypes, rng_factory[, tol=...]).
LAX_OPS = [
    op_record("neg", 1, default_dtypes + complex_dtypes, jtu.rand_small),
    op_record("sign", 1, default_dtypes + uint_dtypes, jtu.rand_small),
    op_record("floor", 1, float_dtypes, jtu.rand_small),
    op_record("ceil", 1, float_dtypes, jtu.rand_small),
    op_record("round", 1, float_dtypes, jtu.rand_default),
    op_record("nextafter", 2,
              [f for f in float_dtypes if f != dtypes.bfloat16],
              jtu.rand_default, tol=0),
    op_record("is_finite", 1, float_dtypes, jtu.rand_small),
    op_record("exp", 1, float_dtypes + complex_dtypes, jtu.rand_small),
    # TODO(b/142975473): on CPU, expm1 for float64 is only accurate to ~float32
    # precision.
    op_record("expm1", 1, float_dtypes + complex_dtypes, jtu.rand_small,
              tol={np.float64: 1e-8}),
    op_record("log", 1, float_dtypes + complex_dtypes, jtu.rand_positive),
    op_record("log1p", 1, float_dtypes + complex_dtypes, jtu.rand_positive),
    # TODO(b/142975473): on CPU, tanh for complex128 is only accurate to
    # ~float32 precision.
    # TODO(b/143135720): on GPU, tanh has only ~float32 precision.
    op_record("tanh", 1, float_dtypes + complex_dtypes, jtu.rand_small,
              tol={np.float64: 1e-9, np.complex128: 1e-7}),
    op_record("sin", 1, float_dtypes + complex_dtypes, jtu.rand_default),
    op_record("cos", 1, float_dtypes + complex_dtypes, jtu.rand_default),
    op_record("atan2", 2, float_dtypes, jtu.rand_default),
    op_record("sqrt", 1, float_dtypes + complex_dtypes, jtu.rand_positive),
    op_record("rsqrt", 1, float_dtypes + complex_dtypes, jtu.rand_positive),
    op_record("square", 1, float_dtypes + complex_dtypes, jtu.rand_default),
    op_record("reciprocal", 1, float_dtypes + complex_dtypes,
              jtu.rand_positive),
    op_record("tan", 1, float_dtypes + complex_dtypes, jtu.rand_default,
              tol={np.float32: 3e-5}),
    op_record("asin", 1, float_dtypes + complex_dtypes, jtu.rand_small),
    op_record("acos", 1, float_dtypes + complex_dtypes, jtu.rand_small),
    op_record("atan", 1, float_dtypes + complex_dtypes, jtu.rand_small),
    op_record("asinh", 1, float_dtypes + complex_dtypes, jtu.rand_default,
              tol={np.complex64: 1e-4, np.complex128: 1e-5}),
    op_record("acosh", 1, float_dtypes + complex_dtypes, jtu.rand_positive),
    # TODO(b/155331781): atanh has only ~float32 precision.
    op_record("atanh", 1, float_dtypes + complex_dtypes, jtu.rand_small,
              tol={np.float64: 1e-9}),
    op_record("sinh", 1, float_dtypes + complex_dtypes, jtu.rand_default),
    op_record("cosh", 1, float_dtypes + complex_dtypes, jtu.rand_default),
    # lgamma's float32 tolerance is looser when the device under test is TPU.
    op_record("lgamma", 1, float_dtypes, jtu.rand_positive,
              tol={np.float32: 1e-3 if jtu.device_under_test() == "tpu" else 1e-5,
                   np.float64: 1e-14}),
    op_record("digamma", 1, float_dtypes, jtu.rand_positive,
              tol={np.float64: 1e-14}),
    op_record("betainc", 3, float_dtypes, jtu.rand_positive,
              tol={np.float64: 1e-14}),
    op_record("igamma", 2,
              [f for f in float_dtypes if f not in [dtypes.bfloat16, np.float16]],
              jtu.rand_positive, tol={np.float64: 1e-14}),
    op_record("igammac", 2,
              [f for f in float_dtypes if f not in [dtypes.bfloat16, np.float16]],
              jtu.rand_positive, tol={np.float64: 1e-14}),
    op_record("erf", 1, float_dtypes, jtu.rand_small),
    op_record("erfc", 1, float_dtypes, jtu.rand_small),
    # TODO(b/142976030): the approximation of erf_inv used by XLA is only
    # accurate to float32 precision.
    op_record("erf_inv", 1, float_dtypes, jtu.rand_small,
              tol={np.float64: 1e-9}),
    op_record("bessel_i0e", 1, float_dtypes, jtu.rand_default),
    op_record("bessel_i1e", 1, float_dtypes, jtu.rand_default),
    op_record("real", 1, complex_dtypes, jtu.rand_default),
    op_record("imag", 1, complex_dtypes, jtu.rand_default),
    op_record("complex", 2, complex_elem_dtypes, jtu.rand_default),
    op_record("conj", 1, complex_elem_dtypes + complex_dtypes,
              jtu.rand_default),
    op_record("abs", 1, default_dtypes + complex_dtypes, jtu.rand_default),
    op_record("pow", 2, float_dtypes + complex_dtypes, jtu.rand_positive),
    # Bitwise ops (boolean operands only here), then population_count.
    op_record("bitwise_and", 2, bool_dtypes, jtu.rand_small),
    op_record("bitwise_not", 1, bool_dtypes, jtu.rand_small),
    op_record("bitwise_or", 2, bool_dtypes, jtu.rand_small),
    op_record("bitwise_xor", 2, bool_dtypes, jtu.rand_small),
    op_record("population_count", 1, int_dtypes + uint_dtypes, jtu.rand_int),
    # Binary arithmetic.
    op_record("add", 2, default_dtypes + complex_dtypes, jtu.rand_small),
    op_record("sub", 2, default_dtypes + complex_dtypes, jtu.rand_small),
    op_record("mul", 2, default_dtypes + complex_dtypes, jtu.rand_small),
    op_record("div", 2, default_dtypes + complex_dtypes, jtu.rand_nonzero),
    op_record("rem", 2, default_dtypes, jtu.rand_nonzero),
    op_record("max", 2, all_dtypes, jtu.rand_small),
    op_record("min", 2, all_dtypes, jtu.rand_small),
    # Comparisons.
    op_record("eq", 2, all_dtypes, jtu.rand_some_equal),
    op_record("ne", 2, all_dtypes, jtu.rand_small),
    op_record("ge", 2, default_dtypes, jtu.rand_small),
    op_record("gt", 2, default_dtypes, jtu.rand_small),
    op_record("le", 2, default_dtypes, jtu.rand_small),
    op_record("lt", 2, default_dtypes, jtu.rand_small),
]
|
|
|
|
|
|
|
|
|
|
|
|
class LaxTest(jtu.JaxTestCase):
|
|
|
|
"""Numerical tests for LAX operations."""
|
|
|
|
|
2018-12-11 08:54:35 -08:00
|
|
|
@parameterized.named_parameters(itertools.chain.from_iterable(
    jtu.cases_from_list(
        {"testcase_name": jtu.format_test_name_suffix(
            rec.op, shapes, itertools.repeat(dtype)),
         "op_name": rec.op,
         "rng_factory": rec.rng_factory,
         "shapes": shapes,
         "dtype": dtype}
        for shape_group in compatible_shapes
        for shapes in itertools.combinations_with_replacement(
            shape_group, rec.nargs)
        for dtype in rec.dtypes)
    for rec in LAX_OPS))
def testOp(self, op_name, rng_factory, shapes, dtype):
  """Runs the lax op named ``op_name`` through ``_CompileAndCheck``."""
  sample = rng_factory(self.rng())
  lax_fn = getattr(lax, op_name)

  def args_maker():
    return [sample(shape, dtype) for shape in shapes]

  self._CompileAndCheck(lax_fn, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-11 08:54:35 -08:00
|
|
|
@parameterized.named_parameters(itertools.chain.from_iterable(
    jtu.cases_from_list(
        {"testcase_name": jtu.format_test_name_suffix(
            rec.op, shapes, itertools.repeat(dtype)),
         "op_name": rec.op,
         "rng_factory": rec.rng_factory,
         "shapes": shapes,
         "dtype": dtype,
         "tol": rec.tol}
        for shape_group in compatible_shapes
        for shapes in itertools.combinations_with_replacement(
            shape_group, rec.nargs)
        for dtype in rec.dtypes)
    for rec in LAX_OPS))
def testOpAgainstNumpy(self, op_name, rng_factory, shapes, dtype, tol):
  """Compares the lax op ``op_name`` with its ``lax_reference`` counterpart.

  ``tol`` is the record's per-dtype tolerance override (possibly None).
  """
  # float64 nextafter is skipped when 64-bit mode is off (presumably because
  # JAX would then compute at a different precision than the NumPy reference).
  is_f64_nextafter = op_name == "nextafter" and dtype == np.float64
  if is_f64_nextafter and not FLAGS.jax_enable_x64:
    raise SkipTest("64-bit mode disabled")

  sample = rng_factory(self.rng())

  def args_maker():
    return [sample(shape, dtype) for shape in shapes]

  lax_fn = getattr(lax, op_name)
  reference_fn = getattr(lax_reference, op_name)
  self._CheckAgainstNumpy(reference_fn, lax_fn, args_maker, tol=tol)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
# TODO test shift_left, shift_right_arithmetic, shift_right_logical
|
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_from_dtype={}_to_dtype={}".format(
        from_dtype, to_dtype),
     "from_dtype": from_dtype,
     "to_dtype": to_dtype,
     "rng_factory": rng_factory}
    for from_dtype, to_dtype in itertools.product(
        [np.float32, np.int32, "float32", "int32"], repeat=2)
    for rng_factory in [jtu.rand_default]))
def testConvertElementType(self, from_dtype, to_dtype, rng_factory):
  """Runs lax.convert_element_type on a (2, 3) input via _CompileAndCheck.

  Dtypes are supplied both as NumPy scalar types and as strings.
  """
  sample = rng_factory(self.rng())

  def args_maker():
    return [sample((2, 3), from_dtype)]

  def convert(x):
    return lax.convert_element_type(x, to_dtype)

  self._CompileAndCheck(convert, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_from_dtype={}_to_dtype={}".format(
        from_dtype, to_dtype),
     "from_dtype": from_dtype,
     "to_dtype": to_dtype,
     "rng_factory": rng_factory}
    for from_dtype, to_dtype in itertools.product(
        [np.float32, np.int32, "float32", "int32"], repeat=2)
    for rng_factory in [jtu.rand_default]))
def testConvertElementTypeAgainstNumpy(self, from_dtype, to_dtype, rng_factory):
  """Compares lax.convert_element_type with lax_reference's version."""
  sample = rng_factory(self.rng())

  def args_maker():
    return [sample((2, 3), from_dtype)]

  def convert(x):
    return lax.convert_element_type(x, to_dtype)

  def reference(x):
    return lax_reference.convert_element_type(x, to_dtype)

  self._CheckAgainstNumpy(reference, convert, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_from_dtype={}_to_dtype={}"
                      .format(from_dtype, to_dtype),
     "from_dtype": from_dtype,
     "to_dtype": to_dtype,
     "rng_factory": rng_factory}
    for from_dtype, to_dtype in itertools.product(
        [np.float32, np.int32, "float32", "int32"], repeat=2)
    for rng_factory in [jtu.rand_default]))
def testBitcastConvertType(self, from_dtype, to_dtype, rng_factory):
  """Runs lax.bitcast_convert_type on a (2, 3) input via _CompileAndCheck."""
  sample = rng_factory(self.rng())

  def args_maker():
    return [sample((2, 3), from_dtype)]

  def bitcast(x):
    return lax.bitcast_convert_type(x, to_dtype)

  self._CompileAndCheck(bitcast, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_from_dtype={}_to_dtype={}"
                      .format(from_dtype, to_dtype),
     "from_dtype": from_dtype,
     "to_dtype": to_dtype,
     "rng_factory": rng_factory}
    for from_dtype, to_dtype in itertools.product(
        [np.float32, np.int32, "float32", "int32"], repeat=2)
    for rng_factory in [jtu.rand_default]))
def testBitcastConvertTypeAgainstNumpy(self, from_dtype, to_dtype, rng_factory):
  """Compares lax.bitcast_convert_type with lax_reference's version."""
  sample = rng_factory(self.rng())

  def args_maker():
    return [sample((2, 3), from_dtype)]

  def bitcast(x):
    return lax.bitcast_convert_type(x, to_dtype)

  def reference(x):
    return lax_reference.bitcast_convert_type(x, to_dtype)

  self._CheckAgainstNumpy(reference, bitcast, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_min_shape={}_operand_shape={}_max_shape={}".format(
        jtu.format_shape_dtype_string(min_shape, dtype),
        jtu.format_shape_dtype_string(operand_shape, dtype),
        jtu.format_shape_dtype_string(max_shape, dtype)),
     "min_shape": min_shape,
     "operand_shape": operand_shape,
     "max_shape": max_shape,
     "dtype": dtype,
     "rng_factory": rng_factory}
    for min_shape, operand_shape, max_shape in (
        ((), (2, 3), ()),
        ((2, 3), (2, 3), ()),
        ((), (2, 3), (2, 3)),
        ((2, 3), (2, 3), (2, 3)),
    )
    for dtype in default_dtypes
    for rng_factory in [jtu.rand_default]))
def testClamp(self, min_shape, operand_shape, max_shape, dtype, rng_factory):
  """Runs lax.clamp through _CompileAndCheck with scalar or (2, 3) bounds."""
  sample = rng_factory(self.rng())

  def args_maker():
    return [sample(s, dtype) for s in (min_shape, operand_shape, max_shape)]

  self._CompileAndCheck(lax.clamp, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_min_shape={}_operand_shape={}_max_shape={}".format(
        jtu.format_shape_dtype_string(min_shape, dtype),
        jtu.format_shape_dtype_string(operand_shape, dtype),
        jtu.format_shape_dtype_string(max_shape, dtype)),
     "min_shape": min_shape,
     "operand_shape": operand_shape,
     "max_shape": max_shape,
     "dtype": dtype,
     "rng_factory": rng_factory}
    for min_shape, operand_shape, max_shape in (
        ((), (2, 3), ()),
        ((2, 3), (2, 3), ()),
        ((), (2, 3), (2, 3)),
        ((2, 3), (2, 3), (2, 3)),
    )
    for dtype in default_dtypes
    for rng_factory in [jtu.rand_default]))
def testClampAgainstNumpy(self, min_shape, operand_shape, max_shape, dtype,
                          rng_factory):
  """Compares lax.clamp with lax_reference.clamp on random operands."""
  sample = rng_factory(self.rng())

  def args_maker():
    return [sample(s, dtype) for s in (min_shape, operand_shape, max_shape)]

  self._CheckAgainstNumpy(lax_reference.clamp, lax.clamp, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_dim={}_baseshape=[{}]_dtype={}_narrs={}".format(
        dim, ",".join(str(d) for d in base_shape), np.dtype(dtype).name,
        num_arrs),
     "dim": dim,
     "base_shape": base_shape,
     "dtype": dtype,
     "num_arrs": num_arrs,
     "rng_factory": rng_factory}
    for num_arrs in [3]
    for dtype in default_dtypes
    for base_shape in [(4,), (3, 4), (2, 3, 4)]
    for dim in range(len(base_shape))
    for rng_factory in [jtu.rand_default]))
def testConcatenate(self, dim, base_shape, dtype, num_arrs, rng_factory):
  """Runs lax.concatenate on ``num_arrs`` arrays along axis ``dim``.

  The operands match ``base_shape`` everywhere except axis ``dim``, whose
  sizes cycle through 3, 1, 4.
  """
  sample = rng_factory(self.rng())
  sizes = itertools.islice(itertools.cycle((3, 1, 4)), num_arrs)
  shapes = [(*base_shape[:dim], size, *base_shape[dim + 1:]) for size in sizes]

  def args_maker():
    return [sample(shape, dtype) for shape in shapes]

  def concat(*operands):
    return lax.concatenate(operands, dim)

  self._CompileAndCheck(concat, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_dim={}_baseshape=[{}]_dtype={}_narrs={}".format(
        dim, ",".join(str(d) for d in base_shape), np.dtype(dtype).name,
        num_arrs),
     "dim": dim,
     "base_shape": base_shape,
     "dtype": dtype,
     "num_arrs": num_arrs,
     "rng_factory": rng_factory}
    for num_arrs in [3]
    for dtype in default_dtypes
    for base_shape in [(4,), (3, 4), (2, 3, 4)]
    for dim in range(len(base_shape))
    for rng_factory in [jtu.rand_default]))
def testConcatenateAgainstNumpy(self, dim, base_shape, dtype, num_arrs, rng_factory):
  """Compares lax.concatenate with lax_reference.concatenate along ``dim``."""
  sample = rng_factory(self.rng())
  sizes = itertools.islice(itertools.cycle((3, 1, 4)), num_arrs)
  shapes = [(*base_shape[:dim], size, *base_shape[dim + 1:]) for size in sizes]

  def args_maker():
    return [sample(shape, dtype) for shape in shapes]

  def concat(*operands):
    return lax.concatenate(operands, dim)

  def reference(*operands):
    return lax_reference.concatenate(operands, dim)

  self._CheckAgainstNumpy(reference, concat, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name":
         "_lhs_shape={}_rhs_shape={}_strides={}_padding={}".format(
             jtu.format_shape_dtype_string(lhs_shape, dtype),
             jtu.format_shape_dtype_string(rhs_shape, dtype),
             strides, padding),
     "lhs_shape": lhs_shape,
     "rhs_shape": rhs_shape,
     "dtype": dtype,
     "strides": strides,
     "padding": padding,
     "rng_factory": rng_factory}
    for lhs_shape, rhs_shape in [
        ((b, i, 9, 10), (j, i, 4, 5))
        for b, i, j in itertools.product([2, 3], repeat=3)]
    for dtype in float_dtypes
    for strides in [(1, 1), (1, 2), (2, 1)]
    for padding in ["VALID", "SAME"]
    for rng_factory in [jtu.rand_small]))
def testConv(self, lhs_shape, rhs_shape, dtype, strides, padding, rng_factory):
  """Runs lax.conv with string padding through _CompileAndCheck."""
  sample = rng_factory(self.rng())

  def args_maker():
    return [sample(lhs_shape, dtype), sample(rhs_shape, dtype)]

  def conv(lhs, rhs):
    return lax.conv(lhs, rhs, strides, padding)

  self._CompileAndCheck(conv, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name":
         "_lhs_shape={}_rhs_shape={}_strides={}_padding={}".format(
             jtu.format_shape_dtype_string(lhs_shape, dtype),
             jtu.format_shape_dtype_string(rhs_shape, dtype),
             strides, padding),
     "lhs_shape": lhs_shape,
     "rhs_shape": rhs_shape,
     "dtype": dtype,
     "strides": strides,
     "padding": padding,
     "rng_factory": rng_factory}
    for lhs_shape, rhs_shape in [
        ((b, i, 9, 10), (j, i, 4, 5))
        for b, i, j in itertools.product([2, 3], repeat=3)]
    for dtype in float_dtypes
    for strides in [(1, 1), (1, 2), (2, 1)]
    for padding in ["VALID", "SAME"]
    for rng_factory in [jtu.rand_small]))
def testConvAgainstNumpy(self, lhs_shape, rhs_shape, dtype, strides, padding,
                         rng_factory):
  """Compares lax.conv with lax_reference.conv on random operands."""
  sample = rng_factory(self.rng())

  def args_maker():
    return [sample(lhs_shape, dtype), sample(rhs_shape, dtype)]

  def conv(lhs, rhs):
    return lax.conv(lhs, rhs, strides, padding)

  def reference(lhs, rhs):
    return lax_reference.conv(lhs, rhs, strides, padding)

  self._CheckAgainstNumpy(reference, conv, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_lhs_shape={}_rhs_shape={}_strides={}_padding={}"
                      "_lhs_dilation={}_rhs_dilation={}".format(
         jtu.format_shape_dtype_string(lhs_shape, dtype),
         jtu.format_shape_dtype_string(rhs_shape, dtype),
         strides, padding, lhs_dilation, rhs_dilation),
     "lhs_shape": lhs_shape,
     "rhs_shape": rhs_shape,
     "dtype": dtype,
     "strides": strides,
     "padding": padding,
     "lhs_dilation": lhs_dilation,
     "rhs_dilation": rhs_dilation,
     "rng_factory": rng_factory}
    for lhs_shape, rhs_shape in [
        ((b, i, 9, 10), (j, i, 4, 5))
        for b, i, j in itertools.product([1, 2, 3], repeat=3)]
    for dtype in float_dtypes
    for strides in [(1, 1), (1, 2), (2, 1)]
    for padding in [((0, 0), (0, 0)), ((1, 2), (2, 0))]
    for lhs_dilation, rhs_dilation in itertools.product(
        [(1, 1), (1, 2), (2, 2)], repeat=2)
    for rng_factory in [jtu.rand_small]))
def testConvWithGeneralPadding(self, lhs_shape, rhs_shape, dtype, strides,
                               padding, lhs_dilation, rhs_dilation, rng_factory):
  """Runs lax.conv_with_general_padding (explicit pads, dilations) jitted."""
  sample = rng_factory(self.rng())

  def args_maker():
    return [sample(lhs_shape, dtype), sample(rhs_shape, dtype)]

  def conv(lhs, rhs):
    return lax.conv_with_general_padding(
        lhs, rhs, strides, padding, lhs_dilation, rhs_dilation)

  self._CompileAndCheck(conv, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_lhs_shape={}_rhs_shape={}_strides={}_padding={}"
                      "_lhs_dilation={}_rhs_dilation={}".format(
         jtu.format_shape_dtype_string(lhs_shape, dtype),
         jtu.format_shape_dtype_string(rhs_shape, dtype),
         strides, padding, lhs_dilation, rhs_dilation),
     "lhs_shape": lhs_shape,
     "rhs_shape": rhs_shape,
     "dtype": dtype,
     "strides": strides,
     "padding": padding,
     "lhs_dilation": lhs_dilation,
     "rhs_dilation": rhs_dilation,
     "rng_factory": rng_factory}
    for lhs_shape, rhs_shape in [
        ((b, i, 9, 10), (j, i, 4, 5))
        for b, i, j in itertools.product([1, 2, 3], repeat=3)]
    for dtype in [np.float32]
    for strides in [(1, 1), (1, 2), (2, 1)]
    for padding in [((0, 0), (0, 0)), ((1, 2), (2, 0))]
    for lhs_dilation, rhs_dilation in itertools.product(
        [(1, 1), (1, 2), (2, 2)], repeat=2)
    for rng_factory in [jtu.rand_small]))
def testConvWithGeneralPaddingAgainstNumpy(
    self, lhs_shape, rhs_shape, dtype, strides, padding, lhs_dilation,
    rhs_dilation, rng_factory):
  """Compares lax.conv_with_general_padding with the lax_reference version.

  The lax side requests ``lax.Precision.HIGHEST``.
  """
  sample = rng_factory(self.rng())

  def args_maker():
    return [sample(lhs_shape, dtype), sample(rhs_shape, dtype)]

  def conv(lhs, rhs):
    return lax.conv_with_general_padding(
        lhs, rhs, strides, padding, lhs_dilation, rhs_dilation,
        precision=lax.Precision.HIGHEST)

  def reference(lhs, rhs):
    return lax_reference.conv_with_general_padding(
        lhs, rhs, strides, padding, lhs_dilation, rhs_dilation)

  self._CheckAgainstNumpy(reference, conv, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_lhs_shape={}_rhs_shape={}_strides={}_padding={}"
                      "_lhs_dilation={}_rhs_dilation={}"
                      "_dims={}".format(
         jtu.format_shape_dtype_string(lhs_shape, dtype),
         jtu.format_shape_dtype_string(rhs_shape, dtype),
         strides, padding, lhs_dilation, rhs_dilation,
         ",".join(dim_nums)),
     "lhs_shape": lhs_shape,
     "rhs_shape": rhs_shape,
     "dtype": dtype,
     "strides": strides,
     "padding": padding,
     "lhs_dilation": lhs_dilation,
     "rhs_dilation": rhs_dilation,
     "dimension_numbers": dim_nums,
     "feature_group_count": feature_group_count,
     "batch_group_count": batch_group_count,
     "perms": perms,
     "rng_factory": rng_factory}
    for batch_group_count, feature_group_count in [(1, 1), (2, 1), (1, 2)]
    for lhs_shape, rhs_shape in [
        ((b * batch_group_count, i * feature_group_count, 9, w),
         (j * feature_group_count * batch_group_count, i, 4, 5))
        for w in [0, 10]
        for b, i, j in itertools.product([2, 3], repeat=3)]
    for dtype in inexact_dtypes
    for strides in [(1, 1), (2, 1)]
    for padding in [((1, 2), (2, 0)), ((10, 8), (7, 13))]
    for lhs_dilation, rhs_dilation in itertools.product(
        [(1, 1), (1, 2), (1, 4)], repeat=2)
    for rng_factory in [jtu.rand_small]
    for dim_nums, perms in [
        (("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])),
        (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])),
        (("NCHW", "HWIO", "NHWC"), ([0, 1, 2, 3], [2, 3, 1, 0])),
    ]))
def testConvGeneralDilated(self, lhs_shape, rhs_shape, dtype, strides,
                           padding, lhs_dilation, rhs_dilation,
                           feature_group_count, batch_group_count,
                           dimension_numbers, perms, rng_factory):
  """Runs lax.conv_general_dilated across layouts and group counts.

  Random operands are generated with ``lhs_shape``/``rhs_shape`` and then
  transposed by ``perms`` so they match the lhs/rhs layouts named in
  ``dimension_numbers``.
  """
  sample = rng_factory(self.rng())
  lhs_perm, rhs_perm = perms

  def args_maker():
    lhs = lax.transpose(sample(lhs_shape, dtype), lhs_perm)
    rhs = lax.transpose(sample(rhs_shape, dtype), rhs_perm)
    return [lhs, rhs]

  def conv(lhs, rhs):
    return lax.conv_general_dilated(
        lhs, rhs, strides, padding, lhs_dilation, rhs_dilation,
        dimension_numbers, feature_group_count=feature_group_count,
        batch_group_count=batch_group_count)

  self._CompileAndCheck(conv, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
# TODO(mattjj): test conv_general_dilated against numpy
|
|
|
|
|
2020-01-09 14:36:37 -05:00
|
|
|
def testConv0DIsDot(self):
  """A conv with no spatial dims ('NC', 'IO', 'NC') behaves like np.dot."""
  sample = jtu.rand_default(self.rng())
  args_maker = lambda: [sample((10, 5), np.float32),
                        sample((5, 7), np.float32)]
  conv_as_dot = partial(lax.conv_general_dilated, window_strides=(),
                        padding='VALID',
                        dimension_numbers=('NC', 'IO', 'NC'))
  self._CompileAndCheck(conv_as_dot, args_maker)
  self._CheckAgainstNumpy(np.dot, conv_as_dot, args_maker, tol=.1)
|
2020-01-09 14:36:37 -05:00
|
|
|
|
|
|
|
|
2019-04-09 22:59:03 -07:00
|
|
|
@staticmethod
|
|
|
|
def _conv_transpose_via_grad(data, kernel, strides, padding,
|
2019-12-17 02:03:17 +00:00
|
|
|
rhs_dilation=None, dimension_numbers=None):
|
2019-04-11 09:12:21 -07:00
|
|
|
"""Helper method: calculates conv transpose via grad for testing."""
|
2019-04-09 22:59:03 -07:00
|
|
|
assert len(data.shape) == len(kernel.shape)
|
|
|
|
nspatial = len(data.shape) - 2
|
|
|
|
one = (1,) * nspatial
|
2019-12-17 02:03:17 +00:00
|
|
|
rhs_dilation = rhs_dilation or one
|
2019-04-09 22:59:03 -07:00
|
|
|
dn = lax.conv_dimension_numbers(data.shape, kernel.shape,
|
|
|
|
dimension_numbers)
|
2020-07-14 13:03:24 -07:00
|
|
|
in_shape = np.take(data.shape, dn.lhs_spec)
|
2019-04-09 22:59:03 -07:00
|
|
|
in_sdims = in_shape[2:]
|
2020-07-14 13:03:24 -07:00
|
|
|
k_shape = np.take(kernel.shape, dn.rhs_spec)
|
2019-04-09 22:59:03 -07:00
|
|
|
k_sdims = k_shape[2:]
|
2019-12-17 02:03:17 +00:00
|
|
|
e_k_sdims = [(k-1) * r + 1 for k, r in zip(k_sdims, rhs_dilation)]
|
2019-04-09 22:59:03 -07:00
|
|
|
if padding == 'VALID':
|
2019-12-17 02:03:17 +00:00
|
|
|
o_sdims = [in_sdims[i]*strides[i] + max(e_k_sdims[i]-strides[i],0)
|
2019-04-09 22:59:03 -07:00
|
|
|
for i in range(nspatial)]
|
|
|
|
elif padding == 'SAME':
|
|
|
|
o_sdims = [in_sdims[i]*strides[i] for i in range(nspatial)]
|
|
|
|
o_shape = [in_shape[0], k_shape[1]] + o_sdims
|
|
|
|
out_spec_inv = [x[0] for x in
|
|
|
|
sorted(enumerate(dn.out_spec), key=lambda x: x[1])]
|
2020-07-14 13:03:24 -07:00
|
|
|
o_layout = np.take(np.array(o_shape), out_spec_inv)
|
|
|
|
placeholder = np.ones(o_layout, data.dtype)
|
2019-04-09 22:59:03 -07:00
|
|
|
conv = lambda x: lax.conv_general_dilated(x, kernel, strides, padding,
|
2019-12-17 02:03:17 +00:00
|
|
|
one, rhs_dilation, dn)
|
2019-04-09 22:59:03 -07:00
|
|
|
_, g = api.vjp(conv, placeholder)
|
|
|
|
return g(data)[0]
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
def _transpose_conv_kernel(data, kernel, dimension_numbers):
|
|
|
|
dn = lax.conv_dimension_numbers(data.shape, kernel.shape,
|
|
|
|
dimension_numbers)
|
2020-07-14 13:03:24 -07:00
|
|
|
spatial_axes = np.array(dn.rhs_spec)[2:]
|
2019-04-09 22:59:03 -07:00
|
|
|
for axis in spatial_axes:
|
2020-07-14 13:03:24 -07:00
|
|
|
kernel = np.flip(kernel, axis)
|
|
|
|
kernel = np.swapaxes(kernel, dn.rhs_spec[0], dn.rhs_spec[1])
|
2019-04-09 22:59:03 -07:00
|
|
|
return kernel
|
|
|
|
|
2019-04-09 15:06:46 -07:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name":
     "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_rhs_dilation={}".format(
         jtu.format_shape_dtype_string(lhs_shape, dtype),
         jtu.format_shape_dtype_string(rhs_shape, dtype), strides, padding, rhs_dilation),
     "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
     "strides": strides, "padding": padding, "rhs_dilation": rhs_dilation,
     "rng_factory": rng_factory, 'dspec': dspec}
    for lhs_shape, rhs_shape in [
        ((b, 9, 10, i), (k, k, j, i))  # NB: i,j flipped in RHS for transpose
        for b, i, j, k in itertools.product([2,3],[2,3],[2,3],[3,4,5])]
    for dtype in float_dtypes
    for strides in [(1, 1), (1, 2), (2, 1), (2, 2), (3, 3)]
    for padding in ["VALID", "SAME"]
    for dspec in [('NHWC', 'HWIO', 'NHWC'),]
    for rhs_dilation in [None, (2, 2)]
    for rng_factory in [jtu.rand_small]))
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def testConvTranspose2DT(self, lhs_shape, rhs_shape, dtype, strides,
                         padding, dspec, rhs_dilation, rng_factory):
  """conv_transpose(transpose_kernel=True) must equal the lhs-grad of conv."""
  sample = rng_factory(self.rng())

  def args_maker():
    return [sample(lhs_shape, dtype), sample(rhs_shape, dtype)]

  shared = dict(rhs_dilation=rhs_dilation, dimension_numbers=dspec)
  direct = partial(lax.conv_transpose, strides=strides, padding=padding,
                   transpose_kernel=True, **shared)
  via_grad = partial(self._conv_transpose_via_grad, strides=strides,
                     padding=padding, **shared)

  # NB: this only checks agreement between two JAX routes; NumPy is not used.
  self._CheckAgainstNumpy(via_grad, direct, args_maker)
|
2019-04-09 15:06:46 -07:00
|
|
|
|
2019-04-09 22:59:03 -07:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name":
     "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_rhs_dilation={}".format(
         jtu.format_shape_dtype_string(lhs_shape, dtype),
         jtu.format_shape_dtype_string(rhs_shape, dtype), strides, padding, rhs_dilation),
     "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
     "strides": strides, "padding": padding, "rhs_dilation": rhs_dilation,
     "rng_factory": rng_factory, 'dspec': dspec}
    for lhs_shape, rhs_shape in [
        ((b, 9, 10, i), (k, k, i, j))
        for b, i, j, k in itertools.product([2,3],[2,3],[2,3],[3,4,5])]
    for dtype in float_dtypes
    for strides in [(1, 1), (1, 2), (2, 1), (2, 2), (3, 3)]
    for padding in ["VALID", "SAME"]
    for dspec in [('NHWC', 'HWIO', 'NHWC'),]
    for rhs_dilation in [None, (2, 2)]
    for rng_factory in [jtu.rand_small]))
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def testConvTranspose2D(self, lhs_shape, rhs_shape, dtype, strides,
                        padding, dspec, rhs_dilation, rng_factory):
  """2D conv_transpose(transpose_kernel=False) vs. grad of conv with a
  flipped/swapped kernel."""
  sample = rng_factory(self.rng())

  def args_maker():
    return [sample(lhs_shape, dtype), sample(rhs_shape, dtype)]

  shared = dict(rhs_dilation=rhs_dilation, dimension_numbers=dspec)
  direct = partial(lax.conv_transpose, strides=strides, padding=padding,
                   transpose_kernel=False, **shared)

  def via_grad(lhs, rhs):
    flipped = self._transpose_conv_kernel(lhs, rhs, dimension_numbers=dspec)
    return self._conv_transpose_via_grad(lhs, flipped, strides, padding,
                                         **shared)

  # NB: this only checks agreement between two JAX routes; NumPy is not used.
  self._CheckAgainstNumpy(via_grad, direct, args_maker)
|
2019-04-09 22:59:03 -07:00
|
|
|
|
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name":
     "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_rhs_dilation={}".format(
         jtu.format_shape_dtype_string(lhs_shape, dtype),
         jtu.format_shape_dtype_string(rhs_shape, dtype), strides, padding, rhs_dilation),
     "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
     "strides": strides, "padding": padding, "rhs_dilation": rhs_dilation,
     "rng_factory": rng_factory, 'dspec': dspec}
    for lhs_shape, rhs_shape in [
        ((b, 10, i), (k, i, j))
        for b, i, j, k in itertools.product([2,3],[2,3],[2,3],[3,4,5])]
    for dtype in float_dtypes
    for strides in [(1,), (2,), (3,)]
    for padding in ["VALID", "SAME"]
    for dspec in [('NHC', 'HIO', 'NHC'),]
    for rhs_dilation in [None, (2,)]
    for rng_factory in [jtu.rand_small]))
def testConvTranspose1D(self, lhs_shape, rhs_shape, dtype, strides,
                        padding, dspec, rhs_dilation, rng_factory):
  """1D conv_transpose(transpose_kernel=False) vs. grad of conv with a
  flipped/swapped kernel."""
  sample = rng_factory(self.rng())

  def args_maker():
    return [sample(lhs_shape, dtype), sample(rhs_shape, dtype)]

  shared = dict(rhs_dilation=rhs_dilation, dimension_numbers=dspec)
  direct = partial(lax.conv_transpose, strides=strides, padding=padding,
                   transpose_kernel=False, **shared)

  def via_grad(lhs, rhs):
    flipped = self._transpose_conv_kernel(lhs, rhs, dimension_numbers=dspec)
    return self._conv_transpose_via_grad(lhs, flipped, strides, padding,
                                         **shared)

  # NB: this only checks agreement between two JAX routes; NumPy is not used.
  self._CheckAgainstNumpy(via_grad, direct, args_maker)
|
2020-07-02 14:38:35 -07:00
|
|
|
|
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name":
     "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_rhs_dilation={}".format(
         jtu.format_shape_dtype_string(lhs_shape, dtype),
         jtu.format_shape_dtype_string(rhs_shape, dtype), strides, padding, rhs_dilation),
     "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
     "strides": strides, "padding": padding, "rhs_dilation": rhs_dilation,
     "rng_factory": rng_factory, 'dspec': dspec}
    for lhs_shape, rhs_shape in [
        ((b, i), (i, j))
        for b, i, j in itertools.product([2,3],[2,3],[2,3])]
    for dtype in float_dtypes
    for strides in [()]
    for padding in ["VALID", "SAME"]
    for dspec in [('NC', 'IO', 'NC'),]
    for rhs_dilation in [None, ()]
    for rng_factory in [jtu.rand_small]))
def testConvTranspose0D(self, lhs_shape, rhs_shape, dtype, strides,
                        padding, dspec, rhs_dilation, rng_factory):
  """0D conv_transpose (no spatial dims) vs. grad of conv with a swapped
  kernel."""
  sample = rng_factory(self.rng())

  def args_maker():
    return [sample(lhs_shape, dtype), sample(rhs_shape, dtype)]

  shared = dict(rhs_dilation=rhs_dilation, dimension_numbers=dspec)
  direct = partial(lax.conv_transpose, strides=strides, padding=padding,
                   transpose_kernel=False, **shared)

  def via_grad(lhs, rhs):
    flipped = self._transpose_conv_kernel(lhs, rhs, dimension_numbers=dspec)
    return self._conv_transpose_via_grad(lhs, flipped, strides, padding,
                                         **shared)

  # NB: this only checks agreement between two JAX routes; NumPy is not used.
  self._CheckAgainstNumpy(via_grad, direct, args_maker)
|
2019-04-09 15:06:46 -07:00
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_lhs_shape={}_rhs_shape={}_precision={}".format(
        jtu.format_shape_dtype_string(lhs_shape, dtype),
        jtu.format_shape_dtype_string(rhs_shape, dtype),
        precision),
     "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
     "precision": precision, "rng_factory": rng_factory}
    for lhs_shape in [(3,), (4, 3)] for rhs_shape in [(3,), (3, 6)]
    for dtype in all_dtypes
    for precision in [None, lax.Precision.DEFAULT, lax.Precision.HIGH,
                      lax.Precision.HIGHEST]
    for rng_factory in [jtu.rand_default]))
def testDot(self, lhs_shape, rhs_shape, dtype, precision, rng_factory):
  """lax.dot compiles and runs for vector/matrix operands at each precision."""
  sample = rng_factory(self.rng())

  def args_maker():
    return [sample(lhs_shape, dtype), sample(rhs_shape, dtype)]

  dot_at_precision = partial(lax.dot, precision=precision)
  self._CompileAndCheck(dot_at_precision, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_lhs_shape={}_rhs_shape={}".format(
        jtu.format_shape_dtype_string(lhs_shape, dtype),
        jtu.format_shape_dtype_string(rhs_shape, dtype)),
     "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
     "rng_factory": rng_factory}
    for lhs_shape in [(3,), (4, 3)] for rhs_shape in [(3,), (3, 6)]
    for dtype in all_dtypes
    for rng_factory in [jtu.rand_default]))
def testDotAgainstNumpy(self, lhs_shape, rhs_shape, dtype, rng_factory):
  """lax.dot at HIGHEST precision matches the lax_reference.dot baseline."""
  sample = rng_factory(self.rng())

  def args_maker():
    return [sample(lhs_shape, dtype), sample(rhs_shape, dtype)]

  # Double-precision types get at least a 1e-14 tolerance floor.
  f64_tol = max(jtu.default_tolerance()[np.dtype(np.float64)], 1e-14)
  c128_tol = max(jtu.default_tolerance()[np.dtype(np.complex128)], 1e-14)
  tol = {np.float16: 1e-2, np.float64: f64_tol, np.complex128: c128_tol}

  precise_dot = partial(lax.dot, precision=lax.Precision.HIGHEST)
  self._CheckAgainstNumpy(lax_reference.dot, precise_dot, args_maker, tol=tol)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name":
     "_lhs_shape={}_rhs_shape={}_lhs_contracting={}_rhs_contracting={}"
     .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
             jtu.format_shape_dtype_string(rhs_shape, dtype),
             lhs_contracting, rhs_contracting),
     "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
     "lhs_contracting": lhs_contracting, "rhs_contracting": rhs_contracting,
     "rng_factory": rng_factory}
    for lhs_shape, rhs_shape, lhs_contracting, rhs_contracting in [
        [(5,), (5,), [0], [0]],
        [(5, 7), (5,), [0], [0]],
        [(7, 5), (5,), [1], [0]],
        [(3, 5), (2, 5), [1], [1]],
        [(5, 3), (5, 2), [0], [0]],
        [(5, 3, 2), (5, 2, 4), [0], [0]],
        [(5, 3, 2), (5, 2, 4), [0,2], [0,1]],
        [(5, 3, 2), (3, 5, 2, 4), [0,2], [1,2]],
        [(1, 2, 2, 3), (1, 2, 3, 1), [1], [1]],
        [(3, 2), (2, 4), [1], [0]],
    ]
    for dtype in all_dtypes
    for rng_factory in [jtu.rand_small]))
def testDotGeneralContractOnly(self, lhs_shape, rhs_shape, dtype,
                               lhs_contracting, rhs_contracting, rng_factory):
  """dot_general with contracting dimensions and no batch dimensions."""
  sample = rng_factory(self.rng())

  def args_maker():
    return [sample(lhs_shape, dtype), sample(rhs_shape, dtype)]

  no_batch_dims = ([], [])
  dims = ((lhs_contracting, rhs_contracting), no_batch_dims)
  contract = partial(lax.dot_general, dimension_numbers=dims)
  self._CompileAndCheck(contract, args_maker, check_dtypes=False)
|
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name":
     "_lhs_shape={}_rhs_shape={}_dimension_numbers={}"
     .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
             jtu.format_shape_dtype_string(rhs_shape, dtype),
             dimension_numbers),
     "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
     "dimension_numbers": dimension_numbers, "rng_factory": rng_factory}
    for lhs_shape, rhs_shape, dimension_numbers in [
        ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
        ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1]))),
        ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1]))),
    ]
    for dtype in all_dtypes
    for rng_factory in [jtu.rand_small]))
def testDotGeneralContractAndBatch(self, lhs_shape, rhs_shape, dtype,
                                   dimension_numbers, rng_factory):
  """dot_general with both contracting and batch dimensions."""
  sample = rng_factory(self.rng())

  def args_maker():
    return [sample(lhs_shape, dtype), sample(rhs_shape, dtype)]

  batched_contract = partial(lax.dot_general,
                             dimension_numbers=dimension_numbers)
  self._CompileAndCheck(batched_contract, args_maker, check_dtypes=False)
|
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name":
     "_lhs_shape={}_rhs_shape={}_dimension_numbers={}"
     .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
             jtu.format_shape_dtype_string(rhs_shape, dtype),
             dimension_numbers),
     "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
     "dimension_numbers": dimension_numbers, "rng_factory": rng_factory}
    for lhs_shape, rhs_shape, dimension_numbers in [
        ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
        ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1]))),
        ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1]))),
    ]
    for dtype in all_dtypes
    for rng_factory in [jtu.rand_small]))
def testDotGeneralAgainstNumpy(self, lhs_shape, rhs_shape, dtype,
                               dimension_numbers, rng_factory):
  """lax.dot_general agrees with lax_reference.dot_general."""
  sample = rng_factory(self.rng())

  def args_maker():
    return [sample(lhs_shape, dtype), sample(rhs_shape, dtype)]

  def lax_version(x, y):
    return lax.dot_general(x, y, dimension_numbers)

  def reference_version(x, y):
    return lax_reference.dot_general(x, y, dimension_numbers)

  self._CheckAgainstNumpy(reference_version, lax_version, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_dtype={}_broadcast_sizes={}".format(
        shape, np.dtype(dtype).name, broadcast_sizes),
     "shape": shape, "dtype": dtype, "broadcast_sizes": broadcast_sizes,
     "rng_factory": rng_factory}
    for shape in [(), (2, 3)]
    for dtype in default_dtypes
    for broadcast_sizes in [(), (2,), (1, 2)]
    for rng_factory in [jtu.rand_default]))
def testBroadcast(self, shape, dtype, broadcast_sizes, rng_factory):
  """lax.broadcast prepends `broadcast_sizes` and compiles under jit."""
  sample = rng_factory(self.rng())

  def args_maker():
    return [sample(shape, dtype)]

  def broadcaster(x):
    return lax.broadcast(x, broadcast_sizes)

  self._CompileAndCheck(broadcaster, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_broadcast_sizes={}".format(
        jtu.format_shape_dtype_string(shape, dtype), broadcast_sizes),
     "shape": shape, "dtype": dtype, "broadcast_sizes": broadcast_sizes,
     "rng_factory": rng_factory}
    for shape in [(), (2, 3)]
    for dtype in default_dtypes
    for broadcast_sizes in [(), (2,), (1, 2)]
    for rng_factory in [jtu.rand_default]))
def testBroadcastAgainstNumpy(self, shape, dtype, broadcast_sizes, rng_factory):
  """lax.broadcast agrees with lax_reference.broadcast."""
  sample = rng_factory(self.rng())

  def args_maker():
    return [sample(shape, dtype)]

  def lax_version(x):
    return lax.broadcast(x, broadcast_sizes)

  def reference_version(x):
    return lax_reference.broadcast(x, broadcast_sizes)

  self._CheckAgainstNumpy(reference_version, lax_version, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_inshape={}_outshape={}_bcdims={}".format(
        jtu.format_shape_dtype_string(inshape, dtype),
        outshape, broadcast_dimensions),
     "inshape": inshape, "dtype": dtype, "outshape": outshape,
     "dimensions": broadcast_dimensions, "rng_factory": rng_factory}
    for inshape, outshape, broadcast_dimensions in [
        ([2], [2, 2], [0]),
        ([2], [2, 2], [1]),
        ([2], [2, 3], [0]),
        ([], [2, 3], []),
        ([1], [2, 3], [1]),
    ]
    for dtype in default_dtypes
    for rng_factory in [jtu.rand_default]))
def testBroadcastInDim(self, inshape, dtype, outshape, dimensions, rng_factory):
  """lax.broadcast_in_dim compiles and runs for valid dimension mappings."""
  sample = rng_factory(self.rng())

  def args_maker():
    return [sample(inshape, dtype)]

  def broadcaster(x):
    return lax.broadcast_in_dim(x, outshape, dimensions)

  self._CompileAndCheck(broadcaster, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-03-16 09:54:58 +01:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_inshape={}_outshape={}_bcdims={}".format(
        jtu.format_shape_dtype_string(inshape, np.float32),
        outshape, broadcast_dimensions),
     "inshape": inshape, "outshape": outshape,
     "broadcast_dimensions": broadcast_dimensions, "err_msg": err_msg}
    for inshape, outshape, broadcast_dimensions, err_msg in [
        ([2], [2, 2], [0, 1], ('broadcast_dimensions must have length equal to '
                               'operand ndim')),
        ([2, 2], [2], [0, 1], ('target broadcast shape must have equal or higher rank '
                               'to the operand shape')),
        ([2], [2, 3], [2], ('broadcast_in_dim broadcast_dimensions must be a subset of output '
                            'dimensions')),
        ([2], [3], [0], ('operand dimension sizes must either be 1, or be '
                         'equal to their corresponding dimensions in the target broadcast shape')),
        ([2, 2], [2, 2], [1, 0], ('broadcast_dimensions must be strictly increasing')),
    ]))
def testBroadcastInDimShapeCheck(self, inshape, outshape, broadcast_dimensions, err_msg):
  """Invalid broadcast_in_dim arguments raise TypeError matching `err_msg`."""
  operand = jtu.rand_default(self.rng())(inshape, np.float32)
  with self.assertRaisesRegex(TypeError, err_msg):
    lax.broadcast_in_dim(operand, shape=outshape,
                         broadcast_dimensions=broadcast_dimensions)
|
|
|
|
|
2020-03-15 20:29:11 -07:00
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_inshape={}_outshape={}_bcdims={}".format(
        jtu.format_shape_dtype_string(inshape, dtype),
        outshape, broadcast_dimensions),
     "inshape": inshape, "dtype": dtype, "outshape": outshape,
     "dimensions": broadcast_dimensions, "rng_factory": rng_factory}
    for inshape, outshape, broadcast_dimensions in [
        ([2], [2, 2], [0]),
        ([2], [2, 2], [1]),
        ([2], [2, 3], [0]),
        ([], [2, 3], []),
        ([1], [2, 3], [1]),
    ]
    for dtype in default_dtypes
    for rng_factory in [jtu.rand_default]))
def testBroadcastInDimAgainstNumpy(self, inshape, dtype, outshape,
                                   dimensions, rng_factory):
  """lax.broadcast_in_dim agrees with lax_reference.broadcast_in_dim."""
  sample = rng_factory(self.rng())

  def args_maker():
    return [sample(inshape, dtype)]

  def lax_version(x):
    return lax.broadcast_in_dim(x, outshape, dimensions)

  def reference_version(x):
    return lax_reference.broadcast_in_dim(x, outshape, dimensions)

  self._CheckAgainstNumpy(reference_version, lax_version, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
Prefer using broadcast_in_dim/squeeze instead of reshape (#3217)
* Prefer using expand_dims/broadcast_in_dim to reshape in lax_numpy.py
`reshape()` is quite powerful, but does not necessarily preserve a notion of
axis identity (particularly for axes of length 1). This is problematic for
transformation rules that need to preserve a notion of axis identity, such as
for masking and a new transformation rule I'm exploring for unraveling pytrees.
This PR rewrites these rules in terms of expand_dims / lax.broadcast_in_dim,
when feasible, which has a well-defined mapping between input and output axes.
In particular: `matmul`, various `stack` functions, the `array` constructor,
broadcasting arithmetic, array indexing, `squeeze` and reductions with
`keepdims=True` no longer use `lax.reshape`.
I also implemented support for multiple axes in `expand_dims` (added in NumPy
1.18), since it was convenient for some of these other functions.
I considered trying to write a masking rule for broadcast_in_dim as well, but
it was trickier than I expected and @JuliusKunze has probably already thought
about it :)
* Remove unnecessary branch
* Add lax.squeeze primitive
* Changes per review
* Fix typing
* Move expand_dims into lax
* Update per review; add comments/documentation
* Type annotations for squeeze/expand_dims
2020-05-28 19:12:50 -07:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_inshape={}_dimensions={}".format(
        jtu.format_shape_dtype_string(inshape, np.float32), dimensions),
     "inshape": inshape, "dimensions": dimensions, "error_type": error_type,
     "err_msg": err_msg}
    for inshape, dimensions, error_type, err_msg in [
        ((1, 2, 3), (0, 0), ValueError, 'dimensions are not unique'),
        ((1, 2, 3), (3,), ValueError, 'axis 3 is out of bounds'),
        ((1, 2, 3), (-4,), ValueError, 'axis -4 is out of bounds'),
        ((1, 2, 3), (1,), ValueError, 'cannot select an axis to squeeze out'),
        ((1, 2, 3), (None,), TypeError, 'cannot be interpreted as an integer'),
    ]))
def testSqueezeShapeCheck(self, inshape, dimensions, error_type, err_msg):
  """Invalid lax.squeeze dimensions raise `error_type` matching `err_msg`."""
  operand = jtu.rand_default(self.rng())(inshape, np.float32)
  with self.assertRaisesRegex(error_type, err_msg):
    lax.squeeze(operand, dimensions=dimensions)
|
|
|
|
|
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
|
|
|
|
{"testcase_name": "_inshape={}_dimensions={}".format(
|
2020-07-14 13:03:24 -07:00
|
|
|
jtu.format_shape_dtype_string(arg_shape, np.float32), dimensions),
|
Prefer using broadcast_in_dim/squeeze instead of reshape (#3217)
* Prefer using expand_dims/broadcast_in_dim to reshape in lax_numpy.py
`reshape()` is quite powerful, but does not necessarily preserve a notion of
axis identity (particularly for axes of length 1). This is problematic for
transformation rules that need to preserve a notion of axis identity, such as
for masking and a new transformation rule I'm exploring for unraveling pytrees.
This PR rewrites these rules in terms of expand_dims / lax.broadcast_in_dim,
when feasible, which has a well-defined mapping between input and output axes.
In particular: `matmul`, various `stack` functions, the `array` constructor,
broadcasting arithmetic, array indexing, `squeeze` and reductions with
`keepdims=True` no longer use `lax.reshape`.
I also implemented support for multiple axes in `expand_dims` (added in NumPy
1.18), since it was convenient for some of these other functions.
I considered trying to write a masking rule for broadcast_in_dim as well, but
it was trickier than I expected and @JuliusKunze has probably already thought
about it :)
* Remove unnecessary branch
* Add lax.squeeze primitive
* Changes per review
* Fix typing
* Move expand_dims into lax
* Update per review; add comments/documentation
* Type annotations for squeeze/expand_dims
2020-05-28 19:12:50 -07:00
|
|
|
"arg_shape": arg_shape, "dimensions": dimensions,
|
|
|
|
"rng_factory": rng_factory}
|
|
|
|
for arg_shape, dimensions in [
|
|
|
|
[(1,), (0,)],
|
|
|
|
[(1,), (-1,)],
|
|
|
|
[(2, 1, 4), (1,)],
|
|
|
|
[(2, 1, 3, 1), (1,)],
|
|
|
|
[(2, 1, 3, 1), (1, 3)],
|
|
|
|
[(2, 1, 3, 1), (3,)],
|
|
|
|
]
|
|
|
|
for rng_factory in [jtu.rand_default]))
|
|
|
|
def testSqueeze(self, arg_shape, dimensions, rng_factory):
|
|
|
|
rng = rng_factory(self.rng())
|
2020-07-14 13:03:24 -07:00
|
|
|
args_maker = lambda: [rng(arg_shape, np.float32)]
|
Prefer using broadcast_in_dim/squeeze instead of reshape (#3217)
* Prefer using expand_dims/broadcast_in_dim to reshape in lax_numpy.py
`reshape()` is quite powerful, but does not necessarily preserve a notion of
axis identity (particularly for axes of length 1). This is problematic for
transformation rules that need to preserve a notion of axis identity, such as
for masking and a new transformation rule I'm exploring for unraveling pytrees.
This PR rewrites these rules in terms of expand_dims / lax.broadcast_in_dim,
when feasible, which has a well-defined mapping between input and output axes.
In particular: `matmul`, various `stack` functions, the `array` constructor,
broadcasting arithmetic, array indexing, `squeeze` and reductions with
`keepdims=True` no longer use `lax.reshape`.
I also implemented support for multiple axes in `expand_dims` (added in NumPy
1.18), since it was convenient for some of these other functions.
I considered trying to write a masking rule for broadcast_in_dim as well, but
it was trickier than I expected and @JuliusKunze has probably already thought
about it :)
* Remove unnecessary branch
* Add lax.squeeze primitive
* Changes per review
* Fix typing
* Move expand_dims into lax
* Update per review; add comments/documentation
* Type annotations for squeeze/expand_dims
2020-05-28 19:12:50 -07:00
|
|
|
op = lambda x: lax.squeeze(x, dimensions)
|
|
|
|
numpy_op = lambda x: lax_reference.squeeze(x, dimensions)
|
2020-06-01 17:19:23 -04:00
|
|
|
self._CompileAndCheck(op, args_maker)
|
2020-08-03 17:17:48 +02:00
|
|
|
self._CheckAgainstNumpy(numpy_op, op, args_maker)
|
Prefer using broadcast_in_dim/squeeze instead of reshape (#3217)
* Prefer using expand_dims/broadcast_in_dim to reshape in lax_numpy.py
`reshape()` is quite powerful, but does not necessarily preserve a notion of
axis identity (particularly for axes of length 1). This is problematic for
transformation rules that need to preserve a notion of axis identity, such as
for masking and a new transformation rule I'm exploring for unraveling pytrees.
This PR rewrites these rules in terms of expand_dims / lax.broadcast_in_dim,
when feasible, which has a well-defined mapping between input and output axes.
In particular: `matmul`, various `stack` functions, the `array` constructor,
broadcasting arithmetic, array indexing, `squeeze` and reductions with
`keepdims=True` no longer use `lax.reshape`.
I also implemented support for multiple axes in `expand_dims` (added in NumPy
1.18), since it was convenient for some of these other functions.
I considered trying to write a masking rule for broadcast_in_dim as well, but
it was trickier than I expected and @JuliusKunze has probably already thought
about it :)
* Remove unnecessary branch
* Add lax.squeeze primitive
* Changes per review
* Fix typing
* Move expand_dims into lax
* Update per review; add comments/documentation
* Type annotations for squeeze/expand_dims
2020-05-28 19:12:50 -07:00
|
|
|
check_grads(op, args_maker(), 2, ["fwd", "rev"], eps=1.)
|
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
|
2018-11-17 18:03:33 -08:00
|
|
|
{"testcase_name": "_inshape={}_outshape={}".format(
|
|
|
|
jtu.format_shape_dtype_string(arg_shape, dtype),
|
|
|
|
jtu.format_shape_dtype_string(out_shape, dtype)),
|
|
|
|
"arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype,
|
2019-11-11 12:51:15 -08:00
|
|
|
"rng_factory": rng_factory}
|
2018-11-17 18:03:33 -08:00
|
|
|
for dtype in default_dtypes
|
|
|
|
for arg_shape, out_shape in [
|
|
|
|
[(3, 4), (12,)], [(2, 1, 4), (8,)], [(2, 2, 4), (2, 8)]
|
|
|
|
]
|
2019-11-11 12:51:15 -08:00
|
|
|
for rng_factory in [jtu.rand_default]))
|
|
|
|
def testReshape(self, arg_shape, out_shape, dtype, rng_factory):
|
2020-05-04 23:00:20 -04:00
|
|
|
rng = rng_factory(self.rng())
|
2018-11-17 18:03:33 -08:00
|
|
|
args_maker = lambda: [rng(arg_shape, dtype)]
|
|
|
|
op = lambda x: lax.reshape(x, out_shape)
|
2020-06-01 17:19:23 -04:00
|
|
|
self._CompileAndCheck(op, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
|
2018-11-17 18:03:33 -08:00
|
|
|
{"testcase_name": "_inshape={}_outshape={}".format(
|
|
|
|
jtu.format_shape_dtype_string(arg_shape, dtype),
|
|
|
|
jtu.format_shape_dtype_string(out_shape, dtype)),
|
|
|
|
"arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype,
|
2019-11-11 12:51:15 -08:00
|
|
|
"rng_factory": rng_factory}
|
2018-11-17 18:03:33 -08:00
|
|
|
for dtype in default_dtypes
|
|
|
|
for arg_shape, out_shape in [
|
|
|
|
[(3, 4), (12,)], [(2, 1, 4), (8,)], [(2, 2, 4), (2, 8)]
|
|
|
|
]
|
2019-11-11 12:51:15 -08:00
|
|
|
for rng_factory in [jtu.rand_default]))
|
|
|
|
def testReshapeAgainstNumpy(self, arg_shape, out_shape, dtype, rng_factory):
|
2020-05-04 23:00:20 -04:00
|
|
|
rng = rng_factory(self.rng())
|
2018-11-17 18:03:33 -08:00
|
|
|
args_maker = lambda: [rng(arg_shape, dtype)]
|
|
|
|
op = lambda x: lax.reshape(x, out_shape)
|
|
|
|
numpy_op = lambda x: lax_reference.reshape(x, out_shape)
|
2020-08-03 17:17:48 +02:00
|
|
|
self._CheckAgainstNumpy(numpy_op, op, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
|
2018-11-17 18:03:33 -08:00
|
|
|
{"testcase_name": "_inshape={}_pads={}"
|
|
|
|
.format(jtu.format_shape_dtype_string(shape, dtype), pads),
|
2019-11-11 12:51:15 -08:00
|
|
|
"shape": shape, "dtype": dtype, "pads": pads, "rng_factory": jtu.rand_small}
|
2020-06-30 12:07:38 -04:00
|
|
|
for shape in [(0, 2), (2, 3)]
|
2018-11-17 18:03:33 -08:00
|
|
|
for dtype in default_dtypes
|
2018-12-06 18:37:59 -05:00
|
|
|
for pads in [[(1, 2, 1), (0, 1, 0)]]))
|
2019-11-11 12:51:15 -08:00
|
|
|
def testPad(self, shape, dtype, pads, rng_factory):
|
2020-05-04 23:00:20 -04:00
|
|
|
rng = rng_factory(self.rng())
|
2018-11-17 18:03:33 -08:00
|
|
|
args_maker = lambda: [rng(shape, dtype)]
|
2020-07-14 13:03:24 -07:00
|
|
|
fun = lambda operand: lax.pad(operand, np.array(0, dtype), pads)
|
2020-06-01 17:19:23 -04:00
|
|
|
self._CompileAndCheck(fun, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
|
2018-11-17 18:03:33 -08:00
|
|
|
{"testcase_name": "_inshape={}_pads={}"
|
|
|
|
.format(jtu.format_shape_dtype_string(shape, dtype), pads),
|
2019-11-11 12:51:15 -08:00
|
|
|
"shape": shape, "dtype": dtype, "pads": pads, "rng_factory": jtu.rand_small}
|
2018-11-17 18:03:33 -08:00
|
|
|
for shape in [(2, 3)]
|
|
|
|
for dtype in default_dtypes
|
2020-06-17 11:57:21 +03:00
|
|
|
for pads in [
|
|
|
|
[(0, 0, 0), (0, 0, 0)], # no padding
|
|
|
|
[(1, 1, 0), (2, 2, 0)], # only positive edge padding
|
|
|
|
[(1, 2, 1), (0, 1, 0)], # edge padding and interior padding
|
|
|
|
[(0, 0, 0), (-1, -1, 0)], # negative padding
|
|
|
|
[(0, 0, 0), (-2, -2, 4)], # add big dilation then remove from edges
|
|
|
|
[(0, 0, 0), (-2, -3, 1)], # remove everything in one dimension
|
|
|
|
]))
|
2019-11-11 12:51:15 -08:00
|
|
|
def testPadAgainstNumpy(self, shape, dtype, pads, rng_factory):
|
2020-05-04 23:00:20 -04:00
|
|
|
rng = rng_factory(self.rng())
|
2018-11-17 18:03:33 -08:00
|
|
|
args_maker = lambda: [rng(shape, dtype)]
|
2020-07-14 13:03:24 -07:00
|
|
|
op = lambda x: lax.pad(x, np.array(0, dtype), pads)
|
|
|
|
numpy_op = lambda x: lax_reference.pad(x, np.array(0, dtype), pads)
|
2020-08-03 17:17:48 +02:00
|
|
|
self._CheckAgainstNumpy(numpy_op, op, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def testReverse(self):
|
|
|
|
rev = api.jit(lambda operand: lax.rev(operand, dimensions))
|
|
|
|
|
2020-01-27 16:14:28 -08:00
|
|
|
dimensions = []
|
2020-07-14 13:03:24 -07:00
|
|
|
self.assertAllClose(np.array([0, 1, 2, 3]), rev(np.array([0, 1, 2, 3])),
|
2020-01-27 16:14:28 -08:00
|
|
|
check_dtypes=False)
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
dimensions = [0]
|
2020-07-14 13:03:24 -07:00
|
|
|
self.assertAllClose(np.array([3, 2, 1]), rev(np.array([1, 2, 3])),
|
2018-11-17 18:03:33 -08:00
|
|
|
check_dtypes=False)
|
|
|
|
|
|
|
|
dimensions = [0, 1]
|
2020-07-14 13:03:24 -07:00
|
|
|
self.assertAllClose(np.array([[6, 5, 4], [3, 2, 1]]),
|
|
|
|
rev(np.array([[1, 2, 3], [4, 5, 6]])),
|
2018-11-17 18:03:33 -08:00
|
|
|
check_dtypes=False)
|
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
|
2018-11-17 18:03:33 -08:00
|
|
|
{"testcase_name": "_predshape={}_argshapes={}".format(
|
2020-07-14 13:03:24 -07:00
|
|
|
jtu.format_shape_dtype_string(pred_shape, np.bool_),
|
2018-11-17 18:03:33 -08:00
|
|
|
jtu.format_shape_dtype_string(arg_shape, arg_dtype)),
|
|
|
|
"pred_shape": pred_shape, "arg_shape": arg_shape, "arg_dtype": arg_dtype,
|
2019-11-11 12:51:15 -08:00
|
|
|
"rng_factory": rng_factory}
|
2018-11-17 18:03:33 -08:00
|
|
|
for arg_shape in [(), (3,), (2, 3)]
|
|
|
|
for pred_shape in ([(), arg_shape] if arg_shape else [()])
|
|
|
|
for arg_dtype in default_dtypes
|
2019-11-11 12:51:15 -08:00
|
|
|
for rng_factory in [jtu.rand_default]))
|
|
|
|
def testSelect(self, pred_shape, arg_shape, arg_dtype, rng_factory):
|
2018-11-17 18:03:33 -08:00
|
|
|
def args_maker():
|
2020-07-14 13:03:24 -07:00
|
|
|
return [rng(pred_shape, np.bool_), rng(arg_shape, arg_dtype),
|
2018-11-17 18:03:33 -08:00
|
|
|
rng(arg_shape, arg_dtype)]
|
2020-05-04 23:00:20 -04:00
|
|
|
rng = rng_factory(self.rng())
|
2020-06-01 17:19:23 -04:00
|
|
|
return self._CompileAndCheck(lax.select, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
|
2018-11-17 18:03:33 -08:00
|
|
|
{"testcase_name": "_predshape={}_argshapes={}".format(
|
2020-07-14 13:03:24 -07:00
|
|
|
jtu.format_shape_dtype_string(pred_shape, np.bool_),
|
2018-11-17 18:03:33 -08:00
|
|
|
jtu.format_shape_dtype_string(arg_shape, arg_dtype)),
|
|
|
|
"pred_shape": pred_shape, "arg_shape": arg_shape, "arg_dtype": arg_dtype,
|
2019-11-11 12:51:15 -08:00
|
|
|
"rng_factory": rng_factory}
|
2018-11-17 18:03:33 -08:00
|
|
|
for arg_shape in [(), (3,), (2, 3)]
|
|
|
|
for pred_shape in ([(), arg_shape] if arg_shape else [()])
|
|
|
|
for arg_dtype in default_dtypes
|
2019-11-11 12:51:15 -08:00
|
|
|
for rng_factory in [jtu.rand_default]))
|
|
|
|
def testSelectAgainstNumpy(self, pred_shape, arg_shape, arg_dtype, rng_factory):
|
2018-11-17 18:03:33 -08:00
|
|
|
def args_maker():
|
2020-07-14 13:03:24 -07:00
|
|
|
return [rng(pred_shape, np.bool_), rng(arg_shape, arg_dtype),
|
2018-11-17 18:03:33 -08:00
|
|
|
rng(arg_shape, arg_dtype)]
|
2020-05-04 23:00:20 -04:00
|
|
|
rng = rng_factory(self.rng())
|
2020-08-03 17:17:48 +02:00
|
|
|
return self._CheckAgainstNumpy(lax_reference.select, lax.select, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
|
2018-11-17 18:03:33 -08:00
|
|
|
{"testcase_name":
|
|
|
|
"_shape={}_start_indices={}_limit_indices={}_strides={}".format(
|
|
|
|
jtu.format_shape_dtype_string(shape, dtype),
|
|
|
|
start_indices, limit_indices, strides),
|
|
|
|
"shape": shape, "dtype": dtype, "starts": start_indices,
|
2019-11-11 12:51:15 -08:00
|
|
|
"limits": limit_indices, "strides": strides, "rng_factory": rng_factory}
|
2018-11-17 18:03:33 -08:00
|
|
|
for shape, start_indices, limit_indices, strides in [
|
|
|
|
[(3,), (1,), (2,), None],
|
|
|
|
[(7,), (4,), (7,), None],
|
|
|
|
[(5,), (1,), (5,), (2,)],
|
|
|
|
[(8,), (1,), (6,), (2,)],
|
|
|
|
[(5, 3), (1, 1), (3, 2), None],
|
|
|
|
[(5, 3), (1, 1), (3, 1), None],
|
|
|
|
[(7, 5, 3), (4, 0, 1), (7, 1, 3), None],
|
|
|
|
[(5, 3), (1, 1), (2, 1), (1, 1)],
|
|
|
|
[(5, 3), (1, 1), (5, 3), (2, 1)],
|
|
|
|
]
|
|
|
|
for dtype in default_dtypes
|
2019-11-11 12:51:15 -08:00
|
|
|
for rng_factory in [jtu.rand_default]))
|
|
|
|
def testSlice(self, shape, dtype, starts, limits, strides, rng_factory):
|
2020-05-04 23:00:20 -04:00
|
|
|
rng = rng_factory(self.rng())
|
2018-11-17 18:03:33 -08:00
|
|
|
args_maker = lambda: [rng(shape, dtype)]
|
|
|
|
op = lambda x: lax.slice(x, starts, limits, strides)
|
2020-06-01 17:19:23 -04:00
|
|
|
self._CompileAndCheck(op, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
|
2018-11-17 18:03:33 -08:00
|
|
|
{"testcase_name":
|
|
|
|
"_shape={}_start_indices={}_limit_indices={}_strides={}".format(
|
|
|
|
jtu.format_shape_dtype_string(shape, dtype),
|
|
|
|
start_indices, limit_indices, strides),
|
|
|
|
"shape": shape, "dtype": dtype, "starts": start_indices,
|
2019-11-11 12:51:15 -08:00
|
|
|
"limits": limit_indices, "strides": strides, "rng_factory": rng_factory}
|
2018-11-17 18:03:33 -08:00
|
|
|
for shape, start_indices, limit_indices, strides in [
|
|
|
|
[(3,), (1,), (2,), None],
|
|
|
|
[(7,), (4,), (7,), None],
|
|
|
|
[(5,), (1,), (5,), (2,)],
|
|
|
|
[(8,), (1,), (6,), (2,)],
|
|
|
|
[(5, 3), (1, 1), (3, 2), None],
|
|
|
|
[(5, 3), (1, 1), (3, 1), None],
|
|
|
|
[(7, 5, 3), (4, 0, 1), (7, 1, 3), None],
|
|
|
|
[(5, 3), (1, 1), (2, 1), (1, 1)],
|
|
|
|
[(5, 3), (1, 1), (5, 3), (2, 1)],
|
|
|
|
]
|
|
|
|
for dtype in default_dtypes
|
2019-11-11 12:51:15 -08:00
|
|
|
for rng_factory in [jtu.rand_default]))
|
2018-11-17 18:03:33 -08:00
|
|
|
def testSliceAgainstNumpy(self, shape, dtype, starts, limits,
|
2019-11-11 12:51:15 -08:00
|
|
|
strides, rng_factory):
|
2020-05-04 23:00:20 -04:00
|
|
|
rng = rng_factory(self.rng())
|
2018-11-17 18:03:33 -08:00
|
|
|
args_maker = lambda: [rng(shape, dtype)]
|
|
|
|
op = lambda x: lax.slice(x, starts, limits, strides)
|
|
|
|
numpy_op = lambda x: lax_reference.slice(x, starts, limits, strides)
|
2020-08-03 17:17:48 +02:00
|
|
|
self._CheckAgainstNumpy(numpy_op, op, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
|
2018-11-17 18:03:33 -08:00
|
|
|
{"testcase_name": "_shape={}_start_indices={}_size_indices={}".format(
|
|
|
|
jtu.format_shape_dtype_string(shape, dtype),
|
|
|
|
start_indices, size_indices),
|
|
|
|
"shape": shape, "dtype": dtype, "start_indices": start_indices,
|
2019-11-11 12:51:15 -08:00
|
|
|
"size_indices": size_indices, "rng_factory": rng_factory}
|
2018-11-17 18:03:33 -08:00
|
|
|
for shape, start_indices, size_indices in [
|
2020-07-14 13:03:24 -07:00
|
|
|
[(3,), np.array((1,)), (1,)],
|
2018-11-17 18:03:33 -08:00
|
|
|
[(5, 3), (1, 1), (3, 1)],
|
2020-07-14 13:03:24 -07:00
|
|
|
[(5, 3), np.array((1, 1)), (3, 1)],
|
|
|
|
[(7, 5, 3), np.array((4, 1, 0)), (2, 0, 1)],
|
2018-11-17 18:03:33 -08:00
|
|
|
]
|
|
|
|
for dtype in default_dtypes
|
2019-11-11 12:51:15 -08:00
|
|
|
for rng_factory in [jtu.rand_default]))
|
|
|
|
def testDynamicSlice(self, shape, dtype, start_indices, size_indices, rng_factory):
|
2020-05-04 23:00:20 -04:00
|
|
|
rng = rng_factory(self.rng())
|
2020-07-14 13:03:24 -07:00
|
|
|
args_maker = lambda: [rng(shape, dtype), np.array(start_indices)]
|
2018-11-17 18:03:33 -08:00
|
|
|
op = lambda x, starts: lax.dynamic_slice(x, starts, size_indices)
|
2020-06-01 17:19:23 -04:00
|
|
|
self._CompileAndCheck(op, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
|
2018-11-17 18:03:33 -08:00
|
|
|
{"testcase_name": "_shape={}_start_indices={}_size_indices={}".format(
|
|
|
|
jtu.format_shape_dtype_string(shape, dtype),
|
|
|
|
start_indices, size_indices),
|
|
|
|
"shape": shape, "dtype": dtype, "start_indices": start_indices,
|
2019-11-11 12:51:15 -08:00
|
|
|
"size_indices": size_indices, "rng_factory": rng_factory}
|
2018-11-17 18:03:33 -08:00
|
|
|
for shape, start_indices, size_indices in [
|
|
|
|
[(3,), (1,), (1,)],
|
|
|
|
[(5, 3), (1, 1), (3, 1)],
|
|
|
|
[(7, 5, 3), (4, 1, 0), (2, 0, 1)],
|
|
|
|
]
|
|
|
|
for dtype in default_dtypes
|
2019-11-11 12:51:15 -08:00
|
|
|
for rng_factory in [jtu.rand_default]))
|
2018-11-17 18:03:33 -08:00
|
|
|
def testDynamicSliceAgainstNumpy(self, shape, dtype, start_indices,
|
2019-11-11 12:51:15 -08:00
|
|
|
size_indices, rng_factory):
|
2020-05-04 23:00:20 -04:00
|
|
|
rng = rng_factory(self.rng())
|
2020-07-14 13:03:24 -07:00
|
|
|
args_maker = lambda: [rng(shape, dtype), np.array(start_indices)]
|
2018-11-17 18:03:33 -08:00
|
|
|
op = lambda x, s: lax.dynamic_slice(x, s, size_indices)
|
|
|
|
numpy_op = lambda x, s: lax_reference.dynamic_slice(x, s, size_indices)
|
2020-08-03 17:17:48 +02:00
|
|
|
self._CheckAgainstNumpy(numpy_op, op, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-12-20 13:29:53 -05:00
|
|
|
def testDynamicSliceInDim(self):
|
|
|
|
# Regression test for mixed type problem in dynamic_slice_in_dim.
|
2020-05-04 23:00:20 -04:00
|
|
|
rng = jtu.rand_default(self.rng())
|
2020-07-14 13:03:24 -07:00
|
|
|
x = rng((6, 7), np.int32)
|
|
|
|
np.testing.assert_equal(lax.dynamic_slice_in_dim(x, 2, 3), x[2:5])
|
2019-12-20 13:29:53 -05:00
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
|
2018-11-17 18:03:33 -08:00
|
|
|
{"testcase_name": "_shape={}_start_indices={}_update_shape={}".format(
|
|
|
|
jtu.format_shape_dtype_string(shape, dtype),
|
|
|
|
start_indices, update_shape),
|
|
|
|
"shape": shape, "dtype": dtype, "start_indices": start_indices,
|
2019-11-11 12:51:15 -08:00
|
|
|
"update_shape": update_shape, "rng_factory": rng_factory}
|
2018-11-17 18:03:33 -08:00
|
|
|
for shape, start_indices, update_shape in [
|
|
|
|
[(3,), (1,), (1,)],
|
|
|
|
[(5, 3), (1, 1), (3, 1)],
|
|
|
|
[(7, 5, 3), (4, 1, 0), (2, 0, 1)],
|
|
|
|
]
|
|
|
|
for dtype in default_dtypes
|
2019-11-11 12:51:15 -08:00
|
|
|
for rng_factory in [jtu.rand_default]))
|
2018-11-17 18:03:33 -08:00
|
|
|
def testDynamicUpdateSlice(self, shape, dtype, start_indices, update_shape,
|
2019-11-11 12:51:15 -08:00
|
|
|
rng_factory):
|
2020-05-04 23:00:20 -04:00
|
|
|
rng = rng_factory(self.rng())
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def args_maker():
|
|
|
|
return [rng(shape, dtype), rng(update_shape, dtype),
|
2020-07-14 13:03:24 -07:00
|
|
|
np.array(start_indices)]
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-06-01 17:19:23 -04:00
|
|
|
self._CompileAndCheck(lax.dynamic_update_slice, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
|
2018-11-17 18:03:33 -08:00
|
|
|
{"testcase_name": "_shape={}_start_indices={}_update_shape={}".format(
|
|
|
|
jtu.format_shape_dtype_string(shape, dtype),
|
|
|
|
start_indices, update_shape),
|
|
|
|
"shape": shape, "dtype": dtype, "start_indices": start_indices,
|
2019-11-11 12:51:15 -08:00
|
|
|
"update_shape": update_shape, "rng_factory": rng_factory}
|
2018-11-17 18:03:33 -08:00
|
|
|
for shape, start_indices, update_shape in [
|
|
|
|
[(3,), (1,), (1,)],
|
|
|
|
[(5, 3), (1, 1), (3, 1)],
|
|
|
|
[(7, 5, 3), (4, 1, 0), (2, 0, 1)],
|
|
|
|
]
|
|
|
|
for dtype in default_dtypes
|
2019-11-11 12:51:15 -08:00
|
|
|
for rng_factory in [jtu.rand_default]))
|
2018-11-28 17:17:52 -08:00
|
|
|
def testDynamicUpdateSliceAgainstNumpy(self, shape, dtype, start_indices,
|
2019-11-11 12:51:15 -08:00
|
|
|
update_shape, rng_factory):
|
2020-05-04 23:00:20 -04:00
|
|
|
rng = rng_factory(self.rng())
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def args_maker():
|
|
|
|
return [rng(shape, dtype), rng(update_shape, dtype),
|
2020-07-14 13:03:24 -07:00
|
|
|
np.array(start_indices)]
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-08-03 17:17:48 +02:00
|
|
|
self._CheckAgainstNumpy(lax_reference.dynamic_update_slice,
|
|
|
|
lax.dynamic_update_slice, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
|
2018-11-17 18:03:33 -08:00
|
|
|
{"testcase_name": "_shape={}_perm={}".format(
|
|
|
|
jtu.format_shape_dtype_string(shape, dtype), perm),
|
2019-11-11 12:51:15 -08:00
|
|
|
"shape": shape, "dtype": dtype, "perm": perm, "rng_factory": rng_factory}
|
2018-11-17 18:03:33 -08:00
|
|
|
for shape, perm in [
|
|
|
|
[(3, 4), (1, 0)],
|
|
|
|
[(3, 4), (0, 1)],
|
|
|
|
[(3, 4, 5), (2, 1, 0)],
|
|
|
|
[(3, 4, 5), (1, 0, 2)],
|
|
|
|
]
|
|
|
|
for dtype in default_dtypes
|
2019-11-11 12:51:15 -08:00
|
|
|
for rng_factory in [jtu.rand_default]))
|
|
|
|
def testTranspose(self, shape, dtype, perm, rng_factory):
|
2020-05-04 23:00:20 -04:00
|
|
|
rng = rng_factory(self.rng())
|
2018-11-17 18:03:33 -08:00
|
|
|
args_maker = lambda: [rng(shape, dtype)]
|
|
|
|
op = lambda x: lax.transpose(x, perm)
|
2020-06-01 17:19:23 -04:00
|
|
|
self._CompileAndCheck(op, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
|
2018-11-17 18:03:33 -08:00
|
|
|
{"testcase_name": "_shape={}_perm={}".format(
|
|
|
|
jtu.format_shape_dtype_string(shape, dtype), perm),
|
2019-11-11 12:51:15 -08:00
|
|
|
"shape": shape, "dtype": dtype, "perm": perm, "rng_factory": rng_factory}
|
2018-11-17 18:03:33 -08:00
|
|
|
for shape, perm in [
|
|
|
|
[(3, 4), (1, 0)],
|
|
|
|
[(3, 4), (0, 1)],
|
|
|
|
[(3, 4, 5), (2, 1, 0)],
|
|
|
|
[(3, 4, 5), (1, 0, 2)],
|
|
|
|
]
|
|
|
|
for dtype in default_dtypes
|
2019-11-11 12:51:15 -08:00
|
|
|
for rng_factory in [jtu.rand_default]))
|
|
|
|
def testTransposeAgainstNumpy(self, shape, dtype, perm, rng_factory):
|
2020-05-04 23:00:20 -04:00
|
|
|
rng = rng_factory(self.rng())
|
2018-11-17 18:03:33 -08:00
|
|
|
args_maker = lambda: [rng(shape, dtype)]
|
|
|
|
op = lambda x: lax.transpose(x, perm)
|
|
|
|
numpy_op = lambda x: lax_reference.transpose(x, perm)
|
2020-08-03 17:17:48 +02:00
|
|
|
self._CheckAgainstNumpy(numpy_op, op, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
|
2019-07-06 12:17:00 -07:00
|
|
|
{"testcase_name": "_op={}_inshape={}_reducedims={}_initval={}"
|
|
|
|
.format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims,
|
|
|
|
init_val),
|
2018-11-17 18:03:33 -08:00
|
|
|
"op": op, "init_val": init_val, "shape": shape, "dtype": dtype,
|
2019-11-11 12:51:15 -08:00
|
|
|
"dims": dims, "rng_factory": rng_factory}
|
2019-11-15 10:02:51 -05:00
|
|
|
for init_val, op, types in [
|
2018-11-17 18:03:33 -08:00
|
|
|
(0, lax.add, default_dtypes),
|
2019-06-17 11:49:54 -07:00
|
|
|
(1, lax.mul, default_dtypes),
|
2019-08-03 21:27:06 -07:00
|
|
|
(0, lax.max, all_dtypes), # non-monoidal
|
2020-07-14 13:03:24 -07:00
|
|
|
(-np.inf, lax.max, float_dtypes),
|
|
|
|
(dtypes.iinfo(np.int32).min, lax.max, [np.int32]),
|
|
|
|
# (dtypes.iinfo(np.int64).min, lax.max, [np.int64]), # TODO fails
|
|
|
|
(dtypes.iinfo(np.uint32).min, lax.max, [np.uint32]),
|
|
|
|
(dtypes.iinfo(np.uint64).min, lax.max, [np.uint64]),
|
|
|
|
(np.inf, lax.min, float_dtypes),
|
|
|
|
(dtypes.iinfo(np.int32).max, lax.min, [np.int32]),
|
|
|
|
# (dtypes.iinfo(np.int64).max, lax.min, [np.int64]), # TODO fails
|
|
|
|
(dtypes.iinfo(np.uint32).max, lax.min, [np.uint32]),
|
|
|
|
(dtypes.iinfo(np.uint64).max, lax.min, [np.uint64]),
|
2018-11-17 18:03:33 -08:00
|
|
|
]
|
2019-11-15 10:02:51 -05:00
|
|
|
for dtype in types
|
2018-11-17 18:03:33 -08:00
|
|
|
for shape, dims in [
|
|
|
|
[(3, 4, 5), (0,)], [(3, 4, 5), (1, 2)],
|
|
|
|
[(3, 4, 5), (0, 2)], [(3, 4, 5), (0, 1, 2)]
|
|
|
|
]
|
2019-11-11 12:51:15 -08:00
|
|
|
for rng_factory in [
|
2020-07-14 13:03:24 -07:00
|
|
|
jtu.rand_default if dtypes.issubdtype(dtype, np.integer)
|
2019-11-11 12:51:15 -08:00
|
|
|
else jtu.rand_small]))
|
|
|
|
def testReduce(self, op, init_val, shape, dtype, dims, rng_factory):
|
2020-05-04 23:00:20 -04:00
|
|
|
rng = rng_factory(self.rng())
|
2020-07-14 13:03:24 -07:00
|
|
|
init_val = np.asarray(init_val, dtype=dtype)
|
2018-11-17 18:03:33 -08:00
|
|
|
fun = lambda operand, init_val: lax.reduce(operand, init_val, op, dims)
|
|
|
|
args_maker = lambda: [rng(shape, dtype), init_val]
|
2020-06-01 17:19:23 -04:00
|
|
|
self._CompileAndCheck(fun, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
# we separately test the version that uses a concrete init_val because it
|
|
|
|
# can hit different code paths
|
|
|
|
fun = lambda operand: lax.reduce(operand, init_val, op, dims)
|
|
|
|
args_maker = lambda: [rng(shape, dtype)]
|
2020-06-01 17:19:23 -04:00
|
|
|
self._CompileAndCheck(fun, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
|
2020-07-20 17:27:24 -04:00
|
|
|
{"testcase_name": ("_op={}_shape={}_dims={}_strides={}_padding={}"
|
|
|
|
"_basedilation={}_windowdilation={}")
|
|
|
|
.format(op.__name__, jtu.format_shape_dtype_string(shape, dtype),
|
|
|
|
dims, strides, padding, base_dilation, window_dilation),
|
|
|
|
"op": op, "init_val": init_val, "dtype": dtype, "shape": shape,
|
|
|
|
"dims": dims, "strides": strides, "padding": padding,
|
|
|
|
"base_dilation": base_dilation, "window_dilation": window_dilation}
|
2018-11-17 18:03:33 -08:00
|
|
|
for init_val, op, dtypes in [
|
2020-07-14 13:03:24 -07:00
|
|
|
(0, lax.add, [np.float32]),
|
|
|
|
(-np.inf, lax.max, [np.float32]),
|
|
|
|
(np.inf, lax.min, [np.float32]),
|
2018-11-17 18:03:33 -08:00
|
|
|
]
|
2020-07-20 17:27:24 -04:00
|
|
|
for shape, dims, strides, padding, base_dilation, window_dilation in (
|
|
|
|
itertools.chain(
|
|
|
|
itertools.product(
|
2019-02-03 21:10:03 -05:00
|
|
|
[(4, 6)],
|
|
|
|
[(2, 1), (1, 2)],
|
2020-07-17 16:05:51 -04:00
|
|
|
[(1, 1), (2, 1), (1, 2)],
|
2020-07-20 17:27:24 -04:00
|
|
|
["VALID", "SAME", [(0, 3), (1, 2)]],
|
|
|
|
[(1, 1), (2, 3)],
|
|
|
|
[(1, 1), (1, 2)]),
|
|
|
|
itertools.product(
|
2019-02-03 21:10:03 -05:00
|
|
|
[(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)],
|
2020-07-17 16:05:51 -04:00
|
|
|
[(1, 2, 2, 1), (1, 1, 1, 1)],
|
2020-07-20 17:27:24 -04:00
|
|
|
["VALID", "SAME", [(0, 1), (1, 0), (2, 3), (0, 2)]],
|
|
|
|
[(1, 1, 1, 1), (2, 1, 3, 2)],
|
|
|
|
[(1, 1, 1, 1), (1, 2, 2, 1)])))
|
|
|
|
for dtype in dtypes))
|
|
|
|
def testReduceWindow(self, op, init_val, dtype, shape, dims, strides, padding,
|
|
|
|
base_dilation, window_dilation):
|
|
|
|
rng = jtu.rand_small(self.rng())
|
|
|
|
init_val = np.asarray(init_val, dtype=dtype)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def fun(operand, init_val):
|
2020-07-20 17:27:24 -04:00
|
|
|
return lax.reduce_window(operand, init_val, op, dims, strides, padding,
|
|
|
|
base_dilation, window_dilation)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-07-17 16:05:51 -04:00
|
|
|
def reference_fun(operand, init_val):
|
|
|
|
return lax_reference.reduce_window(operand, init_val, op, dims, strides,
|
2020-07-20 17:27:24 -04:00
|
|
|
padding, base_dilation)
|
2020-07-17 16:05:51 -04:00
|
|
|
|
2020-07-20 17:27:24 -04:00
|
|
|
args_maker = lambda: [rng(shape, dtype), init_val]
|
|
|
|
self._CompileAndCheck(fun, args_maker)
|
|
|
|
if all(d == 1 for d in window_dilation):
|
2020-08-03 17:17:48 +02:00
|
|
|
self._CheckAgainstNumpy(reference_fun, fun, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
# we separately test the version that uses a concrete init_val because it
|
|
|
|
# can hit different code paths
|
|
|
|
def fun(operand):
|
2020-07-20 17:27:24 -04:00
|
|
|
return lax.reduce_window(operand, init_val, op, dims, strides, padding,
|
|
|
|
base_dilation, window_dilation)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-07-20 17:27:24 -04:00
|
|
|
args_maker = lambda: [rng(shape, dtype)]
|
|
|
|
self._CompileAndCheck(fun, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-06 11:22:01 -04:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_op={}_shape={}_axis={}"
     .format(lax_op.__name__, jtu.format_shape_dtype_string(shape, dtype),
             axis),
     "op": lax_op, "np_op": reference_op, "shape": shape, "dtype": dtype,
     "axis": axis, "rng_factory": rng_factory}
    for lax_op, reference_op, types in [
        (lax.cumsum, np.cumsum, default_dtypes),
        (lax.cumprod, np.cumprod, default_dtypes),
        (lax.cummax, np.maximum.accumulate, default_dtypes),
        (lax.cummin, np.minimum.accumulate, default_dtypes),
    ]
    for dtype in types
    for shape in [[10], [3, 4, 5]]
    for axis in range(len(shape))
    for rng_factory in [
        jtu.rand_default if dtypes.issubdtype(dtype, np.integer)
        else jtu.rand_small]))
def testCumulativeReduce(self, op, np_op, shape, dtype, axis, rng_factory):
  """Compiles a lax cumulative reduction and compares it to NumPy's."""
  rng = rng_factory(self.rng())

  def args_maker():
    return [rng(shape, dtype)]

  lax_fun = partial(op, axis=axis)
  self._CompileAndCheck(lax_fun, args_maker)
  reference_fun = partial(np_op, axis=axis, dtype=dtype)
  self._CheckAgainstNumpy(reference_fun, lax_fun, args_maker)
|
2020-04-06 11:22:01 -04:00
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_axis={}_isstable={}".format(
        jtu.format_shape_dtype_string(shape, dtype), axis, is_stable),
     "shape": shape, "dtype": dtype, "axis": axis, "is_stable": is_stable}
    for dtype in all_dtypes
    for shape in [(5,), (5, 7)]
    for axis in [-1, len(shape) - 1]
    for is_stable in [False, True]))
def testSort(self, shape, dtype, axis, is_stable):
  """Compiles lax.sort along `axis`, stable or not as requested."""
  # TODO(b/141131288): enable complex-valued sorts on TPU.
  if np.issubdtype(dtype, np.complexfloating):
    device = jtu.device_under_test()
    old_cpu_runtime = device == "cpu" and jax.lib.version <= (0, 1, 47)
    if old_cpu_runtime or device == "tpu":
      raise SkipTest("Complex-valued sort not implemented")
  rng = jtu.rand_default(self.rng())

  def sort_fun(x):
    return lax.sort(x, dimension=axis, is_stable=is_stable)

  self._CompileAndCheck(sort_fun, lambda: [rng(shape, dtype)])
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_axis={}_isstable={}".format(
        jtu.format_shape_dtype_string(shape, dtype), axis, is_stable),
     "shape": shape, "dtype": dtype, "axis": axis, "is_stable": is_stable}
    for dtype in all_dtypes
    for shape in [(5,), (5, 7)]
    for axis in [-1, len(shape) - 1]
    for is_stable in [False, True]))
def testSortAgainstNumpy(self, shape, dtype, axis, is_stable):
  """Compares lax.sort with lax_reference.sort (stable kind when requested)."""
  # TODO(b/141131288): enable complex-valued sorts on TPU.
  if np.issubdtype(dtype, np.complexfloating):
    device = jtu.device_under_test()
    if (device == "cpu" and jax.lib.version <= (0, 1, 47)) or device == "tpu":
      raise SkipTest("Complex-valued sort not implemented")
  rng = jtu.rand_default(self.rng())
  reference_kwargs = {'kind': 'stable'} if is_stable else {}

  def lax_sort(x):
    return lax.sort(x, dimension=axis, is_stable=is_stable)

  def reference_sort(x):
    return lax_reference.sort(x, axis, **reference_kwargs)

  self._CheckAgainstNumpy(reference_sort, lax_sort,
                          lambda: [rng(shape, dtype)])
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_keyshape={}_valshape={}_axis={}_isstable={}".format(
        jtu.format_shape_dtype_string(shape, key_dtype),
        jtu.format_shape_dtype_string(shape, val_dtype),
        axis, is_stable),
     "shape": shape, "key_dtype": key_dtype, "val_dtype": val_dtype,
     "axis": axis, "is_stable": is_stable}
    for key_dtype in float_dtypes + complex_dtypes + int_dtypes + uint_dtypes
    for val_dtype in [np.float32, np.int32, np.uint32]
    for shape in [(3,), (5, 3)]
    for axis in [-1, len(shape) - 1]
    for is_stable in [False, True]))
def testSortKeyVal(self, shape, key_dtype, val_dtype, axis, is_stable):
  """Compiles lax.sort_key_val on keys that are globally unique."""
  # TODO(b/141131288): enable complex-valued sorts on TPU.
  if np.issubdtype(key_dtype, np.complexfloating):
    device = jtu.device_under_test()
    if (device == "cpu" and jax.lib.version <= (0, 1, 47)) or device == "tpu":
      raise SkipTest("Complex-valued sort not implemented")
  rng = jtu.rand_default(self.rng())

  # The order of values among tied keys is not guaranteed, so keys are drawn
  # as a random permutation of 0..size-1, making every key distinct.
  def args_maker():
    keys = self.rng().permutation(
        np.arange(prod(shape), dtype=key_dtype)).reshape(shape)
    values = rng(shape, val_dtype)
    return keys, values

  def sort_fun(keys, values):
    return lax.sort_key_val(keys, values, axis, is_stable)

  self._CompileAndCheck(sort_fun, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-07-09 20:05:19 -07:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_num_keys={}".format(
        jtu.format_shape_dtype_string(shape, dtype), num_keys),
     "shape": shape, "dtype": dtype, "num_keys": num_keys}
    for dtype in all_dtypes
    for shape in [(3, 5,), (4, 3)]
    for num_keys in range(1, shape[0] + 1)))
def testSortNumKeys(self, shape, dtype, num_keys):
  """Checks multi-operand lax.sort with `num_keys` against np.lexsort."""
  # TODO(b/141131288): enable complex-valued sorts on TPU.
  if np.issubdtype(dtype, np.complexfloating):
    device = jtu.device_under_test()
    if (device == "cpu" and jax.lib.version <= (0, 1, 47)) or device == "tpu":
      raise SkipTest("Complex-valued sort not implemented")
  rng = jtu.rand_default(self.rng())

  def lax_fun(x):
    # Each row of x becomes one operand; the first num_keys rows are keys.
    return lax.sort(tuple(x), num_keys=num_keys)

  def numpy_fun(x):
    # np.lexsort treats its last key as the primary one, hence the reversal.
    return tuple(x[:, np.lexsort(x[:num_keys][::-1])])

  # TODO: a compile check (self._CompileAndCheck(lax_fun, args_maker)) is
  # disabled for this test; the reason is not recorded here.
  self._CheckAgainstNumpy(numpy_fun, lax_fun, lambda: [rng(shape, dtype)])
|
2020-07-09 20:05:19 -07:00
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_keyshape={}_valshape={}_axis={}".format(
        jtu.format_shape_dtype_string(shape, key_dtype),
        jtu.format_shape_dtype_string(shape, val_dtype),
        axis),
     "shape": shape, "key_dtype": key_dtype, "val_dtype": val_dtype,
     "axis": axis}
    for key_dtype in float_dtypes + complex_dtypes + int_dtypes + uint_dtypes
    for val_dtype in [np.float32, np.int32, np.uint32]
    for shape in [(3,), (5, 3)]
    for axis in [-1, len(shape) - 1]))
def testSortKeyValAgainstNumpy(self, shape, key_dtype, val_dtype, axis):
  """Compares lax.sort_key_val with lax_reference.sort_key_val."""
  # TODO(b/141131288): enable complex-valued sorts on TPU.
  if np.issubdtype(key_dtype, np.complexfloating):
    device = jtu.device_under_test()
    if (device == "cpu" and jax.lib.version <= (0, 1, 47)) or device == "tpu":
      raise SkipTest("Complex-valued sort not implemented")
  rng = jtu.rand_default(self.rng())

  # The order of values among tied keys is not guaranteed, so keys are drawn
  # as a random permutation of 0..size-1, making every key distinct.
  def args_maker():
    keys = self.rng().permutation(
        np.arange(prod(shape), dtype=key_dtype)).reshape(shape)
    values = rng(shape, val_dtype)
    return keys, values

  def lax_op(ks, vs):
    return lax.sort_key_val(ks, vs, axis)

  def reference_op(ks, vs):
    return lax_reference.sort_key_val(ks, vs, axis)

  self._CheckAgainstNumpy(reference_op, lax_op, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-02-20 17:15:25 -08:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_k={}".format(
        jtu.format_shape_dtype_string(shape, dtype), k),
     "rng_factory": rng_factory, "shape": shape, "dtype": dtype, "k": k}
    for dtype in [np.float32, np.int32, np.uint32]
    for shape in [(3,), (5, 3)]
    for k in [1, 3]
    for rng_factory in [jtu.rand_default]))
def testTopK(self, shape, dtype, k, rng_factory):
  """Checks lax.top_k against a sort-based reference, then compiles it.

  `rng_factory` is accepted for parameterization but unused. Inputs are a
  random permutation of 0..size-1, so values (and hence indices) never tie.
  """
  def args_maker():
    flat_values = np.arange(prod(shape), dtype=dtype)
    values = self.rng().permutation(flat_values).reshape(shape)
    return [values]

  def reference_top_k(x):
    bcast_idxs = np.broadcast_to(np.arange(shape[-1], dtype=np.int32), shape)
    sorted_vals, sorted_idxs = lax_reference.sort_key_val(x, bcast_idxs)
    # Take the last k entries of the ascending sort, in descending order.
    return sorted_vals[..., :-k-1:-1], sorted_idxs[..., :-k-1:-1]

  op = lambda vs: lax.top_k(vs, k=k)
  # Reference first, lax op second, matching every other _CheckAgainstNumpy
  # call in this file (previously the two were passed in swapped order).
  self._CheckAgainstNumpy(reference_top_k, op, args_maker)
  self._CompileAndCheck(op, args_maker)
|
2020-02-20 17:15:25 -08:00
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_lhs_shape={}_rhs_shape={}"
     .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
             jtu.format_shape_dtype_string(rhs_shape, dtype)),
     "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
     "rng_factory": rng_factory}
    for lhs_shape, rhs_shape in [((3, 2), (2, 4)),
                                 ((5, 3, 2), (5, 2, 4)),
                                 ((1, 2, 2, 3), (1, 2, 3, 1))]
    for dtype in float_dtypes
    for rng_factory in [jtu.rand_small]))
def testBatchMatMul(self, lhs_shape, rhs_shape, dtype, rng_factory):
  """Compiles lax.batch_matmul on random operands of the given shapes."""
  rng = rng_factory(self.rng())

  def arg_maker():
    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    return [lhs, rhs]

  self._CompileAndCheck(lax.batch_matmul, arg_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def testCollapse(self):
  """Under jit, lax.collapse(x, 0, 2) merges the two leading dimensions."""

  @api.jit
  def collapse_first_two(x):
    return lax.collapse(x, 0, 2)

  shape_pairs = [
      ((2, 3), (6,)),
      ((2, 3, 4), (6, 4)),
      ((1, 2, 3, 4), (2, 3, 4)),
  ]
  for in_shape, out_shape in shape_pairs:
    self.assertEqual(out_shape, collapse_first_two(np.zeros(in_shape)).shape)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-06 18:37:59 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_idxs={}_axes={}".format(
        jtu.format_shape_dtype_string(shape, dtype), idxs, axes),
     "shape": shape, "dtype": dtype, "idxs": idxs, "axes": axes,
     "rng_factory": rng_factory}
    for dtype in all_dtypes
    for shape, idxs, axes in [
        [(3, 4, 5), (np.array([0, 2, 1]),), (0,)],
        [(3, 4, 5), (np.array([-1, -2]),), (0,)],
        [(3, 4, 5), (np.array([0, 2]), np.array([1, 3])), (0, 1)],
        [(3, 4, 5), (np.array([0, 2]), np.array([1, 3])), (0, 2)],
    ]
    for rng_factory in [jtu.rand_default]))
def testIndexTake(self, shape, dtype, idxs, axes, rng_factory):
  """Compiles lax.index_take with random indices shaped like `idxs`."""
  rng = rng_factory(self.rng())

  def args_maker():
    src = rng(shape, dtype)
    random_idxs = tuple(rng(e.shape, e.dtype) for e in idxs)
    return [src, random_idxs]

  def take_fun(src, indices):
    return lax.index_take(src, indices, axes)

  self._CompileAndCheck(take_fun, args_maker)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-01-14 14:33:40 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_idxs={}_dnums={}_slice_sizes={}".format(
        jtu.format_shape_dtype_string(shape, dtype), idxs, dnums,
        slice_sizes),
     "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
     "slice_sizes": slice_sizes, "rng_factory": rng_factory,
     "rng_idx_factory": rng_idx_factory}
    for dtype in all_dtypes
    for shape, idxs, dnums, slice_sizes in [
        ((5,), np.array([[0], [2]]), lax.GatherDimensionNumbers(
            offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
         (1,)),
        ((10,), np.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
         (2,)),
        ((10, 5,), np.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
         (1, 3)),
        ((10, 5), np.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)),
         (1, 3)),
    ]
    for rng_idx_factory in [partial(jtu.rand_int, high=max(shape))]
    for rng_factory in [jtu.rand_default]))
def testGather(self, shape, dtype, idxs, dnums, slice_sizes, rng_factory,
               rng_idx_factory):
  """Compiles lax.gather with a random operand and random start indices."""
  rng = rng_factory(self.rng())
  rng_idx = rng_idx_factory(self.rng())

  def args_maker():
    operand = rng(shape, dtype)
    start_indices = rng_idx(idxs.shape, idxs.dtype)
    return [operand, start_indices]

  gather_fun = partial(lax.gather, dimension_numbers=dnums,
                       slice_sizes=slice_sizes)
  self._CompileAndCheck(gather_fun, args_maker)
|
2019-01-14 14:33:40 -05:00
|
|
|
|
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
        jtu.format_shape_dtype_string(arg_shape, dtype),
        idxs, update_shape, dnums),
     "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
     "update_shape": update_shape, "dnums": dnums,
     "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory}
    for dtype in float_dtypes
    for arg_shape, idxs, update_shape, dnums in [
        ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
        ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(),
            scatter_dims_to_operand_dims=(0,))),
        ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
    ]
    for rng_idx_factory in [partial(jtu.rand_int, high=max(arg_shape))]
    for rng_factory in [jtu.rand_default]))
def testScatterAdd(self, arg_shape, dtype, idxs, update_shape, dnums,
                   rng_factory, rng_idx_factory):
  """Compiles lax.scatter_add with random operand, indices and updates."""
  rng = rng_factory(self.rng())
  rng_idx = rng_idx_factory(self.rng())

  def args_maker():
    operand = rng(arg_shape, dtype)
    scatter_indices = rng_idx(idxs.shape, idxs.dtype)
    updates = rng(update_shape, dtype)
    return [operand, scatter_indices, updates]

  scatter_fun = partial(lax.scatter_add, dimension_numbers=dnums)
  self._CompileAndCheck(scatter_fun, args_maker)
|
2019-01-14 14:33:40 -05:00
|
|
|
|
2019-06-21 19:31:41 -07:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
        jtu.format_shape_dtype_string(arg_shape, dtype),
        idxs, update_shape, dnums),
     "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
     "update_shape": update_shape, "dnums": dnums,
     "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory}
    for dtype in float_dtypes
    for arg_shape, idxs, update_shape, dnums in [
        ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
        ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(),
            scatter_dims_to_operand_dims=(0,))),
        ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
    ]
    for rng_idx_factory in [partial(jtu.rand_int, high=max(arg_shape))]
    for rng_factory in [jtu.rand_default]))
def testScatterMin(self, arg_shape, dtype, idxs, update_shape, dnums,
                   rng_factory, rng_idx_factory):
  """Compiles lax.scatter_min with random operand, indices and updates."""
  rng = rng_factory(self.rng())
  rng_idx = rng_idx_factory(self.rng())

  def args_maker():
    operand = rng(arg_shape, dtype)
    scatter_indices = rng_idx(idxs.shape, idxs.dtype)
    updates = rng(update_shape, dtype)
    return [operand, scatter_indices, updates]

  scatter_fun = partial(lax.scatter_min, dimension_numbers=dnums)
  self._CompileAndCheck(scatter_fun, args_maker)
|
2019-06-21 19:31:41 -07:00
|
|
|
|
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
        jtu.format_shape_dtype_string(arg_shape, dtype),
        idxs, update_shape, dnums),
     "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
     "update_shape": update_shape, "dnums": dnums,
     "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory}
    for dtype in float_dtypes
    for arg_shape, idxs, update_shape, dnums in [
        ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
        ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(),
            scatter_dims_to_operand_dims=(0,))),
        ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
    ]
    for rng_idx_factory in [partial(jtu.rand_int, high=max(arg_shape))]
    for rng_factory in [jtu.rand_default]))
def testScatterMax(self, arg_shape, dtype, idxs, update_shape, dnums,
                   rng_factory, rng_idx_factory):
  """Compiles lax.scatter_max with random operand, indices and updates."""
  rng = rng_factory(self.rng())
  rng_idx = rng_idx_factory(self.rng())

  def args_maker():
    operand = rng(arg_shape, dtype)
    scatter_indices = rng_idx(idxs.shape, idxs.dtype)
    updates = rng(update_shape, dtype)
    return [operand, scatter_indices, updates]

  scatter_fun = partial(lax.scatter_max, dimension_numbers=dnums)
  self._CompileAndCheck(scatter_fun, args_maker)
|
2019-06-21 19:31:41 -07:00
|
|
|
|
2019-03-01 15:41:49 -05:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
        jtu.format_shape_dtype_string(arg_shape, dtype),
        idxs, update_shape, dnums),
     "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
     "update_shape": update_shape, "dnums": dnums,
     "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory}
    for dtype in float_dtypes
    for arg_shape, idxs, update_shape, dnums in [
        ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
        ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(),
            scatter_dims_to_operand_dims=(0,))),
        ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
    ]
    for rng_idx_factory in [partial(jtu.rand_int, high=max(arg_shape))]
    for rng_factory in [jtu.rand_default]))
def testScatter(self, arg_shape, dtype, idxs, update_shape, dnums,
                rng_factory, rng_idx_factory):
  """Compiles lax.scatter with random operand, indices and updates."""
  rng = rng_factory(self.rng())
  rng_idx = rng_idx_factory(self.rng())

  def args_maker():
    operand = rng(arg_shape, dtype)
    scatter_indices = rng_idx(idxs.shape, idxs.dtype)
    updates = rng(update_shape, dtype)
    return [operand, scatter_indices, updates]

  scatter_fun = partial(lax.scatter, dimension_numbers=dnums)
  self._CompileAndCheck(scatter_fun, args_maker)
|
2019-03-01 15:41:49 -05:00
|
|
|
|
2019-06-09 09:49:16 -07:00
|
|
|
def testIssue831(self):
  """Regression test: jit over a jitted fori_loop must not crash."""
  # Tests the DeviceTuple constant handler
  def f(x):
    def body(*args):
      return args[1]
    jitted_fori_loop = api.jit(lax.fori_loop, static_argnums=(2,))
    return jitted_fori_loop(0, 10, body, x)

  api.jit(f)(1.)  # doesn't crash
|
|
|
|
|
2019-08-31 21:23:39 -07:00
|
|
|
def testReshapeWithUnusualShapes(self):
  """lax.reshape accepts a lax-computed dim but rejects non-integer shapes."""
  ans = lax.reshape(np.ones((3,), np.float32), (lax.add(1, 2), 1))
  self.assertAllClose(ans, np.ones((3, 1), np.float32))

  expected_msg = "Shapes must be 1D sequences of concrete values of integer type.*"
  with self.assertRaisesRegex(TypeError, expected_msg):
    lax.reshape(np.ones(3,), (np.array([3, 1]),))
  with self.assertRaisesRegex(TypeError, expected_msg):
    lax.reshape(np.ones(3,), (1.5, 2.0))
|
2019-08-31 21:23:39 -07:00
|
|
|
|
2019-11-14 15:51:27 -05:00
|
|
|
def testDynamicSliceTypeErrors(self):
  """dynamic_slice rejects start indices with mixed integer dtypes."""
  operand = np.ones((3, 4), dtype=np.float32)
  mixed_indices = (np.int32(1), np.int16(2))
  with self.assertRaisesRegex(
      TypeError,
      "index arguments to dynamic_slice must be integers of the same type"):
    lax.dynamic_slice(operand, mixed_indices, (2, 2))
|
2019-11-14 15:51:27 -05:00
|
|
|
|
|
|
|
def testDynamicUpdateSliceTypeErrors(self):
  """dynamic_update_slice rejects start indices with mixed integer dtypes."""
  operand = np.ones((3, 4), dtype=np.float32)
  update = np.zeros((2, 2), dtype=np.float32)
  mixed_indices = (np.int32(1), np.int16(2))
  with self.assertRaisesRegex(
      TypeError,
      "index arguments to dynamic_update_slice must be integers of the same "
      "type"):
    lax.dynamic_update_slice(operand, update, mixed_indices)
|
2019-01-14 14:33:40 -05:00
|
|
|
|
2020-06-01 13:24:40 -07:00
|
|
|
def test_tie_in_error(self):
  """Obsolete check, kept as a skip.

  The test previously verified that tie_in rejected a tuple argument inside
  make_jaxpr; it no longer applies now that tie_in is trivial.
  """
  raise SkipTest("test no longer needed after trivializing tie_in")
|
2020-06-01 13:24:40 -07:00
|
|
|
|
|
|
|
def test_primitive_jaxtype_error(self):
  """Applying a primitive to a str operand raises a TypeError."""
  expected_msg = "Argument .* of type .* is not a valid JAX type"
  with core.skipping_checks(), self.assertRaisesRegex(TypeError, expected_msg):
    lax.add(1, 'hi')
|
|
|
|
|
2020-06-30 21:18:46 -07:00
|
|
|
def test_reduction_with_repeated_axes_error(self):
  """lax.reduce rejects an `axes` tuple that contains duplicates."""
  operand = np.arange(3)
  with self.assertRaisesRegex(ValueError, "duplicate value in 'axes' .*"):
    lax.reduce(operand, 0, lax.add, (0, 0))
|
2020-06-30 21:18:46 -07:00
|
|
|
|
2020-07-28 19:46:00 -07:00
|
|
|
def test_population_count_booleans_not_supported(self):
  """population_count applied to a bool raises a TypeError.

  Regression test for https://github.com/google/jax/issues/3886.
  """
  with self.assertRaisesRegex(
      TypeError, "population_count does not accept dtype bool"):
    lax.population_count(True)
|
|
|
|
|
2020-06-30 21:18:46 -07:00
|
|
|
|
implement lazy sublanguage
Before this commit, this computation would avoid materializing the iota
array at trace time:
@jit
def f(x):
m, n = x.shape
return x + np.arange(n)
But this one would materialize the iota array at trace time and stage it
into the computation as a potentially large array constant:
@jit
def f(x):
m, n = x.shape
return x + np.arange(m)[:, None]
The difference is that previously operations like broadcasts,
transposes, and reshapes that add singleton dimensions (as above) would
force otherwise lazy values to be materialized, while after this commit
broadcasts, transposes, and reshapes are all lazy operations that only
update metadata on their input rather than compiling and executing XLA
computations and producing new buffers.
Also, np.eye and np.tri become lazy (in addition to np.zeros, np.ones, np.full).
This commit replaces the ad-hoc "lazy device constant" system, which was
used to get the simpler behavior in the first example above.
Incidentally fixes #1431
See https://github.com/google/jax/pull/1668 for more.
2020-01-03 15:46:19 -08:00
|
|
|
class LazyConstantTest(jtu.JaxTestCase):
|
|
|
|
def _Check(self, make_const, expected):
  """Checks a constant through several ways of consuming it.

  Args:
    make_const: zero-argument callable that builds the constant under test.
    expected: array holding the expected value; its dtype is also used to
      build the zero operand below.
  """
  # Converting to an ndarray.
  asarray_result = np.asarray(make_const())
  # Passing as an operand (should hit the constant handler).
  zero = np.array(0, expected.dtype)
  argument_result = lax.add(zero, make_const())
  # Closing over it inside a compiled computation.
  jit_result = api.jit(lambda x: lax.add(x, make_const()))(zero)

  for result in (asarray_result, argument_result, jit_result):
    self.assertAllClose(result, expected)

  # repr must not raise.
  repr(make_const())
|
|
|
|
|
2018-12-18 22:45:34 -08:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}_fill={}".format(
        jtu.format_shape_dtype_string(shape, dtype) if dtype else shape,
        fill_value),
     "shape": shape, "dtype": dtype, "fill_value": fill_value}
    for dtype in itertools.chain(default_dtypes, [None])
    for shape in [(), (3,), (2, 3), (2, 3, 4), (1001, 1001)]
    for fill_value in [0, 1, np.pi]))
def testFilledConstant(self, shape, fill_value, dtype):
  """lax.full matches np.full; with no dtype, it is inferred from fill_value."""
  if dtype:
    expected_dtype = dtype
  else:
    expected_dtype = dtypes.result_type(fill_value)
  expected = np.full(shape, fill_value, expected_dtype)
  self._Check(lambda: lax.full(shape, fill_value, dtype), expected)
|
2018-12-18 22:45:34 -08:00
|
|
|
|
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
|
|
|
|
{"testcase_name": "_{}_dim={}".format(
|
|
|
|
jtu.format_shape_dtype_string(shape, dtype), dimension),
|
|
|
|
"shape": shape, "dtype": dtype, "dimension": dimension}
|
|
|
|
for dtype in default_dtypes
|
2019-05-19 12:44:51 -07:00
|
|
|
for shape in [(), (3,), (2, 3), (2, 3, 4),
|
|
|
|
# TODO(mattjj): re-enable
|
|
|
|
# (1001, 1001), (101, 101, 101),
|
|
|
|
]
|
2018-12-18 22:45:34 -08:00
|
|
|
for dimension in range(len(shape))))
|
|
|
|
def testIotaConstant(self, dtype, shape, dimension):
|
|
|
|
make_const = lambda: lax.broadcasted_iota(dtype, shape, dimension)
|
|
|
|
|
2020-07-14 13:03:24 -07:00
|
|
|
arr = np.arange(shape[dimension], dtype=dtypes.canonicalize_dtype(dtype))
|
2018-12-18 22:45:34 -08:00
|
|
|
singleton_shape = [1] * len(shape)
|
|
|
|
singleton_shape[dimension] = shape[dimension]
|
2020-07-14 13:03:24 -07:00
|
|
|
expected = np.broadcast_to(arr.reshape(singleton_shape), shape)
|
2018-12-18 22:45:34 -08:00
|
|
|
|
implement lazy sublanguage
Before this commit, this computation would avoid materializing the iota
array at trace time:
@jit
def f(x):
m, n = x.shape
return x + np.arange(n)
But this one would materialize the iota array at trace time and stage it
into the computation as a potentially large array constant:
@jit
def f(x):
m, n = x.shape
return x + np.arange(m)[:, None]
The difference is that previously operations like broadcasts,
transposes, and reshapes that add singleton dimensions (as above) would
force otherwise lazy values to be materialized, while after this commit
broadcasts, transposes, and reshapes are all lazy operations that only
update metadata on their input rather than compiling and executing XLA
computations and producing new buffers.
Also, np.eye and np.tri become lazy (in addition to np.zeros, np.ones, np.full).
This commit replaces the ad-hoc "lazy device constant" system, which was
used to get the simpler behavior in the first example above.
Incidentally fixes #1431
See https://github.com/google/jax/pull/1668 for more.
2020-01-03 15:46:19 -08:00
|
|
|
self._Check(make_const, expected)
|
2018-12-18 22:45:34 -08:00
|
|
|
|
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
|
|
|
|
{"testcase_name": "_{}_axes={}".format(
|
|
|
|
jtu.format_shape_dtype_string(shape, dtype), axes),
|
|
|
|
"shape": shape, "dtype": dtype, "axes": axes}
|
|
|
|
for dtype in default_dtypes
|
|
|
|
for shape, axes in [
|
|
|
|
[(2, 3), (0, 1)],
|
|
|
|
[(2, 3, 4), (0, 1)],
|
|
|
|
[(2, 3, 4), (0, 2)],
|
|
|
|
[(2, 3, 4), (1, 2)],
|
|
|
|
[(2, 3, 4), (0, 1, 2)],
|
|
|
|
[(2, 3, 4, 2), (0, 1, 2)],
|
|
|
|
[(2, 3, 4, 2), (0, 2, 3)],
|
2019-05-17 13:01:45 -07:00
|
|
|
[(1001, 1001), (0, 1)],
|
2018-12-18 22:45:34 -08:00
|
|
|
]))
|
implement lazy sublanguage
Before this commit, this computation would avoid materializing the iota
array at trace time:
@jit
def f(x):
m, n = x.shape
return x + np.arange(n)
But this one would materialize the iota array at trace time and stage it
into the computation as a potentially large array constant:
@jit
def f(x):
m, n = x.shape
return x + np.arange(m)[:, None]
The difference is that previously operations like broadcasts,
transposes, and reshapes that add singleton dimensions (as above) would
force otherwise lazy values to be materialized, while after this commit
broadcasts, transposes, and reshapes are all lazy operations that only
update metadata on their input rather than compiling and executing XLA
computations and producing new buffers.
Also, np.eye and np.tri become lazy (in addition to np.zeros, np.ones, np.full).
This commit replaces the ad-hoc "lazy device constant" system, which was
used to get the simpler behavior in the first example above.
Incidentally fixes #1431
See https://github.com/google/jax/pull/1668 for more.
2020-01-03 15:46:19 -08:00
|
|
|
@jtu.skip_on_devices("tpu") # TODO(mattjj): investigate failure
|
|
|
|
def testDeltaConstant(self, dtype, shape, axes):
|
|
|
|
make_const = lambda: lax._delta(dtype, shape, axes)
|
2018-12-18 22:45:34 -08:00
|
|
|
# don't check the asarray case, just assume it's right
|
2020-07-14 13:03:24 -07:00
|
|
|
expected = np.asarray(make_const())
|
implement lazy sublanguage
Before this commit, this computation would avoid materializing the iota
array at trace time:
@jit
def f(x):
m, n = x.shape
return x + np.arange(n)
But this one would materialize the iota array at trace time and stage it
into the computation as a potentially large array constant:
@jit
def f(x):
m, n = x.shape
return x + np.arange(m)[:, None]
The difference is that previously operations like broadcasts,
transposes, and reshapes that add singleton dimensions (as above) would
force otherwise lazy values to be materialized, while after this commit
broadcasts, transposes, and reshapes are all lazy operations that only
update metadata on their input rather than compiling and executing XLA
computations and producing new buffers.
Also, np.eye and np.tri become lazy (in addition to np.zeros, np.ones, np.full).
This commit replaces the ad-hoc "lazy device constant" system, which was
used to get the simpler behavior in the first example above.
Incidentally fixes #1431
See https://github.com/google/jax/pull/1668 for more.
2020-01-03 15:46:19 -08:00
|
|
|
self._Check(make_const, expected)
|
2018-12-18 22:45:34 -08:00
|
|
|
|
2020-06-25 15:50:11 +01:00
|
|
|
def testBroadcastInDim(self):
|
|
|
|
arr = lax.full((2, 1), 1.) + 1.
|
2020-07-14 13:03:24 -07:00
|
|
|
arr_np = np.full((2, 1), 1.) + 1.
|
|
|
|
expected = lax_reference.broadcast_in_dim(arr_np, (2, 1, 3), (0, 2))
|
2020-06-25 15:50:11 +01:00
|
|
|
make_const = lambda: lax.broadcast_in_dim(arr, (2, 1, 3), (0, 2))
|
|
|
|
self._Check(make_const, expected)
|
|
|
|
|
2018-12-18 22:45:34 -08:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
if __name__ == '__main__':
|
2020-06-24 16:24:33 -07:00
|
|
|
absltest.main(testLoader=jtu.JaxTestLoader())
|