BUG: fix complex warning on jnp.any/all

This commit is contained in:
Jake VanderPlas 2021-03-30 11:10:34 -07:00
parent 1f1d3dffe2
commit 618317d3b3
2 changed files with 4 additions and 3 deletions

View File

@ -1993,7 +1993,10 @@ def _reduction_init_val(a, init_val):
sign, info = np.sign(init_val), iinfo(a_dtype)
return np.array(info.min if sign < 0 else info.max, dtype=a_dtype)
_cast_to_bool = partial(lax.convert_element_type, new_dtype=bool_)
def _cast_to_bool(operand):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=np.ComplexWarning)
return lax.convert_element_type(operand, bool_)
@_wraps(np.sum, skip_params=['out'])
def sum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,

View File

@ -785,9 +785,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16)
return res
np_fun = _promote_like_jnp(np_fun, inexact)
np_fun = jtu.ignore_warning(category=np.ComplexWarning)(np_fun)
jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims)
jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
args_maker = lambda: [rng(shape, dtype)]
tol = {np.float16: 0.002}
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol)