fix jax.numpy reduction init_val for bools

This commit is contained in:
James Bradbury 2019-08-03 21:27:06 -07:00
parent fd4b84bd95
commit d0c9f45349
3 changed files with 7 additions and 3 deletions

View File

@ -3255,7 +3255,7 @@ def _reduction_computation(c, jaxpr, consts, init_value):
reduce_p = standard_primitive(_reduce_shape_rule, _input_dtype, 'reduce',
_reduce_translation_rule)
# batching.primitive_batchers[reduce_p] = _reduce_batch_rule # TODO(mattjj): test
batching.primitive_batchers[reduce_p] = _reduce_batch_rule
def _reduce_sum_shape_rule(operand, axes, input_shape):

View File

@ -991,6 +991,8 @@ def _reduction_dims(a, axis):
def _reduction_init_val(a, init_val):
a_dtype = xla_bridge.canonicalize_dtype(_dtype(a))
if a_dtype == 'bool':
return onp.array(init_val > 0, dtype=a_dtype)
try:
return onp.array(init_val, dtype=a_dtype)
except OverflowError:

View File

@ -1040,6 +1040,7 @@ class LaxTest(jtu.JaxTestCase):
for init_val, op, dtypes in [
(0, lax.add, default_dtypes),
(1, lax.mul, default_dtypes),
(0, lax.max, all_dtypes), # non-monoidal
(-onp.inf, lax.max, float_dtypes),
(onp.iinfo(onp.int32).min, lax.max, [onp.int32]),
# (onp.iinfo(onp.int64).min, lax.max, [onp.int64]), # TODO fails
@ -2591,14 +2592,15 @@ class LaxVmapTest(jtu.JaxTestCase):
self._CheckBatching(op, 5, bdims, (shape,), dtype, rng)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_op={}_inshape={}_reducedims={}_bdims={}"
{"testcase_name": "_op={}_inshape={}_reducedims={}_initval={}_bdims={}"
.format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims,
bdims),
init_val, bdims),
"op": op, "init_val": init_val, "shape": shape, "dtype": dtype,
"dims": dims, "bdims": bdims, "rng": rng}
for init_val, op, dtypes in [
(0, lax.add, default_dtypes),
(1, lax.mul, default_dtypes),
(0, lax.max, all_dtypes), # non-monoidal
(-onp.inf, lax.max, float_dtypes),
(onp.iinfo(onp.int32).min, lax.max, [onp.int32]),
(onp.iinfo(onp.int64).min, lax.max, [onp.int64]),