Add support for higher derivatives of reduce-window-min/max at reduced precision. On CPU/GPU this means support for float64 derivatives, and on TPU this means support for float32 derivatives.
Warn if we are forced to be imprecise.
parent acda3f398b
commit db369091a2
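
For context on the mechanism: XLA had no variadic ReduceWindow (b/73062247, cited in the diff below), so the translation rule packs each (value, tangent) pair of k-bit floats into one 2k-bit unsigned integer, reduces over the packed integers, and unpacks the winner. A standalone NumPy sketch of that packing trick; pack/unpack are illustrative names, not from this commit:

    import numpy as onp  # the commit-era alias for numpy

    def pack(value, tangent):
      # value occupies the high 32 bits, tangent the low 32 bits
      hi = onp.float32(value).view(onp.uint32).astype(onp.uint64) << onp.uint64(32)
      lo = onp.float32(tangent).view(onp.uint32).astype(onp.uint64)
      return hi | lo

    def unpack(packed):
      value = onp.uint32(packed >> onp.uint64(32)).view(onp.float32)
      tangent = onp.uint32(packed & onp.uint64(0xFFFFFFFF)).view(onp.float32)
      return value, tangent

    assert unpack(pack(3.5, -1.25)) == (onp.float32(3.5), onp.float32(-1.25))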
@@ -3612,11 +3612,29 @@ _UINT_DTYPES = {
   64: onp.uint64,
 }
 
-def _select_and_gather_add_pair_reducer(dtype, select_prim):
-  bits = onp.finfo(dtype).bits
+_float_bitwidths = {
+  xla_client.PrimitiveType.BF16: 16,
+  xla_client.PrimitiveType.F16: 16,
+  xla_client.PrimitiveType.F32: 32,
+  xla_client.PrimitiveType.F64: 64,
+}
+
+_select_and_gather_add_reduction_types = {
+  xla_client.PrimitiveType.BF16: xla_client.PrimitiveType.BF16,
+  xla_client.PrimitiveType.F16: xla_client.PrimitiveType.F16,
+  xla_client.PrimitiveType.F32: xla_client.PrimitiveType.F32,
+  xla_client.PrimitiveType.F64: xla_client.PrimitiveType.F32,
+}
+
+_select_and_gather_add_tpu_reduction_types = {
+  xla_client.PrimitiveType.BF16: xla_client.PrimitiveType.BF16,
+  xla_client.PrimitiveType.F32: xla_client.PrimitiveType.BF16,
+}
+
+def _select_and_gather_add_pair_reducer(etype, select_prim):
+  bits = _float_bitwidths[etype]
   pair_uint_dtype = _UINT_DTYPES[bits * 2]
-  etype = xla_bridge.dtype_to_etype_exact(dtype)
+  uint_etype = xla_bridge.dtype_to_etype_exact(_UINT_DTYPES[bits])
 
   c = xla_bridge.make_computation_builder("select_and_gather_pair_reducer")
   x = c.ParameterWithShape(
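
The tables above encode the fallback rule: a (value, tangent) pair needs an unsigned integer twice as wide as the float, there is no u128 for f64 pairs, and TPUs at the time lacked 64-bit integer support for f32 pairs, hence f64→f32 on CPU/GPU and f32→bf16 on TPU. What the fallback costs is visible in a plain NumPy round trip (illustrative, not from the commit):

    import numpy as onp

    # a float64 routed through the float32 packing path loses its low
    # mantissa bits
    x = onp.float64(1.0) + onp.float64(1e-12)
    roundtrip = onp.float64(onp.float32(x))
    print(x - roundtrip)  # ~1e-12, not 0: precision below f32 was dropped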
@@ -3639,41 +3657,56 @@ def _select_and_gather_add_pair_reducer(dtype, select_prim):
 
 def _select_and_gather_add_translation(
   c, tangents, operand, select_prim, window_dimensions, window_strides,
-  padding):
+  padding, reduction_types=None):
+  reduction_types = reduction_types or _select_and_gather_add_reduction_types
   # XLA doesn't yet implement ReduceWindow on tuples (Google bug b/73062247), so
   # we implement a pair-wise ReduceWindow by packing two k-bit values into a
   # 2k-bit unsigned integer using bit tricks. This will only work for <= 32-bit
   # inputs (since we don't have 128-bit integer types).
-  dtype = c.GetShape(operand).numpy_dtype()
-  bits = onp.finfo(dtype).bits
-  if bits > 32:
-    raise NotImplementedError(
-        "select_and_gather_add is not implemented for type larger than 32 bits")
-  etype = xla_bridge.dtype_to_etype(dtype)
-  uint_etype = xla_bridge.dtype_to_etype(_UINT_DTYPES[bits])
+  shape = c.GetShape(operand)
+  etype = shape.xla_element_type()
+  reduction_etype = reduction_types.get(etype, None)
+  if reduction_etype is None:
+    msg = "Unsupported type for select_and_gather_add: {}"
+    raise ValueError(msg.format(etype))
+
+  if reduction_etype != etype:
+    warnings.warn("Using reduced precision for gradient of reduce-window "
+                  "min/max operator. This is likely from a second or "
+                  "higher derivative of a max-pooling operation and is to "
+                  "work around a missing XLA feature.")
+
+  bits = _float_bitwidths[reduction_etype]
+  uint_etype = xla_bridge.dtype_to_etype_exact(_UINT_DTYPES[bits])
   pair_uint_dtype = _UINT_DTYPES[bits * 2]
+  pair_uint_etype = xla_bridge.dtype_to_etype_exact(pair_uint_dtype)
+
+  operand = c.ConvertElementType(operand, reduction_etype)
   operand = c.BitcastConvertType(operand, uint_etype)
-  tangents = c.BitcastConvertType(tangents, uint_etype)
   operand = c.ConvertElementType(operand, pair_uint_etype)
+  tangents = c.ConvertElementType(tangents, reduction_etype)
+  tangents = c.BitcastConvertType(tangents, uint_etype)
   tangents = c.ConvertElementType(tangents, pair_uint_etype)
   operand = c.ShiftLeft(
       operand, c.Constant(pair_uint_dtype(bits), canonicalize_types=False))
 
   assert select_prim is ge_p or select_prim is le_p
   init = -onp.inf if select_prim is ge_p else onp.inf
-  init = c.BitcastConvertType(c.Constant(dtype.type(init)), uint_etype)
+  init = c.Constant(shape.numpy_dtype().type(init))
+  init = c.ConvertElementType(init, reduction_etype)
+  init = c.BitcastConvertType(init, uint_etype)
   init = c.ConvertElementType(init, pair_uint_etype)
   init = c.ShiftLeft(
       init, c.Constant(pair_uint_dtype(bits), canonicalize_types=False))
 
-  xla_computation = _select_and_gather_add_pair_reducer(dtype, select_prim)
+  xla_computation = _select_and_gather_add_pair_reducer(reduction_etype,
+                                                        select_prim)
   out = c.ReduceWindow(c.Or(operand, tangents), init,
                        xla_computation, window_dimensions, window_strides,
                        padding)
   out = c.ConvertElementType(out, uint_etype)
-  return c.BitcastConvertType(out, etype)
+  out = c.BitcastConvertType(out, reduction_etype)
+  return c.ConvertElementType(out, etype)
 
 def _select_and_gather_add_jvp(
   primals, tangents, select_prim, window_dimensions, window_strides,
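
The warning added here is easy to provoke: any second derivative of a max-pooling computation on float64 inputs takes the f64→f32 path on CPU/GPU. A sketch using present-day jax names (module layout has moved since this commit; jax.hessian, jax.config, and lax.reduce_window are current public API):

    import warnings
    import jax
    import jax.numpy as jnp
    from jax import lax

    jax.config.update("jax_enable_x64", True)  # make the operand a genuine f64

    def f(x):
      # 1-D max pooling via reduce-window; its higher derivatives lower
      # through select_and_gather_add
      pooled = lax.reduce_window(x, -jnp.inf, lax.max, (2,), (1,), "VALID")
      return jnp.sum(pooled ** 2)

    x = jnp.linspace(0.0, 1.0, 6, dtype=jnp.float64)
    with warnings.catch_warnings(record=True) as caught:
      warnings.simplefilter("always")
      jax.hessian(f)(x)
    print([str(w.message) for w in caught])  # expect the reduced-precision warning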
@@ -3706,6 +3739,9 @@ select_and_gather_add_p = standard_primitive(
 ad.primitive_jvps[select_and_gather_add_p] = _select_and_gather_add_jvp
 ad.primitive_transposes[select_and_gather_add_p] = \
     _select_and_gather_add_transpose
+xla.backend_specific_translations['tpu'][select_and_gather_add_p] = partial(
+    _select_and_gather_add_translation,
+    reduction_types=_select_and_gather_add_tpu_reduction_types)
 
 
 sort_shape = lambda operand, dimension: operand.shape
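
The TPU registration above is plain currying: the generic translation rule is reused with the TPU reduction table bound in via functools.partial, so no logic is duplicated per backend. A toy rendering of that pattern, with invented names and stand-in tables (the remaining hunks then move to LaxAutodiffTest):

    from functools import partial

    def translate(etype, reduction_types):
      # pick the (possibly narrower) element type the pair is packed through
      return reduction_types[etype]

    default_rule = partial(translate, reduction_types={"f32": "f32", "f64": "f32"})
    tpu_rule = partial(translate, reduction_types={"f32": "bf16", "bf16": "bf16"})

    assert default_rule("f64") == "f32"  # CPU/GPU: f64 pairs packed as f32
    assert tpu_rule("f32") == "bf16"     # TPU: f32 pairs packed as bf16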
@@ -2029,9 +2029,9 @@ class LaxAutodiffTest(jtu.JaxTestCase):
        "op": op, "init_val": init_val, "dtype": dtype, "padding": padding,
        "rng": rng}
       for init_val, op, dtypes, rng in [
-          (0, lax.add, [onp.float32], jtu.rand_small()),
-          (-onp.inf, lax.max, [onp.float32], jtu.rand_default()),
-          (onp.inf, lax.min, [onp.float32], jtu.rand_default()),
+          (0, lax.add, float_dtypes, jtu.rand_small()),
+          (-onp.inf, lax.max, float_dtypes, jtu.rand_default()),
+          (onp.inf, lax.min, float_dtypes, jtu.rand_default()),
       ]
       for dtype in dtypes
       for padding in ["VALID", "SAME"]
@@ -2045,7 +2045,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
     # TODO(b/31565929): enable when fixed.
     if FLAGS.jax_test_dut == "tpu" and op is not lax.add:
       all_configs = [((6, 5, 4, 3), (2, 2, 1, 1), (1, 2, 1, 1))]
-      test_gradients = False  # TODO(b/73062247): need variadic reduce-window.
+      test_gradients = True  # TODO(b/73062247): need variadic reduce-window.
     else:
       all_configs = itertools.chain(
           itertools.product(
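
The test changes widen the sweep from float32 only to every dtype in float_dtypes and turn gradient testing back on for min/max pooling on TPU. A minimal standalone check in the same spirit, using the public jax.test_util.check_grads helper rather than the harness above:

    import jax.numpy as jnp
    from jax import lax
    from jax.test_util import check_grads

    def maxpool(x):
      # windowed max; arange gives distinct values, so the max is smooth here
      return lax.reduce_window(x, -jnp.inf, lax.max, (2,), (1,), "VALID")

    # order=2 exercises exactly the higher-derivative path this commit enables
    check_grads(maxpool, (jnp.arange(6.0),), order=2, modes=("fwd", "rev"))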