Relax some test tolerances for TPU.

PiperOrigin-RevId: 576192162
This commit is contained in:
Peter Hawkins 2023-10-24 10:44:52 -07:00 committed by jax authors
parent 8b05b1623c
commit e7f1d29716
4 changed files with 10 additions and 6 deletions

View File

@ -455,12 +455,12 @@ class IndexingTest(jtu.JaxTestCase):
# Test with traced integer index # Test with traced integer index
args_maker = lambda: [rng(size, dtype), idx_rng(size, int)] args_maker = lambda: [rng(size, dtype), idx_rng(size, int)]
atol = ( tol = (
5e-5 5e-5
if jtu.test_device_matches(["tpu"]) and funcname in ("log", "exp") if jtu.test_device_matches(["tpu"]) and funcname in ("log", "exp")
else None else None
) )
self._CheckAgainstNumpy(np_op, jnp_op, args_maker, atol=atol) self._CheckAgainstNumpy(np_op, jnp_op, args_maker, atol=tol)
self._CompileAndCheck(jnp_op, args_maker) self._CompileAndCheck(jnp_op, args_maker)
# Test with slice index # Test with slice index
@ -468,7 +468,8 @@ class IndexingTest(jtu.JaxTestCase):
np_op_idx = partial(np_op, idx=idx) np_op_idx = partial(np_op, idx=idx)
jnp_op_idx = partial(jnp_op, idx=idx) jnp_op_idx = partial(jnp_op, idx=idx)
args_maker = lambda: [rng(size, dtype)] args_maker = lambda: [rng(size, dtype)]
self._CheckAgainstNumpy(np_op_idx, jnp_op_idx, args_maker, atol=atol) self._CheckAgainstNumpy(np_op_idx, jnp_op_idx, args_maker, atol=tol,
rtol=tol)
self._CompileAndCheck(jnp_op_idx, args_maker) self._CompileAndCheck(jnp_op_idx, args_maker)
def testIndexApplyBatchingBug(self): def testIndexApplyBatchingBug(self):

View File

@ -2565,7 +2565,10 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
jnp_fun = partial(getattr(jnp, name), size, **kwds) jnp_fun = partial(getattr(jnp, name), size, **kwds)
np_fun = jtu.with_jax_dtype_defaults(partial(getattr(np, name), size, **kwds)) np_fun = jtu.with_jax_dtype_defaults(partial(getattr(np, name), size, **kwds))
args_maker = lambda: [] args_maker = lambda: []
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) tol = (
5e-6 if jtu.test_device_matches(['tpu']) and name == 'kaiser' else None
)
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker, atol=tol, rtol=tol)
self._CompileAndCheck(jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product( @jtu.sample_product(

View File

@ -726,7 +726,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
# Check a ~= qr # Check a ~= qr
norm_error = norm(a - np.matmul(lq, lr)) norm_error = norm(a - np.matmul(lq, lr))
self.assertTrue(np.all(norm_error < 45), msg=np.amax(norm_error)) self.assertTrue(np.all(norm_error < 60), msg=np.amax(norm_error))
# Compare the first 'k' vectors of Q; the remainder form an arbitrary # Compare the first 'k' vectors of Q; the remainder form an arbitrary
# orthonormal basis for the null space. # orthonormal basis for the null space.

View File

@ -150,7 +150,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
def testSoftmaxGrad(self): def testSoftmaxGrad(self):
x = jnp.array([5.5, 1.3, -4.2, 0.9]) x = jnp.array([5.5, 1.3, -4.2, 0.9])
jtu.check_grads(nn.softmax, (x,), order=2, atol=3e-3) jtu.check_grads(nn.softmax, (x,), order=2, atol=5e-3)
def testSoftmaxGradResiduals(self): def testSoftmaxGradResiduals(self):
if not config.softmax_custom_jvp.value: if not config.softmax_custom_jvp.value: