Relax test tolerances.

This makes the tests pass on CPU with a slightly different seed (+ 1).

PiperOrigin-RevId: 542877795
This commit is contained in:
Peter Hawkins 2023-06-23 09:21:32 -07:00 committed by jax authors
parent bfa113ba60
commit 0adfafe293
8 changed files with 13 additions and 13 deletions

View File

@ -629,7 +629,7 @@ class BatchingTest(jtu.JaxTestCase):
ans = vmap(lax.linalg.triangular_solve, in_axes=(1, 2))(a, b)
expected = np.stack(
[lax.linalg.triangular_solve(a[:, i], b[..., i]) for i in range(10)])
self.assertAllClose(ans, expected)
self.assertAllClose(ans, expected, atol=1e-5, rtol=1e-5)
ans = vmap(lax.linalg.triangular_solve, in_axes=(None, 2))(a[:, 0], b)
expected = np.stack(
@ -639,7 +639,7 @@ class BatchingTest(jtu.JaxTestCase):
ans = vmap(lax.linalg.triangular_solve, in_axes=(1, None))(a, b[..., 0])
expected = np.stack(
[lax.linalg.triangular_solve(a[:, i], b[..., 0]) for i in range(10)])
self.assertAllClose(ans, expected)
self.assertAllClose(ans, expected, atol=1e-5, rtol=1e-5)
@parameterized.named_parameters(
{"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(

View File

@ -216,7 +216,7 @@ class CustomRootTest(jtu.JaxTestCase):
# grad check with aux
jtu.check_grads(lambda x, y: root_aux(high_precision_dot(x, x.T), y),
(a, b), order=2, rtol={jnp.float32: 1e-2})
(a, b), order=2, rtol={jnp.float32: 1e-2, np.float64: 3e-5})
# test vmap and jvp combined by jacfwd
fwd = jax.jacfwd(lambda x, y: root_aux(high_precision_dot(x, x.T), y), argnums=(0, 1))

View File

@ -131,7 +131,7 @@ LAX_GRAD_OPS = [
grad_test_spec(lax.rsqrt, nargs=1, order=2, rng_factory=jtu.rand_positive,
dtypes=grad_float_dtypes),
grad_test_spec(lax.rsqrt, nargs=1, order=2, rng_factory=jtu.rand_default,
dtypes=grad_complex_dtypes),
dtypes=grad_complex_dtypes, tol={np.float64: 2e-3}),
grad_test_spec(lax.cbrt, nargs=1, order=2, rng_factory=jtu.rand_default,
dtypes=grad_float_dtypes, tol={np.float64: 5e-3}),
grad_test_spec(lax.logistic, nargs=1, order=2,

View File

@ -233,8 +233,8 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
args_maker = lambda: [rng(shape, dtype)]
tol_spec = {np.float16: 1e-2, np.int16: 2e-7, np.int32: 1E-3,
np.float32: 1e-3, np.complex64: 1e-3, np.float64: 1e-5,
np.complex128: 1e-5}
np.uint32: 3e-7, np.float32: 1e-3, np.complex64: 1e-3,
np.float64: 1e-5, np.complex128: 1e-5}
tol = jtu.tolerance(dtype, tol_spec)
tol = max(tol, jtu.tolerance(out_dtype, tol_spec)) if out_dtype else tol
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
@ -599,7 +599,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
jnp_fun = partial(jnp.nanvar, dtype=out_dtype, axis=axis, ddof=ddof, keepdims=keepdims)
tol = jtu.tolerance(out_dtype, {np.float16: 1e-1, np.float32: 1e-3,
np.float64: 1e-3, np.complex64: 1e-3,
np.complex128: 3e-4})
np.complex128: 5e-4})
if (jnp.issubdtype(dtype, jnp.complexfloating) and
not jnp.issubdtype(out_dtype, jnp.complexfloating)):
self.assertRaises(ValueError, lambda: jnp_fun(*args_maker()))

View File

@ -462,7 +462,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
k, A_mv, M, Q, H)
QA = matmul_high_precision(Q[:, :n].conj().T, A)
QAQ = matmul_high_precision(QA, Q[:, :n])
self.assertAllClose(QAQ, H.T[:n, :], rtol=1e-5, atol=1e-5)
self.assertAllClose(QAQ, H.T[:n, :], rtol=2e-5, atol=2e-5)
def test_gmres_weak_types(self):
x, _ = jax.scipy.sparse.linalg.gmres(lambda x: x, 1.0)

View File

@ -393,7 +393,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
a = (a + np.conj(a.T)) / 2
return [a]
self._CheckAgainstNumpy(
np.linalg.eigvalsh, jnp.linalg.eigvalsh, args_maker, tol=3e-6
np.linalg.eigvalsh, jnp.linalg.eigvalsh, args_maker, tol=2e-5
)
@jtu.sample_product(
@ -1042,7 +1042,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
x, = args_maker()
p, l, u = jsp.linalg.lu(x)
self.assertAllClose(x, np.matmul(p, np.matmul(l, u)),
rtol={np.float32: 1e-3, np.float64: 1e-12,
rtol={np.float32: 1e-3, np.float64: 5e-12,
np.complex64: 1e-3, np.complex128: 1e-12},
atol={np.float32: 1e-5})
self._CompileAndCheck(jsp.linalg.lu, args_maker)
@ -1684,7 +1684,7 @@ class LaxLinalgTest(jtu.JaxTestCase):
w_expected, v_expected = np.linalg.eigh(np.asarray(a))
self.assertAllClose(w_expected, w if sort_eigenvalues else np.sort(w),
rtol=1e-4)
rtol=1e-4, atol=1e-4)
def run_eigh_tridiagonal_test(self, alpha, beta):
n = alpha.shape[-1]

View File

@ -1109,7 +1109,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
result1 = lsp_stats.multivariate_normal.logpdf(x, mean, cov)
result2 = jax.vmap(lsp_stats.multivariate_normal.logpdf)(x, mean, cov)
self.assertArraysEqual(result1, result2, check_dtypes=False)
self.assertArraysAllClose(result1, result2, check_dtypes=False)
@jtu.sample_product(
inshape=[(50,), (3, 50), (2, 12)],

View File

@ -1934,7 +1934,7 @@ class BCOOTest(sptu.SparseTestCase):
jnp.array(rng(rhs_shape, rhs_dtype))]
tol = {np.float64: 1E-13, np.complex128: 1E-13,
np.float32: 1E-6, np.complex64: 1E-6}
np.float32: 2E-6, np.complex64: 2E-6}
with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
self._CheckAgainstDense(operator.matmul, operator.matmul, args_maker, tol=tol)