From e7f1d29716df53f4109120f21930443d8fe51812 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 24 Oct 2023 10:44:52 -0700 Subject: [PATCH] Relax some test tolerances for TPU. PiperOrigin-RevId: 576192162 --- tests/lax_numpy_indexing_test.py | 7 ++++--- tests/lax_numpy_test.py | 5 ++++- tests/linalg_test.py | 2 +- tests/nn_test.py | 2 +- 4 files changed, 10 insertions(+), 6 deletions(-) diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index 6f88f14ad..c75a5cfa7 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -455,12 +455,12 @@ class IndexingTest(jtu.JaxTestCase): # Test with traced integer index args_maker = lambda: [rng(size, dtype), idx_rng(size, int)] - atol = ( + tol = ( 5e-5 if jtu.test_device_matches(["tpu"]) and funcname in ("log", "exp") else None ) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker, atol=atol) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker, atol=tol) self._CompileAndCheck(jnp_op, args_maker) # Test with slice index @@ -468,7 +468,8 @@ class IndexingTest(jtu.JaxTestCase): np_op_idx = partial(np_op, idx=idx) jnp_op_idx = partial(jnp_op, idx=idx) args_maker = lambda: [rng(size, dtype)] - self._CheckAgainstNumpy(np_op_idx, jnp_op_idx, args_maker, atol=atol) + self._CheckAgainstNumpy(np_op_idx, jnp_op_idx, args_maker, atol=tol, + rtol=tol) self._CompileAndCheck(jnp_op_idx, args_maker) def testIndexApplyBatchingBug(self): diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 5f598a203..5b0b5dab5 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -2565,7 +2565,10 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): jnp_fun = partial(getattr(jnp, name), size, **kwds) np_fun = jtu.with_jax_dtype_defaults(partial(getattr(np, name), size, **kwds)) args_maker = lambda: [] - self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) + tol = ( + 5e-6 if jtu.test_device_matches(['tpu']) and name == 'kaiser' else None + ) + self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker, atol=tol, rtol=tol) self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 575f5943e..81f323a32 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -726,7 +726,7 @@ class NumpyLinalgTest(jtu.JaxTestCase): # Check a ~= qr norm_error = norm(a - np.matmul(lq, lr)) - self.assertTrue(np.all(norm_error < 45), msg=np.amax(norm_error)) + self.assertTrue(np.all(norm_error < 60), msg=np.amax(norm_error)) # Compare the first 'k' vectors of Q; the remainder form an arbitrary # orthonormal basis for the null space. diff --git a/tests/nn_test.py b/tests/nn_test.py index 79b30189f..3012fb981 100644 --- a/tests/nn_test.py +++ b/tests/nn_test.py @@ -150,7 +150,7 @@ class NNFunctionsTest(jtu.JaxTestCase): def testSoftmaxGrad(self): x = jnp.array([5.5, 1.3, -4.2, 0.9]) - jtu.check_grads(nn.softmax, (x,), order=2, atol=3e-3) + jtu.check_grads(nn.softmax, (x,), order=2, atol=5e-3) def testSoftmaxGradResiduals(self): if not config.softmax_custom_jvp.value: