mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Increase precision of detrend test on TPU.
The test appears to pass at the higher tolerance these days. PiperOrigin-RevId: 515474890
This commit is contained in:
parent
50408fd694
commit
eb80b17762
@ -126,10 +126,7 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase):
|
||||
return osp_signal.detrend(x, **kwds).astype(dtypes.to_inexact_dtype(x.dtype))
|
||||
jsp_fun = partial(jsp_signal.detrend, **kwds)
|
||||
|
||||
if jtu.device_under_test() == 'tpu':
|
||||
tol = {np.float32: 3e-2, np.float64: 1e-12}
|
||||
else:
|
||||
tol = {np.float32: 1e-5, np.float64: 1e-12}
|
||||
tol = {np.float32: 1e-5, np.float64: 1e-12}
|
||||
|
||||
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, tol=tol)
|
||||
self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
|
||||
|
Loading…
x
Reference in New Issue
Block a user