Bump some SVD test tolerances.

These just barely fail on recent TPUs.

PiperOrigin-RevId: 652571985
This commit is contained in:
Peter Hawkins 2024-07-15 12:53:12 -07:00 committed by jax authors
parent 5d3392927a
commit 2f45cd725a
2 changed files with 2 additions and 2 deletions

View File

@ -909,7 +909,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
if dtype == np.complex128:
atol = 2e-13
else:
atol = 5e-4
atol = 6e-4
self.assertArraysAllClose(t_out, b.real, atol=atol)
def testJspSVDBasic(self):

View File

@ -167,7 +167,7 @@ class SvdTest(jtu.JaxTestCase):
np.testing.assert_almost_equal(diff, 1e-4, decimal=2)
# Check that u and v are orthogonal.
self.assertAllClose(u.T.conj() @ u, np.eye(m), atol=10 * _SVD_TEST_EPS)
self.assertAllClose(v.T.conj() @ v, np.eye(m), atol=10 * _SVD_TEST_EPS)
self.assertAllClose(v.T.conj() @ v, np.eye(m), atol=11 * _SVD_TEST_EPS)
@jtu.sample_product(
[dict(m=m, n=n) for m, n in zip([2, 8, 10, 20], [4, 6, 10, 18])],