mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Relax some test tolerances in N-D FFT tests.
PiperOrigin-RevId: 708259645
This commit is contained in:
parent
5031b6f599
commit
0b190bb665
@ -158,12 +158,13 @@ class FftTest(jtu.JaxTestCase):
|
||||
# Numpy promotes to complex128 aggressively.
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, atol={np.complex64: 2e-6})
|
||||
self._CompileAndCheck(jnp_fn, args_maker, atol={np.complex64: 2e-6},
|
||||
rtol={np.float32: 2e-6})
|
||||
# Test gradient for differentiable types.
|
||||
if (config.enable_x64.value and
|
||||
dtype in (float_dtypes if real and not inverse else inexact_dtypes)):
|
||||
# TODO(skye): can we be more precise?
|
||||
tol = 0.15
|
||||
tol = 0.16
|
||||
jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol)
|
||||
|
||||
# check dtypes
|
||||
|
Loading…
x
Reference in New Issue
Block a user