Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)

Commit bf06633a87 (parent 7b57dc8c80): add tests
@@ -57,6 +57,7 @@ _reduce = functools.reduce

Array = Any
DType = Any
Number = Union[int, float]
Shape = Sequence[int]

def _try_broadcast_shapes(shapes):
@@ -1196,6 +1197,14 @@ def cumprod(operand: Array, axis: int) -> Array:
  """Computes a cumulative product along `axis`."""
  return cumprod_p.bind(operand, axis=int(axis))

def cummax(operand: Array, axis: int, unit: Number) -> Array:
  """Computes a cumulative maximum along `axis`."""
  return cummax_p.bind(operand, axis=int(axis), unit=unit)

def cummin(operand: Array, axis: int, unit: Number) -> Array:
  """Computes a cumulative minimum along `axis`."""
  return cummin_p.bind(operand, axis=int(axis), unit=unit)

def sort(operand: Union[Array, Sequence[Array]], dimension: int = -1,
         is_stable: bool = True) -> Union[Array, Tuple[Array, ...]]:
  """Wraps XLA's `Sort
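For reference, a hedged usage sketch of the two functions added in this hunk (not part of the diff itself); it assumes this branch exposes them as lax.cummax/lax.cummin and that `unit` is the identity of the reduction (for example -inf for max and +inf for min on floating-point inputs):

import jax.numpy as jnp
from jax import lax

x = jnp.array([3., 1., 4., 1., 5.])

# Running maximum / minimum along axis 0; `unit` is the identity element
# that the underlying scan can pad with without changing any result.
print(lax.cummax(x, axis=0, unit=-jnp.inf))  # expected: [3. 3. 4. 4. 5.]
print(lax.cummin(x, axis=0, unit=jnp.inf))   # expected: [3. 1. 1. 1. 1.]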
@@ -4819,6 +4828,8 @@ def _parallel_prefix_scan(x, axis: int, op: Callable, unit):

_cumsum_prefix_scan = partial(_parallel_prefix_scan, op=add, unit=0)
_cumprod_prefix_scan = partial(_parallel_prefix_scan, op=mul, unit=1)
_cummax_prefix_scan = partial(_parallel_prefix_scan, op=max)
_cummin_prefix_scan = partial(_parallel_prefix_scan, op=min)

def _cumred_shape_rule(x, *, axis: int):
  if axis < 0 or axis >= x.ndim:
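The `unit` threaded through `_parallel_prefix_scan` is the identity element of `op` (0 for add, 1 for mul); for max and min it depends on the dtype, which is why `_cummax_prefix_scan` and `_cummin_prefix_scan` leave it to be supplied at call time. Below is a minimal NumPy reference of the same op/unit contract (illustrative only; `reference_scan` is not a function in this diff):

import numpy as np

def reference_scan(x, op, unit):
  # Sequential reference: because `unit` is the identity of `op`,
  # seeding the accumulator with it never changes any output element.
  out = np.empty_like(x)
  acc = unit
  for i, v in enumerate(x):
    acc = op(acc, v)
    out[i] = acc
  return out

print(reference_scan(np.array([3., 1., 4., 1., 5.]), max, unit=-np.inf))     # [3. 3. 4. 4. 5.]
print(reference_scan(np.array([1., 2., 3.]), lambda a, b: a + b, unit=0.0))  # [1. 3. 6.]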
@@ -4836,6 +4847,20 @@ def _cumprod_jvp_rule(primals, tangents, *, axis: int):
  return api.jvp(partial(_cumprod_prefix_scan, axis=axis), primals, tangents)


def _cummax_jvp_rule(primals, tangents, *, axis: int, unit: Number):
  # Irrespective of backend, we always use the parallel prefix scan
  # implementation when differentiating because reduce_window is not
  # arbitrarily differentiable.
  return api.jvp(partial(_cummax_prefix_scan, axis=axis, unit=unit), primals, tangents)


def _cummin_jvp_rule(primals, tangents, *, axis: int, unit: Number):
  # Irrespective of backend, we always use the parallel prefix scan
  # implementation when differentiating because reduce_window is not
  # arbitrarily differentiable.
  return api.jvp(partial(_cummin_prefix_scan, axis=axis, unit=unit), primals, tangents)


def _cumred_tpu_translation_rule(window_reduce: Callable, unit, x, *,
                                 axis: int):
  # On TPU, an implementation using reduce_window is handled specially by the
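As the comments in these JVP rules note, differentiation always routes through the parallel prefix scan. A hedged sanity check (not from this commit) that the registered JVP gives the same gradient as differentiating an equivalent scan built from public JAX ops; it assumes this branch's lax.cummax(operand, axis, unit) signature:

import jax
import jax.numpy as jnp
from jax import lax

x = jnp.array([2., 7., 1., 8., 3., 5.])  # distinct values, so no max ties

def via_primitive(a):
  # grad here exercises ad.primitive_jvps[cummax_p] -> _cummax_jvp_rule
  return lax.cummax(a, axis=0, unit=-jnp.inf).sum()

def via_scan(a):
  # independent reference: the same cumulative max via an associative scan
  return lax.associative_scan(jnp.maximum, a).sum()

g_prim = jax.grad(via_primitive)(x)
g_scan = jax.grad(via_scan)(x)
print(g_prim)                        # expected: [1. 2. 0. 3. 0. 0.]
assert jnp.allclose(g_prim, g_scan)  # agrees because x has no ties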
@@ -4878,6 +4903,26 @@ xla.backend_specific_translations['tpu'][cumprod_p] = xla.lower_fun(
batching.primitive_batchers[cumprod_p] = partial(_cumred_batch_rule, cumprod_p)


cummax_p = standard_primitive(
    _cumred_shape_rule, partial(_reduce_number_dtype_rule, "cummax"),
    'cummax', xla.lower_fun(_cummax_prefix_scan, multiple_results=False))
ad.primitive_jvps[cummax_p] = _cummax_jvp_rule
xla.backend_specific_translations['tpu'][cummax_p] = xla.lower_fun(
    partial(_cumred_tpu_translation_rule, _reduce_window_max),
    multiple_results=False)
batching.primitive_batchers[cummax_p] = partial(_cumred_batch_rule, cummax_p)


cummin_p = standard_primitive(
    _cumred_shape_rule, partial(_reduce_number_dtype_rule, "cummin"),
    'cummin', xla.lower_fun(_cummin_prefix_scan, multiple_results=False))
ad.primitive_jvps[cummin_p] = _cummin_jvp_rule
xla.backend_specific_translations['tpu'][cummin_p] = xla.lower_fun(
    partial(_cumred_tpu_translation_rule, _reduce_window_min),
    multiple_results=False)
batching.primitive_batchers[cummin_p] = partial(_cumred_batch_rule, cummin_p)


def _sort_abstract_eval(*args, **kwargs):
  args = tuple(raise_to_shaped(arg) for arg in args)
  if any(arg.shape != args[0].shape for arg in args[1:]):
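These registrations wire the new primitives into shape/dtype inference, lowering, autodiff, a TPU-specific reduce_window path, and batching. A hedged smoke test for the batching rule (illustrative, not part of this commit; it assumes the lax.cummax(operand, axis, unit) surface from this branch):

import jax
import jax.numpy as jnp
from jax import lax

xs = jnp.array([[3., 1., 4.],
                [1., 5., 9.]])

# vmap over the leading axis exercises
# batching.primitive_batchers[cummax_p] = partial(_cumred_batch_rule, cummax_p)
row_cummax = jax.vmap(lambda row: lax.cummax(row, axis=0, unit=-jnp.inf))
print(row_cummax(xs))  # expected: [[3. 3. 4.]
                       #            [1. 5. 9.]]

# jit goes through the registered lowering: the prefix scan off-TPU,
# or reduce_window via the TPU-specific translation on TPU backends.
print(jax.jit(row_cummax)(xs))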