Implement JVP rule for reduce_prod().

This is sufficient to compute first-order derivatives of a product reduction, but not second-order derivatives: the JVP rule below is itself implemented in terms of reduce-window-prod, which has no JVP rule of its own.
commit 68f2cb4491 (parent 8173e671ae)
Author: Peter Hawkins
Date: 2019-05-05 14:31:46 -04:00
2 changed files with 55 additions and 0 deletions
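
For intuition: by the product rule, the partial derivative of prod(x) with respect to x[i] is the product of all the other elements, so the JVP of a product reduction contracts the tangent against exclusive "left" and "right" partial products. A minimal NumPy sketch of that identity (an illustration, not part of the commit):

import numpy as np

x = np.array([2., 3., 5.])
t = np.array([1., 1., 1.])  # tangent vector

# Exclusive partial products: left[i] = prod(x[:i]), right[i] = prod(x[i+1:]).
left = np.concatenate([[1.], np.cumprod(x[:-1])])
right = np.concatenate([np.cumprod(x[::-1])[::-1][1:], [1.]])

# Product rule: the derivative of prod(x) along t is sum_i t[i] * prod_{j != i} x[j].
jvp = np.sum(t * left * right)
assert np.isclose(jvp, 3*5 + 2*5 + 2*3)  # 15 + 10 + 6 = 31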

@@ -707,6 +707,8 @@ def _get_monoid_reducer(monoid_op, x):
   if (type(aval) is ConcreteArray) and aval.shape == ():
     if monoid_op is add:
       return aval.val == 0 and _reduce_sum
+    if monoid_op is mul:
+      return aval.val == 1 and _reduce_prod
     elif monoid_op is max:
       return aval.val == _get_max_identity(aval.dtype) and _reduce_max
     elif monoid_op is min:
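
With this dispatch in place, a lax.reduce over lax.mul with init value 1 is recognized as a product reduction and bound to the new primitive, so it picks up the JVP defined below. A hedged usage sketch, written against the current public jax API rather than the 2019-era module layout:

import jax
from jax import lax
import jax.numpy as jnp

x = jnp.array([[2., 3.], [4., 5.]])

# A mul-monoid reduce with identity init value 1 is routed to reduce_prod_p.
total = lax.reduce(x, 1., lax.mul, (0, 1))  # 120.0

# First-order differentiation now works; each gradient entry is the
# product of all the other entries.
g = jax.grad(lambda x: lax.reduce(x, 1., lax.mul, (0, 1)))(x)
# [[60. 40.]
#  [30. 24.]]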
@@ -735,6 +737,9 @@ def _get_min_identity(dtype):
 def _reduce_sum(operand, axes):
   return reduce_sum_p.bind(operand, axes=tuple(axes), input_shape=operand.shape)

+def _reduce_prod(operand, axes):
+  return reduce_prod_p.bind(operand, axes=tuple(axes))
+
 def _reduce_max(operand, axes):
   return reduce_max_p.bind(operand, axes=tuple(axes))
@@ -3029,6 +3034,54 @@ ad.deflinear(reduce_sum_p, _reduce_sum_transpose_rule)
 batching.defreducer(reduce_sum_p)

+def _reduce_prod_shape_rule(operand, axes):
+  return tuple(onp.delete(operand.shape, axes))
+
+def _reduce_prod_translation_rule(c, operand, axes):
+  dtype = c.GetShape(operand).numpy_dtype()
+  scalar = xla_bridge.Shape.array_shape(dtype, ())
+  return c.Reduce(operand, c.Constant(onp.array(1, dtype)),
+                  xla.primitive_computation(mul_p, scalar, scalar),
+                  axes)
+
+def _reduce_prod_jvp_rule(tangent, operand, axes):
+  input_shape = onp.array(operand.shape)
+
+  n = onp.prod(input_shape[list(axes)])
+  non_axes = onp.delete(onp.arange(len(input_shape)), axes)
+
+  # Move the reduced axes to the front, and flatten them to 1D.
+  permutation = axes + tuple(non_axes)
+  new_shape = (n,) + tuple(input_shape[non_axes])
+  operand = reshape(operand, new_shape, permutation)
+  tangent = reshape(tangent, new_shape, permutation)
+
+  one = _const(operand, 1)
+  window_dims = [n] + [1] * len(non_axes)
+  window_strides = [1] * (len(non_axes) + 1)
+
+  # Form the partial products of all elements to the left and right of each
+  # element.
+  left_padding = [(n, -1, 0)] + [(0, 0, 0)] * len(non_axes)
+  right_padding = [(-1, n, 0)] + [(0, 0, 0)] * len(non_axes)
+  left_products = _reduce_window_prod(pad(operand, one, left_padding),
+                                      window_dims, window_strides,
+                                      xla_client.PaddingType.VALID)
+  right_products = _reduce_window_prod(pad(operand, one, right_padding),
+                                       window_dims, window_strides,
+                                       xla_client.PaddingType.VALID)
+
+  # Multiply partial products with the tangents and sum.
+  return _reduce_sum(mul(tangent, mul(left_products, right_products)), (0,))
+
+reduce_prod_p = standard_primitive(_reduce_prod_shape_rule, _input_dtype,
+                                   'reduce_prod', _reduce_prod_translation_rule)
+ad.defjvp(reduce_prod_p, _reduce_prod_jvp_rule)
+batching.defreducer(reduce_prod_p)
+
 def _reduce_chooser_shape_rule(operand, axes):
   return tuple(onp.delete(operand.shape, axes))
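
The padding configurations are the heart of the rule: left padding (n, -1, 0) prepends n identity elements and drops the last input element, so each length-n valid window multiplies out to the exclusive left partial product prod(x[:i]); the mirrored right padding (-1, n, 0) yields prod(x[i+1:]). A self-contained NumPy check of the 1-D case, where window_prods is a hypothetical stand-in for reduce-window-prod:

import numpy as np

def window_prods(a, n):
  # Product over every length-n window with stride 1 and VALID padding:
  # a stand-in for _reduce_window_prod on a 1-D operand.
  return np.array([np.prod(a[i:i + n]) for i in range(len(a) - n + 1)])

x = np.array([2., 3., 5.])
n = len(x)

# left_padding (n, -1, 0): n ones in front, last element dropped.
left = window_prods(np.concatenate([np.ones(n), x[:-1]]), n)
# right_padding (-1, n, 0): first element dropped, n ones behind.
right = window_prods(np.concatenate([x[1:], np.ones(n)]), n)

assert np.allclose(left, [1., 2., 6.])    # prod(x[:i])
assert np.allclose(right, [15., 5., 1.])  # prod(x[i+1:])

t = np.ones_like(x)
# JVP: sum_i t[i] * prod_{j != i} x[j], matching the rule's final reduce-sum.
assert np.isclose(np.sum(t * left * right), 15. + 10. + 6.)

Because the rule itself calls reduce-window-prod, taking a second derivative would need a JVP for that primitive, which this commit does not add — hence the first-order-only caveat in the commit message.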

@@ -1881,11 +1881,13 @@ class LaxAutodiffTest(jtu.JaxTestCase):
        "dims": dims, "rng": rng}
       for init_val, op, dtypes in [
           (0, lax.add, inexact_dtypes),
+          (1, lax.mul, inexact_dtypes),
           (-onp.inf, lax.max, inexact_dtypes),
           (onp.inf, lax.min, inexact_dtypes),
       ]
       for dtype in dtypes
       for shape, dims in [
           [(3, 4, 5), ()],
           [(3, 4, 5), (0,)],
           [(3, 4, 5), (1, 2)],
           [(3, 4, 5), (0, 2)],