Handle n==0 case in TPU cumsum/cumprod. (#2617)

This commit is contained in:
Peter Hawkins 2020-04-06 12:33:55 -04:00 committed by GitHub
parent 329321b0f1
commit 36c529d4e3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -4118,6 +4118,8 @@ def _cumred_tpu_translation_rule(window_reduce: Callable, unit, x, axis: int):
# prefix scan implementation when differentiating because reduce_window is not
# arbitrarily differentiable.
n = x.shape[axis]
if n == 0:
return x
padding = [(0, 0, 0)] * x.ndim
padding[axis] = (n - 1, 0, 0)
x = pad(x, _const(x, unit), padding)