Merge pull request #17564 from andportnoy:aportnoy/scipy-spatial-test-increase-tolerance

PiperOrigin-RevId: 564891104
This commit is contained in:
jax authors 2023-09-12 18:21:31 -07:00
commit 7c0abb1c85

View File

@ -232,7 +232,7 @@ class LaxBackedScipySpatialTransformTests(jtu.JaxTestCase):
args_maker = lambda: (rng(shape, dtype), jnp.abs(rng(shape[0], dtype)) if rng_weights else None)
jnp_fn = lambda q, w: jsp_Rotation.from_quat(q).mean(w).as_rotvec()
np_fn = lambda q, w: osp_Rotation.from_quat(q).mean(w).as_rotvec().astype(dtype) # HACK
tol = 5e-3 if jtu.device_under_test() == 'tpu' else 1e-4
tol = 5e-3 # 1e-4 too tight for TF32
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=tol)
self._CompileAndCheck(jnp_fn, args_maker, tol=tol)