don't require passing identity value. It isn't the initial value - identity is required for implementation correctness

This commit is contained in:
Erich Elsen 2020-06-28 19:33:20 +01:00
parent e9b1aae6e1
commit 812d246295
2 changed files with 22 additions and 22 deletions

View File

@ -1197,13 +1197,13 @@ def cumprod(operand: Array, axis: int) -> Array:
"""Computes a cumulative product along `axis`."""
return cumprod_p.bind(operand, axis=int(axis))
def cummax(operand: Array, axis: int, unit: Number) -> Array:
def cummax(operand: Array, axis: int) -> Array:
"""Computes a cumulative maximum along `axis`."""
return cummax_p.bind(operand, axis=int(axis), unit=unit)
return cummax_p.bind(operand, axis=int(axis))
def cummin(operand: Array, axis: int, unit: Number) -> Array:
def cummin(operand: Array, axis: int) -> Array:
"""Computes a cumulative minimum along `axis`."""
return cummin_p.bind(operand, axis=int(axis), unit=unit)
return cummin_p.bind(operand, axis=int(axis))
def sort(operand: Union[Array, Sequence[Array]], dimension: int = -1,
is_stable: bool = True) -> Union[Array, Tuple[Array, ...]]:
@ -4812,7 +4812,17 @@ def _prescan_power_of_two(x, axis: int, op: Callable, unit):
return x, total
def _parallel_prefix_scan(x, axis: int, op: Callable, unit):
def _parallel_prefix_scan(x, axis: int, op: Callable, unit=None):
if not unit and op is max:
if onp.issubdtype(x.dtype, onp.integer):
unit = onp.iinfo(x.dtype).min
else:
unit = onp.finfo(x.dtype).min
elif not unit and op is min:
if onp.issubdtype(x.dtype, onp.integer):
unit = onp.iinfo(x.dtype).max
else:
unit = onp.finfo(x.dtype).max
n = x.shape[axis]
if n == 0:
return x
@ -4828,8 +4838,8 @@ def _parallel_prefix_scan(x, axis: int, op: Callable, unit):
_cumsum_prefix_scan = partial(_parallel_prefix_scan, op=add, unit=0)
_cumprod_prefix_scan = partial(_parallel_prefix_scan, op=mul, unit=1)
_cummax_prefix_scan = partial(_parallel_prefix_scan, op=max)
_cummin_prefix_scan = partial(_parallel_prefix_scan, op=min)
_cummax_prefix_scan = partial(_parallel_prefix_scan, op=max, unit=None)
_cummin_prefix_scan = partial(_parallel_prefix_scan, op=min, unit=None)
def _cumred_shape_rule(x, *, axis: int):
if axis < 0 or axis >= x.ndim:
@ -4847,18 +4857,18 @@ def _cumprod_jvp_rule(primals, tangents, *, axis: int):
return api.jvp(partial(_cumprod_prefix_scan, axis=axis), primals, tangents)
def _cummax_jvp_rule(primals, tangents, *, axis: int, unit: Number):
def _cummax_jvp_rule(primals, tangents, *, axis: int):
# Irrespective of backend, we always use the parallel prefix scan
# implementation when differentiating because reduce_window is not
# arbitrarily differentiable.
return api.jvp(partial(_cummax_prefix_scan, axis=axis, unit=unit), primals, tangents)
return api.jvp(partial(_cummax_prefix_scan, axis=axis), primals, tangents)
def _cummin_jvp_rule(primals, tangents, *, axis: int, unit: Number):
def _cummin_jvp_rule(primals, tangents, *, axis: int):
# Irrespective of backend, we always use the parallel prefix scan
# implementation when differentiating because reduce_window is not
# arbitrarily differentiable.
return api.jvp(partial(_cummin_prefix_scan, axis=axis, unit=unit), primals, tangents)
return api.jvp(partial(_cummin_prefix_scan, axis=axis), primals, tangents)
def _cumred_tpu_translation_rule(window_reduce: Callable, unit, x, *,

View File

@ -1355,17 +1355,7 @@ class LaxTest(jtu.JaxTestCase):
def testCumulativeReduceMaxMin(self, op, onp_op, shape, dtype, axis, rng_factory):
rng = rng_factory(self.rng())
fun = partial(op, axis=axis)
if onp.issubdtype(dtype, onp.integer):
if op == lax.cummax:
unit = onp.iinfo(dtype).min
else:
unit = onp.iinfo(dtype).max
else:
if op == lax.cummax:
unit = onp.finfo(dtype).min
else:
unit = onp.finfo(dtype).max
onp_fun = partial(onp_op, axis=axis, dtype=dtype, unit=unit)
onp_fun = partial(onp_op, axis=axis, dtype=dtype)
args_maker = lambda: [rng(shape, dtype)]
self._CompileAndCheck(fun, args_maker)
self._CheckAgainstNumpy(fun, onp_fun, args_maker)