mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
don't require passing identity value. It isn't the initial value - identity is required for implementation correctness
This commit is contained in:
parent
e9b1aae6e1
commit
812d246295
@ -1197,13 +1197,13 @@ def cumprod(operand: Array, axis: int) -> Array:
|
||||
"""Computes a cumulative product along `axis`."""
|
||||
return cumprod_p.bind(operand, axis=int(axis))
|
||||
|
||||
def cummax(operand: Array, axis: int, unit: Number) -> Array:
|
||||
def cummax(operand: Array, axis: int) -> Array:
|
||||
"""Computes a cumulative maximum along `axis'."""
|
||||
return cummax_p.bind(operand, axis=int(axis), unit=unit)
|
||||
return cummax_p.bind(operand, axis=int(axis))
|
||||
|
||||
def cummin(operand: Array, axis: int, unit: Number) -> Array:
|
||||
def cummin(operand: Array, axis: int) -> Array:
|
||||
"""Computes a cumulative minimum along `axis'."""
|
||||
return cummin_p.bind(operand, axis=int(axis), unit=unit)
|
||||
return cummin_p.bind(operand, axis=int(axis))
|
||||
|
||||
def sort(operand: Union[Array, Sequence[Array]], dimension: int = -1,
|
||||
is_stable: bool = True) -> Union[Array, Tuple[Array, ...]]:
|
||||
@ -4812,7 +4812,17 @@ def _prescan_power_of_two(x, axis: int, op: Callable, unit):
|
||||
return x, total
|
||||
|
||||
|
||||
def _parallel_prefix_scan(x, axis: int, op: Callable, unit):
|
||||
def _parallel_prefix_scan(x, axis: int, op: Callable, unit=None):
|
||||
if not unit and op is max:
|
||||
if onp.issubdtype(x.dtype, onp.integer):
|
||||
unit = onp.iint(x.dtype).min
|
||||
else:
|
||||
unit = onp.fint(x.dtype).min
|
||||
elif not unit and op is min:
|
||||
if onp.issubdtype(x.dtype, onp.integer):
|
||||
unit = onp.iint(x.dtype).max
|
||||
else:
|
||||
unit = onp.fint(x.dtype).max
|
||||
n = x.shape[axis]
|
||||
if n == 0:
|
||||
return x
|
||||
@ -4828,8 +4838,8 @@ def _parallel_prefix_scan(x, axis: int, op: Callable, unit):
|
||||
|
||||
_cumsum_prefix_scan = partial(_parallel_prefix_scan, op=add, unit=0)
|
||||
_cumprod_prefix_scan = partial(_parallel_prefix_scan, op=mul, unit=1)
|
||||
_cummax_prefix_scan = partial(_parallel_prefix_scan, op=max)
|
||||
_cummin_prefix_scan = partial(_parallel_prefix_scan, op=min)
|
||||
_cummax_prefix_scan = partial(_parallel_prefix_scan, op=max, unit=None)
|
||||
_cummin_prefix_scan = partial(_parallel_prefix_scan, op=min, unit=None)
|
||||
|
||||
def _cumred_shape_rule(x, *, axis: int):
|
||||
if axis < 0 or axis >= x.ndim:
|
||||
@ -4847,18 +4857,18 @@ def _cumprod_jvp_rule(primals, tangents, *, axis: int):
|
||||
return api.jvp(partial(_cumprod_prefix_scan, axis=axis), primals, tangents)
|
||||
|
||||
|
||||
def _cummax_jvp_rule(primals, tangents, *, axis: int, unit: Number):
|
||||
def _cummax_jvp_rule(primals, tangents, *, axis: int):
|
||||
# Irrespective of backend, we always use the parallel prefix scan
|
||||
# implementation when differentiating because reduce_window is not
|
||||
# arbitrarily differentiable.
|
||||
return api.jvp(partial(_cummax_prefix_scan, axis=axis, unit=unit), primals, tangents)
|
||||
return api.jvp(partial(_cummax_prefix_scan, axis=axis), primals, tangents)
|
||||
|
||||
|
||||
def _cummin_jvp_rule(primals, tangents, *, axis: int, unit: Number):
|
||||
def _cummin_jvp_rule(primals, tangents, *, axis: int):
|
||||
# Irrespective of backend, we always use the parallel prefix scan
|
||||
# implementation when differentiating because reduce_window is not
|
||||
# arbitrarily differentiable.
|
||||
return api.jvp(partial(_cummin_prefix_scan, axis=axis, unit=unit), primals, tangents)
|
||||
return api.jvp(partial(_cummin_prefix_scan, axis=axis), primals, tangents)
|
||||
|
||||
|
||||
def _cumred_tpu_translation_rule(window_reduce: Callable, unit, x, *,
|
||||
|
@ -1355,17 +1355,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
def testCumulativeReduceMaxMin(self, op, onp_op, shape, dtype, axis, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
fun = partial(op, axis=axis)
|
||||
if onp.issubdtype(dtype, onp.integer):
|
||||
if op == lax.cummax:
|
||||
unit = onp.iinfo(dtype).min
|
||||
else:
|
||||
unit = onp.iinfo(dtype).max
|
||||
else:
|
||||
if op == lax.cummax:
|
||||
unit = onp.finfo(dtype).min
|
||||
else:
|
||||
unit = onp.finfo(dtype).max
|
||||
onp_fun = partial(onp_op, axis=axis, dtype=dtype, unit=unit)
|
||||
onp_fun = partial(onp_op, axis=axis, dtype=dtype)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
self._CompileAndCheck(fun, args_maker)
|
||||
self._CheckAgainstNumpy(fun, onp_fun, args_maker)
|
||||
|
Loading…
x
Reference in New Issue
Block a user