Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
Simplify reduce-precision logic.
Enable TPU gradient tests only up to order 1. The first-order JVP of reduce-window tests select_and_scatter_add, which is the part changed by this PR.
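To make the motivation concrete, here is a minimal sketch (not part of this commit) of the kind of computation the affected test runs: a first-order gradient check of a max reduce-window ("max pooling"). The `max_pool` helper, shapes, and window sizes are illustrative only; per the commit message, even this order-1 check exercises the select_and_scatter_add lowering touched by this change.

```python
import jax
import jax.numpy as jnp
from jax import lax

def max_pool(x):
  # 2x2 max pooling over the spatial axes of an NHWC array (illustrative shapes).
  return lax.reduce_window(x, -jnp.inf, lax.max,
                           window_dimensions=(1, 2, 2, 1),
                           window_strides=(1, 2, 2, 1),
                           padding="VALID")

x = jnp.arange(16, dtype=jnp.float32).reshape(1, 4, 4, 1)
t = jnp.ones_like(x)

# First-order JVP (forward mode), i.e. the "order 1" case enabled on TPU here.
y, y_dot = jax.jvp(max_pool, (x,), (t,))

# Reverse mode over the same computation; the gradient of the max window is
# where the select-and-scatter machinery comes in.
g = jax.grad(lambda v: max_pool(v).sum())(x)
```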
parent 40560d2c9a
commit 165df6204b
@@ -3660,14 +3660,13 @@ def _select_and_gather_add_translation(
                   "min/max operator. This is likely from a second or "
                   "higher derivative of a max-pooling operation and is to work"
                   "around a missing XLA support for pair-reductions.")
-    if etype == xla_client.PrimitiveType.F32:
-      nexp, nmant = 8, 7  # bfloat16 precision.
-    elif etype == xla_client.PrimitiveType.F64:
-      nexp, nmant = onp.finfo(onp.float32).nexp, onp.finfo(onp.float32).nmant
+    r_nbits = nbits // 2
+    # Drop/round the bottom mantissa bits.
+    nexp = onp.finfo(dtype).nexp
+    nmant = r_nbits - nexp - 1
 
     double_word_dtype = word_dtype = _UINT_DTYPES[nbits]
     word_type = xla_bridge.dtype_to_etype_exact(word_dtype)
-    r_nbits = nbits // 2
 
     # Packs two values into a tuple.
     def pack(a, b):
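For reference, a small worked example (not from the patch) of what the new formula computes. It assumes, as the surrounding `pack` helper suggests, that each reduced-precision value must fit in the top `nbits // 2` bits of a word, so sign, exponent and mantissa together may use at most `r_nbits` bits.

```python
import numpy as onp

for dtype, nbits in [(onp.float32, 32), (onp.float64, 64)]:
  r_nbits = nbits // 2
  nexp = onp.finfo(dtype).nexp   # exponent bits of the original dtype
  nmant = r_nbits - nexp - 1     # mantissa bits left after 1 sign + nexp exponent bits
  print(dtype.__name__, nexp, nmant)

# prints: float32 8 7
#         float64 11 20
```

For float32 this reproduces the previously hard-coded bfloat16 layout (8 exponent, 7 mantissa bits); for float64 it keeps the full 11-bit float64 exponent range and drops mantissa bits until the value fits in 32 bits, rather than reusing float32's layout as the old branch did.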
@@ -2045,7 +2045,9 @@ class LaxAutodiffTest(jtu.JaxTestCase):
     # TODO(b/31565929): enable when fixed.
     if FLAGS.jax_test_dut == "tpu" and op is not lax.add:
       all_configs = [((6, 5, 4, 3), (2, 2, 1, 1), (1, 2, 1, 1))]
-      test_gradients = True  # TODO(b/73062247): need variadic reduce-window.
+
+      # TODO(b/73062247): need variadic reduce-window for better precision.
+      gradient_order = 1
     else:
       all_configs = itertools.chain(
           itertools.product(
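As a side note on what `gradient_order` controls: `check_grads(f, args, order, modes, ...)` numerically verifies derivatives up to the given order, so the TPU branch above only checks first derivatives while the default branch goes up to third order. A minimal sketch, assuming the public `jax.test_util.check_grads` behaves like the `jtu.check_grads` used in the test (the function and inputs below are made up for illustration):

```python
import jax.numpy as jnp
from jax.test_util import check_grads

f = lambda x: jnp.sum(x ** 3)
x = jnp.array([1.0, 2.0, 3.0])

check_grads(f, (x,), order=1, modes=["fwd", "rev"])  # first derivatives only
check_grads(f, (x,), order=3, modes=["fwd", "rev"])  # up to third derivatives
```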
@@ -2058,22 +2060,19 @@ class LaxAutodiffTest(jtu.JaxTestCase):
               [(1, 1, 2, 1), (2, 1, 2, 1)],  # window_dimensions
               [(1, 2, 2, 1), (1, 1, 1, 1)]),  # strides
       )
-      test_gradients = True
+      gradient_order = 3
 
     def fun(operand):
       return lax.reduce_window(operand, init_val, op, dims, strides, padding)
 
     # pylint: disable=cell-var-from-loop
     for shape, dims, strides in all_configs:
       operand = rng(shape, dtype)
       if op is not lax.add:
         # this test can fail if there are duplicates in operand
         self.assertEqual(onp.unique(operand).size, operand.size,
                          msg="test requires operand elements to be unique.")
       jtu.check_vjp(fun, partial(api.vjp, fun), (operand,), 1e-2, 1e-2, 1e-2)
-      if test_gradients:
-        check_grads(fun, (operand,), 3, ["fwd", "rev"], 1e-2, 1e-2, 1e-2)
+      check_grads(fun, (operand,), gradient_order, ["fwd", "rev"], 1e-2, 1e-2,
+                  1e-2)
     # pylint: enable=cell-var-from-loop
 
     # TODO(b/205052657): enable more tests when supported
     @parameterized.named_parameters(jtu.cases_from_list(
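One more illustrative sketch (not from the patch) of why the test above asserts that operand entries are unique before checking gradients of min/max windows: with tied maxima the cotangent can legitimately be routed to either tied element, so autodiff and numerical estimates may disagree. The 1-D `window_max` helper below is hypothetical:

```python
import jax
import jax.numpy as jnp
from jax import lax

def window_max(x):
  # Size-2 max windows with stride 1 over a 1-D array.
  return lax.reduce_window(x, -jnp.inf, lax.max, (2,), (1,), "VALID")

unique = jnp.array([1.0, 2.0, 3.0])
tied = jnp.array([2.0, 2.0, 3.0])

print(jax.grad(lambda v: window_max(v).sum())(unique))  # well-defined gradient
print(jax.grad(lambda v: window_max(v).sum())(tied))    # tie broken by convention
```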