add tests

This commit is contained in:
Erich Elsen 2020-06-28 18:21:09 +01:00
parent 7b57dc8c80
commit bf06633a87

View File

@ -57,6 +57,7 @@ _reduce = functools.reduce
Array = Any
DType = Any
Number = Union[int, float]
Shape = Sequence[int]
def _try_broadcast_shapes(shapes):
@ -1196,6 +1197,14 @@ def cumprod(operand: Array, axis: int) -> Array:
"""Computes a cumulative product along `axis`."""
return cumprod_p.bind(operand, axis=int(axis))
def cummax(operand: Array, axis: int, unit: Number) -> Array:
"""Computes a cumulative maximum along `axis'."""
return cummax_p.bind(operand, axis=int(axis), unit=unit)
def cummin(operand: Array, axis: int, unit: Number) -> Array:
"""Computes a cumulative minimum along `axis'."""
return cummin_p.bind(operand, axis=int(axis), unit=unit)
def sort(operand: Union[Array, Sequence[Array]], dimension: int = -1,
is_stable: bool = True) -> Union[Array, Tuple[Array, ...]]:
"""Wraps XLA's `Sort
@ -4819,6 +4828,8 @@ def _parallel_prefix_scan(x, axis: int, op: Callable, unit):
_cumsum_prefix_scan = partial(_parallel_prefix_scan, op=add, unit=0)
_cumprod_prefix_scan = partial(_parallel_prefix_scan, op=mul, unit=1)
_cummax_prefix_scan = partial(_parallel_prefix_scan, op=max)
_cummin_prefix_scan = partial(_parallel_prefix_scan, op=min)
def _cumred_shape_rule(x, *, axis: int):
if axis < 0 or axis >= x.ndim:
@ -4836,6 +4847,20 @@ def _cumprod_jvp_rule(primals, tangents, *, axis: int):
return api.jvp(partial(_cumprod_prefix_scan, axis=axis), primals, tangents)
def _cummax_jvp_rule(primals, tangents, *, axis: int, unit: Number):
# Irrespective of backend, we always use the parallel prefix scan
# implementation when differentiating because reduce_window is not
# arbitrarily differentiable.
return api.jvp(partial(_cummax_prefix_scan, axis=axis, unit=unit), primals, tangents)
def _cummin_jvp_rule(primals, tangents, *, axis: int, unit=unit):
# Irrespective of backend, we always use the parallel prefix scan
# implementation when differentiating because reduce_window is not
# arbitrarily differentiable.
return api.jvp(partial(_cummin_prefix_scan, axis=axis, unit=unit), primals, tangents)
def _cumred_tpu_translation_rule(window_reduce: Callable, unit, x, *,
axis: int):
# On TPU, an implementation using reduce_window is handled specially by the
@ -4878,6 +4903,26 @@ xla.backend_specific_translations['tpu'][cumprod_p] = xla.lower_fun(
batching.primitive_batchers[cumprod_p] = partial(_cumred_batch_rule, cumprod_p)
cummax_p = standard_primitive(
_cumred_shape_rule, partial(_reduce_number_dtype_rule, "cummax"),
'cummax', xla.lower_fun(_cummax_prefix_scan, multiple_results=False))
ad.primitive_jvps[cummax_p] = _cummax_jvp_rule
xla.backend_specific_translations['tpu'][cummax_p] = xla.lower_fun(
partial(_cumred_tpu_translation_rule, _reduce_window_max),
multiple_results=False)
batching.primitive_batchers[cummax_p] = partial(_cumred_batch_rule, cummax_p)
cummin_p = standard_primitive(
_cumred_shape_rule, partial(_reduce_number_dtype_rule, "cummin"),
'cummin', xla.lower_fun(_cummin_prefix_scan, multiple_results=False))
ad.primitive_jvps[cummin_p] = _cummin_jvp_rule
xla.backend_specific_translations['tpu'][cummin_p] = xla.lower_fun(
partial(_cumred_tpu_translation_rule, _reduce_window_min),
multiple_results=False)
batching.primitive_batchers[cummin_p] = partial(_cumred_batch_rule, cummin_p)
def _sort_abstract_eval(*args, **kwargs):
args = tuple(raise_to_shaped(arg) for arg in args)
if any(arg.shape != args[0].shape for arg in args[1:]):