Relax some test tolerances that appear to be sensitive to the random seed.

This commit is contained in:
Peter Hawkins 2020-12-06 15:44:44 -05:00
parent b133f14ad4
commit 03f423bb4c
3 changed files with 11 additions and 5 deletions

View File

@ -133,16 +133,16 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
expected = np.linalg.solve(posify(a), b)
actual = lax_cg(posify(a), b)
self.assertAllClose(expected, actual)
self.assertAllClose(expected, actual, atol=1e-5, rtol=1e-5)
actual = jit(lax_cg)(posify(a), b)
self.assertAllClose(expected, actual)
self.assertAllClose(expected, actual, atol=1e-5, rtol=1e-5)
# numerical gradients are only well defined if ``a`` is guaranteed to be
# positive definite.
jtu.check_grads(
lambda x, y: lax_cg(posify(x), y),
(a, b), order=2, rtol=1e-2)
(a, b), order=2, rtol=2e-1)
def test_cg_ndarray(self):
A = lambda x: 2 * x

View File

@ -52,6 +52,11 @@ def op_record(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums=(), t
nondiff_argnums = tuple(sorted(set(nondiff_argnums)))
return OpRecord(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums, test_name)
# TODO(phawkins): we should probably separate out the function domains used for
# autodiff tests from the function domains used for equivalence testing. For
# example, logit should closely match its scipy equivalent everywhere, but we
# don't expect numerical gradient tests to pass for inputs very close to 0.
JAX_SPECIAL_FUNCTION_RECORDS = [
op_record("betaln", 2, float_dtypes, jtu.rand_positive, False),
op_record("betainc", 3, float_dtypes, jtu.rand_positive, False),
@ -68,7 +73,8 @@ JAX_SPECIAL_FUNCTION_RECORDS = [
op_record("i0e", 1, float_dtypes, jtu.rand_default, True),
op_record("i1", 1, float_dtypes, jtu.rand_default, True),
op_record("i1e", 1, float_dtypes, jtu.rand_default, True),
op_record("logit", 1, float_dtypes, jtu.rand_uniform, True),
op_record("logit", 1, float_dtypes, partial(jtu.rand_uniform, low=0.05,
high=0.95), True),
op_record("log_ndtr", 1, float_dtypes, jtu.rand_default, True),
op_record("ndtri", 1, float_dtypes, partial(jtu.rand_uniform, low=0.05,
high=0.95),

View File

@ -997,7 +997,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
actual_ps, actual_ls, actual_us = vmap(jsp.linalg.lu)(jnp.stack(args))
self.assertAllClose(ps, actual_ps)
self.assertAllClose(ls, actual_ls)
self.assertAllClose(ls, actual_ls, rtol=5e-6)
self.assertAllClose(us, actual_us)
@parameterized.named_parameters(jtu.cases_from_list(