mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
consolidate jvp rule definitions
This commit is contained in:
parent
a98249d766
commit
1f15ffc45f
@ -4856,25 +4856,12 @@ def _cumred_shape_rule(x, *, axis: int):
|
||||
def _cumsum_transpose_rule(t, *, axis: int):
|
||||
return [rev(cumsum(rev(t, (axis,)), axis=axis), (axis,))]
|
||||
|
||||
def _cumprod_jvp_rule(primals, tangents, *, axis: int):
|
||||
def _cumulative_jvp_rule(primals, tangents, *, axis: int,
|
||||
prefix_scan: Callable):
|
||||
# Irrespective of backend, we always use the parallel prefix scan
|
||||
# implementation when differentiating because reduce_window is not
|
||||
# arbitrarily differentiable.
|
||||
return api.jvp(partial(_cumprod_prefix_scan, axis=axis), primals, tangents)
|
||||
|
||||
|
||||
def _cummax_jvp_rule(primals, tangents, *, axis: int):
|
||||
# Irrespective of backend, we always use the parallel prefix scan
|
||||
# implementation when differentiating because reduce_window is not
|
||||
# arbitrarily differentiable.
|
||||
return api.jvp(partial(_cummax_prefix_scan, axis=axis), primals, tangents)
|
||||
|
||||
|
||||
def _cummin_jvp_rule(primals, tangents, *, axis: int):
|
||||
# Irrespective of backend, we always use the parallel prefix scan
|
||||
# implementation when differentiating because reduce_window is not
|
||||
# arbitrarily differentiable.
|
||||
return api.jvp(partial(_cummin_prefix_scan, axis=axis), primals, tangents)
|
||||
return api.jvp(partial(prefix_scan, axis=axis), primals, tangents)
|
||||
|
||||
|
||||
def _cumred_tpu_translation_rule(window_reduce: Callable, x, *,
|
||||
@ -4925,7 +4912,7 @@ xla.backend_specific_translations['tpu'][cumsum_p] = xla.lower_fun(
|
||||
batching.primitive_batchers[cumsum_p] = partial(_cumred_batch_rule, cumsum_p)
|
||||
|
||||
|
||||
def _generic_reducer_primitive(name, prefix_scan_fn, jvp_rule, reduce_window_fn):
|
||||
def _cumulative_reduction_primitive(name, prefix_scan_fn, jvp_rule, reduce_window_fn):
|
||||
reducer_p = standard_primitive(
|
||||
_cumred_shape_rule, partial(_reduce_number_dtype_rule, name),
|
||||
name, xla.lower_fun(prefix_scan_fn, multiple_results=False))
|
||||
@ -4937,14 +4924,20 @@ def _generic_reducer_primitive(name, prefix_scan_fn, jvp_rule, reduce_window_fn)
|
||||
return reducer_p
|
||||
|
||||
|
||||
cumprod_p = _generic_reducer_primitive("cumprod", _cumprod_prefix_scan,
|
||||
_cumprod_jvp_rule, _reduce_window_prod)
|
||||
cumprod_p = _cumulative_reduction_primitive("cumprod", _cumprod_prefix_scan,
|
||||
partial(_cumulative_jvp_rule,
|
||||
prefix_scan=_cumprod_prefix_scan),
|
||||
_reduce_window_prod)
|
||||
|
||||
cummax_p = _generic_reducer_primitive("cummax", _cummax_prefix_scan,
|
||||
_cummax_jvp_rule, _reduce_window_max)
|
||||
cummax_p = _cumulative_reduction_primitive("cummax", _cummax_prefix_scan,
|
||||
partial(_cumulative_jvp_rule,
|
||||
prefix_scan=_cummax_prefix_scan),
|
||||
_reduce_window_max)
|
||||
|
||||
cummin_p = _generic_reducer_primitive("cummin", _cummin_prefix_scan,
|
||||
_cummin_jvp_rule, _reduce_window_min)
|
||||
cummin_p = _cumulative_reduction_primitive("cummin", _cummin_prefix_scan,
|
||||
partial(_cumulative_jvp_rule,
|
||||
prefix_scan=_cummin_prefix_scan),
|
||||
_reduce_window_min)
|
||||
|
||||
|
||||
def _sort_abstract_eval(*args, **kwargs):
|
||||
|
Loading…
x
Reference in New Issue
Block a user