ensure lax.reduce monoid test uses original numpy (#3573)

This commit is contained in:
Matthew Johnson 2020-06-26 11:44:16 -07:00 committed by GitHub
parent 99a43f20db
commit 11caa21eca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1054,17 +1054,17 @@ def _get_monoid_reducer(monoid_op: Callable, x: Array) -> Optional[Callable]:
dtype = _dtype(x)
if (type(aval) is ConcreteArray) and aval.shape == ():
if monoid_op is add:
return aval.val == 0 and _reduce_sum
return onp.equal(aval.val, 0) and _reduce_sum
if monoid_op is mul:
return aval.val == 1 and _reduce_prod
return onp.equal(aval.val, 1) and _reduce_prod
elif monoid_op is bitwise_or and dtype == onp.bool_:
return aval.val == _get_max_identity(dtype) and _reduce_or
return onp.equal(aval.val, _get_max_identity(dtype)) and _reduce_or
elif monoid_op is bitwise_and and dtype == onp.bool_:
return aval.val == _get_min_identity(dtype) and _reduce_and
return onp.equal(aval.val, _get_min_identity(dtype)) and _reduce_and
elif monoid_op is max:
return aval.val == _get_max_identity(dtype) and _reduce_max
return onp.equal(aval.val, _get_max_identity(dtype)) and _reduce_max
elif monoid_op is min:
return aval.val == _get_min_identity(dtype) and _reduce_min
return onp.equal(aval.val, _get_min_identity(dtype)) and _reduce_min
return None
def _get_max_identity(dtype: DType) -> Array: