mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
[XLA:CPU] Relax test tolerances for tests using XLA:CPU.
An upcoming change to XLA:CPU will disable reassociation on floating point operators by default which is an unsound fast math optimization. This change is being made to fix numerical errors in softmax computations caused by reassocation. After that change, we will enable reassociation only in reduction operators where it is very important for performance and the XLA operator contract allows that. Since this change alters the order of operations, it may cause small numerical changes leading to test failures. This change relaxes test tolerances to make tests pass. PiperOrigin-RevId: 431453240
This commit is contained in:
parent
043561ae02
commit
c339330bc1
@ -165,8 +165,8 @@ class ODETest(jtu.JaxTestCase):
|
||||
ans = jax.grad(g)(2.) # don't crash
|
||||
expected = jax.grad(f, 0)(2., 0.1) + jax.grad(f, 0)(2., 0.2)
|
||||
|
||||
atol = {jnp.float64: 5e-15}
|
||||
rtol = {jnp.float64: 2e-15}
|
||||
atol = {jnp.float32: 1e-5, jnp.float64: 5e-15}
|
||||
rtol = {jnp.float32: 1e-5, jnp.float64: 2e-15}
|
||||
self.assertAllClose(ans, expected, check_dtypes=False, atol=atol, rtol=rtol)
|
||||
|
||||
@jtu.skip_on_devices("tpu", "gpu")
|
||||
|
Loading…
x
Reference in New Issue
Block a user