mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
ensure lax.reduce monoid test uses original numpy (#3573)
This commit is contained in:
parent
99a43f20db
commit
11caa21eca
@ -1054,17 +1054,17 @@ def _get_monoid_reducer(monoid_op: Callable, x: Array) -> Optional[Callable]:
|
||||
dtype = _dtype(x)
|
||||
if (type(aval) is ConcreteArray) and aval.shape == ():
|
||||
if monoid_op is add:
|
||||
return aval.val == 0 and _reduce_sum
|
||||
return onp.equal(aval.val, 0) and _reduce_sum
|
||||
if monoid_op is mul:
|
||||
return aval.val == 1 and _reduce_prod
|
||||
return onp.equal(aval.val, 1) and _reduce_prod
|
||||
elif monoid_op is bitwise_or and dtype == onp.bool_:
|
||||
return aval.val == _get_max_identity(dtype) and _reduce_or
|
||||
return onp.equal(aval.val, _get_max_identity(dtype)) and _reduce_or
|
||||
elif monoid_op is bitwise_and and dtype == onp.bool_:
|
||||
return aval.val == _get_min_identity(dtype) and _reduce_and
|
||||
return onp.equal(aval.val, _get_min_identity(dtype)) and _reduce_and
|
||||
elif monoid_op is max:
|
||||
return aval.val == _get_max_identity(dtype) and _reduce_max
|
||||
return onp.equal(aval.val, _get_max_identity(dtype)) and _reduce_max
|
||||
elif monoid_op is min:
|
||||
return aval.val == _get_min_identity(dtype) and _reduce_min
|
||||
return onp.equal(aval.val, _get_min_identity(dtype)) and _reduce_min
|
||||
return None
|
||||
|
||||
def _get_max_identity(dtype: DType) -> Array:
|
||||
|
Loading…
x
Reference in New Issue
Block a user