From 2f45cd725a51c124b710c06cefb7ef901aaf4851 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 15 Jul 2024 12:53:12 -0700 Subject: [PATCH] Bump some SVD test tolerances. These just barely fail on recent TPUs. PiperOrigin-RevId: 652571985 --- tests/linalg_test.py | 2 +- tests/svd_test.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/linalg_test.py b/tests/linalg_test.py index e687d4edb..2a64b95b9 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -909,7 +909,7 @@ class NumpyLinalgTest(jtu.JaxTestCase): if dtype == np.complex128: atol = 2e-13 else: - atol = 5e-4 + atol = 6e-4 self.assertArraysAllClose(t_out, b.real, atol=atol) def testJspSVDBasic(self): diff --git a/tests/svd_test.py b/tests/svd_test.py index 52833b434..b349c3ca6 100644 --- a/tests/svd_test.py +++ b/tests/svd_test.py @@ -167,7 +167,7 @@ class SvdTest(jtu.JaxTestCase): np.testing.assert_almost_equal(diff, 1e-4, decimal=2) # Check that u and v are orthogonal. self.assertAllClose(u.T.conj() @ u, np.eye(m), atol=10 * _SVD_TEST_EPS) - self.assertAllClose(v.T.conj() @ v, np.eye(m), atol=10 * _SVD_TEST_EPS) + self.assertAllClose(v.T.conj() @ v, np.eye(m), atol=11 * _SVD_TEST_EPS) @jtu.sample_product( [dict(m=m, n=n) for m, n in zip([2, 8, 10, 20], [4, 6, 10, 18])],