mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Relax some test tolerances for TPU.
PiperOrigin-RevId: 576192162
This commit is contained in:
parent
8b05b1623c
commit
e7f1d29716
@ -455,12 +455,12 @@ class IndexingTest(jtu.JaxTestCase):
|
|||||||
|
|
||||||
# Test with traced integer index
|
# Test with traced integer index
|
||||||
args_maker = lambda: [rng(size, dtype), idx_rng(size, int)]
|
args_maker = lambda: [rng(size, dtype), idx_rng(size, int)]
|
||||||
atol = (
|
tol = (
|
||||||
5e-5
|
5e-5
|
||||||
if jtu.test_device_matches(["tpu"]) and funcname in ("log", "exp")
|
if jtu.test_device_matches(["tpu"]) and funcname in ("log", "exp")
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
self._CheckAgainstNumpy(np_op, jnp_op, args_maker, atol=atol)
|
self._CheckAgainstNumpy(np_op, jnp_op, args_maker, atol=tol)
|
||||||
self._CompileAndCheck(jnp_op, args_maker)
|
self._CompileAndCheck(jnp_op, args_maker)
|
||||||
|
|
||||||
# Test with slice index
|
# Test with slice index
|
||||||
@ -468,7 +468,8 @@ class IndexingTest(jtu.JaxTestCase):
|
|||||||
np_op_idx = partial(np_op, idx=idx)
|
np_op_idx = partial(np_op, idx=idx)
|
||||||
jnp_op_idx = partial(jnp_op, idx=idx)
|
jnp_op_idx = partial(jnp_op, idx=idx)
|
||||||
args_maker = lambda: [rng(size, dtype)]
|
args_maker = lambda: [rng(size, dtype)]
|
||||||
self._CheckAgainstNumpy(np_op_idx, jnp_op_idx, args_maker, atol=atol)
|
self._CheckAgainstNumpy(np_op_idx, jnp_op_idx, args_maker, atol=tol,
|
||||||
|
rtol=tol)
|
||||||
self._CompileAndCheck(jnp_op_idx, args_maker)
|
self._CompileAndCheck(jnp_op_idx, args_maker)
|
||||||
|
|
||||||
def testIndexApplyBatchingBug(self):
|
def testIndexApplyBatchingBug(self):
|
||||||
|
@ -2565,7 +2565,10 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
|||||||
jnp_fun = partial(getattr(jnp, name), size, **kwds)
|
jnp_fun = partial(getattr(jnp, name), size, **kwds)
|
||||||
np_fun = jtu.with_jax_dtype_defaults(partial(getattr(np, name), size, **kwds))
|
np_fun = jtu.with_jax_dtype_defaults(partial(getattr(np, name), size, **kwds))
|
||||||
args_maker = lambda: []
|
args_maker = lambda: []
|
||||||
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
|
tol = (
|
||||||
|
5e-6 if jtu.test_device_matches(['tpu']) and name == 'kaiser' else None
|
||||||
|
)
|
||||||
|
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker, atol=tol, rtol=tol)
|
||||||
self._CompileAndCheck(jnp_fun, args_maker)
|
self._CompileAndCheck(jnp_fun, args_maker)
|
||||||
|
|
||||||
@jtu.sample_product(
|
@jtu.sample_product(
|
||||||
|
@ -726,7 +726,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
|||||||
|
|
||||||
# Check a ~= qr
|
# Check a ~= qr
|
||||||
norm_error = norm(a - np.matmul(lq, lr))
|
norm_error = norm(a - np.matmul(lq, lr))
|
||||||
self.assertTrue(np.all(norm_error < 45), msg=np.amax(norm_error))
|
self.assertTrue(np.all(norm_error < 60), msg=np.amax(norm_error))
|
||||||
|
|
||||||
# Compare the first 'k' vectors of Q; the remainder form an arbitrary
|
# Compare the first 'k' vectors of Q; the remainder form an arbitrary
|
||||||
# orthonormal basis for the null space.
|
# orthonormal basis for the null space.
|
||||||
|
@ -150,7 +150,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
|
|||||||
|
|
||||||
def testSoftmaxGrad(self):
|
def testSoftmaxGrad(self):
|
||||||
x = jnp.array([5.5, 1.3, -4.2, 0.9])
|
x = jnp.array([5.5, 1.3, -4.2, 0.9])
|
||||||
jtu.check_grads(nn.softmax, (x,), order=2, atol=3e-3)
|
jtu.check_grads(nn.softmax, (x,), order=2, atol=5e-3)
|
||||||
|
|
||||||
def testSoftmaxGradResiduals(self):
|
def testSoftmaxGradResiduals(self):
|
||||||
if not config.softmax_custom_jvp.value:
|
if not config.softmax_custom_jvp.value:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user