actually return the primitive

This commit is contained in:
Erich Elsen 2020-06-28 20:31:30 +01:00
parent a189737ecb
commit a98249d766

View File

@ -4934,6 +4934,7 @@ def _generic_reducer_primitive(name, prefix_scan_fn, jvp_rule, reduce_window_fn)
partial(_cumred_tpu_translation_rule, reduce_window_fn),
multiple_results=False)
batching.primitive_batchers[reducer_p] = partial(_cumred_batch_rule, reducer_p)
return reducer_p
cumprod_p = _generic_reducer_primitive("cumprod", _cumprod_prefix_scan,