looser TPU precision

This commit is contained in:
Dan F-M 2022-06-28 17:07:30 -04:00
parent 0788d5708a
commit dc2a50ff21

View File

@ -34,12 +34,12 @@ one_and_two_dim_shapes = [(4,), (3, 4), (3, 1), (1, 4)]
def genNamedParametersNArgs(n):
return parameterized.named_parameters(
jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix("", shapes, dtypes),
"shapes": shapes, "dtypes": dtypes}
for shapes in itertools.combinations_with_replacement(all_shapes, n)
for dtypes in itertools.combinations_with_replacement(jtu.dtypes.floating, n)))
return parameterized.named_parameters(
jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix("", shapes, dtypes),
"shapes": shapes, "dtypes": dtypes}
for shapes in itertools.combinations_with_replacement(all_shapes, n)
for dtypes in itertools.combinations_with_replacement(jtu.dtypes.floating, n)))
# Allow implicit rank promotion in these tests, as virtually every test exercises it.
@ -704,7 +704,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
for func in [None, "evaluate", "logpdf", "pdf"]))
def testKde(self, inshape, dtype, outsize, weights, method, func):
if method == "callable":
method = lambda kde: jax.numpy.power(kde.neff, -1./(kde.d+4))
method = lambda kde: jax.numpy.power(kde.neff, -1./(kde.d+4))
def scipy_fun(dataset, points, w):
w = np.abs(w) if weights else None
@ -732,8 +732,11 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [
rng(inshape, dtype), rng(outshape, dtype), rng(inshape[-1:], dtype)]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
tol={np.float32: 1e-3, np.float64: 1e-14})
self._CheckAgainstNumpy(
scipy_fun, lax_fun, args_maker, tol={
np.float32: 1e-2 if jtu.device_under_test() == "tpu" else 1e-3,
np.float64: 1e-14
})
self._CompileAndCheck(
lax_fun, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})
@ -887,4 +890,4 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self.assertAllClose(evaluate_kde(kde, x), kde.evaluate(x))
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
absltest.main(testLoader=jtu.JaxTestLoader())