Remove the `unit` argument and determine it automatically from the op for all ops

This commit is contained in:
Erich Elsen 2020-06-28 20:21:35 +01:00
parent 4fe9c1d624
commit d3f6d85da5

View File

@ -4812,17 +4812,23 @@ def _prescan_power_of_two(x, axis: int, op: Callable, unit):
return x, total
def _parallel_prefix_scan(x, axis: int, op: Callable, unit=None):
if not unit and op is max:
def _parallel_prefix_scan(x, axis: int, op: Callable):
if op is max:
if onp.issubdtype(x.dtype, onp.integer):
unit = onp.iinfo(x.dtype).min
else:
unit = dtypes.finfo(x.dtype).min
elif not unit and op is min:
elif op is min:
if onp.issubdtype(x.dtype, onp.integer):
unit = onp.iinfo(x.dtype).max
else:
unit = dtypes.finfo(x.dtype).max
elif op is add:
unit = 0
elif op is mul:
unit = 1
else:
raise ValueError("Unknown type of reducer, got {}".format(op))
n = x.shape[axis]
if n == 0:
return x
@ -4836,10 +4842,10 @@ def _parallel_prefix_scan(x, axis: int, op: Callable, unit=None):
x, total = _prescan_power_of_two(x, axis, op, unit)
return concatenate((slice_in_dim(x, 1, n, axis=axis), total), dimension=axis)
_cumsum_prefix_scan = partial(_parallel_prefix_scan, op=add, unit=0)
_cumprod_prefix_scan = partial(_parallel_prefix_scan, op=mul, unit=1)
_cummax_prefix_scan = partial(_parallel_prefix_scan, op=max, unit=None)
_cummin_prefix_scan = partial(_parallel_prefix_scan, op=min, unit=None)
_cumsum_prefix_scan = partial(_parallel_prefix_scan, op=add)
_cumprod_prefix_scan = partial(_parallel_prefix_scan, op=mul)
_cummax_prefix_scan = partial(_parallel_prefix_scan, op=max)
_cummin_prefix_scan = partial(_parallel_prefix_scan, op=min)
def _cumred_shape_rule(x, *, axis: int):
if axis < 0 or axis >= x.ndim:
@ -4871,20 +4877,26 @@ def _cummin_jvp_rule(primals, tangents, *, axis: int):
return api.jvp(partial(_cummin_prefix_scan, axis=axis), primals, tangents)
def _cumred_tpu_translation_rule(window_reduce: Callable, unit, x, *,
def _cumred_tpu_translation_rule(window_reduce: Callable, x, *,
axis: int):
# On TPU, an implementation using reduce_window is handled specially by the
# compiler and is efficient. On other backends, it is O(n^2).
if not unit and window_reduce is _reduce_window_max:
if window_reduce is _reduce_window_max:
if onp.issubdtype(x.dtype, onp.integer):
unit = onp.iinfo(x.dtype).min
else:
unit = dtypes.finfo(x.dtype).min
elif not unit and window_reduce is _reduce_window_min:
elif window_reduce is _reduce_window_min:
if onp.issubdtype(x.dtype, onp.integer):
unit = onp.iinfo(x.dtype).max
else:
unit = dtypes.finfo(x.dtype).max
elif window_reduce is _reduce_window_sum:
unit = 0
elif window_reduce is _reduce_window_prod:
unit = 1
else:
raise ValueError("Unknown type of reducer, get {}".format(window_reduce))
n = x.shape[axis]
if n == 0:
return x
@ -4908,7 +4920,7 @@ cumsum_p = standard_primitive(
'cumsum', xla.lower_fun(_cumsum_prefix_scan, multiple_results=False))
ad.deflinear(cumsum_p, _cumsum_transpose_rule)
xla.backend_specific_translations['tpu'][cumsum_p] = xla.lower_fun(
partial(_cumred_tpu_translation_rule, _reduce_window_sum, 0),
partial(_cumred_tpu_translation_rule, _reduce_window_sum),
multiple_results=False)
batching.primitive_batchers[cumsum_p] = partial(_cumred_batch_rule, cumsum_p)
@ -4918,7 +4930,7 @@ cumprod_p = standard_primitive(
'cumprod', xla.lower_fun(_cumprod_prefix_scan, multiple_results=False))
ad.primitive_jvps[cumprod_p] = _cumprod_jvp_rule
xla.backend_specific_translations['tpu'][cumprod_p] = xla.lower_fun(
partial(_cumred_tpu_translation_rule, _reduce_window_prod, 1),
partial(_cumred_tpu_translation_rule, _reduce_window_prod),
multiple_results=False)
batching.primitive_batchers[cumprod_p] = partial(_cumred_batch_rule, cumprod_p)
@ -4928,7 +4940,7 @@ cummax_p = standard_primitive(
'cummax', xla.lower_fun(_cummax_prefix_scan, multiple_results=False))
ad.primitive_jvps[cummax_p] = _cummax_jvp_rule
xla.backend_specific_translations['tpu'][cummax_p] = xla.lower_fun(
partial(_cumred_tpu_translation_rule, _reduce_window_max, None),
partial(_cumred_tpu_translation_rule, _reduce_window_max),
multiple_results=False)
batching.primitive_batchers[cummax_p] = partial(_cumred_batch_rule, cummax_p)
@ -4938,7 +4950,7 @@ cummin_p = standard_primitive(
'cummin', xla.lower_fun(_cummin_prefix_scan, multiple_results=False))
ad.primitive_jvps[cummin_p] = _cummin_jvp_rule
xla.backend_specific_translations['tpu'][cummin_p] = xla.lower_fun(
partial(_cumred_tpu_translation_rule, _reduce_window_min, None),
partial(_cumred_tpu_translation_rule, _reduce_window_min),
multiple_results=False)
batching.primitive_batchers[cummin_p] = partial(_cumred_batch_rule, cummin_p)