# rocm_jax/tests/lax_test.py
# 2018-11-18 15:15:47 -08:00
#
# 1982 lines
# 86 KiB
# Python

# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import functools
from functools import partial
import itertools
from absl import flags
from absl.testing import absltest
from absl.testing import parameterized
import numpy as onp
import numpy.random as npr
from jax import api
from jax import core
from jax import lax
from jax import test_util as jtu
from jax import lax_reference
from jax.interpreters import xla
from jax.lib import xla_bridge
# Module-level handle on absl's global flag registry.
FLAGS = flags.FLAGS
def num_float_bits(dtype):
  """Return `onp.finfo(...).bits` for `dtype` after canonicalization.

  The dtype is first passed through `xla_bridge.canonicalize_dtype`, and the
  bit count is read from the floating-point info of the result.
  """
  canonical_dtype = xla_bridge.canonicalize_dtype(dtype)
  float_info = onp.finfo(canonical_dtype)
  return float_info.bits
### lax tests
# For standard unops and binops, we can generate a large number of tests on
# arguments of appropriate shapes and dtypes using the following table.
# Dtype families used to parameterize the tests in this file.
float_dtypes = [onp.float32, onp.float64]
complex_dtypes = [onp.complex64]
int_dtypes = [onp.int32, onp.int64]
bool_dtypes = [onp.bool_]
# Real-valued numeric dtypes: floats followed by ints (no complex, no bool).
default_dtypes = float_dtypes + int_dtypes
all_dtypes = float_dtypes + complex_dtypes + int_dtypes + bool_dtypes
# Groups of shapes; argument shapes for a test case are drawn (with
# replacement) from within a single group.
compatible_shapes = [[(3,)], [(3, 4), (3, 1), (1, 4)], [(2, 3, 4), (2, 1, 4)]]
# One row of the op table: the op, its argument count, the dtypes to test it
# with, the random-input generator, and a tolerance.
OpRecord = collections.namedtuple(
    "OpRecord", ["op", "nargs", "dtypes", "rng", "tol"])


def op_record(op, nargs, dtypes, rng, tol=1e-5):
  """Build an `OpRecord`, filling in a default `tol` of 1e-5."""
  return OpRecord(op=op, nargs=nargs, dtypes=dtypes, rng=rng, tol=tol)
# Table driving the generic op tests. Each row is
# op_record(op, nargs, dtypes, rng[, tol]): the lax function, how many
# arguments it takes, which dtypes to test it with, the random-input
# generator to use, and optionally a non-default tolerance.
LAX_OPS = [
    op_record(lax.neg, 1, default_dtypes + complex_dtypes, jtu.rand_small()),
    op_record(lax.sign, 1, default_dtypes, jtu.rand_small()),
    op_record(lax.floor, 1, float_dtypes, jtu.rand_small()),
    op_record(lax.ceil, 1, float_dtypes, jtu.rand_small()),
    op_record(lax.round, 1, float_dtypes, jtu.rand_default()),
    op_record(lax.is_finite, 1, float_dtypes, jtu.rand_small()),
    op_record(lax.exp, 1, float_dtypes + complex_dtypes, jtu.rand_small()),
    op_record(lax.expm1, 1, float_dtypes + complex_dtypes, jtu.rand_small()),
    op_record(lax.log, 1, float_dtypes + complex_dtypes, jtu.rand_positive()),
    op_record(lax.log1p, 1, float_dtypes + complex_dtypes, jtu.rand_positive()),
    op_record(lax.tanh, 1, float_dtypes + complex_dtypes, jtu.rand_small()),
    op_record(lax.sin, 1, float_dtypes + complex_dtypes, jtu.rand_default()),
    op_record(lax.cos, 1, float_dtypes + complex_dtypes, jtu.rand_default()),
    op_record(lax.atan2, 2, float_dtypes, jtu.rand_default()),
    op_record(lax.sqrt, 1, float_dtypes + complex_dtypes, jtu.rand_positive()),
    op_record(lax.rsqrt, 1, float_dtypes + complex_dtypes, jtu.rand_positive()),
    op_record(lax.square, 1, float_dtypes + complex_dtypes, jtu.rand_default()),
    op_record(lax.reciprocal, 1, float_dtypes + complex_dtypes, jtu.rand_positive()),
    op_record(lax.tan, 1, float_dtypes, jtu.rand_default()),
    op_record(lax.asin, 1, float_dtypes, jtu.rand_small()),
    op_record(lax.acos, 1, float_dtypes, jtu.rand_small()),
    op_record(lax.atan, 1, float_dtypes, jtu.rand_small()),
    op_record(lax.sinh, 1, float_dtypes + complex_dtypes, jtu.rand_default()),
    op_record(lax.cosh, 1, float_dtypes + complex_dtypes, jtu.rand_default()),
    op_record(lax.asinh, 1, float_dtypes + complex_dtypes, jtu.rand_positive()),
    op_record(lax.acosh, 1, float_dtypes + complex_dtypes, jtu.rand_positive()),
    op_record(lax.lgamma, 1, float_dtypes, jtu.rand_positive()),
    op_record(lax.digamma, 1, float_dtypes, jtu.rand_positive()),
    op_record(lax.erf, 1, float_dtypes, jtu.rand_small()),
    op_record(lax.erfc, 1, float_dtypes, jtu.rand_small()),
    # The only row that overrides the default tolerance.
    op_record(lax.erf_inv, 1, float_dtypes, jtu.rand_small(), tol=1e-2),
    op_record(lax.real, 1, complex_dtypes, jtu.rand_default()),
    op_record(lax.imag, 1, complex_dtypes, jtu.rand_default()),
    op_record(lax.complex, 2, [onp.float32], jtu.rand_default()),
    op_record(lax.conj, 1, [onp.float32] + complex_dtypes, jtu.rand_default()),
    op_record(lax.abs, 1, default_dtypes + complex_dtypes, jtu.rand_default()),
    op_record(lax.pow, 2, float_dtypes + complex_dtypes, jtu.rand_positive()),
    op_record(lax.bitwise_and, 2, bool_dtypes, jtu.rand_small()),
    op_record(lax.bitwise_not, 1, bool_dtypes, jtu.rand_small()),
    op_record(lax.bitwise_or, 2, bool_dtypes, jtu.rand_small()),
    op_record(lax.bitwise_xor, 2, bool_dtypes, jtu.rand_small()),
    op_record(lax.add, 2, default_dtypes + complex_dtypes, jtu.rand_small()),
    op_record(lax.sub, 2, default_dtypes + complex_dtypes, jtu.rand_small()),
    op_record(lax.mul, 2, default_dtypes + complex_dtypes, jtu.rand_small()),
    op_record(lax.div, 2, default_dtypes + complex_dtypes, jtu.rand_nonzero()),
    op_record(lax.rem, 2, default_dtypes, jtu.rand_nonzero()),
    op_record(lax.max, 2, default_dtypes, jtu.rand_small()),
    op_record(lax.min, 2, default_dtypes, jtu.rand_small()),
    op_record(lax.eq, 2, all_dtypes, jtu.rand_some_equal()),
    op_record(lax.ne, 2, all_dtypes, jtu.rand_small()),
    op_record(lax.ge, 2, default_dtypes, jtu.rand_small()),
    op_record(lax.gt, 2, default_dtypes, jtu.rand_small()),
    op_record(lax.le, 2, default_dtypes, jtu.rand_small()),
    op_record(lax.lt, 2, default_dtypes, jtu.rand_small()),
]
# Short alias for itertools.combinations_with_replacement.
CombosWithReplacement = itertools.combinations_with_replacement
class LaxTest(jtu.JaxTestCase):
"""Numerical tests for LAX operations."""
@parameterized.named_parameters(
    {
        "testcase_name": jtu.format_test_name_suffix(
            record.op.__name__, shapes, itertools.repeat(dtype)),
        "op": record.op,
        "rng": record.rng,
        "shapes": shapes,
        "dtype": dtype,
    }
    for record in LAX_OPS
    for group in compatible_shapes
    for shapes in CombosWithReplacement(group, record.nargs)
    for dtype in record.dtypes)
def testOp(self, op, rng, shapes, dtype):
  """Run `op` through `_CompileAndCheck` on one random array per shape."""
  def args_maker():
    # Every argument shares the same dtype; only the shapes vary.
    return [rng(shape, dtype) for shape in shapes]

  self._CompileAndCheck(op, args_maker, check_dtypes=True)
@parameterized.named_parameters(
    {
        "testcase_name": jtu.format_test_name_suffix(
            record.op.__name__, shapes, itertools.repeat(dtype)),
        "op": record.op,
        "rng": record.rng,
        "shapes": shapes,
        "dtype": dtype,
        "tol": record.tol,
    }
    for record in LAX_OPS
    for group in compatible_shapes
    for shapes in CombosWithReplacement(group, record.nargs)
    for dtype in record.dtypes)
def testOpAgainstNumpy(self, op, rng, shapes, dtype, tol):
  """Compare `op` with the same-named function from `lax_reference`."""
  def args_maker():
    return [rng(shape, dtype) for shape in shapes]

  # The reference implementation is looked up by the lax op's own name.
  reference_op = getattr(lax_reference, op.__name__)
  self._CheckAgainstNumpy(op, reference_op, args_maker, tol=tol)
# TODO test shift_left, shift_right_arithmetic, shift_right_logical
@parameterized.named_parameters(
    {
        "testcase_name": "_from_dtype={}_to_dtype={}".format(src, dst),
        "from_dtype": src,
        "to_dtype": dst,
        "rng": rng,
    }
    for src, dst in itertools.product(
        [onp.float32, onp.int32, "float32", "int32"], repeat=2)
    for rng in [jtu.rand_default()])
def testConvertElementType(self, from_dtype, to_dtype, rng):
  """Run `lax.convert_element_type` through `_CompileAndCheck`.

  Dtypes are given both as numpy types and as strings.
  """
  def args_maker():
    return [rng((2, 3), from_dtype)]

  def convert(x):
    return lax.convert_element_type(x, to_dtype)

  self._CompileAndCheck(convert, args_maker, check_dtypes=True)
@parameterized.named_parameters(
    {
        "testcase_name": "_from_dtype={}_to_dtype={}".format(src, dst),
        "from_dtype": src,
        "to_dtype": dst,
        "rng": rng,
    }
    for src, dst in itertools.product(
        [onp.float32, onp.int32, "float32", "int32"], repeat=2)
    for rng in [jtu.rand_default()])
def testConvertElementTypeAgainstNumpy(self, from_dtype, to_dtype, rng):
  """Compare `lax.convert_element_type` with its `lax_reference` version."""
  def args_maker():
    return [rng((2, 3), from_dtype)]

  def lax_convert(x):
    return lax.convert_element_type(x, to_dtype)

  def reference_convert(x):
    return lax_reference.convert_element_type(x, to_dtype)

  self._CheckAgainstNumpy(lax_convert, reference_convert, args_maker)
@parameterized.named_parameters(
    {
        "testcase_name": "_from_dtype={}_to_dtype={}".format(src, dst),
        "from_dtype": src,
        "to_dtype": dst,
        "rng": rng,
    }
    for src, dst in itertools.product(
        [onp.float32, onp.int32, "float32", "int32"], repeat=2)
    for rng in [jtu.rand_default()])
def testBitcastConvertType(self, from_dtype, to_dtype, rng):
  """Run `lax.bitcast_convert_type` through `_CompileAndCheck`."""
  def args_maker():
    return [rng((2, 3), from_dtype)]

  def bitcast(x):
    return lax.bitcast_convert_type(x, to_dtype)

  self._CompileAndCheck(bitcast, args_maker, check_dtypes=True)
@parameterized.named_parameters(
    {
        "testcase_name": "_from_dtype={}_to_dtype={}".format(src, dst),
        "from_dtype": src,
        "to_dtype": dst,
        "rng": rng,
    }
    for src, dst in itertools.product(
        [onp.float32, onp.int32, "float32", "int32"], repeat=2)
    for rng in [jtu.rand_default()])
def testBitcastConvertTypeAgainstNumpy(self, from_dtype, to_dtype, rng):
  """Compare `lax.bitcast_convert_type` with its `lax_reference` version."""
  def args_maker():
    return [rng((2, 3), from_dtype)]

  def lax_bitcast(x):
    return lax.bitcast_convert_type(x, to_dtype)

  def reference_bitcast(x):
    return lax_reference.bitcast_convert_type(x, to_dtype)

  self._CheckAgainstNumpy(lax_bitcast, reference_bitcast, args_maker)
@parameterized.named_parameters(
    {
        "testcase_name": "_min_shape={}_operand_shape={}_max_shape={}".format(
            jtu.format_shape_dtype_string(min_shape, dtype),
            jtu.format_shape_dtype_string(operand_shape, dtype),
            jtu.format_shape_dtype_string(max_shape, dtype)),
        "min_shape": min_shape,
        "operand_shape": operand_shape,
        "max_shape": max_shape,
        "dtype": dtype,
        "rng": rng,
    }
    for min_shape, operand_shape, max_shape in [
        [(), (2, 3), ()],
        [(2, 3), (2, 3), ()],
        [(), (2, 3), (2, 3)],
        [(2, 3), (2, 3), (2, 3)],
    ]
    for dtype in default_dtypes
    for rng in [jtu.rand_default()])
def testClamp(self, min_shape, operand_shape, max_shape, dtype, rng):
  """Run `lax.clamp` through `_CompileAndCheck`.

  The min and max bounds are each either scalar-shaped or operand-shaped.
  """
  def args_maker():
    # Argument order is (min, operand, max).
    return [rng(min_shape, dtype),
            rng(operand_shape, dtype),
            rng(max_shape, dtype)]

  self._CompileAndCheck(lax.clamp, args_maker, check_dtypes=True)
@parameterized.named_parameters(
    {
        "testcase_name": "_min_shape={}_operand_shape={}_max_shape={}".format(
            jtu.format_shape_dtype_string(min_shape, dtype),
            jtu.format_shape_dtype_string(operand_shape, dtype),
            jtu.format_shape_dtype_string(max_shape, dtype)),
        "min_shape": min_shape,
        "operand_shape": operand_shape,
        "max_shape": max_shape,
        "dtype": dtype,
        "rng": rng,
    }
    for min_shape, operand_shape, max_shape in [
        [(), (2, 3), ()],
        [(2, 3), (2, 3), ()],
        [(), (2, 3), (2, 3)],
        [(2, 3), (2, 3), (2, 3)],
    ]
    for dtype in default_dtypes
    for rng in [jtu.rand_default()])
