mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
fix jax.numpy reduction init_val for bools
This commit is contained in:
parent
fd4b84bd95
commit
d0c9f45349
@ -3255,7 +3255,7 @@ def _reduction_computation(c, jaxpr, consts, init_value):
|
||||
|
||||
reduce_p = standard_primitive(_reduce_shape_rule, _input_dtype, 'reduce',
|
||||
_reduce_translation_rule)
|
||||
# batching.primitive_batchers[reduce_p] = _reduce_batch_rule # TODO(mattjj): test
|
||||
batching.primitive_batchers[reduce_p] = _reduce_batch_rule
|
||||
|
||||
|
||||
def _reduce_sum_shape_rule(operand, axes, input_shape):
|
||||
|
@ -991,6 +991,8 @@ def _reduction_dims(a, axis):
|
||||
|
||||
def _reduction_init_val(a, init_val):
|
||||
a_dtype = xla_bridge.canonicalize_dtype(_dtype(a))
|
||||
if a_dtype == 'bool':
|
||||
return onp.array(init_val > 0, dtype=a_dtype)
|
||||
try:
|
||||
return onp.array(init_val, dtype=a_dtype)
|
||||
except OverflowError:
|
||||
|
@ -1040,6 +1040,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
for init_val, op, dtypes in [
|
||||
(0, lax.add, default_dtypes),
|
||||
(1, lax.mul, default_dtypes),
|
||||
(0, lax.max, all_dtypes), # non-monoidal
|
||||
(-onp.inf, lax.max, float_dtypes),
|
||||
(onp.iinfo(onp.int32).min, lax.max, [onp.int32]),
|
||||
# (onp.iinfo(onp.int64).min, lax.max, [onp.int64]), # TODO fails
|
||||
@ -2591,14 +2592,15 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
self._CheckBatching(op, 5, bdims, (shape,), dtype, rng)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_op={}_inshape={}_reducedims={}_bdims={}"
|
||||
{"testcase_name": "_op={}_inshape={}_reducedims={}_initval={}_bdims={}"
|
||||
.format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims,
|
||||
bdims),
|
||||
init_val, bdims),
|
||||
"op": op, "init_val": init_val, "shape": shape, "dtype": dtype,
|
||||
"dims": dims, "bdims": bdims, "rng": rng}
|
||||
for init_val, op, dtypes in [
|
||||
(0, lax.add, default_dtypes),
|
||||
(1, lax.mul, default_dtypes),
|
||||
(0, lax.max, all_dtypes), # non-monoidal
|
||||
(-onp.inf, lax.max, float_dtypes),
|
||||
(onp.iinfo(onp.int32).min, lax.max, [onp.int32]),
|
||||
(onp.iinfo(onp.int64).min, lax.max, [onp.int64]),
|
||||
|
Loading…
x
Reference in New Issue
Block a user