mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Handle n==0 case in TPU cumsum/cumprod. (#2617)
This commit is contained in:
parent
329321b0f1
commit
36c529d4e3
@ -4118,6 +4118,8 @@ def _cumred_tpu_translation_rule(window_reduce: Callable, unit, x, axis: int):
|
||||
# prefix scan implementation when differentiating because reduce_window is not
|
||||
# arbitrarily differentiable.
|
||||
n = x.shape[axis]
|
||||
if n == 0:
|
||||
return x
|
||||
padding = [(0, 0, 0)] * x.ndim
|
||||
padding[axis] = (n - 1, 0, 0)
|
||||
x = pad(x, _const(x, unit), padding)
|
||||
|
Loading…
x
Reference in New Issue
Block a user