def testClampAgainstNumpy(self, min_shape, operand_shape, max_shape, dtype,
                          rng):
  """Compare `lax.clamp` with `lax_reference.clamp` on random inputs."""
  def args_maker():
    # Argument order is (min, operand, max).
    return [rng(min_shape, dtype),
            rng(operand_shape, dtype),
            rng(max_shape, dtype)]

  self._CheckAgainstNumpy(lax.clamp, lax_reference.clamp, args_maker)
@parameterized.named_parameters(
    {
        "testcase_name": "_dim={}_baseshape=[{}]_dtype={}_narrs={}".format(
            dim, ",".join(str(d) for d in base_shape), onp.dtype(dtype).name,
            num_arrs),
        "dim": dim,
        "base_shape": base_shape,
        "dtype": dtype,
        "num_arrs": num_arrs,
        "rng": rng,
    }
    for num_arrs in [3]
    for dtype in default_dtypes
    for base_shape in [(4,), (3, 4), (2, 3, 4)]
    for dim in range(len(base_shape))
    for rng in [jtu.rand_default()])
def testConcatenate(self, dim, base_shape, dtype, num_arrs, rng):
  """Run `lax.concatenate` along `dim` through `_CompileAndCheck`.

  The inputs share `base_shape` except along `dim`, where their sizes cycle
  through 3, 1, 4.
  """
  leading, trailing = base_shape[:dim], base_shape[dim+1:]
  sizes = itertools.islice(itertools.cycle([3, 1, 4]), num_arrs)
  shapes = [leading + (size,) + trailing for size in sizes]

  def args_maker():
    return [rng(shape, dtype) for shape in shapes]

  def concat(*arrays):
    return lax.concatenate(arrays, dim)

  self._CompileAndCheck(concat, args_maker, check_dtypes=True)
@parameterized.named_parameters(
    {
        "testcase_name": "_dim={}_baseshape=[{}]_dtype={}_narrs={}".format(
            dim, ",".join(str(d) for d in base_shape), onp.dtype(dtype).name,
            num_arrs),
        "dim": dim,
        "base_shape": base_shape,
        "dtype": dtype,
        "num_arrs": num_arrs,
        "rng": rng,
    }
    for num_arrs in [3]
    for dtype in default_dtypes
    for base_shape in [(4,), (3, 4), (2, 3, 4)]
    for dim in range(len(base_shape))
    for rng in [jtu.rand_default()])
def testConcatenateAgainstNumpy(self, dim, base_shape, dtype, num_arrs, rng):
  """Compare `lax.concatenate` with `lax_reference.concatenate`.

  The inputs share `base_shape` except along `dim`, where their sizes cycle
  through 3, 1, 4.
  """
  leading, trailing = base_shape[:dim], base_shape[dim+1:]
  sizes = itertools.islice(itertools.cycle([3, 1, 4]), num_arrs)
  shapes = [leading + (size,) + trailing for size in sizes]

  def args_maker():
    return [rng(shape, dtype) for shape in shapes]

  def lax_concat(*arrays):
    return lax.concatenate(arrays, dim)

  def reference_concat(*arrays):
    return lax_reference.concatenate(arrays, dim)

  self._CheckAgainstNumpy(lax_concat, reference_concat, args_maker)
@parameterized.named_parameters(
    {
        "testcase_name":
            "_lhs_shape={}_rhs_shape={}_strides={}_padding={}".format(
                jtu.format_shape_dtype_string(lhs_shape, dtype),
                jtu.format_shape_dtype_string(rhs_shape, dtype),
                strides, padding),
        "lhs_shape": lhs_shape,
        "rhs_shape": rhs_shape,
        "dtype": dtype,
        "strides": strides,
        "padding": padding,
        "rng": rng,
    }
    for lhs_shape, rhs_shape in [
        ((b, i, 9, 10), (j, i, 4, 5))
        for b, i, j in itertools.product([2, 3], repeat=3)]
    for dtype in [onp.float32]
    for strides in [(1, 1), (1, 2), (2, 1)]
    for padding in ["VALID", "SAME"]
    for rng in [jtu.rand_small()])
def testConv(self, lhs_shape, rhs_shape, dtype, strides, padding, rng):
  """Run `lax.conv` with string padding modes through `_CompileAndCheck`."""
  def args_maker():
    return [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]

  def conv(lhs, rhs):
    return lax.conv(lhs, rhs, strides, padding)

  self._CompileAndCheck(conv, args_maker, check_dtypes=True)
@parameterized.named_parameters(
    {
        "testcase_name":
            "_lhs_shape={}_rhs_shape={}_strides={}_padding={}".format(
                jtu.format_shape_dtype_string(lhs_shape, dtype),
                jtu.format_shape_dtype_string(rhs_shape, dtype),
                strides, padding),
        "lhs_shape": lhs_shape,
        "rhs_shape": rhs_shape,
        "dtype": dtype,
        "strides": strides,
        "padding": padding,
        "rng": rng,
    }
    for lhs_shape, rhs_shape in [
        ((b, i, 9, 10), (j, i, 4, 5))
        for b, i, j in itertools.product([2, 3], repeat=3)]
    for dtype in [onp.float32]
    for strides in [(1, 1), (1, 2), (2, 1)]
    for padding in ["VALID", "SAME"]
    for rng in [jtu.rand_small()])
def testConvAgainstNumpy(self, lhs_shape, rhs_shape, dtype, strides, padding,
                         rng):
  """Compare `lax.conv` with `lax_reference.conv` on random inputs."""
  def args_maker():
    return [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]

  def lax_conv(lhs, rhs):
    return lax.conv(lhs, rhs, strides, padding)

  def reference_conv(lhs, rhs):
    return lax_reference.conv(lhs, rhs, strides, padding)

  self._CheckAgainstNumpy(lax_conv, reference_conv, args_maker)
@parameterized.named_parameters(
    {
        "testcase_name":
            "_lhs_shape={}_rhs_shape={}_strides={}_padding={}"
            "_lhs_dilation={}_rhs_dilation={}".format(
                jtu.format_shape_dtype_string(lhs_shape, dtype),
                jtu.format_shape_dtype_string(rhs_shape, dtype),
                strides, padding, lhs_dilation, rhs_dilation),
        "lhs_shape": lhs_shape,
        "rhs_shape": rhs_shape,
        "dtype": dtype,
        "strides": strides,
        "padding": padding,
        "lhs_dilation": lhs_dilation,
        "rhs_dilation": rhs_dilation,
        "rng": rng,
    }
    for lhs_shape, rhs_shape in [
        ((b, i, 9, 10), (j, i, 4, 5))
        for b, i, j in itertools.product([1, 2, 3], repeat=3)]
    for dtype in [onp.float32]
    for strides in [(1, 1), (1, 2), (2, 1)]
    for padding in [((0, 0), (0, 0)), ((1, 2), (2, 0))]
    for lhs_dilation, rhs_dilation in itertools.product(
        [(1, 1), (1, 2), (2, 2)], repeat=2)
    for rng in [jtu.rand_small()])
def testConvWithGeneralPadding(self, lhs_shape, rhs_shape, dtype, strides,
                               padding, lhs_dilation, rhs_dilation, rng):
  """Run `lax.conv_with_general_padding` through `_CompileAndCheck`.

  Padding is given as explicit pairs, combined with lhs and rhs dilations.
  """
  def args_maker():
    return [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]

  def padded_conv(lhs, rhs):
    return lax.conv_with_general_padding(
        lhs, rhs, strides, padding, lhs_dilation, rhs_dilation)

  self._CompileAndCheck(padded_conv, args_maker, check_dtypes=True)
@parameterized.named_parameters(
    {
        "testcase_name":
            "_lhs_shape={}_rhs_shape={}_strides={}_padding={}"
            "_lhs_dilation={}_rhs_dilation={}".format(
                jtu.format_shape_dtype_string(lhs_shape, dtype),
                jtu.format_shape_dtype_string(rhs_shape, dtype),
                strides, padding, lhs_dilation, rhs_dilation),
        "lhs_shape": lhs_shape,
        "rhs_shape": rhs_shape,
        "dtype": dtype,
        "strides": strides,
        "padding": padding,
        "lhs_dilation": lhs_dilation,
        "rhs_dilation": rhs_dilation,
        "rng": rng,
    }
    for lhs_shape, rhs_shape in [
        ((b, i, 9, 10), (j, i, 4, 5))
        for b, i, j in itertools.product([1, 2, 3], repeat=3)]
    for dtype in [onp.float32]
    for strides in [(1, 1), (1, 2), (2, 1)]
    for padding in [((0, 0), (0, 0)), ((1, 2), (2, 0))]
    for lhs_dilation, rhs_dilation in itertools.product(
        [(1, 1), (1, 2), (2, 2)], repeat=2)
    for rng in [jtu.rand_small()])
def DISABLED_testConvWithGeneralPaddingAgainstNumpy(
    self, lhs_shape, rhs_shape, dtype, strides, padding, lhs_dilation,
    rhs_dilation, rng):
  """Compare `lax.conv_with_general_padding` with its reference version.

  The method name carries a `DISABLED_` prefix instead of `test`.
  """
  # TODO(mattjj): make this test pass
  def args_maker():
    return [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]

  def lax_padded_conv(lhs, rhs):
    return lax.conv_with_general_padding(
        lhs, rhs, strides, padding, lhs_dilation, rhs_dilation)

  def reference_padded_conv(lhs, rhs):
    return lax_reference.conv_with_general_padding(
        lhs, rhs, strides, padding, lhs_dilation, rhs_dilation)

  self._CheckAgainstNumpy(lax_padded_conv, reference_padded_conv, args_maker)
@parameterized.named_parameters(
    {
        "testcase_name":
            "_lhs_shape={}_rhs_shape={}_strides={}_padding={}"
            "_lhs_dilation={}_rhs_dilation={}"
            "_dims={}".format(
                jtu.format_shape_dtype_string(lhs_shape, dtype),
                jtu.format_shape_dtype_string(rhs_shape, dtype),
                strides, padding, lhs_dilation, rhs_dilation,
                ",".join(dim_nums)),
        "lhs_shape": lhs_shape,
        "rhs_shape": rhs_shape,
        "dtype": dtype,
        "strides": strides,
        "padding": padding,
        "lhs_dilation": lhs_dilation,
        "rhs_dilation": rhs_dilation,
        "dimension_numbers": dim_nums,
        "perms": perms,
        "rng": rng,
    }
    for lhs_shape, rhs_shape in [
        ((b, i, 9, 10), (j, i, 4, 5))
        for b, i, j in itertools.product([2, 3], repeat=3)]
    for dtype in [onp.float32]
    for strides in [(1, 1), (2, 1)]
    for padding in [((1, 2), (2, 0))]
    for lhs_dilation, rhs_dilation in itertools.product(
        [(1, 1), (1, 2)], repeat=2)
    for rng in [jtu.rand_small()]
    for dim_nums, perms in [
        (("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])),
        (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])),
    ])
def testConvGeneralDilated(self, lhs_shape, rhs_shape, dtype, strides,
                           padding, lhs_dilation, rhs_dilation,
                           dimension_numbers, perms, rng):
  """Run `lax.conv_general_dilated` through `_CompileAndCheck`.

  Random operands are generated with `lhs_shape` / `rhs_shape` and then
  transposed by `perms` so that they match `dimension_numbers`.
  """
  lhs_perm, rhs_perm = perms  # permute to compatible shapes

  def args_maker():
    lhs = lax.transpose(rng(lhs_shape, dtype), lhs_perm)
    rhs = lax.transpose(rng(rhs_shape, dtype), rhs_perm)
    return [lhs, rhs]

  def dilated_conv(lhs, rhs):
    return lax.conv_general_dilated(
        lhs, rhs, strides, padding, lhs_dilation, rhs_dilation,
        dimension_numbers)

  self._CompileAndCheck(dilated_conv, args_maker, check_dtypes=True)
# TODO(mattjj): test conv_general_dilated against numpy
@parameterized.named_parameters(
    {
        "testcase_name": "_lhs_shape={}_rhs_shape={}".format(
            jtu.format_shape_dtype_string(lhs_shape, dtype),
            jtu.format_shape_dtype_string(rhs_shape, dtype)),
        "lhs_shape": lhs_shape,
        "rhs_shape": rhs_shape,
        "dtype": dtype,
        "rng": rng,
    }
    for lhs_shape in [(3,), (4, 3)]
    for rhs_shape in [(3,), (3, 6)]
    for dtype in float_dtypes
    for rng in [jtu.rand_default()])
def testDot(self, lhs_shape, rhs_shape, dtype, rng):
  """Run `lax.dot` through `_CompileAndCheck` on 1-D and 2-D operands."""
  def args_maker():
    return [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]

  self._CompileAndCheck(lax.dot, args_maker, check_dtypes=True)
@parameterized.named_parameters(
    {
        "testcase_name": "_lhs_shape={}_rhs_shape={}".format(
            jtu.format_shape_dtype_string(lhs_shape, dtype),
            jtu.format_shape_dtype_string(rhs_shape, dtype)),
        "lhs_shape": lhs_shape,
        "rhs_shape": rhs_shape,
        "dtype": dtype,
        "rng": rng,
    }
    for lhs_shape in [(3,), (4, 3)]
    for rhs_shape in [(3,), (3, 6)]
    for dtype in float_dtypes
    for rng in [jtu.rand_default()])
def testDotAgainstNumpy(self, lhs_shape, rhs_shape, dtype, rng):
  """Compare `lax.dot` with `lax_reference.dot` on random operands."""
  def args_maker():
    return [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]

  self._CheckAgainstNumpy(lax.dot, lax_reference.dot, args_maker)
@parameterized.named_parameters(
    {
        "testcase_name":
            "_lhs_shape={}_rhs_shape={}_lhs_contracting={}_rhs_contracting={}"
            .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
                    jtu.format_shape_dtype_string(rhs_shape, dtype),
                    lhs_contracting, rhs_contracting),
        "lhs_shape": lhs_shape,
        "rhs_shape": rhs_shape,
        "dtype": dtype,
        "lhs_contracting": lhs_contracting,
        "rhs_contracting": rhs_contracting,
        "rng": rng,
    }
    for lhs_shape, rhs_shape, lhs_contracting, rhs_contracting in [
        # these all fail with "RuntimeError: Unimplemented: Dot with
        # non-standard contracting dimensions not implemented."
        # [(3, 5), (2, 5), [1], [1]],
        # [(5, 3), (5, 2), [0], [0]],
        # [(5, 3, 2), (5, 2, 4), [0], [0]],
        # [(5, 3, 2), (5, 2, 4), [0,2], [0,1]],
        # [(1, 2, 2, 3), (1, 2, 3, 1), [1], [1]],
        [(3, 2), (2, 4), [1], [0]],
    ]
    for dtype in float_dtypes
    for rng in [jtu.rand_small()])
