mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
cumsum is linear, so its gradient can be linear also. (#2618)
* cumsum is linear, so its gradient can be linear also. * Rename _impl functions to _prefix_scan.
This commit is contained in:
parent
fc23e071bc
commit
cf4dd84b14
@ -4101,22 +4101,28 @@ def _parallel_prefix_scan(x, axis: int, op: Callable, unit):
|
||||
x, total = _prescan_power_of_two(x, axis, op, unit)
|
||||
return concatenate((slice_in_dim(x, 1, n, axis=axis), total), dimension=axis)
|
||||
|
||||
_cumsum_prefix_scan = partial(_parallel_prefix_scan, op=add, unit=0)
|
||||
_cumprod_prefix_scan = partial(_parallel_prefix_scan, op=mul, unit=1)
|
||||
|
||||
def _cumred_shape_rule(x, axis):
|
||||
if axis < 0 or axis >= x.ndim:
|
||||
raise ValueError(
|
||||
"axis {} is out of bounds for array of shape {}".format(axis, x.shape))
|
||||
return x.shape
|
||||
|
||||
def _cumsum_transpose_rule(t, axis: int):
|
||||
return [rev(cumsum(rev(t, (axis,)), axis=axis), (axis,))]
|
||||
|
||||
def _cumred_jvp_rule(impl: Callable, primals, tangents, axis: int):
|
||||
return api.jvp(partial(impl, axis=axis), primals, tangents)
|
||||
def _cumprod_jvp_rule(primals, tangents, axis: int):
|
||||
# Irrespective of backend, we always use the parallel prefix scan
|
||||
# implementation when differentiating because reduce_window is not
|
||||
# arbitrarily differentiable.
|
||||
return api.jvp(partial(_cumprod_prefix_scan, axis=axis), primals, tangents)
|
||||
|
||||
|
||||
def _cumred_tpu_translation_rule(window_reduce: Callable, unit, x, axis: int):
|
||||
# On TPU, an implementation using reduce_window is handled specially by the
|
||||
# compiler. However, irrespective of backend, we always use the parallel
|
||||
# prefix scan implementation when differentiating because reduce_window is not
|
||||
# arbitrarily differentiable.
|
||||
# compiler and is efficient. On other backends, it is O(n^2).
|
||||
n = x.shape[axis]
|
||||
if n == 0:
|
||||
return x
|
||||
@ -4135,23 +4141,20 @@ def _cumred_batch_rule(prim, batched_args, batch_dims, axis: int):
|
||||
return prim.bind(operand, axis=axis), bdim
|
||||
|
||||
|
||||
_cumsum_impl = partial(_parallel_prefix_scan, op=add, unit=0)
|
||||
|
||||
cumsum_p = standard_primitive(
|
||||
_cumred_shape_rule, partial(_reduce_number_dtype_rule, "cumsum"),
|
||||
'cumsum', xla.lower_fun(_cumsum_impl, multiple_results=False))
|
||||
ad.primitive_jvps[cumsum_p] = partial(_cumred_jvp_rule, _cumsum_impl)
|
||||
'cumsum', xla.lower_fun(_cumsum_prefix_scan, multiple_results=False))
|
||||
ad.deflinear(cumsum_p, _cumsum_transpose_rule)
|
||||
xla.backend_specific_translations['tpu'][cumsum_p] = xla.lower_fun(
|
||||
partial(_cumred_tpu_translation_rule, _reduce_window_sum, 0),
|
||||
multiple_results=False)
|
||||
batching.primitive_batchers[cumsum_p] = partial(_cumred_batch_rule, cumsum_p)
|
||||
|
||||
_cumprod_impl= partial(_parallel_prefix_scan, op=mul, unit=1)
|
||||
|
||||
cumprod_p = standard_primitive(
|
||||
_cumred_shape_rule, partial(_reduce_number_dtype_rule, "cumprod"),
|
||||
'cumprod', xla.lower_fun(_cumprod_impl, multiple_results=False))
|
||||
ad.primitive_jvps[cumprod_p] = partial(_cumred_jvp_rule, _cumprod_impl)
|
||||
'cumprod', xla.lower_fun(_cumprod_prefix_scan, multiple_results=False))
|
||||
ad.primitive_jvps[cumprod_p] = _cumprod_jvp_rule
|
||||
xla.backend_specific_translations['tpu'][cumprod_p] = xla.lower_fun(
|
||||
partial(_cumred_tpu_translation_rule, _reduce_window_prod, 1),
|
||||
multiple_results=False)
|
||||
|
Loading…
x
Reference in New Issue
Block a user