mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Remove the `unit` argument and determine the unit (identity) value automatically for all ops
This commit is contained in:
parent
4fe9c1d624
commit
d3f6d85da5
@ -4812,17 +4812,23 @@ def _prescan_power_of_two(x, axis: int, op: Callable, unit):
|
||||
return x, total
|
||||
|
||||
|
||||
def _parallel_prefix_scan(x, axis: int, op: Callable, unit=None):
|
||||
if not unit and op is max:
|
||||
def _parallel_prefix_scan(x, axis: int, op: Callable):
|
||||
if op is max:
|
||||
if onp.issubdtype(x.dtype, onp.integer):
|
||||
unit = onp.iinfo(x.dtype).min
|
||||
else:
|
||||
unit = dtypes.finfo(x.dtype).min
|
||||
elif not unit and op is min:
|
||||
elif op is min:
|
||||
if onp.issubdtype(x.dtype, onp.integer):
|
||||
unit = onp.iinfo(x.dtype).max
|
||||
else:
|
||||
unit = dtypes.finfo(x.dtype).max
|
||||
elif op is add:
|
||||
unit = 0
|
||||
elif op is mul:
|
||||
unit = 1
|
||||
else:
|
||||
raise ValueError("Unknown type of reducer, got {}".format(op))
|
||||
n = x.shape[axis]
|
||||
if n == 0:
|
||||
return x
|
||||
@ -4836,10 +4842,10 @@ def _parallel_prefix_scan(x, axis: int, op: Callable, unit=None):
|
||||
x, total = _prescan_power_of_two(x, axis, op, unit)
|
||||
return concatenate((slice_in_dim(x, 1, n, axis=axis), total), dimension=axis)
|
||||
|
||||
_cumsum_prefix_scan = partial(_parallel_prefix_scan, op=add, unit=0)
|
||||
_cumprod_prefix_scan = partial(_parallel_prefix_scan, op=mul, unit=1)
|
||||
_cummax_prefix_scan = partial(_parallel_prefix_scan, op=max, unit=None)
|
||||
_cummin_prefix_scan = partial(_parallel_prefix_scan, op=min, unit=None)
|
||||
_cumsum_prefix_scan = partial(_parallel_prefix_scan, op=add)
|
||||
_cumprod_prefix_scan = partial(_parallel_prefix_scan, op=mul)
|
||||
_cummax_prefix_scan = partial(_parallel_prefix_scan, op=max)
|
||||
_cummin_prefix_scan = partial(_parallel_prefix_scan, op=min)
|
||||
|
||||
def _cumred_shape_rule(x, *, axis: int):
|
||||
if axis < 0 or axis >= x.ndim:
|
||||
@ -4871,20 +4877,26 @@ def _cummin_jvp_rule(primals, tangents, *, axis: int):
|
||||
return api.jvp(partial(_cummin_prefix_scan, axis=axis), primals, tangents)
|
||||
|
||||
|
||||
def _cumred_tpu_translation_rule(window_reduce: Callable, unit, x, *,
|
||||
def _cumred_tpu_translation_rule(window_reduce: Callable, x, *,
|
||||
axis: int):
|
||||
# On TPU, an implementation using reduce_window is handled specially by the
|
||||
# compiler and is efficient. On other backends, it is O(n^2).
|
||||
if not unit and window_reduce is _reduce_window_max:
|
||||
if window_reduce is _reduce_window_max:
|
||||
if onp.issubdtype(x.dtype, onp.integer):
|
||||
unit = onp.iinfo(x.dtype).min
|
||||
else:
|
||||
unit = dtypes.finfo(x.dtype).min
|
||||
elif not unit and window_reduce is _reduce_window_min:
|
||||
elif window_reduce is _reduce_window_min:
|
||||
if onp.issubdtype(x.dtype, onp.integer):
|
||||
unit = onp.iinfo(x.dtype).max
|
||||
else:
|
||||
unit = dtypes.finfo(x.dtype).max
|
||||
elif window_reduce is _reduce_window_sum:
|
||||
unit = 0
|
||||
elif window_reduce is _reduce_window_prod:
|
||||
unit = 1
|
||||
else:
|
||||
raise ValueError("Unknown type of reducer, get {}".format(window_reduce))
|
||||
n = x.shape[axis]
|
||||
if n == 0:
|
||||
return x
|
||||
@ -4908,7 +4920,7 @@ cumsum_p = standard_primitive(
|
||||
'cumsum', xla.lower_fun(_cumsum_prefix_scan, multiple_results=False))
|
||||
ad.deflinear(cumsum_p, _cumsum_transpose_rule)
|
||||
xla.backend_specific_translations['tpu'][cumsum_p] = xla.lower_fun(
|
||||
partial(_cumred_tpu_translation_rule, _reduce_window_sum, 0),
|
||||
partial(_cumred_tpu_translation_rule, _reduce_window_sum),
|
||||
multiple_results=False)
|
||||
batching.primitive_batchers[cumsum_p] = partial(_cumred_batch_rule, cumsum_p)
|
||||
|
||||
@ -4918,7 +4930,7 @@ cumprod_p = standard_primitive(
|
||||
'cumprod', xla.lower_fun(_cumprod_prefix_scan, multiple_results=False))
|
||||
ad.primitive_jvps[cumprod_p] = _cumprod_jvp_rule
|
||||
xla.backend_specific_translations['tpu'][cumprod_p] = xla.lower_fun(
|
||||
partial(_cumred_tpu_translation_rule, _reduce_window_prod, 1),
|
||||
partial(_cumred_tpu_translation_rule, _reduce_window_prod),
|
||||
multiple_results=False)
|
||||
batching.primitive_batchers[cumprod_p] = partial(_cumred_batch_rule, cumprod_p)
|
||||
|
||||
@ -4928,7 +4940,7 @@ cummax_p = standard_primitive(
|
||||
'cummax', xla.lower_fun(_cummax_prefix_scan, multiple_results=False))
|
||||
ad.primitive_jvps[cummax_p] = _cummax_jvp_rule
|
||||
xla.backend_specific_translations['tpu'][cummax_p] = xla.lower_fun(
|
||||
partial(_cumred_tpu_translation_rule, _reduce_window_max, None),
|
||||
partial(_cumred_tpu_translation_rule, _reduce_window_max),
|
||||
multiple_results=False)
|
||||
batching.primitive_batchers[cummax_p] = partial(_cumred_batch_rule, cummax_p)
|
||||
|
||||
@ -4938,7 +4950,7 @@ cummin_p = standard_primitive(
|
||||
'cummin', xla.lower_fun(_cummin_prefix_scan, multiple_results=False))
|
||||
ad.primitive_jvps[cummin_p] = _cummin_jvp_rule
|
||||
xla.backend_specific_translations['tpu'][cummin_p] = xla.lower_fun(
|
||||
partial(_cumred_tpu_translation_rule, _reduce_window_min, None),
|
||||
partial(_cumred_tpu_translation_rule, _reduce_window_min),
|
||||
multiple_results=False)
|
||||
batching.primitive_batchers[cummin_p] = partial(_cumred_batch_rule, cummin_p)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user