mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Relax some test tolerances that appear to be sensitive to the random seed.
This commit is contained in:
parent
b133f14ad4
commit
03f423bb4c
@ -133,16 +133,16 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
|
||||
expected = np.linalg.solve(posify(a), b)
|
||||
actual = lax_cg(posify(a), b)
|
||||
self.assertAllClose(expected, actual)
|
||||
self.assertAllClose(expected, actual, atol=1e-5, rtol=1e-5)
|
||||
|
||||
actual = jit(lax_cg)(posify(a), b)
|
||||
self.assertAllClose(expected, actual)
|
||||
self.assertAllClose(expected, actual, atol=1e-5, rtol=1e-5)
|
||||
|
||||
# numerical gradients are only well defined if ``a`` is guaranteed to be
|
||||
# positive definite.
|
||||
jtu.check_grads(
|
||||
lambda x, y: lax_cg(posify(x), y),
|
||||
(a, b), order=2, rtol=1e-2)
|
||||
(a, b), order=2, rtol=2e-1)
|
||||
|
||||
def test_cg_ndarray(self):
|
||||
A = lambda x: 2 * x
|
||||
|
@ -52,6 +52,11 @@ def op_record(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums=(), t
|
||||
nondiff_argnums = tuple(sorted(set(nondiff_argnums)))
|
||||
return OpRecord(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums, test_name)
|
||||
|
||||
# TODO(phawkins): we should probably separate out the function domains used for
|
||||
# autodiff tests from the function domains used for equivalence testing. For
|
||||
# example, logit should closely match its scipy equivalent everywhere, but we
|
||||
# don't expect numerical gradient tests to pass for inputs very close to 0.
|
||||
|
||||
JAX_SPECIAL_FUNCTION_RECORDS = [
|
||||
op_record("betaln", 2, float_dtypes, jtu.rand_positive, False),
|
||||
op_record("betainc", 3, float_dtypes, jtu.rand_positive, False),
|
||||
@ -68,7 +73,8 @@ JAX_SPECIAL_FUNCTION_RECORDS = [
|
||||
op_record("i0e", 1, float_dtypes, jtu.rand_default, True),
|
||||
op_record("i1", 1, float_dtypes, jtu.rand_default, True),
|
||||
op_record("i1e", 1, float_dtypes, jtu.rand_default, True),
|
||||
op_record("logit", 1, float_dtypes, jtu.rand_uniform, True),
|
||||
op_record("logit", 1, float_dtypes, partial(jtu.rand_uniform, low=0.05,
|
||||
high=0.95), True),
|
||||
op_record("log_ndtr", 1, float_dtypes, jtu.rand_default, True),
|
||||
op_record("ndtri", 1, float_dtypes, partial(jtu.rand_uniform, low=0.05,
|
||||
high=0.95),
|
||||
|
@ -997,7 +997,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
|
||||
actual_ps, actual_ls, actual_us = vmap(jsp.linalg.lu)(jnp.stack(args))
|
||||
self.assertAllClose(ps, actual_ps)
|
||||
self.assertAllClose(ls, actual_ls)
|
||||
self.assertAllClose(ls, actual_ls, rtol=5e-6)
|
||||
self.assertAllClose(us, actual_us)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
|
Loading…
x
Reference in New Issue
Block a user