Increase precision of detrend test on TPU.

The test appears to pass at the higher tolerance these days.

PiperOrigin-RevId: 515474890
This commit is contained in:
Peter Hawkins 2023-03-09 16:31:58 -08:00 committed by jax authors
parent 50408fd694
commit eb80b17762

View File

@ -126,10 +126,7 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase):
return osp_signal.detrend(x, **kwds).astype(dtypes.to_inexact_dtype(x.dtype))
jsp_fun = partial(jsp_signal.detrend, **kwds)
if jtu.device_under_test() == 'tpu':
tol = {np.float32: 3e-2, np.float64: 1e-12}
else:
tol = {np.float32: 1e-5, np.float64: 1e-12}
tol = {np.float32: 1e-5, np.float64: 1e-12}
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, tol=tol)
self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)