[XLA] Adjust the error tolerance of tests impacted by upcoming change in XLA:CPU

We are about to change the vectorization strategy for XLA:CPU. This change may lead
to some numerical differences due to the fact the vectorization might happen differently
(e.g., code that was scalar could now be vectorized, code that was vectorized could now
be scalar, vectorization may happen with a different VL, etc.). As a result, we have
to increase the error tolerance of the impacted tests.

PiperOrigin-RevId: 412061380
This commit is contained in:
Diego Caballero 2021-11-24 08:07:13 -08:00 committed by jax authors
parent 4e21922055
commit 4d64677277

View File

@ -286,7 +286,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
args_maker = lambda: [rng(shape, dtype) + (d - 1) / 2.]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
tol={np.float32: 1e-3, np.float64: 1e-14})
self._CompileAndCheck(lax_fun, args_maker)
self._CompileAndCheck(
lax_fun, args_maker, rtol={
np.float32: 3e-07,
np.float64: 4e-15
})
def testIssue980(self):
x = np.full((4,), -1e20, dtype=np.float32)