mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
test: change random matrix generation for Rotation
This commit is contained in:
parent
bf9671731c
commit
4d433d063e
@ -196,7 +196,10 @@ class LaxBackedScipySpatialTransformTests(jtu.JaxTestCase):
|
||||
)
|
||||
def testRotationFromMatrix(self, shape, dtype):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: (rng(shape, dtype),)
|
||||
def args_maker():
|
||||
# Use QR to ensure valid positive-definite rotation matrix.
|
||||
q, _ = onp.linalg.qr(rng(shape, dtype))
|
||||
return [q]
|
||||
jnp_fn = lambda m: jsp_Rotation.from_matrix(m).as_rotvec()
|
||||
np_fn = lambda m: osp_Rotation.from_matrix(m).as_rotvec().astype(dtype)
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
|
Loading…
x
Reference in New Issue
Block a user