mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Increase comparison tolerance in SciPy spatial RotationMean subtest
Previous value leads to failures on A100 runners in github.com/NVIDIA/JAX-Toolbox CI: https://github.com/NVIDIA/JAX-Toolbox/actions/runs/6144692887/job/16670611913#step:8:1014 The suspected reason is the use of TF32 math for matmuls: decorating the function with @jax.default_matmul_precision("float32") allows the test to pass. We thought it's better to loosen the tolerance but preserve the original execution mode. The fully qualified test case is tests/scipy_spatial_test.py::LaxBackedScipySpatialTransformTests::testRotationMean0
This commit is contained in:
parent
dbb0e8f214
commit
34ea2b2e8a
@ -232,7 +232,7 @@ class LaxBackedScipySpatialTransformTests(jtu.JaxTestCase):
|
||||
args_maker = lambda: (rng(shape, dtype), jnp.abs(rng(shape[0], dtype)) if rng_weights else None)
|
||||
jnp_fn = lambda q, w: jsp_Rotation.from_quat(q).mean(w).as_rotvec()
|
||||
np_fn = lambda q, w: osp_Rotation.from_quat(q).mean(w).as_rotvec().astype(dtype) # HACK
|
||||
tol = 5e-3 if jtu.device_under_test() == 'tpu' else 1e-4
|
||||
tol = 5e-3 # 1e-4 too tight for TF32
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=tol)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, tol=tol)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user