mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[XLA] Adjust the error tolerance of tests impacted by upcoming change in XLA:CPU
We are about to change the vectorization strategy for XLA:CPU. This change may lead to some numerical differences due to the fact the vectorization might happen differently (e.g., code that was scalar could now be vectorized, code that was vectorized could now be scalar, vectorization may happen with a different VL, etc.). As a result, we have to increase the error tolerance of the impacted tests. PiperOrigin-RevId: 412061380
This commit is contained in:
parent
4e21922055
commit
4d64677277
@ -286,7 +286,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
args_maker = lambda: [rng(shape, dtype) + (d - 1) / 2.]
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
|
||||
tol={np.float32: 1e-3, np.float64: 1e-14})
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
self._CompileAndCheck(
|
||||
lax_fun, args_maker, rtol={
|
||||
np.float32: 3e-07,
|
||||
np.float64: 4e-15
|
||||
})
|
||||
|
||||
def testIssue980(self):
|
||||
x = np.full((4,), -1e20, dtype=np.float32)
|
||||
|
Loading…
x
Reference in New Issue
Block a user