diff --git a/tests/scipy_spatial_test.py b/tests/scipy_spatial_test.py index fe2232d7f..3da98efce 100644 --- a/tests/scipy_spatial_test.py +++ b/tests/scipy_spatial_test.py @@ -196,7 +196,10 @@ class LaxBackedScipySpatialTransformTests(jtu.JaxTestCase): ) def testRotationFromMatrix(self, shape, dtype): rng = jtu.rand_default(self.rng()) - args_maker = lambda: (rng(shape, dtype),) + def args_maker(): + # Use QR to ensure valid positive-definite rotation matrix. + q, _ = onp.linalg.qr(rng(shape, dtype)) + return [q] jnp_fn = lambda m: jsp_Rotation.from_matrix(m).as_rotvec() np_fn = lambda m: osp_Rotation.from_matrix(m).as_rotvec().astype(dtype) self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)