diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index 83827db78..3b1ea21d5 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -885,7 +885,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): ) def testKde(self, inshape, dtype, outsize, weights, method, func): if method == "callable": - method = lambda kde: jax.numpy.power(kde.neff, -1./(kde.d+4)) + method = lambda kde: kde.neff ** -1./(kde.d+4) def scipy_fun(dataset, points, w): w = np.abs(w) if weights else None @@ -915,11 +915,12 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): rng(inshape, dtype), rng(outshape, dtype), rng(inshape[-1:], dtype)] self._CheckAgainstNumpy( scipy_fun, lax_fun, args_maker, tol={ - np.float32: 1e-2 if jtu.device_under_test() == "tpu" else 1e-3, - np.float64: 1e-14 + np.float32: 2e-2 if jtu.device_under_test() == "tpu" else 1e-3, + np.float64: 3e-14 }) self._CompileAndCheck( - lax_fun, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15}) + lax_fun, args_maker, rtol={np.float32: 3e-5, np.float64: 3e-14}, + atol={np.float32: 3e-4, np.float64: 3e-14}) @jtu.sample_product( shape=[(15,), (3, 15), (1, 12)],