mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
actually return the primitive
This commit is contained in:
parent
a189737ecb
commit
a98249d766
@ -4934,6 +4934,7 @@ def _generic_reducer_primitive(name, prefix_scan_fn, jvp_rule, reduce_window_fn)
|
||||
partial(_cumred_tpu_translation_rule, reduce_window_fn),
|
||||
multiple_results=False)
|
||||
batching.primitive_batchers[reducer_p] = partial(_cumred_batch_rule, reducer_p)
|
||||
return reducer_p
|
||||
|
||||
|
||||
cumprod_p = _generic_reducer_primitive("cumprod", _cumprod_prefix_scan,
|
||||
|
Loading…
x
Reference in New Issue
Block a user