diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index a10a73697..c6d56885a 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -5706,6 +5706,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self.assertAllClose(x.trace(), jnp.array(x).trace()) self.assertAllClose(x.trace(), jax.jit(lambda y: y.trace())(x)) + @jtu.ignore_warning(category=RuntimeWarning, message="divide by zero") def testIntegerPowersArePrecise(self): # See https://github.com/jax-ml/jax/pull/3036 # Checks if the squares of float32 integers have no numerical errors. diff --git a/tests/scipy_optimize_test.py b/tests/scipy_optimize_test.py index 70a00e14c..ffa576850 100644 --- a/tests/scipy_optimize_test.py +++ b/tests/scipy_optimize_test.py @@ -117,6 +117,7 @@ class TestBFGS(jtu.JaxTestCase): jax_res = jax.scipy.optimize.minimize(fun=eval_func, x0=x0, method='BFGS') self.assertLess(jax_res.fun, 1e-6) + @jtu.ignore_warning(category=RuntimeWarning, message='divide by zero') def test_minimize_bad_initial_values(self): # This test runs deliberately "bad" initial values to test that handling # of failed line search, etc. is the same across implementations