def testDotGeneralContractOnly(self, lhs_shape, rhs_shape, dtype,
                               lhs_contracting, rhs_contracting, rng):
  """Run `lax.dot_general` with contracting dims and no batch dims."""
  contracting_dims = (lhs_contracting, rhs_contracting)
  batch_dims = ([], [])
  dimension_numbers = (contracting_dims, batch_dims)

  def args_maker():
    return [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]

  def contract(lhs, rhs):
    return lax.dot_general(lhs, rhs, dimension_numbers)

  self._CompileAndCheck(contract, args_maker, check_dtypes=False)
@parameterized.named_parameters(
    {
        "testcase_name":
            "_lhs_shape={}_rhs_shape={}_dimension_numbers={}"
            .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
                    jtu.format_shape_dtype_string(rhs_shape, dtype),
                    dimension_numbers),
        "lhs_shape": lhs_shape,
        "rhs_shape": rhs_shape,
        "dtype": dtype,
        "dimension_numbers": dimension_numbers,
        "rng": rng,
    }
    for lhs_shape, rhs_shape, dimension_numbers in [
        ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
        ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1]))),
    ]
    for dtype in float_dtypes
    for rng in [jtu.rand_small()])
def testDotGeneralContractAndBatch(self, lhs_shape, rhs_shape, dtype,
                                   dimension_numbers, rng):
  """Run `lax.dot_general` with both contracting and batch dimensions."""
  def args_maker():
    return [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]

  def batched_dot(lhs, rhs):
    return lax.dot_general(lhs, rhs, dimension_numbers)

  self._CompileAndCheck(batched_dot, args_maker, check_dtypes=False)
@parameterized.named_parameters(
    {
        "testcase_name":
            "_lhs_shape={}_rhs_shape={}_dimension_numbers={}"
            .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
                    jtu.format_shape_dtype_string(rhs_shape, dtype),
                    dimension_numbers),
        "lhs_shape": lhs_shape,
        "rhs_shape": rhs_shape,
        "dtype": dtype,
        "dimension_numbers": dimension_numbers,
        "rng": rng,
    }
    for lhs_shape, rhs_shape, dimension_numbers in [
        ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
        ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1]))),
    ]
    for dtype in float_dtypes
    for rng in [jtu.rand_small()])
def testDotGeneralAgainstNumpy(self, lhs_shape, rhs_shape, dtype,
                               dimension_numbers, rng):
  """Compare `lax.dot_general` with `lax_reference.dot_general`."""
  def args_maker():
    return [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]

  def lax_dot_general(x, y):
    return lax.dot_general(x, y, dimension_numbers)

  def reference_dot_general(x, y):
    return lax_reference.dot_general(x, y, dimension_numbers)

  self._CheckAgainstNumpy(lax_dot_general, reference_dot_general, args_maker)
@parameterized.named_parameters(
    {
        "testcase_name": "_shape={}_dtype={}_broadcast_sizes={}".format(
            shape, onp.dtype(dtype).name, broadcast_sizes),
        "shape": shape,
        "dtype": dtype,
        "broadcast_sizes": broadcast_sizes,
        "rng": rng,
    }
    for shape in [(), (2, 3)]
    for dtype in default_dtypes
    for broadcast_sizes in [(), (2,), (1, 2)]
    for rng in [jtu.rand_default()])
def testBroadcast(self, shape, dtype, broadcast_sizes, rng):
  """Run `lax.broadcast` with `broadcast_sizes` through `_CompileAndCheck`."""
  def args_maker():
    return [rng(shape, dtype)]

  def broadcast(x):
    return lax.broadcast(x, broadcast_sizes)

  self._CompileAndCheck(broadcast, args_maker, check_dtypes=True)
@parameterized.named_parameters(
    {
        "testcase_name": "_shape={}_broadcast_sizes={}".format(
            jtu.format_shape_dtype_string(shape, dtype), broadcast_sizes),
        "shape": shape,
        "dtype": dtype,
        "broadcast_sizes": broadcast_sizes,
        "rng": rng,
    }
    for shape in [(), (2, 3)]
    for dtype in default_dtypes
    for broadcast_sizes in [(), (2,), (1, 2)]
    for rng in [jtu.rand_default()])
def testBroadcastAgainstNumpy(self, shape, dtype, broadcast_sizes, rng):
  """Compare `lax.broadcast` with `lax_reference.broadcast`."""
  def args_maker():
    return [rng(shape, dtype)]

  def lax_broadcast(x):
    return lax.broadcast(x, broadcast_sizes)

  def reference_broadcast(x):
    return lax_reference.broadcast(x, broadcast_sizes)

  self._CheckAgainstNumpy(lax_broadcast, reference_broadcast, args_maker)
@parameterized.named_parameters(
    {
        "testcase_name": "_inshape={}_outshape={}_bcdims={}".format(
            jtu.format_shape_dtype_string(inshape, dtype),
            outshape, broadcast_dimensions),
        "inshape": inshape,
        "dtype": dtype,
        "outshape": outshape,
        "dimensions": broadcast_dimensions,
        "rng": rng,
    }
    for inshape, outshape, broadcast_dimensions in [
        ([2], [2, 2], [0]),
        ([2], [2, 2], [1]),
        ([2], [2, 3], [0]),
        ([], [2, 3], []),
    ]
    for dtype in default_dtypes
    for rng in [jtu.rand_default()])
def testBroadcastInDim(self, inshape, dtype, outshape, dimensions, rng):
  """Run `lax.broadcast_in_dim` through `_CompileAndCheck`."""
  def args_maker():
    return [rng(inshape, dtype)]

  def broadcast_in_dim(x):
    return lax.broadcast_in_dim(x, outshape, dimensions)

  self._CompileAndCheck(broadcast_in_dim, args_maker, check_dtypes=True)
@parameterized.named_parameters(
    {
        "testcase_name": "_inshape={}_outshape={}_bcdims={}".format(
            jtu.format_shape_dtype_string(inshape, dtype),
            outshape, broadcast_dimensions),
        "inshape": inshape,
        "dtype": dtype,
        "outshape": outshape,
        "dimensions": broadcast_dimensions,
        "rng": rng,
    }
    for inshape, outshape, broadcast_dimensions in [
        ([2], [2, 2], [0]),
        ([2], [2, 2], [1]),
        ([2], [2, 3], [0]),
        ([], [2, 3], []),
    ]
    for dtype in default_dtypes
    for rng in [jtu.rand_default()])
def testBroadcastInDimAgainstNumpy(self, inshape, dtype, outshape,
                                   dimensions, rng):
  """Compare `lax.broadcast_in_dim` with its `lax_reference` version."""
  def args_maker():
    return [rng(inshape, dtype)]

  def lax_broadcast_in_dim(x):
    return lax.broadcast_in_dim(x, outshape, dimensions)

  def reference_broadcast_in_dim(x):
    return lax_reference.broadcast_in_dim(x, outshape, dimensions)

  self._CheckAgainstNumpy(
      lax_broadcast_in_dim, reference_broadcast_in_dim, args_maker)
@parameterized.named_parameters(
    {
        "testcase_name": "_inshape={}_outshape={}".format(
            jtu.format_shape_dtype_string(arg_shape, dtype),
            jtu.format_shape_dtype_string(out_shape, dtype)),
        "arg_shape": arg_shape,
        "out_shape": out_shape,
        "dtype": dtype,
        "rng": rng,
    }
    for dtype in default_dtypes
    for arg_shape, out_shape in [
        [(3, 4), (12,)],
        [(2, 1, 4), (8,)],
        [(2, 2, 4), (2, 8)],
    ]
    for rng in [jtu.rand_default()])
def testReshape(self, arg_shape, out_shape, dtype, rng):
  """Run `lax.reshape` to `out_shape` through `_CompileAndCheck`."""
  def args_maker():
    return [rng(arg_shape, dtype)]

  def reshape(x):
    return lax.reshape(x, out_shape)

  self._CompileAndCheck(reshape, args_maker, check_dtypes=True)
@parameterized.named_parameters(
    {
        "testcase_name": "_inshape={}_outshape={}".format(
            jtu.format_shape_dtype_string(arg_shape, dtype),
            jtu.format_shape_dtype_string(out_shape, dtype)),
        "arg_shape": arg_shape,
        "out_shape": out_shape,
        "dtype": dtype,
        "rng": rng,
    }
    for dtype in default_dtypes
    for arg_shape, out_shape in [
        [(3, 4), (12,)],
        [(2, 1, 4), (8,)],
        [(2, 2, 4), (2, 8)],
    ]
    for rng in [jtu.rand_default()])
def testReshapeAgainstNumpy(self, arg_shape, out_shape, dtype, rng):
  """Compare `lax.reshape` with `lax_reference.reshape`."""
  def args_maker():
    return [rng(arg_shape, dtype)]

  def lax_reshape(x):
    return lax.reshape(x, out_shape)

  def reference_reshape(x):
    return lax_reference.reshape(x, out_shape)

  self._CheckAgainstNumpy(lax_reshape, reference_reshape, args_maker)
@parameterized.named_parameters(
    {
        "testcase_name": "_inshape={}_pads={}".format(
            jtu.format_shape_dtype_string(shape, dtype), pads),
        "shape": shape,
        "dtype": dtype,
        "pads": pads,
        "rng": jtu.rand_small(),
    }
    for shape in [(2, 3)]
    for dtype in default_dtypes
    for pads in [[(1, 2, 1), (0, 1, 0)]])
def testPad(self, shape, dtype, pads, rng):
  """Run `lax.pad` with a zero padding value through `_CompileAndCheck`."""
  def args_maker():
    return [rng(shape, dtype)]

  def pad_with_zero(operand):
    # The padding value is a zero scalar of the operand's dtype.
    padding_value = onp.array(0, dtype)
    return lax.pad(operand, padding_value, pads)

  self._CompileAndCheck(pad_with_zero, args_maker, check_dtypes=True)
@parameterized.named_parameters(
    {
        "testcase_name": "_inshape={}_pads={}".format(
            jtu.format_shape_dtype_string(shape, dtype), pads),
        "shape": shape,
        "dtype": dtype,
        "pads": pads,
        "rng": jtu.rand_small(),
    }
    for shape in [(2, 3)]
    for dtype in default_dtypes
    for pads in [[(1, 2, 1), (0, 1, 0)]])
def testPadAgainstNumpy(self, shape, dtype, pads, rng):
  """Compare `lax.pad` with `lax_reference.pad`, padding with zeros."""
  def args_maker():
    return [rng(shape, dtype)]

  def lax_pad(x):
    return lax.pad(x, onp.array(0, dtype), pads)

  def reference_pad(x):
    return lax_reference.pad(x, onp.array(0, dtype), pads)

  self._CheckAgainstNumpy(lax_pad, reference_pad, args_maker)
def testReverse(self):
  """Check jitted `lax.rev` on a 1-D input and on a 2-D input."""
  def reverse(operand):
    # `axes` is read from the enclosing scope when this body runs, so the
    # rebinding below changes which dimensions get reversed.
    return lax.rev(operand, axes)

  jitted_reverse = api.jit(reverse)

  axes = [0]
  vector = onp.array([1, 2, 3])
  self.assertAllClose(onp.array([3, 2, 1]), jitted_reverse(vector),
                      check_dtypes=False)

  axes = [0, 1]
  matrix = onp.array([[1, 2, 3], [4, 5, 6]])
  self.assertAllClose(onp.array([[6, 5, 4], [3, 2, 1]]),
                      jitted_reverse(matrix),
                      check_dtypes=False)
@parameterized.named_parameters(
    {
        "testcase_name": "_predshape={}_argshapes={}".format(
            jtu.format_shape_dtype_string(pred_shape, onp.bool_),
            jtu.format_shape_dtype_string(arg_shape, arg_dtype)),
        "pred_shape": pred_shape,
        "arg_shape": arg_shape,
        "arg_dtype": arg_dtype,
        "rng": rng,
    }
    for arg_shape in [(), (3,), (2, 3)]
    for pred_shape in ([(), arg_shape] if arg_shape else [()])
    for arg_dtype in default_dtypes
    for rng in [jtu.rand_default()])
def testSelect(self, pred_shape, arg_shape, arg_dtype, rng):
  """Run `lax.select` through `_CompileAndCheck`.

  The boolean predicate is either scalar-shaped or shaped like the two
  value arguments.
  """
  def args_maker():
    predicate = rng(pred_shape, onp.bool_)
    on_true = rng(arg_shape, arg_dtype)
    on_false = rng(arg_shape, arg_dtype)
    return [predicate, on_true, on_false]

  return self._CompileAndCheck(lax.select, args_maker, check_dtypes=True)
@parameterized.named_parameters(
    {
        "testcase_name": "_predshape={}_argshapes={}".format(
            jtu.format_shape_dtype_string(pred_shape, onp.bool_),
            jtu.format_shape_dtype_string(arg_shape, arg_dtype)),
        "pred_shape": pred_shape,
        "arg_shape": arg_shape,
        "arg_dtype": arg_dtype,
        "rng": rng,
    }
    for arg_shape in [(), (3,), (2, 3)]
    for pred_shape in ([(), arg_shape] if arg_shape else [()])
    for arg_dtype in default_dtypes
    for rng in [jtu.rand_default()])
def testSelectAgainstNumpy(self, pred_shape, arg_shape, arg_dtype, rng):
  """Compare `lax.select` with `lax_reference.select`."""
  def args_maker():
    predicate = rng(pred_shape, onp.bool_)
    on_true = rng(arg_shape, arg_dtype)
    on_false = rng(arg_shape, arg_dtype)
    return [predicate, on_true, on_false]

  return self._CheckAgainstNumpy(lax.select, lax_reference.select, args_maker)
@parameterized.named_parameters(
    {
        "testcase_name":
            "_shape={}_start_indices={}_limit_indices={}_strides={}".format(
                jtu.format_shape_dtype_string(shape, dtype),
                start_indices, limit_indices, strides),
        "shape": shape,
        "dtype": dtype,
        "starts": start_indices,
        "limits": limit_indices,
        "strides": strides,
        "rng": rng,
    }
    for shape, start_indices, limit_indices, strides in [
        [(3,), (1,), (2,), None],
        [(7,), (4,), (7,), None],
        [(5,), (1,), (5,), (2,)],
        [(8,), (1,), (6,), (2,)],
        [(5, 3), (1, 1), (3, 2), None],
        [(5, 3), (1, 1), (3, 1), None],
        [(7, 5, 3), (4, 0, 1), (7, 1, 3), None],
        [(5, 3), (1, 1), (2, 1), (1, 1)],
        [(5, 3), (1, 1), (5, 3), (2, 1)],
    ]
    for dtype in default_dtypes
    for rng in [jtu.rand_default()])
