Increase comparison tolerance in SciPy spatial RotationMean subtest

Previous value leads to failures on A100 runners in
github.com/NVIDIA/JAX-Toolbox CI:
https://github.com/NVIDIA/JAX-Toolbox/actions/runs/6144692887/job/16670611913#step:8:1014

The suspected reason is the use of TF32 math for matmuls: decorating the
function with @jax.default_matmul_precision("float32") allows the test to pass.
We thought it's better to loosen the tolerance but preserve the original
execution mode.

The fully qualified test case is
tests/scipy_spatial_test.py::LaxBackedScipySpatialTransformTests::testRotationMean0
This commit is contained in:
Andrey Portnoy 2023-09-12 16:12:14 -04:00
parent dbb0e8f214
commit 34ea2b2e8a

View File

@ -232,7 +232,7 @@ class LaxBackedScipySpatialTransformTests(jtu.JaxTestCase):
args_maker = lambda: (rng(shape, dtype), jnp.abs(rng(shape[0], dtype)) if rng_weights else None)
jnp_fn = lambda q, w: jsp_Rotation.from_quat(q).mean(w).as_rotvec()
np_fn = lambda q, w: osp_Rotation.from_quat(q).mean(w).as_rotvec().astype(dtype) # HACK
tol = 5e-3 if jtu.device_under_test() == 'tpu' else 1e-4
tol = 5e-3 # 1e-4 too tight for TF32
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=tol)
self._CompileAndCheck(jnp_fn, args_maker, tol=tol)