test: change random matrix generation for Rotation

This commit is contained in:
Jake VanderPlas 2025-01-31 09:21:46 -08:00
parent bf9671731c
commit 4d433d063e

View File

@ -196,7 +196,10 @@ class LaxBackedScipySpatialTransformTests(jtu.JaxTestCase):
)
def testRotationFromMatrix(self, shape, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype),)
def args_maker():
# Use QR to ensure valid positive-definite rotation matrix.
q, _ = onp.linalg.qr(rng(shape, dtype))
return [q]
jnp_fn = lambda m: jsp_Rotation.from_matrix(m).as_rotvec()
np_fn = lambda m: osp_Rotation.from_matrix(m).as_rotvec().astype(dtype)
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)