diff --git a/jax/_src/lax/control_flow.py b/jax/_src/lax/control_flow.py index 2ffc7f31c..ea7fe0d5e 100644 --- a/jax/_src/lax/control_flow.py +++ b/jax/_src/lax/control_flow.py @@ -2574,39 +2574,27 @@ def _cumred_dtype_rule(name, operand, *args, **kw): "of number.".format(name, np.dtype(operand.dtype).name)) return dtypes.canonicalize_dtype(operand.dtype) -cumsum_p = lax.standard_primitive( - _cumred_shape_rule, partial(_cumred_dtype_rule, "cumsum"), - 'cumsum') -ad.deflinear2(cumsum_p, _cumsum_transpose_rule) -xla.backend_specific_translations['tpu'][cumsum_p] = xla.lower_fun( - partial(_cumred_tpu_translation_rule, lax._reduce_window_sum), - multiple_results=False) -batching.primitive_batchers[cumsum_p] = partial(_cumred_batch_rule, cumsum_p) - - -def _cumulative_reduction_primitive(name, reduce_window_fn): +def _cumulative_reduction_primitive(name, + reduce_fn, + tpu_reduce_window_fn): reducer_p = lax.standard_primitive( _cumred_shape_rule, partial(_cumred_dtype_rule, name), - name) + name, + translation_rule=xla.lower_fun( + partial(associative_scan, reduce_fn), + multiple_results=False)) xla.backend_specific_translations['tpu'][reducer_p] = xla.lower_fun( - partial(_cumred_tpu_translation_rule, reduce_window_fn), + partial(_cumred_tpu_translation_rule, tpu_reduce_window_fn), multiple_results=False) batching.primitive_batchers[reducer_p] = partial(_cumred_batch_rule, reducer_p) return reducer_p +cumsum_p = _cumulative_reduction_primitive("cumsum", lax.add, lax._reduce_window_sum) +ad.deflinear2(cumsum_p, _cumsum_transpose_rule) +cumprod_p = _cumulative_reduction_primitive("cumprod", lax.mul, lax._reduce_window_prod) +cummax_p = _cumulative_reduction_primitive("cummax", lax.max, lax._reduce_window_max) +cummin_p = _cumulative_reduction_primitive("cummin", lax.min, lax._reduce_window_min) -cumprod_p = _cumulative_reduction_primitive("cumprod", lax._reduce_window_prod) -cummax_p = _cumulative_reduction_primitive("cummax", lax._reduce_window_max) -cummin_p = _cumulative_reduction_primitive("cummin", lax._reduce_window_min) - -xla.translations[cumsum_p] = xla.lower_fun( - partial(associative_scan, lax.add), multiple_results=False) -xla.translations[cumprod_p] = xla.lower_fun( - partial(associative_scan, lax.mul), multiple_results=False) -xla.translations[cummin_p] = xla.lower_fun( - partial(associative_scan, lax.min), multiple_results=False) -xla.translations[cummax_p] = xla.lower_fun( - partial(associative_scan, lax.max), multiple_results=False) def _cumulative_jvp_rule(primals, tangents, *, axis: int, reverse: bool, combine_fn: Callable):