mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Fix problematic rotation tests
This commit is contained in:
parent
7e5e50114c
commit
5a2936d19d
@ -66,9 +66,8 @@ class LaxBackedScipySpatialTransformTests(jtu.JaxTestCase):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: (rng(shape, dtype),)
|
||||
jnp_fn = lambda q: jsp_Rotation.from_quat(q).as_euler(seq=seq, degrees=degrees)
|
||||
# TODO(chrisflesher): re-enable this after accounting for sign degeneracy
|
||||
# np_fn = lambda q: osp_Rotation.from_quat(q).as_euler(seq=seq, degrees=degrees).astype(dtype) # HACK
|
||||
# self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
np_fn = lambda q: osp_Rotation.from_quat(q).as_euler(seq=seq, degrees=degrees).astype(dtype) # HACK
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
|
||||
|
||||
@jtu.sample_product(
|
||||
@ -79,9 +78,8 @@ class LaxBackedScipySpatialTransformTests(jtu.JaxTestCase):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: (rng(shape, dtype),)
|
||||
jnp_fn = lambda q: jsp_Rotation.from_quat(q).as_matrix()
|
||||
# TODO(chrisflesher): re-enable this after accounting for sign degeneracy
|
||||
# np_fn = lambda q: osp_Rotation.from_quat(q).as_matrix().astype(dtype) # HACK
|
||||
# self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
np_fn = lambda q: osp_Rotation.from_quat(q).as_matrix().astype(dtype) # HACK
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
|
||||
|
||||
@jtu.sample_product(
|
||||
@ -92,9 +90,8 @@ class LaxBackedScipySpatialTransformTests(jtu.JaxTestCase):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: (rng(shape, dtype),)
|
||||
jnp_fn = lambda q: jsp_Rotation.from_quat(q).as_mrp()
|
||||
# TODO(chrisflesher): re-enable this after accounting for sign degeneracy
|
||||
# np_fn = lambda q: osp_Rotation.from_quat(q).as_mrp().astype(dtype) # HACK
|
||||
# self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
np_fn = lambda q: osp_Rotation.from_quat(q).as_mrp().astype(dtype) # HACK
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
|
||||
|
||||
@jtu.sample_product(
|
||||
@ -106,10 +103,8 @@ class LaxBackedScipySpatialTransformTests(jtu.JaxTestCase):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: (rng(shape, dtype),)
|
||||
jnp_fn = lambda q: jsp_Rotation.from_quat(q).as_rotvec(degrees=degrees)
|
||||
# TODO(chrisflesher): re-enable this after accounting for sign degeneracy
|
||||
# np_fn = lambda q: osp_Rotation.from_quat(q).as_rotvec(degrees=degrees).astype(dtype) # HACK
|
||||
# self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True,
|
||||
# tol=1e-4)
|
||||
np_fn = lambda q: osp_Rotation.from_quat(q).as_rotvec(degrees=degrees).astype(dtype) # HACK
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
|
||||
|
||||
@jtu.sample_product(
|
||||
@ -119,10 +114,9 @@ class LaxBackedScipySpatialTransformTests(jtu.JaxTestCase):
|
||||
def testRotationAsQuat(self, shape, dtype):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: (rng(shape, dtype),)
|
||||
jnp_fn = lambda q: jsp_Rotation.from_quat(q).as_quat()
|
||||
# TODO(chrisflesher): re-enable this after accounting for sign degeneracy
|
||||
# np_fn = lambda q: osp_Rotation.from_quat(q).as_quat().astype(dtype) # HACK
|
||||
# self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
jnp_fn = lambda q: jsp_Rotation.from_quat(jnp.where(jnp.sum(q, axis=0) > 0, q, -q)).as_quat()
|
||||
np_fn = lambda q: osp_Rotation.from_quat(onp.where(jnp.sum(q, axis=0) > 0, q, -q)).as_quat().astype(dtype) # HACK
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
|
||||
|
||||
@jtu.sample_product(
|
||||
@ -135,10 +129,9 @@ class LaxBackedScipySpatialTransformTests(jtu.JaxTestCase):
|
||||
self.skipTest("Scipy 1.8.0 needed for concatenate.")
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: (rng(shape, dtype), rng(other_shape, dtype),)
|
||||
jnp_fn = lambda q, o: jsp_Rotation.concatenate([jsp_Rotation.from_quat(q), jsp_Rotation.from_quat(o)]).as_quat()
|
||||
# TODO(chrisflesher): re-enable this after accounting for sign degeneracy
|
||||
# np_fn = lambda q, o: osp_Rotation.concatenate([osp_Rotation.from_quat(q), osp_Rotation.from_quat(o)]).as_quat().astype(dtype) # HACK
|
||||
# self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
jnp_fn = lambda q, o: jsp_Rotation.concatenate([jsp_Rotation.from_quat(q), jsp_Rotation.from_quat(o)]).as_rotvec()
|
||||
np_fn = lambda q, o: osp_Rotation.concatenate([osp_Rotation.from_quat(q), osp_Rotation.from_quat(o)]).as_rotvec().astype(dtype) # HACK
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
|
||||
|
||||
@jtu.sample_product(
|
||||
@ -149,10 +142,9 @@ class LaxBackedScipySpatialTransformTests(jtu.JaxTestCase):
|
||||
def testRotationGetItem(self, shape, dtype, indexer):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: (rng(shape, dtype),)
|
||||
jnp_fn = lambda q: jsp_Rotation.from_quat(q)[indexer].as_quat()
|
||||
# TODO(chrisflesher): re-enable this after accounting for sign degeneracy
|
||||
# np_fn = lambda q: osp_Rotation.from_quat(q)[indexer].as_quat().astype(dtype) # HACK
|
||||
# self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
jnp_fn = lambda q: jsp_Rotation.from_quat(jnp.where(jnp.sum(q, axis=0) > 0, q, -q))[indexer].as_quat()
|
||||
np_fn = lambda q: osp_Rotation.from_quat(onp.where(onp.sum(q, axis=0) > 0, q, -q))[indexer].as_quat().astype(dtype) # HACK
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
|
||||
|
||||
@jtu.sample_product(
|
||||
@ -166,9 +158,8 @@ class LaxBackedScipySpatialTransformTests(jtu.JaxTestCase):
|
||||
shape = (size, len(seq))
|
||||
args_maker = lambda: (rng(shape, dtype),)
|
||||
jnp_fn = lambda a: jsp_Rotation.from_euler(seq, a, degrees).as_rotvec()
|
||||
# TODO(chrisflesher): re-enable this after accounting for sign degeneracy
|
||||
# np_fn = lambda a: osp_Rotation.from_euler(seq, a, degrees).as_rotvec().astype(dtype) # HACK
|
||||
# self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
np_fn = lambda a: osp_Rotation.from_euler(seq, a, degrees).as_rotvec().astype(dtype) # HACK
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
|
||||
|
||||
@jtu.sample_product(
|
||||
@ -192,9 +183,8 @@ class LaxBackedScipySpatialTransformTests(jtu.JaxTestCase):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: (rng(shape, dtype),)
|
||||
jnp_fn = lambda m: jsp_Rotation.from_mrp(m).as_rotvec()
|
||||
# TODO(chrisflesher): re-enable this after accounting for sign degeneracy
|
||||
# np_fn = lambda m: osp_Rotation.from_mrp(m).as_rotvec().astype(dtype) # HACK
|
||||
# self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
np_fn = lambda m: osp_Rotation.from_mrp(m).as_rotvec().astype(dtype) # HACK
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
|
||||
|
||||
@jtu.sample_product(
|
||||
@ -204,10 +194,9 @@ class LaxBackedScipySpatialTransformTests(jtu.JaxTestCase):
|
||||
def testRotationFromRotvec(self, shape, dtype):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: (rng(shape, dtype),)
|
||||
jnp_fn = lambda r: jsp_Rotation.from_rotvec(r).as_quat()
|
||||
# TODO(chrisflesher): re-enable this after accounting for sign degeneracy
|
||||
# np_fn = lambda r: osp_Rotation.from_rotvec(r).as_quat().astype(dtype) # HACK
|
||||
# self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
jnp_fn = lambda r: jsp_Rotation.from_rotvec(r).as_rotvec()
|
||||
np_fn = lambda r: osp_Rotation.from_rotvec(r).as_rotvec().astype(dtype) # HACK
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
|
||||
|
||||
@jtu.sample_product(
|
||||
@ -216,10 +205,9 @@ class LaxBackedScipySpatialTransformTests(jtu.JaxTestCase):
|
||||
)
|
||||
def testRotationIdentity(self, num, dtype):
|
||||
args_maker = lambda: (num,)
|
||||
jnp_fn = lambda n: jsp_Rotation.identity(n, dtype).as_quat()
|
||||
# TODO(chrisflesher): re-enable this after accounting for sign degeneracy
|
||||
# np_fn = lambda n: osp_Rotation.identity(n).as_quat().astype(dtype) # HACK
|
||||
# self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
jnp_fn = lambda n: jsp_Rotation.identity(n, dtype).as_rotvec()
|
||||
np_fn = lambda n: osp_Rotation.identity(n).as_rotvec().astype(dtype) # HACK
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
|
||||
|
||||
@jtu.sample_product(
|
||||
@ -243,10 +231,9 @@ class LaxBackedScipySpatialTransformTests(jtu.JaxTestCase):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: (rng(shape, dtype), jnp.abs(rng(shape[0], dtype)) if rng_weights else None)
|
||||
jnp_fn = lambda q, w: jsp_Rotation.from_quat(q).mean(w).as_rotvec()
|
||||
# TODO(chrisflesher): re-enable this after accounting for sign degeneracy
|
||||
# np_fn = lambda q, w: osp_Rotation.from_quat(q).mean(w).as_rotvec().astype(dtype) # HACK
|
||||
np_fn = lambda q, w: osp_Rotation.from_quat(q).mean(w).as_rotvec().astype(dtype) # HACK
|
||||
tol = 5e-3 if jtu.device_under_test() == 'tpu' else 1e-4
|
||||
# self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=tol)
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=tol)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, tol=tol)
|
||||
|
||||
@jtu.sample_product(
|
||||
@ -258,9 +245,8 @@ class LaxBackedScipySpatialTransformTests(jtu.JaxTestCase):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: (rng(shape, dtype), rng(other_shape, dtype))
|
||||
jnp_fn = lambda q, o: (jsp_Rotation.from_quat(q) * jsp_Rotation.from_quat(o)).as_rotvec()
|
||||
# TODO(chrisflesher): re-enable this after accounting for sign degeneracy
|
||||
# np_fn = lambda q, o: (osp_Rotation.from_quat(q) * osp_Rotation.from_quat(o)).as_rotvec().astype(dtype) # HACK
|
||||
# self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
np_fn = lambda q, o: (osp_Rotation.from_quat(q) * osp_Rotation.from_quat(o)).as_rotvec().astype(dtype) # HACK
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
|
||||
|
||||
@jtu.sample_product(
|
||||
@ -270,10 +256,9 @@ class LaxBackedScipySpatialTransformTests(jtu.JaxTestCase):
|
||||
def testRotationInv(self, shape, dtype):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: (rng(shape, dtype),)
|
||||
jnp_fn = lambda q: jsp_Rotation.from_quat(q).inv().as_quat()
|
||||
# TODO(chrisflesher): re-enable this after accounting for sign degeneracy
|
||||
# np_fn = lambda q: osp_Rotation.from_quat(q).inv().as_quat().astype(dtype) # HACK
|
||||
# self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
jnp_fn = lambda q: jsp_Rotation.from_quat(q).inv().as_rotvec()
|
||||
np_fn = lambda q: osp_Rotation.from_quat(q).inv().as_rotvec().astype(dtype) # HACK
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
|
||||
|
||||
@jtu.sample_product(
|
||||
@ -284,9 +269,8 @@ class LaxBackedScipySpatialTransformTests(jtu.JaxTestCase):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: (rng(shape, dtype),)
|
||||
jnp_fn = lambda q: len(jsp_Rotation.from_quat(q))
|
||||
# TODO(chrisflesher): re-enable this after accounting for sign degeneracy
|
||||
# np_fn = lambda q: len(osp_Rotation.from_quat(q))
|
||||
# self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
np_fn = lambda q: len(osp_Rotation.from_quat(q))
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
|
||||
|
||||
@jtu.sample_product(
|
||||
@ -297,9 +281,8 @@ class LaxBackedScipySpatialTransformTests(jtu.JaxTestCase):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: (rng(shape, dtype),)
|
||||
jnp_fn = lambda q: jsp_Rotation.from_quat(q).single
|
||||
# TODO(chrisflesher): re-enable this after accounting for sign degeneracy
|
||||
# np_fn = lambda q: osp_Rotation.from_quat(q).single
|
||||
# self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
np_fn = lambda q: osp_Rotation.from_quat(q).single
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
|
||||
|
||||
@jtu.sample_product(
|
||||
|
Loading…
x
Reference in New Issue
Block a user