mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Merge pull request #17564 from andportnoy:aportnoy/scipy-spatial-test-increase-tolerance
PiperOrigin-RevId: 564891104
This commit is contained in:
commit
7c0abb1c85
@ -232,7 +232,7 @@ class LaxBackedScipySpatialTransformTests(jtu.JaxTestCase):
|
||||
args_maker = lambda: (rng(shape, dtype), jnp.abs(rng(shape[0], dtype)) if rng_weights else None)
|
||||
jnp_fn = lambda q, w: jsp_Rotation.from_quat(q).mean(w).as_rotvec()
|
||||
np_fn = lambda q, w: osp_Rotation.from_quat(q).mean(w).as_rotvec().astype(dtype) # HACK
|
||||
tol = 5e-3 if jtu.device_under_test() == 'tpu' else 1e-4
|
||||
tol = 5e-3 # 1e-4 too tight for TF32
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=tol)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, tol=tol)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user