Minor cleanup of the translation rules for cumred primitives

This commit is contained in:
George Necula 2021-05-21 11:10:41 +03:00
parent 1e9c7e4995
commit e7766838db

View File

@ -2574,39 +2574,27 @@ def _cumred_dtype_rule(name, operand, *args, **kw):
"of number.".format(name, np.dtype(operand.dtype).name))
return dtypes.canonicalize_dtype(operand.dtype)
cumsum_p = lax.standard_primitive(
_cumred_shape_rule, partial(_cumred_dtype_rule, "cumsum"),
'cumsum')
ad.deflinear2(cumsum_p, _cumsum_transpose_rule)
xla.backend_specific_translations['tpu'][cumsum_p] = xla.lower_fun(
partial(_cumred_tpu_translation_rule, lax._reduce_window_sum),
multiple_results=False)
batching.primitive_batchers[cumsum_p] = partial(_cumred_batch_rule, cumsum_p)
def _cumulative_reduction_primitive(name, reduce_window_fn):
def _cumulative_reduction_primitive(name,
reduce_fn,
tpu_reduce_window_fn):
reducer_p = lax.standard_primitive(
_cumred_shape_rule, partial(_cumred_dtype_rule, name),
name)
name,
translation_rule=xla.lower_fun(
partial(associative_scan, reduce_fn),
multiple_results=False))
xla.backend_specific_translations['tpu'][reducer_p] = xla.lower_fun(
partial(_cumred_tpu_translation_rule, reduce_window_fn),
partial(_cumred_tpu_translation_rule, tpu_reduce_window_fn),
multiple_results=False)
batching.primitive_batchers[reducer_p] = partial(_cumred_batch_rule, reducer_p)
return reducer_p
cumsum_p = _cumulative_reduction_primitive("cumsum", lax.add, lax._reduce_window_sum)
ad.deflinear2(cumsum_p, _cumsum_transpose_rule)
cumprod_p = _cumulative_reduction_primitive("cumprod", lax.mul, lax._reduce_window_prod)
cummax_p = _cumulative_reduction_primitive("cummax", lax.max, lax._reduce_window_max)
cummin_p = _cumulative_reduction_primitive("cummin", lax.min, lax._reduce_window_min)
cumprod_p = _cumulative_reduction_primitive("cumprod", lax._reduce_window_prod)
cummax_p = _cumulative_reduction_primitive("cummax", lax._reduce_window_max)
cummin_p = _cumulative_reduction_primitive("cummin", lax._reduce_window_min)
xla.translations[cumsum_p] = xla.lower_fun(
partial(associative_scan, lax.add), multiple_results=False)
xla.translations[cumprod_p] = xla.lower_fun(
partial(associative_scan, lax.mul), multiple_results=False)
xla.translations[cummin_p] = xla.lower_fun(
partial(associative_scan, lax.min), multiple_results=False)
xla.translations[cummax_p] = xla.lower_fun(
partial(associative_scan, lax.max), multiple_results=False)
def _cumulative_jvp_rule(primals, tangents, *, axis: int, reverse: bool,
combine_fn: Callable):