mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
BUG: fix complex warning on jnp.any/all
This commit is contained in:
parent
1f1d3dffe2
commit
618317d3b3
@ -1993,7 +1993,10 @@ def _reduction_init_val(a, init_val):
|
||||
sign, info = np.sign(init_val), iinfo(a_dtype)
|
||||
return np.array(info.min if sign < 0 else info.max, dtype=a_dtype)
|
||||
|
||||
_cast_to_bool = partial(lax.convert_element_type, new_dtype=bool_)
|
||||
def _cast_to_bool(operand):
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", category=np.ComplexWarning)
|
||||
return lax.convert_element_type(operand, bool_)
|
||||
|
||||
@_wraps(np.sum, skip_params=['out'])
|
||||
def sum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
|
||||
|
@ -785,9 +785,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16)
|
||||
return res
|
||||
np_fun = _promote_like_jnp(np_fun, inexact)
|
||||
np_fun = jtu.ignore_warning(category=np.ComplexWarning)(np_fun)
|
||||
jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims)
|
||||
jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
tol = {np.float16: 0.002}
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol)
|
||||
|
Loading…
x
Reference in New Issue
Block a user