cumsum is linear, so its gradient can be linear also. (#2618)

* cumsum is linear, so its gradient can be linear also.

* Rename _impl functions to _prefix_scan.
Peter Hawkins 2020-04-06 15:14:22 -04:00 committed by GitHub
parent fc23e071bc
commit cf4dd84b14

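Why the title holds: cumsum is a linear map, so reverse-mode differentiation does not need to trace back through the prefix-scan implementation at all; the transpose of cumsum along an axis is simply a reversed cumsum, which is what the new _cumsum_transpose_rule below encodes. A minimal sketch of that identity using only public JAX APIs (illustrative, not part of this commit):

import jax
import jax.numpy as jnp
import numpy as np

x = jnp.arange(6.0)                              # primal point; irrelevant for a linear map
t = jnp.array([1.0, 0.0, 2.0, -1.0, 3.0, 0.5])   # cotangent

# VJP of the linear map v -> cumsum(v, axis=0), applied to the cotangent t.
_, vjp_fn = jax.vjp(lambda v: jnp.cumsum(v, axis=0), x)
via_vjp, = vjp_fn(t)

# Closed form used by the new transpose rule: reverse, cumsum, reverse.
closed_form = jnp.flip(jnp.cumsum(jnp.flip(t, axis=0), axis=0), axis=0)

np.testing.assert_allclose(via_vjp, closed_form)

In the second hunk, ad.deflinear(cumsum_p, _cumsum_transpose_rule) registers exactly this: for a primitive declared linear, JAX derives the JVP by applying the primitive to the tangents and uses the supplied rule for transposition, so the generic _cumred_jvp_rule registration is no longer needed for cumsum.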

@@ -4101,22 +4101,28 @@ def _parallel_prefix_scan(x, axis: int, op: Callable, unit):
   x, total = _prescan_power_of_two(x, axis, op, unit)
   return concatenate((slice_in_dim(x, 1, n, axis=axis), total), dimension=axis)
 
+_cumsum_prefix_scan = partial(_parallel_prefix_scan, op=add, unit=0)
+_cumprod_prefix_scan = partial(_parallel_prefix_scan, op=mul, unit=1)
+
 def _cumred_shape_rule(x, axis):
   if axis < 0 or axis >= x.ndim:
     raise ValueError(
       "axis {} is out of bounds for array of shape {}".format(axis, x.shape))
   return x.shape
 
+def _cumsum_transpose_rule(t, axis: int):
+  return [rev(cumsum(rev(t, (axis,)), axis=axis), (axis,))]
+
-def _cumred_jvp_rule(impl: Callable, primals, tangents, axis: int):
-  return api.jvp(partial(impl, axis=axis), primals, tangents)
+def _cumprod_jvp_rule(primals, tangents, axis: int):
+  # Irrespective of backend, we always use the parallel prefix scan
+  # implementation when differentiating because reduce_window is not
+  # arbitrarily differentiable.
+  return api.jvp(partial(_cumprod_prefix_scan, axis=axis), primals, tangents)
 
 def _cumred_tpu_translation_rule(window_reduce: Callable, unit, x, axis: int):
   # On TPU, an implementation using reduce_window is handled specially by the
-  # compiler. However, irrespective of backend, we always use the parallel
-  # prefix scan implementation when differentiating because reduce_window is not
-  # arbitrarily differentiable.
+  # compiler and is efficient. On other backends, it is O(n^2).
   n = x.shape[axis]
   if n == 0:
     return x
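
cumprod, by contrast, is not linear, so it keeps an explicit JVP rule, and per the comment above that rule always differentiates through the parallel prefix scan because reduce_window is not arbitrarily differentiable. A quick numerical check of what such a JVP must produce, using the hand-derived closed form for nonzero inputs, d cumprod(x)[t] = cumprod(x) * cumsum(t / x) (illustrative, not from this commit):

import jax
import jax.numpy as jnp
import numpy as np

x = jnp.array([1.5, 2.0, 0.5, 3.0])   # keep entries nonzero so t / x is well defined
t = jnp.array([0.1, -0.2, 0.3, 0.4])

# Forward-mode derivative of cumprod at x in the direction t.
primal_out, tangent_out = jax.jvp(lambda v: jnp.cumprod(v, axis=0), (x,), (t,))

# Hand-derived tangent for nonzero x: each partial product scales the
# accumulated relative perturbation of its factors.
closed_form = jnp.cumprod(x) * jnp.cumsum(t / x)

np.testing.assert_allclose(tangent_out, closed_form, rtol=1e-6)
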
@@ -4135,23 +4141,20 @@ def _cumred_batch_rule(prim, batched_args, batch_dims, axis: int):
   return prim.bind(operand, axis=axis), bdim
 
-_cumsum_impl = partial(_parallel_prefix_scan, op=add, unit=0)
 cumsum_p = standard_primitive(
   _cumred_shape_rule, partial(_reduce_number_dtype_rule, "cumsum"),
-  'cumsum', xla.lower_fun(_cumsum_impl, multiple_results=False))
-ad.primitive_jvps[cumsum_p] = partial(_cumred_jvp_rule, _cumsum_impl)
+  'cumsum', xla.lower_fun(_cumsum_prefix_scan, multiple_results=False))
+ad.deflinear(cumsum_p, _cumsum_transpose_rule)
 xla.backend_specific_translations['tpu'][cumsum_p] = xla.lower_fun(
   partial(_cumred_tpu_translation_rule, _reduce_window_sum, 0),
   multiple_results=False)
 batching.primitive_batchers[cumsum_p] = partial(_cumred_batch_rule, cumsum_p)
 
-_cumprod_impl = partial(_parallel_prefix_scan, op=mul, unit=1)
 cumprod_p = standard_primitive(
   _cumred_shape_rule, partial(_reduce_number_dtype_rule, "cumprod"),
-  'cumprod', xla.lower_fun(_cumprod_impl, multiple_results=False))
-ad.primitive_jvps[cumprod_p] = partial(_cumred_jvp_rule, _cumprod_impl)
+  'cumprod', xla.lower_fun(_cumprod_prefix_scan, multiple_results=False))
+ad.primitive_jvps[cumprod_p] = _cumprod_jvp_rule
 xla.backend_specific_translations['tpu'][cumprod_p] = xla.lower_fun(
   partial(_cumred_tpu_translation_rule, _reduce_window_prod, 1),
   multiple_results=False)
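
The 'tpu' registrations above lower both primitives through _cumred_tpu_translation_rule, i.e. via reduce_window (_reduce_window_sum / _reduce_window_prod): the TPU compiler recognizes that pattern and compiles it efficiently, while on other backends it costs O(n^2) work, which is why the parallel prefix scan stays the default lowering. A rough public-API sketch of the reduce_window formulation for the 1-D cumsum case (the real rule uses the internal _reduce_window_sum helper and handles an arbitrary axis; this is only illustrative):

import jax.numpy as jnp
from jax import lax

def cumsum_via_reduce_window(x):
  # output[k] = sum(x[0:k + 1]): a length-n window with stride 1, padded on
  # the low side with n - 1 elements that contribute the init value 0.
  n = x.shape[0]
  return lax.reduce_window(x, 0.0, lax.add,
                           window_dimensions=(n,),
                           window_strides=(1,),
                           padding=((n - 1, 0),))

x = jnp.arange(1.0, 6.0)
print(cumsum_via_reduce_window(x))  # [ 1.  3.  6. 10. 15.], same as jnp.cumsum(x)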