mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Minor cleanup of the translation rules for cumred primitives
This commit is contained in:
parent
1e9c7e4995
commit
e7766838db
@ -2574,39 +2574,27 @@ def _cumred_dtype_rule(name, operand, *args, **kw):
|
||||
"of number.".format(name, np.dtype(operand.dtype).name))
|
||||
return dtypes.canonicalize_dtype(operand.dtype)
|
||||
|
||||
cumsum_p = lax.standard_primitive(
|
||||
_cumred_shape_rule, partial(_cumred_dtype_rule, "cumsum"),
|
||||
'cumsum')
|
||||
ad.deflinear2(cumsum_p, _cumsum_transpose_rule)
|
||||
xla.backend_specific_translations['tpu'][cumsum_p] = xla.lower_fun(
|
||||
partial(_cumred_tpu_translation_rule, lax._reduce_window_sum),
|
||||
multiple_results=False)
|
||||
batching.primitive_batchers[cumsum_p] = partial(_cumred_batch_rule, cumsum_p)
|
||||
|
||||
|
||||
def _cumulative_reduction_primitive(name, reduce_window_fn):
|
||||
def _cumulative_reduction_primitive(name,
|
||||
reduce_fn,
|
||||
tpu_reduce_window_fn):
|
||||
reducer_p = lax.standard_primitive(
|
||||
_cumred_shape_rule, partial(_cumred_dtype_rule, name),
|
||||
name)
|
||||
name,
|
||||
translation_rule=xla.lower_fun(
|
||||
partial(associative_scan, reduce_fn),
|
||||
multiple_results=False))
|
||||
xla.backend_specific_translations['tpu'][reducer_p] = xla.lower_fun(
|
||||
partial(_cumred_tpu_translation_rule, reduce_window_fn),
|
||||
partial(_cumred_tpu_translation_rule, tpu_reduce_window_fn),
|
||||
multiple_results=False)
|
||||
batching.primitive_batchers[reducer_p] = partial(_cumred_batch_rule, reducer_p)
|
||||
return reducer_p
|
||||
|
||||
cumsum_p = _cumulative_reduction_primitive("cumsum", lax.add, lax._reduce_window_sum)
|
||||
ad.deflinear2(cumsum_p, _cumsum_transpose_rule)
|
||||
cumprod_p = _cumulative_reduction_primitive("cumprod", lax.mul, lax._reduce_window_prod)
|
||||
cummax_p = _cumulative_reduction_primitive("cummax", lax.max, lax._reduce_window_max)
|
||||
cummin_p = _cumulative_reduction_primitive("cummin", lax.min, lax._reduce_window_min)
|
||||
|
||||
cumprod_p = _cumulative_reduction_primitive("cumprod", lax._reduce_window_prod)
|
||||
cummax_p = _cumulative_reduction_primitive("cummax", lax._reduce_window_max)
|
||||
cummin_p = _cumulative_reduction_primitive("cummin", lax._reduce_window_min)
|
||||
|
||||
xla.translations[cumsum_p] = xla.lower_fun(
|
||||
partial(associative_scan, lax.add), multiple_results=False)
|
||||
xla.translations[cumprod_p] = xla.lower_fun(
|
||||
partial(associative_scan, lax.mul), multiple_results=False)
|
||||
xla.translations[cummin_p] = xla.lower_fun(
|
||||
partial(associative_scan, lax.min), multiple_results=False)
|
||||
xla.translations[cummax_p] = xla.lower_fun(
|
||||
partial(associative_scan, lax.max), multiple_results=False)
|
||||
|
||||
def _cumulative_jvp_rule(primals, tangents, *, axis: int, reverse: bool,
|
||||
combine_fn: Callable):
|
||||
|
Loading…
x
Reference in New Issue
Block a user