def testSlice(self, shape, dtype, starts, limits, strides, rng):
  """Run `lax.slice` through `_CompileAndCheck`, with and without strides."""
  def args_maker():
    return [rng(shape, dtype)]

  def static_slice(x):
    return lax.slice(x, starts, limits, strides)

  self._CompileAndCheck(static_slice, args_maker, check_dtypes=True)
@parameterized.named_parameters(
    {
        "testcase_name":
            "_shape={}_start_indices={}_limit_indices={}_strides={}".format(
                jtu.format_shape_dtype_string(shape, dtype),
                start_indices, limit_indices, strides),
        "shape": shape,
        "dtype": dtype,
        "starts": start_indices,
        "limits": limit_indices,
        "strides": strides,
        "rng": rng,
    }
    for shape, start_indices, limit_indices, strides in [
        [(3,), (1,), (2,), None],
        [(7,), (4,), (7,), None],
        [(5,), (1,), (5,), (2,)],
        [(8,), (1,), (6,), (2,)],
        [(5, 3), (1, 1), (3, 2), None],
        [(5, 3), (1, 1), (3, 1), None],
        [(7, 5, 3), (4, 0, 1), (7, 1, 3), None],
        [(5, 3), (1, 1), (2, 1), (1, 1)],
        [(5, 3), (1, 1), (5, 3), (2, 1)],
    ]
    for dtype in default_dtypes
    for rng in [jtu.rand_default()])
def testSliceAgainstNumpy(self, shape, dtype, starts, limits,
                          strides, rng):
  """Compare `lax.slice` with `lax_reference.slice`."""
  def args_maker():
    return [rng(shape, dtype)]

  def lax_slice(x):
    return lax.slice(x, starts, limits, strides)

  def reference_slice(x):
    return lax_reference.slice(x, starts, limits, strides)

  self._CheckAgainstNumpy(lax_slice, reference_slice, args_maker)
@parameterized.named_parameters(
    {
        "testcase_name": "_shape={}_start_indices={}_size_indices={}".format(
            jtu.format_shape_dtype_string(shape, dtype),
            start_indices, size_indices),
        "shape": shape,
        "dtype": dtype,
        "start_indices": start_indices,
        "size_indices": size_indices,
        "rng": rng,
    }
    for shape, start_indices, size_indices in [
        [(3,), (1,), (1,)],
        [(5, 3), (1, 1), (3, 1)],
        [(7, 5, 3), (4, 1, 0), (2, 0, 1)],
    ]
    for dtype in default_dtypes
    for rng in [jtu.rand_default()])
def testDynamicSlice(self, shape, dtype, start_indices, size_indices, rng):
  """Run `lax.dynamic_slice` through `_CompileAndCheck`.

  The start indices are passed as an array argument; the sizes are closed
  over.
  """
  def args_maker():
    return [rng(shape, dtype), onp.array(start_indices)]

  def dynamic_slice(x, starts):
    return lax.dynamic_slice(x, starts, size_indices)

  self._CompileAndCheck(dynamic_slice, args_maker, check_dtypes=True)
@parameterized.named_parameters(
    {
        "testcase_name": "_shape={}_start_indices={}_size_indices={}".format(
            jtu.format_shape_dtype_string(shape, dtype),
            start_indices, size_indices),
        "shape": shape,
        "dtype": dtype,
        "start_indices": start_indices,
        "size_indices": size_indices,
        "rng": rng,
    }
    for shape, start_indices, size_indices in [
        [(3,), (1,), (1,)],
        [(5, 3), (1, 1), (3, 1)],
        [(7, 5, 3), (4, 1, 0), (2, 0, 1)],
    ]
    for dtype in default_dtypes
    for rng in [jtu.rand_default()])
def testDynamicSliceAgainstNumpy(self, shape, dtype, start_indices,
                                 size_indices, rng):
  """Compare `lax.dynamic_slice` with `lax_reference.dynamic_slice`."""
  def args_maker():
    return [rng(shape, dtype), onp.array(start_indices)]

  def lax_dynamic_slice(x, starts):
    return lax.dynamic_slice(x, starts, size_indices)

  def reference_dynamic_slice(x, starts):
    return lax_reference.dynamic_slice(x, starts, size_indices)

  self._CheckAgainstNumpy(
      lax_dynamic_slice, reference_dynamic_slice, args_maker)
@parameterized.named_parameters(
    {"testcase_name": "_shape={}_start_indices={}_update_shape={}".format(
        jtu.format_shape_dtype_string(shape, dtype),
        start_indices, update_shape),
     "shape": shape, "dtype": dtype, "start_indices": start_indices,
     "update_shape": update_shape, "rng": rng}
    for shape, start_indices, update_shape in [
        [(3,), (1,), (1,)],
        [(5, 3), (1, 1), (3, 1)],
        [(7, 5, 3), (4, 1, 0), (2, 0, 1)],
    ]
    for dtype in default_dtypes
    for rng in [jtu.rand_default()])
def testDynamicUpdateSlice(self, shape, dtype, start_indices, update_shape,
                           rng):
  # Runs lax.dynamic_update_slice through self._CompileAndCheck with three
  # runtime arguments: the operand, the update block and the start indices.
  def make_args():
    operand = rng(shape, dtype)
    update = rng(update_shape, dtype)
    return [operand, update, onp.array(start_indices)]

  self._CompileAndCheck(lax.dynamic_update_slice, make_args,
                        check_dtypes=True)
@parameterized.named_parameters(
    {"testcase_name": "_shape={}_start_indices={}_update_shape={}".format(
        jtu.format_shape_dtype_string(shape, dtype),
        start_indices, update_shape),
     "shape": shape, "dtype": dtype, "start_indices": start_indices,
     "update_shape": update_shape, "rng": rng}
    for shape, start_indices, update_shape in [
        [(3,), (1,), (1,)],
        [(5, 3), (1, 1), (3, 1)],
        [(7, 5, 3), (4, 1, 0), (2, 0, 1)],
    ]
    for dtype in default_dtypes
    for rng in [jtu.rand_default()])
def testDynamicUpdateSliceAgainstNumpy(self, shape, dtype, start_indices,
                                       update_shape, rng):
  # Checks lax.dynamic_update_slice against
  # lax_reference.dynamic_update_slice on identical inputs.
  #
  # This method was previously also named testDynamicUpdateSlice. Defining the
  # same name twice in a class body rebinds the attribute, so the earlier
  # compile-and-check test of that name was silently dropped. The name now
  # follows the "...AgainstNumpy" convention of the neighbouring tests so
  # that both variants are collected.
  def args_maker():
    return [rng(shape, dtype), rng(update_shape, dtype),
            onp.array(start_indices)]
  self._CheckAgainstNumpy(lax.dynamic_update_slice,
                          lax_reference.dynamic_update_slice, args_maker)
@parameterized.named_parameters(
    {"testcase_name": "_shape={}_perm={}".format(
        jtu.format_shape_dtype_string(shape, dtype), perm),
     "shape": shape, "dtype": dtype, "perm": perm, "rng": rng}
    for shape, perm in [
        [(3, 4), (1, 0)],
        [(3, 4), (0, 1)],
        [(3, 4, 5), (2, 1, 0)],
        [(3, 4, 5), (1, 0, 2)],
    ]
    for dtype in default_dtypes
    for rng in [jtu.rand_default()])
def testTranspose(self, shape, dtype, perm, rng):
  # Runs lax.transpose with a fixed permutation through
  # self._CompileAndCheck on a random operand.
  def make_args():
    return [rng(shape, dtype)]

  def transpose(operand):
    return lax.transpose(operand, perm)

  self._CompileAndCheck(transpose, make_args, check_dtypes=True)
@parameterized.named_parameters(
    {"testcase_name": "_shape={}_perm={}".format(
        jtu.format_shape_dtype_string(shape, dtype), perm),
     "shape": shape, "dtype": dtype, "perm": perm, "rng": rng}
    for shape, perm in [
        [(3, 4), (1, 0)],
        [(3, 4), (0, 1)],
        [(3, 4, 5), (2, 1, 0)],
        [(3, 4, 5), (1, 0, 2)],
    ]
    for dtype in default_dtypes
    for rng in [jtu.rand_default()])
def testTransposeAgainstNumpy(self, shape, dtype, perm, rng):
  # Checks lax.transpose against lax_reference.transpose for the same
  # permutation and the same random operand.
  def make_args():
    return [rng(shape, dtype)]

  def lax_transpose(operand):
    return lax.transpose(operand, perm)

  def reference_transpose(operand):
    return lax_reference.transpose(operand, perm)

  self._CheckAgainstNumpy(lax_transpose, reference_transpose, make_args)
@parameterized.named_parameters(
    {"testcase_name": "_op={}_inshape={}_reducedims={}"
     .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims),
     "op": op, "init_val": init_val, "shape": shape, "dtype": dtype,
     "dims": dims, "rng": rng}
    for init_val, op, dtypes in [
        (0, lax.add, default_dtypes),
        (-onp.inf, lax.max, float_dtypes),
        (onp.iinfo(onp.int32).min, lax.max, [onp.int32]),
        (onp.iinfo(onp.int64).min, lax.max, [onp.int64]),
        (onp.iinfo(onp.uint32).min, lax.max, [onp.uint32]),
        (onp.iinfo(onp.uint64).min, lax.max, [onp.uint64]),
        (onp.inf, lax.min, float_dtypes),
        (onp.iinfo(onp.int32).max, lax.min, [onp.int32]),
        (onp.iinfo(onp.int64).max, lax.min, [onp.int64]),
        (onp.iinfo(onp.uint32).max, lax.min, [onp.uint32]),
        (onp.iinfo(onp.uint64).max, lax.min, [onp.uint64]),
    ]
    for dtype in dtypes
    for shape, dims in [
        [(3, 4, 5), (0,)], [(3, 4, 5), (1, 2)],
        [(3, 4, 5), (0, 2)], [(3, 4, 5), (0, 1, 2)]
    ]
    for rng in [jtu.rand_small()])
def testReduce(self, op, init_val, shape, dtype, dims, rng):
  # Runs lax.reduce through self._CompileAndCheck in two forms: once with the
  # initial value passed as a runtime argument, and once with it closed over
  # as a concrete array.
  init_val = onp.asarray(init_val, dtype=dtype)

  def reduce_with_init_arg(operand, init):
    return lax.reduce(operand, init, op, dims)

  def make_args_with_init():
    return [rng(shape, dtype), init_val]

  self._CompileAndCheck(reduce_with_init_arg, make_args_with_init,
                        check_dtypes=True)

  # we separately test the version that uses a concrete init_val because it
  # can hit different code paths
  def reduce_with_concrete_init(operand):
    return lax.reduce(operand, init_val, op, dims)

  def make_operand_only():
    return [rng(shape, dtype)]

  self._CompileAndCheck(reduce_with_concrete_init, make_operand_only,
                        check_dtypes=True)
@parameterized.named_parameters(
    {"testcase_name": "_op={}_dtype={}_padding={}"
     .format(op.__name__, onp.dtype(dtype).name, padding),
     "op": op, "init_val": init_val, "dtype": dtype, "padding": padding,
     "rng": rng}
    for init_val, op, dtypes in [
        (0, lax.add, [onp.float32]),
        (-onp.inf, lax.max, [onp.float32]),
        (onp.inf, lax.min, [onp.float32]),
    ]
    for dtype in dtypes
    for padding in ["VALID", "SAME"]
    for rng in [jtu.rand_small()])
def testReduceWindow(self, op, init_val, dtype, padding, rng):
  # Runs lax.reduce_window through self._CompileAndCheck over a set of
  # (shape, window dims, strides) configurations. Each configuration is run
  # once with init_val as a runtime argument and once with it closed over.
  init_val = onp.asarray(init_val, dtype=dtype)

  # We need this conditional and the corresponding loop logic to be in the
  # test method, rather than at the parameterized test level, because it
  # depends on FLAGS for the device under test.
  if FLAGS.jax_test_dut == "tpu":
    all_configs = [((4, 6), (2, 1), (1, 2))]
  else:
    # Materialize the configurations into a list: they are iterated twice
    # below, and a bare itertools.chain is a one-shot iterator that would be
    # exhausted by the first loop, so the second loop would check nothing.
    all_configs = list(itertools.chain(
        itertools.product(
            [(4, 6)],
            [(2, 1), (1, 2)],
            [(1, 1), (2, 1), (1, 2)]),
        itertools.product(
            [(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)],
            [(1, 2, 2, 1), (1, 1, 1, 1)])))

  # `dims` and `strides` are read from the enclosing scope; they are bound by
  # the for-loops below before `fun` is called.
  def fun(operand, init_val):
    return lax.reduce_window(operand, init_val, op, dims, strides, padding)

  # pylint: disable=cell-var-from-loop
  for shape, dims, strides in all_configs:
    args_maker = lambda: [rng(shape, dtype), init_val]
    self._CompileAndCheck(fun, args_maker, check_dtypes=True)
  # pylint: enable=cell-var-from-loop

  # we separately test the version that uses a concrete init_val because it
  # can hit different code paths
  def fun(operand):
    return lax.reduce_window(operand, init_val, op, dims, strides, padding)

  # pylint: disable=cell-var-from-loop
  for shape, dims, strides in all_configs:
    args_maker = lambda: [rng(shape, dtype)]
    self._CompileAndCheck(fun, args_maker, check_dtypes=True)
  # pylint: enable=cell-var-from-loop
# TODO(b/205052657): enable more tests when supported
@parameterized.named_parameters(
    {"testcase_name": "_shape={}_axis={}".format(
        jtu.format_shape_dtype_string(shape, dtype), axis),
     "rng": rng, "shape": shape, "dtype": dtype, "axis": axis}
    for dtype in [onp.float32, onp.int32, onp.uint32]
    for shape in [(5,), (5, 7)]
    for axis in [-1, len(shape) - 1]
    for rng in [jtu.rand_default()])
def testSort(self, shape, dtype, axis, rng):
  # Runs lax.sort along `axis` through self._CompileAndCheck. Operands with
  # more than one dimension are skipped unless the device under test is a TPU.
  #
  # `FLAGS.jax_test_dut or ""` tolerates an unset flag, matching the guard
  # used in LaxAutodiffTest.testOpGrad.
  if len(shape) > 1 and not (FLAGS.jax_test_dut or "").startswith("tpu"):
    msg = "sort only implemented for R1 on non-TPU backends"
    # unittest.skip() only builds a decorator; returning it made the test
    # count as passed. Raising SkipTest reports it as skipped.
    raise absltest.unittest.SkipTest(msg)

  args_maker = lambda: [rng(shape, dtype)]
  fun = lambda x: lax.sort(x, axis)
  self._CompileAndCheck(fun, args_maker, check_dtypes=True)
