Add support for higher derivatives of reduce-window-min/max at reduced precision. On CPU/GPU this means support for float64 derivatives, and on TPU this means support for float32 derivatives.

Warn if we are forced to be imprecise.
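
Concretely, the pair-packing trick used to differentiate reduce-window min/max needs an unsigned integer type twice as wide as the float type, and 64 bits is the widest integer type available, so float64 reductions are demoted to float32 (and on TPU, which lacks 64-bit integer support, float32 is demoted to bfloat16). A minimal sketch of the newly supported case, assuming it is run against this revision (shapes, values, and names are illustrative only, not code from this change):

import numpy as onp
import jax.numpy as np
from jax import grad, lax

def max_pool_sum(x):
  # 2x2 max pooling with stride 2 via reduce_window, followed by a nonlinear
  # scalar summary so that the second derivative is nontrivial.
  y = lax.reduce_window(x, onp.float32(-onp.inf), lax.max, (2, 2), (2, 2),
                        'VALID')
  return np.sum(y ** 2)

x = onp.arange(16.0, dtype=onp.float32).reshape(4, 4)
# Differentiating the gradient of a max pool reaches select_and_gather_add;
# on TPU the reduction now runs in bfloat16 (for float64 inputs on CPU/GPU,
# in float32) and emits the new warning.
d2 = grad(lambda v: np.sum(grad(max_pool_sum)(v) ** 2))(x)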
Peter Hawkins 2019-06-28 20:27:10 -04:00
parent acda3f398b
commit db369091a2
2 changed files with 55 additions and 19 deletions

jax/lax.py

@@ -3612,11 +3612,29 @@ _UINT_DTYPES = {
 
   64: onp.uint64,
 }
 
-def _select_and_gather_add_pair_reducer(dtype, select_prim):
-  bits = onp.finfo(dtype).bits
+_float_bitwidths = {
+  xla_client.PrimitiveType.BF16: 16,
+  xla_client.PrimitiveType.F16: 16,
+  xla_client.PrimitiveType.F32: 32,
+  xla_client.PrimitiveType.F64: 64,
+}
+
+_select_and_gather_add_reduction_types = {
+  xla_client.PrimitiveType.BF16: xla_client.PrimitiveType.BF16,
+  xla_client.PrimitiveType.F16: xla_client.PrimitiveType.F16,
+  xla_client.PrimitiveType.F32: xla_client.PrimitiveType.F32,
+  xla_client.PrimitiveType.F64: xla_client.PrimitiveType.F32,
+}
+
+_select_and_gather_add_tpu_reduction_types = {
+  xla_client.PrimitiveType.BF16: xla_client.PrimitiveType.BF16,
+  xla_client.PrimitiveType.F32: xla_client.PrimitiveType.BF16,
+}
+
+def _select_and_gather_add_pair_reducer(etype, select_prim):
+  bits = _float_bitwidths[etype]
   pair_uint_dtype = _UINT_DTYPES[bits * 2]
   uint_etype = xla_bridge.dtype_to_etype_exact(_UINT_DTYPES[bits])
-  etype = xla_bridge.dtype_to_etype_exact(dtype)
   c = xla_bridge.make_computation_builder("select_and_gather_pair_reducer")
   x = c.ParameterWithShape(
@@ -3639,41 +3657,56 @@ def _select_and_gather_add_pair_reducer(dtype, select_prim):
 
 def _select_and_gather_add_translation(
     c, tangents, operand, select_prim, window_dimensions, window_strides,
-    padding):
+    padding, reduction_types=None):
+  reduction_types = reduction_types or _select_and_gather_add_reduction_types
   # XLA doesn't yet implement ReduceWindow on tuples (Google bug b/73062247), so
   # we implement a pair-wise ReduceWindow by packing two k-bit values into
   # 2k-bit unsigned integer using bit tricks. This will only work for <= 32-bit
   # inputs (since we don't have 128-bit integer types).
-  dtype = c.GetShape(operand).numpy_dtype()
-  bits = onp.finfo(dtype).bits
-  if bits > 32:
-    raise NotImplementedError(
-        "select_and_gather_add is not implemented for type larger than 32 bits")
-  etype = xla_bridge.dtype_to_etype(dtype)
-  uint_etype = xla_bridge.dtype_to_etype(_UINT_DTYPES[bits])
+  shape = c.GetShape(operand)
+  etype = shape.xla_element_type()
+  reduction_etype = reduction_types.get(etype, None)
+  if reduction_etype is None:
+    msg = "Unsupported type for select_and_gather_add: {}"
+    raise ValueError(msg.format(etype))
+  if reduction_etype != etype:
+    warnings.warn("Using reduced precision for gradient of reduce-window "
+                  "min/max operator. This is likely from a second or "
+                  "higher derivative of a max-pooling operation and is to work "
+                  "around a missing XLA feature.")
+  bits = _float_bitwidths[reduction_etype]
+  uint_etype = xla_bridge.dtype_to_etype_exact(_UINT_DTYPES[bits])
+  pair_uint_dtype = _UINT_DTYPES[bits * 2]
+  pair_uint_etype = xla_bridge.dtype_to_etype_exact(pair_uint_dtype)
 
+  operand = c.ConvertElementType(operand, reduction_etype)
   operand = c.BitcastConvertType(operand, uint_etype)
-  tangents = c.BitcastConvertType(tangents, uint_etype)
+  operand = c.ConvertElementType(operand, pair_uint_etype)
+  tangents = c.ConvertElementType(tangents, reduction_etype)
+  tangents = c.BitcastConvertType(tangents, uint_etype)
+  tangents = c.ConvertElementType(tangents, pair_uint_etype)
   operand = c.ShiftLeft(
       operand, c.Constant(pair_uint_dtype(bits), canonicalize_types=False))
 
   assert select_prim is ge_p or select_prim is le_p
   init = -onp.inf if select_prim is ge_p else onp.inf
-  init = c.BitcastConvertType(c.Constant(dtype.type(init)), uint_etype)
+  init = c.Constant(shape.numpy_dtype().type(init))
+  init = c.ConvertElementType(init, reduction_etype)
+  init = c.BitcastConvertType(init, uint_etype)
+  init = c.ConvertElementType(init, pair_uint_etype)
   init = c.ShiftLeft(
       init, c.Constant(pair_uint_dtype(bits), canonicalize_types=False))
 
-  xla_computation = _select_and_gather_add_pair_reducer(dtype, select_prim)
+  xla_computation = _select_and_gather_add_pair_reducer(reduction_etype,
+                                                        select_prim)
   out = c.ReduceWindow(c.Or(operand, tangents), init,
                        xla_computation, window_dimensions, window_strides,
                        padding)
   out = c.ConvertElementType(out, uint_etype)
-  return c.BitcastConvertType(out, etype)
+  out = c.BitcastConvertType(out, reduction_etype)
+  return c.ConvertElementType(out, etype)
 
 
 def _select_and_gather_add_jvp(
     primals, tangents, select_prim, window_dimensions, window_strides,
@@ -3706,6 +3739,9 @@ select_and_gather_add_p = standard_primitive(
 ad.primitive_jvps[select_and_gather_add_p] = _select_and_gather_add_jvp
 ad.primitive_transposes[select_and_gather_add_p] = \
   _select_and_gather_add_transpose
+xla.backend_specific_translations['tpu'][select_and_gather_add_p] = partial(
+    _select_and_gather_add_translation,
+    reduction_types=_select_and_gather_add_tpu_reduction_types)
 
 sort_shape = lambda operand, dimension: operand.shape
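
A note on the bit trick the comment in _select_and_gather_add_translation alludes to: ReduceWindow can carry only one scalar per window, so the operand and its tangent are bitcast to k-bit unsigned integers and packed into one 2k-bit word, operand in the high half and tangent in the low half; the pair reducer unpacks the high halves to compare operands and keeps the winning word whole, and the selected tangent is recovered from the low half at the end. Because 64 bits is the widest unsigned type available, float64 pairs cannot be packed directly, hence the demotion tables above. A standalone numpy sketch of the packing, illustrative only and not code from this commit:

import numpy as onp

def pack(value, tangent):
  # Bitcast each float32 to uint32, widen to uint64, and pack: the operand's
  # bits occupy the high 32 bits and the tangent's bits the low 32, so a
  # reducer can unpack the high half to compare and keep the winning pair.
  hi = onp.uint64(value.view(onp.uint32))
  lo = onp.uint64(tangent.view(onp.uint32))
  return (hi << onp.uint64(32)) | lo

def unpack_tangent(packed):
  # The reduction's result lives in the low 32 bits; bitcast back to float32.
  return onp.uint32(packed & onp.uint64(0xFFFFFFFF)).view(onp.float32)

v, t = onp.float32(3.5), onp.float32(-1.25)
assert unpack_tangent(pack(v, t)) == t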

tests/lax_test.py

@@ -2029,9 +2029,9 @@ class LaxAutodiffTest(jtu.JaxTestCase):
        "op": op, "init_val": init_val, "dtype": dtype, "padding": padding,
        "rng": rng}
       for init_val, op, dtypes, rng in [
-          (0, lax.add, [onp.float32], jtu.rand_small()),
-          (-onp.inf, lax.max, [onp.float32], jtu.rand_default()),
-          (onp.inf, lax.min, [onp.float32], jtu.rand_default()),
+          (0, lax.add, float_dtypes, jtu.rand_small()),
+          (-onp.inf, lax.max, float_dtypes, jtu.rand_default()),
+          (onp.inf, lax.min, float_dtypes, jtu.rand_default()),
       ]
       for dtype in dtypes
       for padding in ["VALID", "SAME"]
@@ -2045,7 +2045,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
     # TODO(b/31565929): enable when fixed.
     if FLAGS.jax_test_dut == "tpu" and op is not lax.add:
       all_configs = [((6, 5, 4, 3), (2, 2, 1, 1), (1, 2, 1, 1))]
-      test_gradients = False  # TODO(b/73062247): need variadic reduce-window.
+      test_gradients = True  # TODO(b/73062247): need variadic reduce-window.
     else:
       all_configs = itertools.chain(
           itertools.product(