mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Relax some test tolerances for TPU.
PiperOrigin-RevId: 576192162
This commit is contained in:
parent
8b05b1623c
commit
e7f1d29716
@ -455,12 +455,12 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
|
||||
# Test with traced integer index
|
||||
args_maker = lambda: [rng(size, dtype), idx_rng(size, int)]
|
||||
atol = (
|
||||
tol = (
|
||||
5e-5
|
||||
if jtu.test_device_matches(["tpu"]) and funcname in ("log", "exp")
|
||||
else None
|
||||
)
|
||||
self._CheckAgainstNumpy(np_op, jnp_op, args_maker, atol=atol)
|
||||
self._CheckAgainstNumpy(np_op, jnp_op, args_maker, atol=tol)
|
||||
self._CompileAndCheck(jnp_op, args_maker)
|
||||
|
||||
# Test with slice index
|
||||
@ -468,7 +468,8 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
np_op_idx = partial(np_op, idx=idx)
|
||||
jnp_op_idx = partial(jnp_op, idx=idx)
|
||||
args_maker = lambda: [rng(size, dtype)]
|
||||
self._CheckAgainstNumpy(np_op_idx, jnp_op_idx, args_maker, atol=atol)
|
||||
self._CheckAgainstNumpy(np_op_idx, jnp_op_idx, args_maker, atol=tol,
|
||||
rtol=tol)
|
||||
self._CompileAndCheck(jnp_op_idx, args_maker)
|
||||
|
||||
def testIndexApplyBatchingBug(self):
|
||||
|
@ -2565,7 +2565,10 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
jnp_fun = partial(getattr(jnp, name), size, **kwds)
|
||||
np_fun = jtu.with_jax_dtype_defaults(partial(getattr(np, name), size, **kwds))
|
||||
args_maker = lambda: []
|
||||
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
|
||||
tol = (
|
||||
5e-6 if jtu.test_device_matches(['tpu']) and name == 'kaiser' else None
|
||||
)
|
||||
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker, atol=tol, rtol=tol)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
|
@ -726,7 +726,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
|
||||
# Check a ~= qr
|
||||
norm_error = norm(a - np.matmul(lq, lr))
|
||||
self.assertTrue(np.all(norm_error < 45), msg=np.amax(norm_error))
|
||||
self.assertTrue(np.all(norm_error < 60), msg=np.amax(norm_error))
|
||||
|
||||
# Compare the first 'k' vectors of Q; the remainder form an arbitrary
|
||||
# orthonormal basis for the null space.
|
||||
|
@ -150,7 +150,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
|
||||
|
||||
def testSoftmaxGrad(self):
|
||||
x = jnp.array([5.5, 1.3, -4.2, 0.9])
|
||||
jtu.check_grads(nn.softmax, (x,), order=2, atol=3e-3)
|
||||
jtu.check_grads(nn.softmax, (x,), order=2, atol=5e-3)
|
||||
|
||||
def testSoftmaxGradResiduals(self):
|
||||
if not config.softmax_custom_jvp.value:
|
||||
|
Loading…
x
Reference in New Issue
Block a user