# TODO(b/205052657): enable more tests when supported
@parameterized.named_parameters(
    {"testcase_name": "_shape={}_axis={}".format(
        jtu.format_shape_dtype_string(shape, dtype), axis),
     "rng": rng, "shape": shape, "dtype": dtype, "axis": axis}
    for dtype in [onp.float32, onp.int32, onp.uint32]
    for shape in [(5,), (5, 7)]
    for axis in [-1, len(shape) - 1]
    for rng in [jtu.rand_default()])
def testSortAgainstNumpy(self, shape, dtype, axis, rng):
  # Checks lax.sort against lax_reference.sort along `axis`. Operands with
  # more than one dimension are skipped unless the device under test is a TPU.
  #
  # `FLAGS.jax_test_dut or ""` tolerates an unset flag, matching the guard
  # used in LaxAutodiffTest.testOpGrad.
  if len(shape) > 1 and not (FLAGS.jax_test_dut or "").startswith("tpu"):
    msg = "sort only implemented for R1 on non-TPU backends"
    # unittest.skip() only builds a decorator; returning it made the test
    # count as passed. Raising SkipTest reports it as skipped.
    raise absltest.unittest.SkipTest(msg)

  args_maker = lambda: [rng(shape, dtype)]
  op = lambda x: lax.sort(x, axis)
  numpy_op = lambda x: lax_reference.sort(x, axis)
  self._CheckAgainstNumpy(op, numpy_op, args_maker)
# TODO(b/205052657): enable more tests when supported
@parameterized.named_parameters(
    {"testcase_name": "_keyshape={}_valshape={}_axis={}".format(
        jtu.format_shape_dtype_string(shape, key_dtype),
        jtu.format_shape_dtype_string(shape, val_dtype),
        axis),
     "rng": rng, "shape": shape,
     "key_dtype": key_dtype, "val_dtype": val_dtype, "axis": axis}
    for key_dtype in [onp.float32, onp.int32, onp.uint32]
    for val_dtype in [onp.float32, onp.int32, onp.uint32]
    for shape in [(3,), (5, 3)]
    for axis in [-1, len(shape) - 1]
    for rng in [jtu.rand_default()])
def testSortKeyVal(self, shape, key_dtype, val_dtype, axis, rng):
  # Runs lax.sort_key_val along `axis` through self._CompileAndCheck.
  # Operands with more than one dimension are skipped unless the device under
  # test is a TPU.
  #
  # `FLAGS.jax_test_dut or ""` tolerates an unset flag, matching the guard
  # used in LaxAutodiffTest.testOpGrad.
  if len(shape) > 1 and not (FLAGS.jax_test_dut or "").startswith("tpu"):
    msg = "sort_key_val only implemented for R1 non-TPU backends"
    # unittest.skip() only builds a decorator; returning it made the test
    # count as passed. Raising SkipTest reports it as skipped.
    raise absltest.unittest.SkipTest(msg)

  # This test relies on the property that wherever keys are tied, values are
  # too, since we don't guarantee the same ordering of values with equal keys.
  # To avoid that case, we generate unique keys (globally in the key array).
  perm_rng = onp.random.RandomState(0)
  def args_maker():
    flat_keys = onp.arange(onp.prod(shape, dtype=int), dtype=key_dtype)
    keys = perm_rng.permutation(flat_keys).reshape(shape)
    values = rng(shape, val_dtype)
    return keys, values

  fun = lambda keys, values: lax.sort_key_val(keys, values, axis)
  self._CompileAndCheck(fun, args_maker, check_dtypes=True)
# TODO(b/205052657): enable more tests when supported
@parameterized.named_parameters(
    {"testcase_name": "_keyshape={}_valshape={}_axis={}".format(
        jtu.format_shape_dtype_string(shape, key_dtype),
        jtu.format_shape_dtype_string(shape, val_dtype),
        axis),
     "rng": rng, "shape": shape,
     "key_dtype": key_dtype, "val_dtype": val_dtype, "axis": axis}
    for key_dtype in [onp.float32, onp.int32, onp.uint32]
    for val_dtype in [onp.float32, onp.int32, onp.uint32]
    for shape in [(3,), (5, 3)]
    for axis in [-1, len(shape) - 1]
    for rng in [jtu.rand_default()])
def testSortKeyValAgainstNumpy(self, shape, key_dtype, val_dtype, axis, rng):
  # Checks lax.sort_key_val against lax_reference.sort_key_val along `axis`.
  # Operands with more than one dimension are skipped unless the device under
  # test is a TPU.
  #
  # `FLAGS.jax_test_dut or ""` tolerates an unset flag, matching the guard
  # used in LaxAutodiffTest.testOpGrad.
  if len(shape) > 1 and not (FLAGS.jax_test_dut or "").startswith("tpu"):
    msg = "sort_key_val only implemented for R1 non-TPU backends"
    # unittest.skip() only builds a decorator; returning it made the test
    # count as passed. Raising SkipTest reports it as skipped.
    raise absltest.unittest.SkipTest(msg)

  # This test relies on the property that wherever keys are tied, values are
  # too, since we don't guarantee the same ordering of values with equal keys.
  # To avoid that case, we generate unique keys (globally in the key array).
  perm_rng = onp.random.RandomState(0)
  def args_maker():
    flat_keys = onp.arange(onp.prod(shape, dtype=int), dtype=key_dtype)
    keys = perm_rng.permutation(flat_keys).reshape(shape)
    values = rng(shape, val_dtype)
    return keys, values

  op = lambda ks, vs: lax.sort_key_val(ks, vs, axis)
  numpy_op = lambda ks, vs: lax_reference.sort_key_val(ks, vs, axis)
  self._CheckAgainstNumpy(op, numpy_op, args_maker)
def testWhileWithTuple(self):
  # Counts loop iterations with lax._while_loop over a (position, counter)
  # state, and checks the plain and api.jit-wrapped versions both return
  # limit - init.
  limit = 10

  def below_limit(state):
    position, _ = state
    return lax.lt(position, limit)

  def step(state):
    position, tally = state
    return (lax.add(position, 1), lax.add(tally, 1))

  def loop(init):
    _, tally = lax._while_loop(below_limit, step, (init, 0))
    return tally

  jitted_loop = api.jit(loop)

  self.assertEqual(loop(2), limit - 2)
  # The jitted version is called twice with the same argument before a new
  # one, exactly as in the original sequence of assertions.
  for start in (2, 2, 3):
    self.assertEqual(jitted_loop(start), limit - start)
def testNestedWhile(self):
  # Nests one lax._while_loop inside the body of another. For outer index i
  # the inner loop increments the counter once per j in 0..i inclusive, so
  # the expected results below are (num, 1 + 2 + ... + num).

  def inner_loop(i, count):  # pylint: disable=missing-docstring
    def inner_cond(state):
      outer_idx, inner_idx, _ = state
      return lax.le(inner_idx, outer_idx)

    def inner_body(state):
      outer_idx, inner_idx, tally = state
      return (outer_idx, lax.add(inner_idx, 1), lax.add(tally, 1))

    _, _, count = lax._while_loop(inner_cond, inner_body, (i, 0, count))
    return count

  def outer_loop(num):  # pylint: disable=missing-docstring
    def outer_cond(state):
      bound, idx, _ = state
      return lax.lt(idx, bound)

    def outer_body(state):
      bound, idx, tally = state
      return (bound, lax.add(idx, 1), inner_loop(idx, tally))

    _, idx, tally = lax._while_loop(outer_cond, outer_body, (num, 0, 0))
    return (idx, tally)

  jitted_loop = api.jit(outer_loop)

  self.assertEqual(outer_loop(3), (3, 6))
  for arg, expected in [(3, (3, 6)), (3, (3, 6)), (2, (2, 3)), (4, (4, 10))]:
    self.assertEqual(jitted_loop(arg), expected)
def testNestedWhileWithDynamicUpdateSlice(self):
  # Copies the lower triangle of a 5x5 matrix entry by entry using two nested
  # lax._while_loop calls and lax.dynamic_update_slice, then compares the
  # result with onp.tril for both the plain and api.jit-wrapped function.
  num_rows = 5

  def write_entry(target, value, row, col):
    # Reshape the value to [1, 1] so it can be written as a 2D update block.
    block = lax.reshape(value, [1, 1])
    return lax.dynamic_update_slice(target, block, (row, col))

  def inner_loop(i, arr, out):  # pylint: disable=missing-docstring
    def inner_cond(state):
      row, col, _, _ = state
      return lax.le(col, row)

    def inner_body(state):
      row, col, source, dest = state
      source_row = lax.dynamic_index_in_dim(source, row, 0, False)
      entry = lax.dynamic_index_in_dim(source_row, col, 0, False)
      dest = write_entry(dest, entry, row, col)
      return (row, lax.add(col, 1), source, dest)

    _, _, _, out = lax._while_loop(inner_cond, inner_body, (i, 0, arr, out))
    return out

  def outer_loop(arr):  # pylint: disable=missing-docstring
    def outer_cond(state):
      row, bound, _, _ = state
      return lax.lt(row, bound)

    def outer_body(state):
      row, bound, source, dest = state
      return (lax.add(row, 1), bound, source, inner_loop(row, source, dest))

    out = onp.zeros(arr.shape, dtype=arr.dtype)
    _, _, _, out = lax._while_loop(outer_cond, outer_body,
                                   (0, num_rows, arr, out))
    return out

  jitted_loop = api.jit(outer_loop)
  arr = npr.RandomState(0).randn(5, 5)

  self.assertAllClose(outer_loop(arr), onp.tril(arr), check_dtypes=False)
  # Call the jitted function twice, as in the original assertions.
  for _ in range(2):
    self.assertAllClose(jitted_loop(arr), onp.tril(arr), check_dtypes=False)
def testLoopWithConjunctionCondition(self):
  # Sums the first `num` entries of an array with a lax._while_loop whose
  # condition is the bitwise AND of two comparisons (i < num and
  # i < arr.shape[0]), and compares against onp.sum(x[:num]).

  def sum_first_n(arr, num):  # pylint: disable=missing-docstring
    def in_range(state):
      values, bound, idx, _ = state
      below_bound = lax.lt(idx, bound)
      below_length = lax.lt(idx, values.shape[0])
      return lax.bitwise_and(below_bound, below_length)

    def accumulate(state):
      values, bound, idx, running = state
      element = lax.dynamic_index_in_dim(values, idx, 0, False)
      return (values, bound, lax.add(idx, 1), lax.add(running, element))

    _, _, _, total = lax._while_loop(in_range, accumulate, (arr, num, 0, 0.))
    return total

  jitted_sum = api.jit(sum_first_n)
  x = npr.RandomState(0).randn(10)

  for num in [0, 5, 10, 15]:
    self.assertAllClose(sum_first_n(x, num), onp.sum(x[:num]),
                        check_dtypes=False)
    # Call the jitted function twice per value, as in the original.
    for _ in range(2):
      self.assertAllClose(jitted_sum(x, num), onp.sum(x[:num]),
                          check_dtypes=False)
def testForiLoopBasic(self):
  # Sums the loop indices 0..num-1 with lax.fori_loop and checks both the
  # expected totals and agreement between the plain and jitted versions.
  def count(num):
    def add_index(i, running):
      return lax.add(running, i)
    return lax.fori_loop(0, num, add_index, 0)

  jitted_count = api.jit(count)

  for num, expected in [(2, 1), (3, 3), (4, 6)]:
    self.assertEqual(count(num), expected)
    self.assertEqual(count(num), jitted_count(num))
def testForiLoopTupleState(self):
  # Sums the first `num` entries of an array with lax.fori_loop carrying an
  # (array, total) tuple as loop state; compares against onp.sum(x[:num]).
  def sum_first_n(arr, num):
    def accumulate(i, state):
      values, running = state
      element = lax.dynamic_index_in_dim(values, i, 0, False)
      return (values, lax.add(running, element))

    upper = lax.min(arr.shape[0], num)
    _, total = lax.fori_loop(0, upper, accumulate, (arr, 0.))
    return total

  jitted_sum = api.jit(sum_first_n)
  x = npr.RandomState(0).randn(10)

  for num in [0, 5, 10, 15]:
    self.assertAllClose(sum_first_n(x, num), onp.sum(x[:num]),
                        check_dtypes=False)
    # Call the jitted function twice per value, as in the original.
    for _ in range(2):
      self.assertAllClose(jitted_sum(x, num), onp.sum(x[:num]),
                          check_dtypes=False)
def testForiLoopDictState(self):
  # Same summation as the tuple-state test, but the lax.fori_loop state is a
  # dict with keys 'arr' and 'total'.
  def sum_first_n(arr, num):
    def accumulate(i, state):
      values, running = state['arr'], state['total']
      element = lax.dynamic_index_in_dim(values, i, 0, False)
      return {'arr': values, 'total': lax.add(running, element)}

    initial_state = {'arr': arr, 'total': 0.}
    upper = lax.min(arr.shape[0], num)
    final_state = lax.fori_loop(0, upper, accumulate, initial_state)
    return final_state['total']

  jitted_sum = api.jit(sum_first_n)
  x = npr.RandomState(0).randn(10)

  for num in [0, 5, 10, 15]:
    self.assertAllClose(sum_first_n(x, num), onp.sum(x[:num]),
                        check_dtypes=False)
    # Call the jitted function twice per value, as in the original.
    for _ in range(2):
      self.assertAllClose(jitted_sum(x, num), onp.sum(x[:num]),
                          check_dtypes=False)
def testForiLoopEmptyTupleInState(self):
  # Same summation as the tuple-state test, but the lax.fori_loop state
  # carries an additional empty tuple as its third element.
  def sum_first_n(arr, num):
    def accumulate(i, state):
      values, running, _ = state
      element = lax.dynamic_index_in_dim(values, i, 0, False)
      return (values, lax.add(running, element), ())

    upper = lax.min(arr.shape[0], num)
    _, total, _ = lax.fori_loop(0, upper, accumulate, (arr, 0., ()))
    return total

  jitted_sum = api.jit(sum_first_n)
  x = npr.RandomState(0).randn(10)

  for num in [0, 5, 10, 15]:
    self.assertAllClose(sum_first_n(x, num), onp.sum(x[:num]),
                        check_dtypes=False)
    # Call the jitted function twice per value, as in the original.
    for _ in range(2):
      self.assertAllClose(jitted_sum(x, num), onp.sum(x[:num]),
                          check_dtypes=False)
@parameterized.named_parameters(
    {"testcase_name": "_lhs_shape={}_rhs_shape={}"
     .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
             jtu.format_shape_dtype_string(rhs_shape, dtype)),
     "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
     "rng": rng}
    for lhs_shape, rhs_shape in [((3, 2), (2, 4)),
                                 ((5, 3, 2), (5, 2, 4)),
                                 ((1, 2, 2, 3), (1, 2, 3, 1))]
    for dtype in float_dtypes
    for rng in [jtu.rand_small()])
def testBatchMatMul(self, lhs_shape, rhs_shape, dtype, rng):
  # Runs lax.batch_matmul through self._CompileAndCheck on a random
  # (lhs, rhs) pair of the given shapes.
  def make_args():
    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    return [lhs, rhs]

  self._CompileAndCheck(lax.batch_matmul, make_args, check_dtypes=True)
def testCollapse(self):
  # Checks the output shape of a jitted lax.collapse(x, 0, 2) for inputs of
  # rank 2, 3 and 4.
  collapse_first_two = api.jit(lambda x: lax.collapse(x, 0, 2))

  cases = [
      ((2, 3), (6,)),
      ((2, 3, 4), (6, 4)),
      ((1, 2, 3, 4), (2, 3, 4)),
  ]
  for input_shape, expected_shape in cases:
    result = collapse_first_two(onp.zeros(input_shape))
    self.assertEqual(expected_shape, result.shape)
@parameterized.named_parameters(
    {"testcase_name": "_shape={}_idxs={}_axes={}".format(
        jtu.format_shape_dtype_string(shape, dtype), idxs, axes),
     "shape": shape, "dtype": dtype, "idxs": idxs, "axes": axes, "rng": rng}
    for dtype in all_dtypes
    for shape, idxs, axes in [
        [(3, 4, 5), (onp.array([0, 2, 1]),), (0,)],
        [(3, 4, 5), (onp.array([-1, -2]),), (0,)],
        [(3, 4, 5), (onp.array([0, 2]), onp.array([1, 3])), (0, 1)],
        [(3, 4, 5), (onp.array([0, 2]), onp.array([1, 3])), (0, 2)],
    ]
    for rng in [jtu.rand_default()])
def testIndexTake(self, shape, dtype, idxs, axes, rng):
  # Runs lax.index_take through self._CompileAndCheck. The parameterized
  # `idxs` only supply the shape and dtype of each index array; the values
  # actually used are drawn from `rng`.
  def make_args():
    source = rng(shape, dtype)
    random_idxs = tuple(rng(idx.shape, idx.dtype) for idx in idxs)
    return [source, random_idxs]

  def index_take(src, take_idxs):
    return lax.index_take(src, take_idxs, axes)

  self._CompileAndCheck(index_take, make_args, check_dtypes=True)
@parameterized.named_parameters(
    {"testcase_name": "_dst_shape={}_idxs={}_axes={}".format(
        jtu.format_shape_dtype_string(dst_shape, dtype), idxs, axes),
     "dst_shape": dst_shape, "dtype": dtype, "idxs": idxs, "axes": axes,
     "rng": rng}
    for dtype in default_dtypes
    for dst_shape, idxs, axes in [
        [(3, 4, 5), (onp.array([0, 2, 1]),), (0,)],
        [(3, 4, 5), (onp.array([-1, -2]),), (0,)],
        [(3, 4, 5), (onp.array([0, 2]), onp.array([1, 3])), (0, 1)],
        [(3, 4, 5), (onp.array([0, 2]), onp.array([1, 3])), (0, 2)],
    ]
    for rng in [jtu.rand_default()])
def testIndexUntake(self, dst_shape, dtype, idxs, axes, rng):
  # Runs lax.index_untake through self._CompileAndCheck with random source,
  # destination and index arrays.

  # We call lax.index_take to get the shapes right
  probe = lax.index_take(rng(dst_shape, dtype), idxs, axes)
  src_shape = probe.shape

  def make_args():
    source = rng(src_shape, dtype)
    destination = rng(dst_shape, dtype)
    random_idxs = tuple(rng(idx.shape, idx.dtype) for idx in idxs)
    return [source, destination, random_idxs]

  def index_untake(src, dst, untake_idxs):
    return lax.index_untake(src, dst, untake_idxs, axes)

  self._CompileAndCheck(index_untake, make_args, check_dtypes=True)
# Record describing one op in the gradient-test table: the op itself, how
# many arguments it takes, the differentiation order passed to the checker,
# the random-input generator, and the dtypes to test.
GradTestSpec = collections.namedtuple(
    "GradTestSpec", "op nargs order rng dtypes")
# Table of ops exercised by LaxAutodiffTest.testOpGrad, which iterates over
# this list and draws one argument per `nargs` from `rng` for every dtype in
# `dtypes`. `order` is forwarded to check_grads, which does not yet use it
# (see the TODO there). Entries that are commented out are disabled pending
# the TODOs next to them.
LAX_GRAD_OPS = [
GradTestSpec(lax.neg, nargs=1, order=2, rng=jtu.rand_default(),
dtypes=[onp.float64, onp.complex64]),
GradTestSpec(lax.floor, nargs=1, order=2, rng=jtu.rand_default(),
dtypes=[onp.float64]),
GradTestSpec(lax.ceil, nargs=1, order=2, rng=jtu.rand_default(),
dtypes=[onp.float64]),
GradTestSpec(lax.round, nargs=1, order=2, rng=jtu.rand_default(),
dtypes=[onp.float64]),
# GradTestSpec(lax.rem, nargs=2, order=2, rng=jtu.rand_default(),
# dtypes=[onp.float64]), # TODO(mattjj): enable
GradTestSpec(lax.exp, nargs=1, order=2, rng=jtu.rand_small(),
dtypes=[onp.float64, onp.complex64]),
GradTestSpec(lax.expm1, nargs=1, order=2, rng=jtu.rand_default(),
dtypes=[onp.float64, onp.complex64]),
GradTestSpec(lax.log, nargs=1, order=2, rng=jtu.rand_positive(),
dtypes=[onp.float64, onp.complex64]),
GradTestSpec(lax.log1p, nargs=1, order=2, rng=jtu.rand_positive(),
dtypes=[onp.float64, onp.complex64]),
GradTestSpec(lax.tanh, nargs=1, order=2, rng=jtu.rand_default(),
dtypes=[onp.float64, onp.complex64]),
GradTestSpec(lax.sin, nargs=1, order=2, rng=jtu.rand_default(),
dtypes=[onp.float64]),
GradTestSpec(lax.cos, nargs=1, order=2, rng=jtu.rand_default(),
dtypes=[onp.float64]),
GradTestSpec(lax.erf, nargs=1, order=2, rng=jtu.rand_small(),
dtypes=[onp.float64]),
GradTestSpec(lax.erfc, nargs=1, order=2, rng=jtu.rand_small(),
dtypes=[onp.float64]),
GradTestSpec(lax.erf_inv, nargs=1, order=2, rng=jtu.rand_small(),
dtypes=[onp.float64]),
# GradTestSpec(lax.lgamma, nargs=1, order=2, rng=jtu.rand_small(),
# dtypes=[onp.float64]), # TODO(mattjj): enable
GradTestSpec(lax.real, nargs=1, order=2, rng=jtu.rand_default(),
dtypes=[onp.complex64]),
GradTestSpec(lax.imag, nargs=1, order=2, rng=jtu.rand_default(),
dtypes=[onp.complex64]),
# GradTestSpec(lax.complex, nargs=2, order=2, rng=jtu.rand_default(),
# dtypes=[onp.float32]), # TODO(mattjj): enable
GradTestSpec(lax.conj, nargs=1, order=2, rng=jtu.rand_default(),
dtypes=[onp.float32, onp.complex64]),
GradTestSpec(lax.abs, nargs=1, order=2, rng=jtu.rand_positive(),
dtypes=[onp.float64, onp.complex64]),
GradTestSpec(lax.pow, nargs=2, order=2, rng=jtu.rand_positive(),
dtypes=[onp.float64, onp.complex64]),
GradTestSpec(lax.add, nargs=2, order=2, rng=jtu.rand_default(),
dtypes=[onp.float64, onp.complex64]),
GradTestSpec(lax.sub, nargs=2, order=2, rng=jtu.rand_default(),
dtypes=[onp.float64, onp.complex64]),
GradTestSpec(lax.mul, nargs=2, order=2, rng=jtu.rand_default(),
dtypes=[onp.float64, onp.complex64]),
GradTestSpec(lax.div, nargs=2, order=1, rng=jtu.rand_not_small(),
dtypes=[onp.float64, onp.complex64]),
GradTestSpec(lax.max, nargs=2, order=2, rng=jtu.rand_some_equal(),
dtypes=[onp.float64]),
GradTestSpec(lax.min, nargs=2, order=2, rng=jtu.rand_some_equal(),
dtypes=[onp.float64]),
]
def check_grads(f, args, order, atol=None, rtol=None, eps=None):
  """Checks the JVP and VJP of `f` at `args` via jtu.check_jvp/check_vjp.

  Any of `atol`, `rtol` and `eps` that is falsy (None or 0) is replaced by a
  default tolerance: 1e-6 when FLAGS.jax_enable_x64 is set, 1e-2 otherwise.
  `order` is accepted but not used yet.
  """
  # TODO(mattjj,dougalm): add higher-order check
  if FLAGS.jax_enable_x64:
    fallback = 1e-6
  else:
    fallback = 1e-2

  atol = atol if atol else fallback
  rtol = rtol if rtol else fallback
  eps = eps if eps else fallback

  forward_mode = partial(api.jvp, f)
  jtu.check_jvp(f, forward_mode, args, atol, rtol, eps)
  reverse_mode = partial(api.vjp, f)
  jtu.check_vjp(f, reverse_mode, args, atol, rtol, eps)
def check_grads_bilinear(f, args, order, atol=None, rtol=None):
  """Checks gradients of a two-argument `f` one argument at a time.

  `args` is an (lhs, rhs) pair. Each argument is checked with check_grads
  while the other is held fixed.
  """
  # Can use large eps to make up for numerical inaccuracies since the op is
  # bilinear (relying on the fact that we only check one arg at a time)
  fixed_lhs, fixed_rhs = args

  def wrt_lhs(lhs):
    return f(lhs, fixed_rhs)

  def wrt_rhs(rhs):
    return f(fixed_lhs, rhs)

  check_grads(wrt_lhs, (fixed_lhs,), order, atol, rtol, eps=1.)
  check_grads(wrt_rhs, (fixed_rhs,), order, atol, rtol, eps=1.)
class LaxAutodiffTest(jtu.JaxTestCase):
@parameterized.named_parameters(
    {"testcase_name": jtu.format_test_name_suffix(
        rec.op.__name__, shapes, itertools.repeat(dtype)),
     "op": rec.op, "rng": rec.rng, "shapes": shapes, "dtype": dtype,
     "order": rec.order}
    for rec in LAX_GRAD_OPS
    for shape_group in compatible_shapes
    for shapes in CombosWithReplacement(shape_group, rec.nargs)
    for dtype in rec.dtypes
)
def testOpGrad(self, op, rng, shapes, dtype, order):
  # Checks the gradients of one LAX_GRAD_OPS entry with check_grads on
  # randomly generated arguments. On TPU, complex64 inputs and lax.pow are
  # skipped.
  if FLAGS.jax_test_dut and FLAGS.jax_test_dut.startswith("tpu"):
    # unittest.skip() only builds a decorator; returning it made these cases
    # count as passed. Raising SkipTest reports them as skipped.
    if dtype is onp.complex64:
      raise absltest.unittest.SkipTest("complex grads unimplemented on tpu")
    if op is lax.pow:
      raise absltest.unittest.SkipTest("pow grad imprecise on tpu")

  # A tolerance of None makes check_grads fall back to its own default.
  tol = 1e-1 if num_float_bits(dtype) == 32 else None
  args = tuple(rng(shape, dtype) for shape in shapes)
  check_grads(op, args, order, tol, tol, tol)
@parameterized.named_parameters(
    {"testcase_name": "_from_dtype={}_to_dtype={}".format(
        from_dtype, to_dtype),
     "from_dtype": from_dtype, "to_dtype": to_dtype, "rng": rng}
    for from_dtype, to_dtype in itertools.product(
        [onp.float32, onp.float64], repeat=2)
    for rng in [jtu.rand_default()])
def testConvertElementTypeGrad(self, from_dtype, to_dtype, rng):
  # Checks first-order gradients of lax.convert_element_type on a random
  # (2, 3) operand, with atol, rtol and eps all set to 1e-3.
  operand = rng((2, 3), from_dtype)

  def convert_element_type(x):
    return lax.convert_element_type(x, to_dtype)

  check_grads(convert_element_type, (operand,), 1, 1e-3, 1e-3, 1e-3)
@parameterized.named_parameters(
    {"testcase_name": "_min_shape={}_operand_shape={}_max_shape={}".format(
        jtu.format_shape_dtype_string(min_shape, dtype),
        jtu.format_shape_dtype_string(operand_shape, dtype),
        jtu.format_shape_dtype_string(max_shape, dtype)),
     "min_shape": min_shape, "operand_shape": operand_shape,
     "max_shape": max_shape, "dtype": dtype, "rng": rng}
    for min_shape, operand_shape, max_shape in [
        [(), (), ()],
        [(), (2, 3), ()],
        [(2, 3), (2, 3), (2, 3)],
    ]
    for dtype in float_dtypes
    for rng in [jtu.rand_default()])
def testClampGrad(self, min_shape, operand_shape, max_shape, dtype, rng):
  # Checks gradients of lax.clamp with respect to (min, operand, max). The
  # two random bounds are ordered elementwise so the lower bound never
  # exceeds the upper bound.
  if onp.finfo(dtype).bits == 32:
    tol = 1e-2
  else:
    tol = None  # check_grads falls back to its own default

  # Draw the three arrays in (min, operand, max) order.
  first_bound, operand, second_bound = [
      rng(shape, dtype) for shape in (min_shape, operand_shape, max_shape)]
  lower = onp.minimum(first_bound, second_bound)  # broadcast
  upper = onp.maximum(first_bound, second_bound)  # broadcast

  check_grads(lax.clamp, (lower, operand, upper), 2, tol, tol, tol)
@parameterized.named_parameters(
    {"testcase_name": "_dim={}_baseshape=[{}]_dtype={}_narrs={}".format(
        dim, ",".join(str(d) for d in base_shape), onp.dtype(dtype).name,
        num_arrs),
     "dim": dim, "base_shape": base_shape, "dtype": dtype,
     "num_arrs": num_arrs, "rng": rng}
    for num_arrs in [3]
    for dtype in float_dtypes
    for base_shape in [(4,), (3, 4), (2, 3, 4)]
    for dim in range(len(base_shape))
    for rng in [jtu.rand_default()])
