mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #962 from hawkinsp/cumsum
Wrap np.cumsum/cumprod in a jit to avoid materializing padded output.
This commit is contained in:
commit
231a0526eb
@ -1132,8 +1132,12 @@ nanprod = _make_nan_reduction(onp.nanprod, prod, 1, nan_if_all_nan=False)
|
||||
|
||||
def _make_cumulative_reduction(onp_reduction, window_reduce, init_val,
|
||||
squash_nan=False):
|
||||
@_wraps(onp_reduction)
|
||||
def cumulative_reduction(a, axis=None, dtype=None):
|
||||
# We want to allow XLA to fuse the pad and reduce-window operators to
|
||||
# avoid materializing the padded output.
|
||||
# Consider removing `jit` once again if reduce-window is generalized to
|
||||
# support arbitrary padding.
|
||||
@partial(jit, static_argnums=(1, 2))
|
||||
def _cumulative_reduction(a, axis, dtype):
|
||||
if axis is None or isscalar(a):
|
||||
a = ravel(a)
|
||||
axis = 0
|
||||
@ -1166,6 +1170,11 @@ def _make_cumulative_reduction(onp_reduction, window_reduce, init_val,
|
||||
return window_reduce(
|
||||
a, window_dims, strides, xla_client.PaddingType.VALID)
|
||||
|
||||
@_wraps(onp_reduction)
|
||||
def cumulative_reduction(a, axis=None, dtype=None):
|
||||
# jit doesn't support kwargs as static_args.
|
||||
return _cumulative_reduction(a, axis, dtype)
|
||||
|
||||
return cumulative_reduction
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user