mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Rework type support for lax cumulative reductions (#3609)
This commit is contained in:
parent
420ef4e0a8
commit
db8f66d508
@ -4812,27 +4812,12 @@ def _prescan_power_of_two(x, axis: int, op: Callable, unit):
|
||||
return x, total
|
||||
|
||||
|
||||
def _parallel_prefix_scan(x, axis: int, op: Callable):
|
||||
if op is max:
|
||||
if onp.issubdtype(x.dtype, onp.integer):
|
||||
unit = onp.iinfo(x.dtype).min
|
||||
elif onp.issubdtype(x.dtype, onp.bool_):
|
||||
unit = False
|
||||
else: # inexact
|
||||
unit = -onp.inf
|
||||
elif op is min:
|
||||
if onp.issubdtype(x.dtype, onp.integer):
|
||||
def _parallel_prefix_scan(x, axis: int, op: Callable, unit: Any):
|
||||
if onp.issubdtype(x.dtype, onp.integer):
|
||||
if onp.isposinf(unit):
|
||||
unit = onp.iinfo(x.dtype).max
|
||||
elif onp.issubdtype(x.dtype, onp.bool_):
|
||||
unit = True
|
||||
else: # inexact
|
||||
unit = onp.inf
|
||||
elif op is add:
|
||||
unit = 0
|
||||
elif op is mul:
|
||||
unit = 1
|
||||
else:
|
||||
raise ValueError("Unknown type of reducer, got {}".format(op))
|
||||
elif onp.isneginf(unit):
|
||||
unit = onp.iinfo(x.dtype).min
|
||||
n = x.shape[axis]
|
||||
if n == 0:
|
||||
return x
|
||||
@ -4846,10 +4831,10 @@ def _parallel_prefix_scan(x, axis: int, op: Callable):
|
||||
x, total = _prescan_power_of_two(x, axis, op, unit)
|
||||
return concatenate((slice_in_dim(x, 1, n, axis=axis), total), dimension=axis)
|
||||
|
||||
_cumsum_prefix_scan = partial(_parallel_prefix_scan, op=add)
|
||||
_cumprod_prefix_scan = partial(_parallel_prefix_scan, op=mul)
|
||||
_cummax_prefix_scan = partial(_parallel_prefix_scan, op=max)
|
||||
_cummin_prefix_scan = partial(_parallel_prefix_scan, op=min)
|
||||
_cumsum_prefix_scan = partial(_parallel_prefix_scan, op=add, unit=0)
|
||||
_cumprod_prefix_scan = partial(_parallel_prefix_scan, op=mul, unit=1)
|
||||
_cummax_prefix_scan = partial(_parallel_prefix_scan, op=max, unit=-onp.inf)
|
||||
_cummin_prefix_scan = partial(_parallel_prefix_scan, op=min, unit=onp.inf)
|
||||
|
||||
def _cumred_shape_rule(x, *, axis: int):
|
||||
if axis < 0 or axis >= x.ndim:
|
||||
|
Loading…
x
Reference in New Issue
Block a user