def testConcatenateGrad(self, dim, base_shape, dtype, num_arrs, rng):
  # Checks gradients of lax.concatenate along `dim`. The operands share
  # `base_shape` except along `dim`, where their sizes cycle through 3, 1, 4.
  if onp.finfo(dtype).bits == 32:
    tol = 1e-2
  else:
    tol = None  # check_grads falls back to its own default

  prefix, suffix = base_shape[:dim], base_shape[dim+1:]
  sizes = itertools.islice(itertools.cycle([3, 1, 4]), num_arrs)
  shapes = [prefix + (size,) + suffix for size in sizes]
  operands = tuple(rng(shape, dtype) for shape in shapes)

  def concatenate(*arrays):
    return lax.concatenate(arrays, dim)

  check_grads(concatenate, operands, 2, tol, tol, tol)
@parameterized.named_parameters(
    {"testcase_name":
     "_lhs_shape={}_rhs_shape={}_strides={}_padding={}"
     .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
             jtu.format_shape_dtype_string(rhs_shape, dtype),
             strides, padding),
     "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
     "strides": strides, "padding": padding, "rng": rng,}
    for lhs_shape, rhs_shape, all_strides in itertools.chain(
        [((b, i, 3, 4), (j, i, 1, 2), [(1, 1), (1, 2), (2, 1)])
         for b, i, j in itertools.product([2, 3], repeat=3)],
        [((4, 2, 1), (3, 2, 1), [(1,)])])
    for strides in all_strides
    for dtype in [onp.float32]
    for padding in ["VALID", "SAME"]
    for rng in [jtu.rand_small()])
@jtu.skip_on_devices("tpu")
def testConvGrad(self, lhs_shape, rhs_shape, dtype, strides, padding, rng):
  # Checks gradients of lax.conv with respect to each operand separately via
  # check_grads_bilinear.
  lhs_operand = rng(lhs_shape, dtype)
  rhs_operand = rng(rhs_shape, dtype)

  def conv(lhs, rhs):
    return lax.conv(lhs, rhs, window_strides=strides, padding=padding)

  check_grads_bilinear(conv, (lhs_operand, rhs_operand), order=2,
                       atol=1e-2, rtol=1e-2)
@parameterized.named_parameters(
    {"testcase_name":
     "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_"
     "rhs_dilation={}"
     .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
             jtu.format_shape_dtype_string(rhs_shape, dtype),
             strides, padding, lhs_dil, rhs_dil),
     "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
     "strides": strides, "padding": padding, "lhs_dil": lhs_dil,
     "rhs_dil": rhs_dil, "rng": rng}
    for lhs_shape, rhs_shape, all_strides, all_pads, lhs_dils, rhs_dils in
    itertools.chain(
        [((b, i, 3, 4), (j, i, 1, 2), [(1, 1), (1, 2), (2, 1)],
          [((0, 0), (0, 0)), ((-1, 0), (0, -1)), ((1, 0), (0, 1))],
          [(1, 1), (2, 1)], [(1, 1)])
         for b, i, j in itertools.product([2, 3], repeat=3)],
        [((4, 2, 1), (3, 2, 1), [(1,)], [((1, 1),), ((0, 0),)],
          [(1,), (2,)], [(1,), (2,)])])
    for strides in all_strides
    for rhs_dil in rhs_dils
    for lhs_dil in lhs_dils
    for dtype in [onp.float32]
    for padding in all_pads
    for rng in [jtu.rand_small()])
@jtu.skip_on_devices("tpu")
def testConvWithGeneralPaddingGrad(self, lhs_shape, rhs_shape, dtype, strides,
                                   padding, lhs_dil, rhs_dil, rng):
  # Checks gradients of lax.conv_with_general_padding (explicit padding plus
  # lhs/rhs dilation) with respect to each operand via check_grads_bilinear.
  lhs_operand = rng(lhs_shape, dtype)
  rhs_operand = rng(rhs_shape, dtype)

  def conv(lhs, rhs):
    return lax.conv_with_general_padding(
        lhs, rhs, window_strides=strides, padding=padding,
        lhs_dilation=lhs_dil, rhs_dilation=rhs_dil)

  check_grads_bilinear(conv, (lhs_operand, rhs_operand), order=2,
                       atol=1e-2, rtol=1e-2)
@parameterized.named_parameters(
    {"testcase_name":
     "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_"
     "rhs_dilation={}_dims={}"
     .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
             jtu.format_shape_dtype_string(rhs_shape, dtype),
             strides, padding, lhs_dil, rhs_dil, ",".join(dim_nums)),
     "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
     "strides": strides, "padding": padding, "lhs_dil": lhs_dil,
     "rhs_dil": rhs_dil, "rng": rng, "dimension_numbers": dim_nums,
     "perms": perms}
    for lhs_shape, rhs_shape, all_strides, all_pads, lhs_dils, rhs_dils in [
        ((b, i, 3, 4), (j, i, 1, 2), [(1, 1), (1, 2), (2, 1)],
         [((0, 0), (0, 0)), ((-1, 0), (0, -1)), ((1, 0), (0, 1))],
         [(1, 1), (2, 1)], [(1, 1)])
        for b, i, j in itertools.product([1, 2], repeat=3)]
    for strides in all_strides
    for rhs_dil in rhs_dils
    for lhs_dil in lhs_dils
    for dtype in [onp.float32]
    for padding in all_pads
    for rng in [jtu.rand_default()]
    for dim_nums, perms in [
        (("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])),
        (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0]))])
@jtu.skip_on_devices("tpu")
def testConvGeneralDilatedGrad(self, lhs_shape, rhs_shape, dtype, strides,
                               padding, lhs_dil, rhs_dil, dimension_numbers,
                               perms, rng):
  # Checks gradients of lax.conv_general_dilated with respect to each operand
  # via check_grads_bilinear. The random operands are transposed by `perms`
  # before being passed with the given `dimension_numbers`.
  if onp.finfo(dtype).bits == 32:
    tol = 1e-1
  else:
    tol = 1e-3

  lhs_perm, rhs_perm = perms  # permute to compatible shapes
  lhs_operand = onp.transpose(rng(lhs_shape, dtype), lhs_perm)
  rhs_operand = onp.transpose(rng(rhs_shape, dtype), rhs_perm)

  def conv(lhs, rhs):
    return lax.conv_general_dilated(
        lhs, rhs, window_strides=strides, padding=padding,
        lhs_dilation=lhs_dil, rhs_dilation=rhs_dil,
        dimension_numbers=dimension_numbers)

  check_grads_bilinear(conv, (lhs_operand, rhs_operand), order=2,
                       atol=tol, rtol=tol)
@parameterized.named_parameters(
    {"testcase_name": "_lhs_shape={}_rhs_shape={}".format(
        jtu.format_shape_dtype_string(lhs_shape, dtype),
        jtu.format_shape_dtype_string(rhs_shape, dtype)),
     "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
     "rng": jtu.rand_default()}
    for lhs_shape in [(2,), (3, 2)] for rhs_shape in [(2,), (2, 4)]
    for dtype in float_dtypes)
def testDotGrad(self, lhs_shape, rhs_shape, dtype, rng):
  # Checks gradients of lax.dot with respect to each operand via
  # check_grads_bilinear. The tolerance is chosen from num_float_bits(dtype),
  # which canonicalizes the dtype first.
  if num_float_bits(dtype) == 32:
    tol = 1e-1
  else:
    tol = 1e-3

  lhs_operand = rng(lhs_shape, dtype)
  rhs_operand = rng(rhs_shape, dtype)
  check_grads_bilinear(lax.dot, (lhs_operand, rhs_operand), order=2,
                       atol=tol, rtol=tol)
@parameterized.named_parameters(
    {"testcase_name":
     "_lhs_shape={}_rhs_shape={}_dimension_numbers={}"
     .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
             jtu.format_shape_dtype_string(rhs_shape, dtype),
             dimension_numbers),
     "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
     "dimension_numbers": dimension_numbers, "rng": jtu.rand_small()}
    for lhs_shape, rhs_shape, dimension_numbers in [
        ((3, 2), (2, 4), (([1], [0]), ([], []))),
        ((3, 5), (2, 5), (([1], [1]), ([], []))),
        ((5, 3), (5, 2), (([0], [0]), ([], []))),
        ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
    ]
    for dtype in float_dtypes)
@jtu.skip_on_devices("tpu")
def testDotGeneralContractAndBatchGrads(self, lhs_shape, rhs_shape, dtype,
                                        dimension_numbers, rng):
  # Checks gradients of lax.dot_general for the given dimension_numbers with
  # respect to each operand via check_grads_bilinear.
  if onp.finfo(dtype).bits == 32:
    tol = 1e-1
  else:
    tol = 1e-2

  lhs_operand = rng(lhs_shape, dtype)
  rhs_operand = rng(rhs_shape, dtype)

  def dot_general(lhs, rhs):
    return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers)

  check_grads_bilinear(dot_general, (lhs_operand, rhs_operand), order=2,
                       atol=tol, rtol=tol)
@parameterized.named_parameters(
    {"testcase_name": "_shape={}_dtype={}_broadcast_sizes={}".format(
        shape, onp.dtype(dtype).name, broadcast_sizes),
     "shape": shape, "dtype": dtype, "broadcast_sizes": broadcast_sizes,
     "rng": rng}
    for shape in [(), (2, 3)]
    for dtype in float_dtypes
    for broadcast_sizes in [(), (2,), (1, 2)]
    for rng in [jtu.rand_default()])
def testBroadcastGrad(self, shape, dtype, broadcast_sizes, rng):
  # Checks gradients of lax.broadcast with fixed broadcast_sizes on a random
  # operand.
  if onp.finfo(dtype).bits == 32:
    tol = 1e-2
  else:
    tol = None  # check_grads falls back to its own default

  operand = rng(shape, dtype)

  def broadcast(x):
    return lax.broadcast(x, broadcast_sizes)

  check_grads(broadcast, (operand,), 2, tol, tol, tol)
@parameterized.named_parameters(
    {"testcase_name": "_inshape={}_outshape={}_bcdims={}".format(
        jtu.format_shape_dtype_string(inshape, dtype),
        outshape, broadcast_dimensions),
     "inshape": inshape, "dtype": dtype, "outshape": outshape,
     "dimensions": broadcast_dimensions, "rng": jtu.rand_default()}
    for (inshape, outshape, broadcast_dimensions), dtype in itertools.product(
        [
            ([2], [2, 2], [0]),
            ([2], [2, 2], [1]),
            ([2], [2, 3], [0]),
            ([], [2, 3], []),
        ],
        float_dtypes))
def testBroadcastInDimGrad(self, inshape, dtype, outshape, dimensions, rng):
  # Gradient check for lax.broadcast_in_dim with respect to its operand;
  # outshape and dimensions are captured by the closure.
  if onp.finfo(dtype).bits == 32:
    tol = 1e-2
  else:
    tol = None
  x = rng(inshape, dtype)

  def broadcast_in_dim(operand):
    return lax.broadcast_in_dim(operand, outshape, dimensions)

  check_grads(broadcast_in_dim, (x,), 2, tol, tol, tol)
@parameterized.named_parameters(
    {"testcase_name": "_inshape={}_outshape={}".format(
        jtu.format_shape_dtype_string(arg_shape, dtype),
        jtu.format_shape_dtype_string(out_shape, dtype)),
     "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype,
     "rng": jtu.rand_default()}
    for dtype, (arg_shape, out_shape) in itertools.product(
        float_dtypes,
        [
            [(3, 4), (12,)],
            [(2, 1, 4), (8,)],
            [(2, 2, 4), (2, 8)],
        ]))
def testReshapeGrad(self, arg_shape, out_shape, dtype, rng):
  # Gradient check for lax.reshape from arg_shape to out_shape.
  is_32bit = onp.finfo(dtype).bits == 32
  tol = 1e-2 if is_32bit else None
  x = rng(arg_shape, dtype)

  def reshape(operand):
    return lax.reshape(operand, out_shape)

  check_grads(reshape, (x,), 2, tol, tol, tol)
@parameterized.named_parameters(
    {"testcase_name": "_inshape={}_pads={}"
     .format(jtu.format_shape_dtype_string(shape, dtype), pads),
     "shape": shape, "dtype": dtype, "pads": pads, "rng": jtu.rand_small()}
    for shape, dtype, pads in itertools.product(
        [(2, 3)],
        float_dtypes,
        [[(1, 2, 1), (0, 1, 0)]]))
def testPadGrad(self, shape, dtype, pads, rng):
  # Gradient check for lax.pad in two variants: first with the padding value
  # held constant, then with the padding value as a differentiated argument.
  tol = None
  if onp.finfo(dtype).bits == 32:
    tol = 1e-2

  # Variant 1: differentiate with respect to the operand only.
  def pad_operand_only(x):
    return lax.pad(x, onp.array(0, dtype), pads)

  x = rng(shape, dtype)
  check_grads(pad_operand_only, (x,), 2, tol, tol, tol)

  # Variant 2: differentiate with respect to operand and padding value.
  def pad_with_value(x, fill):
    return lax.pad(x, fill, pads)

  x = rng(shape, dtype)
  fill = onp.array(0., dtype)
  check_grads(pad_with_value, (x, fill), 2, tol, tol, tol)
def testReverseGrad(self):
  # Gradient check for lax.rev on a 1-D array reversed along axis 0 and on a
  # 2-D array reversed along both axes.
  cases = [
      ([0], onp.array([3., 2., 1.])),
      ([0, 1], onp.array([[6., 5., 4.], [3., 2., 1.]])),
  ]
  for dimensions, operand in cases:
    # The closure is consumed within this iteration, so it always sees the
    # dimensions of the current case.
    def rev(x):
      return lax.rev(x, dimensions)  # pylint: disable=cell-var-from-loop

    check_grads(rev, (operand,), 2)
@parameterized.named_parameters(
    {"testcase_name": "_predshape={}_argshapes={}".format(
        jtu.format_shape_dtype_string(pred_shape, onp.bool_),
        jtu.format_shape_dtype_string(arg_shape, dtype)),
     "pred_shape": pred_shape, "arg_shape": arg_shape, "dtype": dtype,
     "rng": jtu.rand_default()}
    for arg_shape in [(), (3,), (2, 3)]
    # A scalar predicate is always tried; a predicate shaped like the
    # arguments is added only when the arguments are not scalar.
    for pred_shape in ([(), arg_shape] if arg_shape else [()])
    for dtype in float_dtypes)
