mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Bump some SVD test tolerances.
These just barely fail on recent TPUs. PiperOrigin-RevId: 652571985
This commit is contained in:
parent
5d3392927a
commit
2f45cd725a
@ -909,7 +909,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
if dtype == np.complex128:
|
||||
atol = 2e-13
|
||||
else:
|
||||
atol = 5e-4
|
||||
atol = 6e-4
|
||||
self.assertArraysAllClose(t_out, b.real, atol=atol)
|
||||
|
||||
def testJspSVDBasic(self):
|
||||
|
@ -167,7 +167,7 @@ class SvdTest(jtu.JaxTestCase):
|
||||
np.testing.assert_almost_equal(diff, 1e-4, decimal=2)
|
||||
# Check that u and v are orthogonal.
|
||||
self.assertAllClose(u.T.conj() @ u, np.eye(m), atol=10 * _SVD_TEST_EPS)
|
||||
self.assertAllClose(v.T.conj() @ v, np.eye(m), atol=10 * _SVD_TEST_EPS)
|
||||
self.assertAllClose(v.T.conj() @ v, np.eye(m), atol=11 * _SVD_TEST_EPS)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(m=m, n=n) for m, n in zip([2, 8, 10, 20], [4, 6, 10, 18])],
|
||||
|
Loading…
x
Reference in New Issue
Block a user