[XLA:CPU] Relax test tolerances for tests using XLA:CPU.

An upcoming change to XLA:CPU will disable reassociation on floating point operators by default which is an unsound fast math optimization. This change is being made to fix numerical errors in softmax computations caused by reassocation. After that change, we will enable reassociation only in reduction operators where it is very important for performance and the XLA operator contract allows that.

Since this change alters the order of operations, it may cause small numerical changes leading to test failures. This change relaxes test tolerances to make tests pass.

PiperOrigin-RevId: 431453240
This commit is contained in:
Peter Hawkins 2022-02-28 09:26:20 -08:00 committed by jax authors
parent 043561ae02
commit c339330bc1

View File

@ -165,8 +165,8 @@ class ODETest(jtu.JaxTestCase):
ans = jax.grad(g)(2.) # don't crash
expected = jax.grad(f, 0)(2., 0.1) + jax.grad(f, 0)(2., 0.2)
atol = {jnp.float64: 5e-15}
rtol = {jnp.float64: 2e-15}
atol = {jnp.float32: 1e-5, jnp.float64: 5e-15}
rtol = {jnp.float32: 1e-5, jnp.float64: 2e-15}
self.assertAllClose(ans, expected, check_dtypes=False, atol=atol, rtol=rtol)
@jtu.skip_on_devices("tpu", "gpu")