def testSelectGrad(self, pred_shape, arg_shape, dtype, rng):
  # Gradient check for lax.select with respect to on_true and on_false; the
  # boolean predicate is captured by the closure and not differentiated.
  if onp.finfo(dtype).bits == 32:
    tol = 1e-2
  else:
    tol = None
  pred = rng(pred_shape, onp.bool_)
  true_vals = rng(arg_shape, dtype)
  false_vals = rng(arg_shape, dtype)

  def select(on_true, on_false):
    return lax.select(pred, on_true, on_false)

  check_grads(select, (true_vals, false_vals), 2, tol, tol, tol)
@parameterized.named_parameters(
    {"testcase_name":
     "_shape={}_start_indices={}_limit_indices={}_strides={}".format(
         jtu.format_shape_dtype_string(shape, dtype),
         start_indices, limit_indices, strides),
     "shape": shape, "dtype": dtype, "starts": start_indices,
     "limits": limit_indices, "strides": strides, "rng": rng}
    # Each row is [shape, start_indices, limit_indices, strides]; strides is
    # None in the rows that exercise the call without explicit strides.
    for shape, start_indices, limit_indices, strides in [
        [(3,), (1,), (2,), None],
        [(7,), (4,), (7,), None],
        [(5,), (1,), (5,), (2,)],
        [(8,), (1,), (6,), (2,)],
        [(5, 3), (1, 1), (3, 2), None],
        [(5, 3), (1, 1), (3, 1), None],
        [(7, 5, 3), (4, 0, 1), (7, 1, 3), None],
        [(5, 3), (1, 1), (2, 1), (1, 1)],
        [(5, 3), (1, 1), (5, 3), (2, 1)],
    ]
    for dtype in float_dtypes
    for rng in [jtu.rand_default()])
def testSliceGrad(self, shape, dtype, starts, limits, strides, rng):
  # Gradient check for lax.slice with respect to its operand; starts, limits
  # and strides are captured by the closure.
  tol = 1e-2 if onp.finfo(dtype).bits == 32 else None
  operand = rng(shape, dtype)
  # Named slice_fn rather than `slice` so the builtin is not shadowed.
  slice_fn = lambda x: lax.slice(x, starts, limits, strides)
  check_grads(slice_fn, (operand,), 2, tol, tol, tol)
@parameterized.named_parameters(
    {"testcase_name": "_shape={}_start_indices={}_size_indices={}".format(
        jtu.format_shape_dtype_string(shape, dtype),
        start_indices, size_indices),
     "shape": shape, "dtype": dtype, "start_indices": start_indices,
     "size_indices": size_indices, "rng": jtu.rand_default()}
    for (shape, start_indices, size_indices), dtype in itertools.product(
        [
            [(3,), (1,), (1,)],
            [(5, 3), (1, 1), (3, 1)],
            [(7, 5, 3), (4, 1, 0), (2, 0, 1)],
        ],
        float_dtypes))
def testDynamicSliceGrad(self, shape, dtype, start_indices, size_indices,
                         rng):
  # Gradient check for lax.dynamic_slice with respect to its operand only;
  # start_indices and size_indices are captured by the closure.
  tol = None
  if onp.finfo(dtype).bits == 32:
    tol = 1e-2
  x = rng(shape, dtype)

  def dynamic_slice(operand):
    return lax.dynamic_slice(operand, start_indices, size_indices)

  check_grads(dynamic_slice, (x,), 2, tol, tol, tol)
@parameterized.named_parameters(
    {"testcase_name": "_shape={}_start_indices={}_update_shape={}".format(
        jtu.format_shape_dtype_string(shape, dtype),
        start_indices, update_shape),
     "shape": shape, "dtype": dtype, "start_indices": start_indices,
     "update_shape": update_shape, "rng": jtu.rand_default()}
    for (shape, start_indices, update_shape), dtype in itertools.product(
        [
            [(3,), (1,), (1,)],
            [(5, 3), (1, 1), (3, 1)],
            [(7, 5, 3), (4, 1, 0), (2, 0, 1)],
        ],
        float_dtypes))
def testDynamicUpdateSliceGrad(self, shape, dtype, start_indices,
                               update_shape, rng):
  # Gradient check for lax.dynamic_update_slice with respect to both the
  # operand and the update; the start indices are captured by the closure.
  if onp.finfo(dtype).bits == 32:
    tol = 1e-2
  else:
    tol = None
  base = rng(shape, dtype)
  patch = rng(update_shape, dtype)
  # The indices are handed to lax as a numpy array rather than a tuple.
  starts = onp.array(start_indices)

  def dus(x, y):
    return lax.dynamic_update_slice(x, y, starts)

  check_grads(dus, (base, patch), 2, tol, tol, tol)
@parameterized.named_parameters(
    {"testcase_name": "_shape={}_perm={}".format(
        jtu.format_shape_dtype_string(shape, dtype), perm),
     "shape": shape, "dtype": dtype, "perm": perm,
     "rng": jtu.rand_default()}
    for (shape, perm), dtype in itertools.product(
        [
            [(3, 4), (1, 0)],
            [(3, 4), (0, 1)],
            [(3, 4, 5), (2, 1, 0)],
            [(3, 4, 5), (1, 0, 2)],
        ],
        float_dtypes))
def testTransposeGrad(self, shape, dtype, perm, rng):
  # Gradient check for lax.transpose with the permutation held fixed.
  is_32bit = onp.finfo(dtype).bits == 32
  tol = 1e-2 if is_32bit else None
  x = rng(shape, dtype)

  def transpose(operand):
    return lax.transpose(operand, perm)

  check_grads(transpose, (x,), 2, tol, tol, tol)
@parameterized.named_parameters(
    {"testcase_name": "_op={}_inshape={}_reducedims={}"
     .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims),
     "op": op, "init_val": init_val, "shape": shape, "dtype": dtype,
     "dims": dims, "rng": jtu.rand_small()}
    # Each row pairs a reduction op with the initial value passed to
    # lax.reduce and the dtypes to run it on.
    for init_val, op, dtypes in [
        (0, lax.add, float_dtypes),
        (-onp.inf, lax.max, float_dtypes),
        (onp.inf, lax.min, float_dtypes),
    ]
    for dtype in dtypes
    for shape, dims in [
        [(3, 4, 5), (0,)],
        [(3, 4, 5), (1, 2)],
        [(3, 4, 5), (0, 2)],
        [(3, 4, 5), (0, 1, 2)],
    ])
def testReduceGrad(self, op, init_val, shape, dtype, dims, rng):
  # Gradient check for lax.reduce over the given dims with respect to the
  # operand; the initial value and reduction op are held fixed.
  if onp.finfo(dtype).bits == 32:
    tol = 1e-2
  else:
    tol = None
  x = rng(shape, dtype)
  # Cast the Python scalar initial value to the operand dtype.
  init_array = onp.asarray(init_val, dtype=dtype)

  def reduce_fn(operand):
    return lax.reduce(operand, init_array, op, dims)

  check_grads(reduce_fn, (x,), 1, tol, tol, tol)
@parameterized.named_parameters(
    {"testcase_name": "_op={}_dtype={}_padding={}"
     .format(op.__name__, onp.dtype(dtype).name, padding),
     "op": op, "init_val": init_val, "dtype": dtype, "padding": padding,
     "rng": rng}
    # Each row carries its own rng. No separate `for rng in ...` clause may
    # follow these loops: it would rebind rng and silently discard the per-op
    # choice made here (rand_small for lax.add).
    for init_val, op, dtypes, rng in [
        (0, lax.add, [onp.float32], jtu.rand_small()),
        (-onp.inf, lax.max, [onp.float32], jtu.rand_default()),
        (onp.inf, lax.min, [onp.float32], jtu.rand_default()),
    ]
    for dtype in dtypes
    for padding in ["VALID", "SAME"])
def testReduceWindowGrad(self, op, init_val, dtype, padding, rng):
  # Gradient check for lax.reduce_window over several (shape, window
  # dimensions, strides) configurations for the given op and padding mode.
  init_val = onp.asarray(init_val, dtype=dtype)

  # We need this conditional and the corresponding loop logic to be in the
  # test method, rather than at the parameterized test level, because it
  # depends on FLAGS for the device under test.
  if FLAGS.jax_test_dut == "tpu":
    all_configs = [((6, 5, 4, 3), (2, 2, 1, 1), (1, 2, 1, 1))]
  else:
    all_configs = itertools.chain(
        itertools.product(
            [(4, 6)],  # shapes
            [(2, 1), (1, 2)],  # window_dimensions
            [(1, 1), (2, 1), (1, 2)]  # strides
        ),
        itertools.product(
            [(3, 2, 4, 6)],  # shapes
            [(1, 1, 2, 1), (2, 1, 2, 1)],  # window_dimensions
            [(1, 2, 2, 1), (1, 1, 1, 1)]),  # strides
    )

  # `fun` reads dims and strides from the loop below at call time, so it is
  # only ever called inside that loop.
  def fun(operand):
    return lax.reduce_window(operand, init_val, op, dims, strides, padding)

  # pylint: disable=cell-var-from-loop
  for shape, dims, strides in all_configs:
    operand = rng(shape, dtype)
    if op is lax.add:
      check_grads(fun, (operand,), 1, 1e-2, 1e-2, 1e-2)
    else:
      # This test can fail if there are duplicates in operand, so require
      # all elements to be distinct before checking the VJP.
      self.assertEqual(onp.unique(operand).size, operand.size,
                       msg="test requires operand elements to be unique.")
      jtu.check_vjp(fun, partial(api.vjp, fun), (operand,),
                    1e-2, 1e-2, 1e-2)
  # pylint: enable=cell-var-from-loop
# TODO(b/205052657): enable more tests when supported
@parameterized.named_parameters(
    {"testcase_name": "_shape={}_axis={}".format(
        jtu.format_shape_dtype_string(shape, dtype), axis),
     "rng": jtu.rand_default(), "shape": shape, "dtype": dtype, "axis": axis}
    for dtype, shape in itertools.product(
        [onp.float32],
        [(5,), (5, 7)])
    # Only the last axis of each shape is exercised.
    for axis in [len(shape) - 1])
def testSortGrad(self, shape, dtype, axis, rng):
  # Gradient check for lax.sort along the given axis.
  tol = None
  if onp.finfo(dtype).bits == 32:
    tol = 1e-2
  x = rng(shape, dtype)

  def sort(operand):
    return lax.sort(operand, axis)

  check_grads(sort, (x,), 2, tol, tol, tol)
# TODO(b/205052657): enable more tests when supported
@parameterized.named_parameters(
    {"testcase_name": "_keyshape={}_valshape={}_axis={}".format(
        jtu.format_shape_dtype_string(shape, key_dtype),
        jtu.format_shape_dtype_string(shape, val_dtype),
        axis),
     "rng": jtu.rand_default(), "shape": shape,
     "key_dtype": key_dtype, "val_dtype": val_dtype, "axis": axis}
    for key_dtype, val_dtype, shape in itertools.product(
        [onp.float32],
        [onp.float32],
        [(3,), (5, 3)])
    # Only the last axis of each shape is exercised.
    for axis in [len(shape) - 1])
def testSortKeyValGrad(self, shape, key_dtype, val_dtype, axis, rng):
  # Gradient check for lax.sort_key_val with respect to keys and values.
  #
  # The ordering of values whose keys are equal is not guaranteed, so ties
  # must be avoided: the keys are a fixed-seed permutation of 0..N-1, which
  # makes every key unique across the whole array.
  perm_rng = onp.random.RandomState(0)
  num_keys = onp.prod(shape, dtype=int)
  distinct_keys = onp.arange(num_keys, dtype=key_dtype)
  keys = perm_rng.permutation(distinct_keys).reshape(shape)
  values = rng(shape, val_dtype)

  def fun(keys, values):
    return lax.sort_key_val(keys, values, axis)

  check_grads(fun, (keys, values), 2, 1e-2, 1e-2, 1e-2)
@parameterized.named_parameters(
    {"testcase_name": "_shape={}_idxs={}_axes={}".format(
        jtu.format_shape_dtype_string(shape, dtype), idxs, axes),
     "shape": shape, "dtype": dtype, "idxs": idxs, "axes": axes,
     "rng": jtu.rand_default()}
    for dtype in float_dtypes
    for shape, idxs, axes in [
        [(3, 4, 5), (onp.array([0, 2, 1]),), (0,)],
        [(3, 4, 5), (onp.array([-1, -2]),), (0,)],
        [(3, 4, 5), (onp.array([0, 2]), onp.array([1, 3])), (0, 1)],
        [(3, 4, 5), (onp.array([0, 2]), onp.array([1, 3])), (0, 2)],
    ])
def testIndexTakeGrad(self, shape, dtype, idxs, axes, rng):
  # Gradient check for lax.index_take with respect to src.
  #
  # The parameterized idxs only supply shapes and dtypes: the index arrays
  # actually used are fresh rng draws of matching shape and dtype.
  drawn_idxs = tuple(rng(template.shape, template.dtype)
                     for template in idxs)
  source = rng(shape, dtype)

  def index_take(src):
    return lax.index_take(src, drawn_idxs, axes)

  check_grads(index_take, (source,), 2, 1e-2, 1e-2, 1e-2)
@parameterized.named_parameters(
    {"testcase_name": "_dst_shape={}_idxs={}_axes={}".format(
        jtu.format_shape_dtype_string(dst_shape, dtype), idxs, axes),
     "dst_shape": dst_shape, "dtype": dtype, "idxs": idxs, "axes": axes,
     "rng": jtu.rand_default()}
    for dtype in float_dtypes
    for dst_shape, idxs, axes in [
        [(3, 4, 5), (onp.array([0, 2, 1]),), (0,)],
        [(3, 4, 5), (onp.array([-1, -2]),), (0,)],
        [(3, 4, 5), (onp.array([0, 2]), onp.array([1, 3])), (0, 1)],
        [(3, 4, 5), (onp.array([0, 2]), onp.array([1, 3])), (0, 2)],
    ])
def testIndexUntakeGrad(self, dst_shape, dtype, idxs, axes, rng):
  # Gradient check for lax.index_untake with respect to src and dst.
  #
  # The src shape is obtained by running lax.index_take on a dst-shaped
  # array with the parameterized idxs and reading the result's shape.
  probe = rng(dst_shape, dtype)
  src_shape = lax.index_take(probe, idxs, axes).shape
  # From here on the parameterized idxs are replaced by rng draws of
  # matching shape and dtype.
  drawn_idxs = tuple(rng(template.shape, template.dtype)
                     for template in idxs)
  source = rng(src_shape, dtype)
  dest = rng(dst_shape, dtype)

  def index_untake(src, dst):
    return lax.index_untake(src, dst, drawn_idxs, axes)

  check_grads(index_untake, (source, dest), 2, 1e-2, 1e-2, 1e-2)
# When this file is executed as a script, hand control to absltest's main
# entry point.
if __name__ == '__main__':
  absltest.main()