Merge pull request #26239 from jakevdp:rotation-test

PiperOrigin-RevId: 721884090
This commit is contained in:
jax authors 2025-01-31 13:30:43 -08:00
commit b039f976d5

View File

@ -196,7 +196,10 @@ class LaxBackedScipySpatialTransformTests(jtu.JaxTestCase):
)
def testRotationFromMatrix(self, shape, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype),)
def args_maker():
# Use QR to ensure valid positive-definite rotation matrix.
q, _ = onp.linalg.qr(rng(shape, dtype))
return [q]
jnp_fn = lambda m: jsp_Rotation.from_matrix(m).as_rotvec()
np_fn = lambda m: osp_Rotation.from_matrix(m).as_rotvec().astype(dtype)
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)