consolidate jvp rule definitions

This commit is contained in:
Erich Elsen 2020-06-28 20:39:20 +01:00
parent a98249d766
commit 1f15ffc45f

View File

@ -4856,25 +4856,12 @@ def _cumred_shape_rule(x, *, axis: int):
def _cumsum_transpose_rule(t, *, axis: int):
return [rev(cumsum(rev(t, (axis,)), axis=axis), (axis,))]
def _cumprod_jvp_rule(primals, tangents, *, axis: int):
def _cumulative_jvp_rule(primals, tangents, *, axis: int,
prefix_scan: Callable):
# Irrespective of backend, we always use the parallel prefix scan
# implementation when differentiating because reduce_window is not
# arbitrarily differentiable.
return api.jvp(partial(_cumprod_prefix_scan, axis=axis), primals, tangents)
def _cummax_jvp_rule(primals, tangents, *, axis: int):
# Irrespective of backend, we always use the parallel prefix scan
# implementation when differentiating because reduce_window is not
# arbitrarily differentiable.
return api.jvp(partial(_cummax_prefix_scan, axis=axis), primals, tangents)
def _cummin_jvp_rule(primals, tangents, *, axis: int):
# Irrespective of backend, we always use the parallel prefix scan
# implementation when differentiating because reduce_window is not
# arbitrarily differentiable.
return api.jvp(partial(_cummin_prefix_scan, axis=axis), primals, tangents)
return api.jvp(partial(prefix_scan, axis=axis), primals, tangents)
def _cumred_tpu_translation_rule(window_reduce: Callable, x, *,
@ -4925,7 +4912,7 @@ xla.backend_specific_translations['tpu'][cumsum_p] = xla.lower_fun(
batching.primitive_batchers[cumsum_p] = partial(_cumred_batch_rule, cumsum_p)
def _generic_reducer_primitive(name, prefix_scan_fn, jvp_rule, reduce_window_fn):
def _cumulative_reduction_primitive(name, prefix_scan_fn, jvp_rule, reduce_window_fn):
reducer_p = standard_primitive(
_cumred_shape_rule, partial(_reduce_number_dtype_rule, name),
name, xla.lower_fun(prefix_scan_fn, multiple_results=False))
@ -4937,14 +4924,20 @@ def _generic_reducer_primitive(name, prefix_scan_fn, jvp_rule, reduce_window_fn)
return reducer_p
cumprod_p = _generic_reducer_primitive("cumprod", _cumprod_prefix_scan,
_cumprod_jvp_rule, _reduce_window_prod)
cumprod_p = _cumulative_reduction_primitive("cumprod", _cumprod_prefix_scan,
partial(_cumulative_jvp_rule,
prefix_scan=_cumprod_prefix_scan),
_reduce_window_prod)
cummax_p = _generic_reducer_primitive("cummax", _cummax_prefix_scan,
_cummax_jvp_rule, _reduce_window_max)
cummax_p = _cumulative_reduction_primitive("cummax", _cummax_prefix_scan,
partial(_cumulative_jvp_rule,
prefix_scan=_cummax_prefix_scan),
_reduce_window_max)
cummin_p = _generic_reducer_primitive("cummin", _cummin_prefix_scan,
_cummin_jvp_rule, _reduce_window_min)
cummin_p = _cumulative_reduction_primitive("cummin", _cummin_prefix_scan,
partial(_cumulative_jvp_rule,
prefix_scan=_cummin_prefix_scan),
_reduce_window_min)
def _sort_abstract_eval(*args, **kwargs):