Merge pull request #962 from hawkinsp/cumsum

Wrap np.cumsum/cumprod in a jit to avoid materializing padded output.
This commit is contained in:
Peter Hawkins 2019-07-02 14:41:48 -04:00 committed by GitHub
commit 231a0526eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1132,8 +1132,12 @@ nanprod = _make_nan_reduction(onp.nanprod, prod, 1, nan_if_all_nan=False)
def _make_cumulative_reduction(onp_reduction, window_reduce, init_val,
squash_nan=False):
@_wraps(onp_reduction)
def cumulative_reduction(a, axis=None, dtype=None):
# We want to allow XLA to fuse the pad and reduce-window operators to
# avoid materializing the padded output.
# Consider removing `jit` once again if reduce-window is generalized to
# support arbitrary padding.
@partial(jit, static_argnums=(1, 2))
def _cumulative_reduction(a, axis, dtype):
if axis is None or isscalar(a):
a = ravel(a)
axis = 0
@ -1166,6 +1170,11 @@ def _make_cumulative_reduction(onp_reduction, window_reduce, init_val,
return window_reduce(
a, window_dims, strides, xla_client.PaddingType.VALID)
@_wraps(onp_reduction)
def cumulative_reduction(a, axis=None, dtype=None):
# jit doesn't support kwargs as static_args.
return _cumulative_reduction(a, axis, dtype)
return cumulative_reduction