mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 03:46:06 +00:00
Relax test tolerances.
This makes the tests pass on CPU with a slightly different seed (+ 1).

PiperOrigin-RevId: 542877795
This commit is contained in:
parent bfa113ba60
commit 0adfafe293
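For context on what the tolerance arguments below mean: JAX's assertAllClose-style test helpers bottom out in NumPy's allclose criterion, which accepts two arrays when |actual - desired| <= atol + rtol * |desired| holds elementwise. The following is a minimal illustrative sketch (the array values are made up, not from this commit) of why adding explicit atol/rtol bounds lets a result computed under a slightly perturbed seed still pass:

import numpy as np

# NumPy accepts arrays as "close" when, elementwise,
#   |actual - desired| <= atol + rtol * |desired|
desired = np.array([1.0, 2.0, 3.0], dtype=np.float32)
actual = desired + 3e-6  # small numerical drift, e.g. from a different seed

# The default (rtol=1e-7, atol=0) is too strict for this drift:
#   np.testing.assert_allclose(actual, desired)  # would raise AssertionError

# With relaxed bounds like the ones added in this commit, the check passes:
np.testing.assert_allclose(actual, desired, atol=1e-5, rtol=1e-5)
print("ok")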
@@ -629,7 +629,7 @@ class BatchingTest(jtu.JaxTestCase):
     ans = vmap(lax.linalg.triangular_solve, in_axes=(1, 2))(a, b)
     expected = np.stack(
       [lax.linalg.triangular_solve(a[:, i], b[..., i]) for i in range(10)])
-    self.assertAllClose(ans, expected)
+    self.assertAllClose(ans, expected, atol=1e-5, rtol=1e-5)
 
     ans = vmap(lax.linalg.triangular_solve, in_axes=(None, 2))(a[:, 0], b)
     expected = np.stack(
@@ -639,7 +639,7 @@ class BatchingTest(jtu.JaxTestCase):
     ans = vmap(lax.linalg.triangular_solve, in_axes=(1, None))(a, b[..., 0])
     expected = np.stack(
       [lax.linalg.triangular_solve(a[:, i], b[..., 0]) for i in range(10)])
-    self.assertAllClose(ans, expected)
+    self.assertAllClose(ans, expected, atol=1e-5, rtol=1e-5)
 
   @parameterized.named_parameters(
       {"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
@@ -216,7 +216,7 @@ class CustomRootTest(jtu.JaxTestCase):
 
     # grad check with aux
     jtu.check_grads(lambda x, y: root_aux(high_precision_dot(x, x.T), y),
-                    (a, b), order=2, rtol={jnp.float32: 1e-2})
+                    (a, b), order=2, rtol={jnp.float32: 1e-2, np.float64: 3e-5})
 
     # test vmap and jvp combined by jacfwd
     fwd = jax.jacfwd(lambda x, y: root_aux(high_precision_dot(x, x.T), y), argnums=(0, 1))
@@ -131,7 +131,7 @@ LAX_GRAD_OPS = [
     grad_test_spec(lax.rsqrt, nargs=1, order=2, rng_factory=jtu.rand_positive,
                    dtypes=grad_float_dtypes),
     grad_test_spec(lax.rsqrt, nargs=1, order=2, rng_factory=jtu.rand_default,
-                   dtypes=grad_complex_dtypes),
+                   dtypes=grad_complex_dtypes, tol={np.float64: 2e-3}),
     grad_test_spec(lax.cbrt, nargs=1, order=2, rng_factory=jtu.rand_default,
                    dtypes=grad_float_dtypes, tol={np.float64: 5e-3}),
     grad_test_spec(lax.logistic, nargs=1, order=2,
@@ -233,8 +233,8 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
     jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
     args_maker = lambda: [rng(shape, dtype)]
     tol_spec = {np.float16: 1e-2, np.int16: 2e-7, np.int32: 1E-3,
-                np.float32: 1e-3, np.complex64: 1e-3, np.float64: 1e-5,
-                np.complex128: 1e-5}
+                np.uint32: 3e-7, np.float32: 1e-3, np.complex64: 1e-3,
+                np.float64: 1e-5, np.complex128: 1e-5}
     tol = jtu.tolerance(dtype, tol_spec)
     tol = max(tol, jtu.tolerance(out_dtype, tol_spec)) if out_dtype else tol
     self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
@@ -599,7 +599,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
     jnp_fun = partial(jnp.nanvar, dtype=out_dtype, axis=axis, ddof=ddof, keepdims=keepdims)
     tol = jtu.tolerance(out_dtype, {np.float16: 1e-1, np.float32: 1e-3,
                                     np.float64: 1e-3, np.complex64: 1e-3,
-                                    np.complex128: 3e-4})
+                                    np.complex128: 5e-4})
     if (jnp.issubdtype(dtype, jnp.complexfloating) and
         not jnp.issubdtype(out_dtype, jnp.complexfloating)):
       self.assertRaises(ValueError, lambda: jnp_fun(*args_maker()))
@@ -462,7 +462,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
         k, A_mv, M, Q, H)
     QA = matmul_high_precision(Q[:, :n].conj().T, A)
     QAQ = matmul_high_precision(QA, Q[:, :n])
-    self.assertAllClose(QAQ, H.T[:n, :], rtol=1e-5, atol=1e-5)
+    self.assertAllClose(QAQ, H.T[:n, :], rtol=2e-5, atol=2e-5)
 
   def test_gmres_weak_types(self):
     x, _ = jax.scipy.sparse.linalg.gmres(lambda x: x, 1.0)
@@ -393,7 +393,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
       a = (a + np.conj(a.T)) / 2
       return [a]
     self._CheckAgainstNumpy(
-        np.linalg.eigvalsh, jnp.linalg.eigvalsh, args_maker, tol=3e-6
+        np.linalg.eigvalsh, jnp.linalg.eigvalsh, args_maker, tol=2e-5
     )
 
   @jtu.sample_product(
@@ -1042,7 +1042,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
     x, = args_maker()
     p, l, u = jsp.linalg.lu(x)
     self.assertAllClose(x, np.matmul(p, np.matmul(l, u)),
-                        rtol={np.float32: 1e-3, np.float64: 1e-12,
+                        rtol={np.float32: 1e-3, np.float64: 5e-12,
                               np.complex64: 1e-3, np.complex128: 1e-12},
                         atol={np.float32: 1e-5})
     self._CompileAndCheck(jsp.linalg.lu, args_maker)
@@ -1684,7 +1684,7 @@ class LaxLinalgTest(jtu.JaxTestCase):
 
     w_expected, v_expected = np.linalg.eigh(np.asarray(a))
     self.assertAllClose(w_expected, w if sort_eigenvalues else np.sort(w),
-                        rtol=1e-4)
+                        rtol=1e-4, atol=1e-4)
 
   def run_eigh_tridiagonal_test(self, alpha, beta):
     n = alpha.shape[-1]
@@ -1109,7 +1109,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
 
     result1 = lsp_stats.multivariate_normal.logpdf(x, mean, cov)
     result2 = jax.vmap(lsp_stats.multivariate_normal.logpdf)(x, mean, cov)
-    self.assertArraysEqual(result1, result2, check_dtypes=False)
+    self.assertArraysAllClose(result1, result2, check_dtypes=False)
 
   @jtu.sample_product(
     inshape=[(50,), (3, 50), (2, 12)],
@@ -1934,7 +1934,7 @@ class BCOOTest(sptu.SparseTestCase):
                   jnp.array(rng(rhs_shape, rhs_dtype))]
 
     tol = {np.float64: 1E-13, np.complex128: 1E-13,
-           np.float32: 1E-6, np.complex64: 1E-6}
+           np.float32: 2E-6, np.complex64: 2E-6}
 
     with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
      self._CheckAgainstDense(operator.matmul, operator.matmul, args_maker, tol=tol)