Rework type support for lax cumulative reductions (#3609)

This commit is contained in:
Jake Vanderplas 2020-06-30 11:36:27 -07:00 committed by GitHub
parent 420ef4e0a8
commit db8f66d508
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -4812,27 +4812,12 @@ def _prescan_power_of_two(x, axis: int, op: Callable, unit):
return x, total
def _parallel_prefix_scan(x, axis: int, op: Callable):
if op is max:
if onp.issubdtype(x.dtype, onp.integer):
unit = onp.iinfo(x.dtype).min
elif onp.issubdtype(x.dtype, onp.bool_):
unit = False
else: # inexact
unit = -onp.inf
elif op is min:
if onp.issubdtype(x.dtype, onp.integer):
def _parallel_prefix_scan(x, axis: int, op: Callable, unit: Any):
if onp.issubdtype(x.dtype, onp.integer):
if onp.isposinf(unit):
unit = onp.iinfo(x.dtype).max
elif onp.issubdtype(x.dtype, onp.bool_):
unit = True
else: # inexact
unit = onp.inf
elif op is add:
unit = 0
elif op is mul:
unit = 1
else:
raise ValueError("Unknown type of reducer, got {}".format(op))
elif onp.isneginf(unit):
unit = onp.iinfo(x.dtype).min
n = x.shape[axis]
if n == 0:
return x
@ -4846,10 +4831,10 @@ def _parallel_prefix_scan(x, axis: int, op: Callable):
x, total = _prescan_power_of_two(x, axis, op, unit)
return concatenate((slice_in_dim(x, 1, n, axis=axis), total), dimension=axis)
_cumsum_prefix_scan = partial(_parallel_prefix_scan, op=add)
_cumprod_prefix_scan = partial(_parallel_prefix_scan, op=mul)
_cummax_prefix_scan = partial(_parallel_prefix_scan, op=max)
_cummin_prefix_scan = partial(_parallel_prefix_scan, op=min)
_cumsum_prefix_scan = partial(_parallel_prefix_scan, op=add, unit=0)
_cumprod_prefix_scan = partial(_parallel_prefix_scan, op=mul, unit=1)
_cummax_prefix_scan = partial(_parallel_prefix_scan, op=max, unit=-onp.inf)
_cummin_prefix_scan = partial(_parallel_prefix_scan, op=min, unit=onp.inf)
def _cumred_shape_rule(x, *, axis: int):
if axis < 0 or axis >= x.ndim: