Simplify reduce-precision logic.

Enable TPU gradient tests only up to order 1. The first-order JVP of reduce-window exercises select_and_scatter_add, which is the part changed by this PR.
This commit is contained in:
Peter Hawkins 2019-07-02 11:34:49 -04:00
parent 40560d2c9a
commit 165df6204b
2 changed files with 10 additions and 12 deletions

View File

@@ -3660,14 +3660,13 @@ def _select_and_gather_add_translation(
"min/max operator. This is likely from a second or "
"higher derivative of a max-pooling operation and is to work"
"around a missing XLA support for pair-reductions.")
if etype == xla_client.PrimitiveType.F32:
nexp, nmant = 8, 7 # bfloat16 precision.
elif etype == xla_client.PrimitiveType.F64:
nexp, nmant = onp.finfo(onp.float32).nexp, onp.finfo(onp.float32).nmant
r_nbits = nbits // 2
# Drop/round the bottom mantissa bits.
nexp = onp.finfo(dtype).nexp
nmant = r_nbits - nexp - 1
double_word_dtype = word_dtype = _UINT_DTYPES[nbits]
word_type = xla_bridge.dtype_to_etype_exact(word_dtype)
r_nbits = nbits // 2
# Packs two values into a tuple.
def pack(a, b):

View File

@ -2045,7 +2045,9 @@ class LaxAutodiffTest(jtu.JaxTestCase):
# TODO(b/31565929): enable when fixed.
if FLAGS.jax_test_dut == "tpu" and op is not lax.add:
all_configs = [((6, 5, 4, 3), (2, 2, 1, 1), (1, 2, 1, 1))]
test_gradients = True # TODO(b/73062247): need variadic reduce-window.
# TODO(b/73062247): need variadic reduce-window for better precision.
gradient_order = 1
else:
all_configs = itertools.chain(
itertools.product(
@@ -2058,22 +2060,19 @@ class LaxAutodiffTest(jtu.JaxTestCase):
[(1, 1, 2, 1), (2, 1, 2, 1)], # window_dimensions
[(1, 2, 2, 1), (1, 1, 1, 1)]), # strides
)
test_gradients = True
gradient_order = 3
def fun(operand):
return lax.reduce_window(operand, init_val, op, dims, strides, padding)
# pylint: disable=cell-var-from-loop
for shape, dims, strides in all_configs:
operand = rng(shape, dtype)
if op is not lax.add:
# this test can fail if there are duplicates in operand
self.assertEqual(onp.unique(operand).size, operand.size,
msg="test requires operand elements to be unique.")
jtu.check_vjp(fun, partial(api.vjp, fun), (operand,), 1e-2, 1e-2, 1e-2)
if test_gradients:
check_grads(fun, (operand,), 3, ["fwd", "rev"], 1e-2, 1e-2, 1e-2)
# pylint: enable=cell-var-from-loop
check_grads(fun, (operand,), gradient_order, ["fwd", "rev"], 1e-2, 1e-2,
1e-2)
# TODO(b/205052657): enable more tests when supported
@parameterized.named_parameters(jtu.cases_from_list(