diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index 6890130f8..841c27780 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -34,12 +34,12 @@ one_and_two_dim_shapes = [(4,), (3, 4), (3, 1), (1, 4)] def genNamedParametersNArgs(n): - return parameterized.named_parameters( - jtu.cases_from_list( - {"testcase_name": jtu.format_test_name_suffix("", shapes, dtypes), - "shapes": shapes, "dtypes": dtypes} - for shapes in itertools.combinations_with_replacement(all_shapes, n) - for dtypes in itertools.combinations_with_replacement(jtu.dtypes.floating, n))) + return parameterized.named_parameters( + jtu.cases_from_list( + {"testcase_name": jtu.format_test_name_suffix("", shapes, dtypes), + "shapes": shapes, "dtypes": dtypes} + for shapes in itertools.combinations_with_replacement(all_shapes, n) + for dtypes in itertools.combinations_with_replacement(jtu.dtypes.floating, n))) # Allow implicit rank promotion in these tests, as virtually every test exercises it. @@ -704,7 +704,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): for func in [None, "evaluate", "logpdf", "pdf"])) def testKde(self, inshape, dtype, outsize, weights, method, func): if method == "callable": - method = lambda kde: jax.numpy.power(kde.neff, -1./(kde.d+4)) + method = lambda kde: jax.numpy.power(kde.neff, -1./(kde.d+4)) def scipy_fun(dataset, points, w): w = np.abs(w) if weights else None @@ -732,8 +732,11 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): rng = jtu.rand_default(self.rng()) args_maker = lambda: [ rng(inshape, dtype), rng(outshape, dtype), rng(inshape[-1:], dtype)] - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, - tol={np.float32: 1e-3, np.float64: 1e-14}) + self._CheckAgainstNumpy( + scipy_fun, lax_fun, args_maker, tol={ + np.float32: 1e-2 if jtu.device_under_test() == "tpu" else 1e-3, + np.float64: 1e-14 + }) self._CompileAndCheck( lax_fun, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15}) @@ -887,4 +890,4 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): self.assertAllClose(evaluate_kde(kde, x), kde.evaluate(x)) if __name__ == "__main__": - absltest.main(testLoader=jtu.JaxTestLoader()) + absltest.main(testLoader=jtu.JaxTestLoader())