Merge pull request #16246 from chrisflesher:scipy-rotation-v3

PiperOrigin-RevId: 538788621
This commit is contained in:
jax authors 2023-06-08 08:10:58 -07:00
commit 8d27f20637
8 changed files with 754 additions and 0 deletions

View File

@ -8,6 +8,9 @@ Remember to align the itemized text with the first line of an item within a list
## jax 0.4.12
* Changes
* Added {class}`scipy.spatial.transform.Rotation` and {class}`scipy.spatial.transform.Slerp`
* Deprecations
* `jax.abstract_arrays` and its contents are now deprecated. See related
functionality in :mod:`jax.core`.

View File

@ -89,6 +89,17 @@ jax.scipy.signal
stft
welch
jax.scipy.spatial.transform
---------------------------
.. automodule:: jax.scipy.spatial.transform
.. autosummary::
:toctree: _autosummary
Rotation
Slerp
jax.scipy.sparse.linalg
-----------------------

View File

@ -0,0 +1,13 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,388 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import re
import typing
import scipy.spatial.transform
import jax
import jax.numpy as jnp
from jax._src.numpy.util import _wraps
@_wraps(scipy.spatial.transform.Rotation)
class Rotation(typing.NamedTuple):
"""Rotation in 3 dimensions."""
quat: jax.Array
@classmethod
def concatenate(cls, rotations: typing.Sequence):
"""Concatenate a sequence of `Rotation` objects."""
return cls(jnp.concatenate([rotation.quat for rotation in rotations]))
@classmethod
def from_euler(cls, seq: str, angles: jax.Array, degrees: bool = False):
"""Initialize from Euler angles."""
num_axes = len(seq)
if num_axes < 1 or num_axes > 3:
raise ValueError("Expected axis specification to be a non-empty "
"string of upto 3 characters, got {}".format(seq))
intrinsic = (re.match(r'^[XYZ]{1,3}$', seq) is not None)
extrinsic = (re.match(r'^[xyz]{1,3}$', seq) is not None)
if not (intrinsic or extrinsic):
raise ValueError("Expected axes from `seq` to be from ['x', 'y', "
"'z'] or ['X', 'Y', 'Z'], got {}".format(seq))
if any(seq[i] == seq[i+1] for i in range(num_axes - 1)):
raise ValueError("Expected consecutive axes to be different, "
"got {}".format(seq))
angles = jnp.atleast_1d(angles)
axes = jnp.array([_elementary_basis_index(x) for x in seq.lower()])
return cls(_elementary_quat_compose(angles, axes, intrinsic, degrees))
@classmethod
def from_matrix(cls, matrix: jax.Array):
"""Initialize from rotation matrix."""
return cls(_from_matrix(matrix))
@classmethod
def from_mrp(cls, mrp: jax.Array):
"""Initialize from Modified Rodrigues Parameters (MRPs)."""
return cls(_from_mrp(mrp))
@classmethod
def from_quat(cls, quat: jax.Array):
"""Initialize from quaternions."""
return cls(_normalize_quaternion(quat))
@classmethod
def from_rotvec(cls, rotvec: jax.Array, degrees: bool = False):
"""Initialize from rotation vectors."""
return cls(_from_rotvec(rotvec, degrees))
@classmethod
def identity(cls, num: typing.Optional[int] = None, dtype=float):
"""Get identity rotation(s)."""
assert num is None
quat = jnp.array([0., 0., 0., 1.], dtype=dtype)
return cls(quat)
@classmethod
def random(cls, random_key: jax.Array, num: typing.Optional[int] = None):
"""Generate uniformly distributed rotations."""
# Need to implement scipy.stats.special_ortho_group for this to work...
raise NotImplementedError
def __getitem__(self, indexer):
"""Extract rotation(s) at given index(es) from object."""
if self.single:
raise TypeError("Single rotation is not subscriptable.")
return Rotation(self.quat[indexer])
def __len__(self):
"""Number of rotations contained in this object."""
if self.single:
raise TypeError('Single rotation has no len().')
else:
return self.quat.shape[0]
def __mul__(self, other):
"""Compose this rotation with the other."""
return Rotation.from_quat(_compose_quat(self.quat, other.quat))
def apply(self, vectors: jax.Array, inverse: bool = False) -> jax.Array:
"""Apply this rotation to one or more vectors."""
return _apply(self.as_matrix(), vectors, inverse)
def as_euler(self, seq: str, degrees: bool = False):
"""Represent as Euler angles."""
if len(seq) != 3:
raise ValueError("Expected 3 axes, got {}.".format(seq))
intrinsic = (re.match(r'^[XYZ]{1,3}$', seq) is not None)
extrinsic = (re.match(r'^[xyz]{1,3}$', seq) is not None)
if not (intrinsic or extrinsic):
raise ValueError("Expected axes from `seq` to be from "
"['x', 'y', 'z'] or ['X', 'Y', 'Z'], "
"got {}".format(seq))
if any(seq[i] == seq[i+1] for i in range(2)):
raise ValueError("Expected consecutive axes to be different, "
"got {}".format(seq))
axes = jnp.array([_elementary_basis_index(x) for x in seq.lower()])
return _compute_euler_from_quat(self.quat, axes, extrinsic, degrees)
def as_matrix(self) -> jax.Array:
"""Represent as rotation matrix."""
return _as_matrix(self.quat)
def as_mrp(self) -> jax.Array:
"""Represent as Modified Rodrigues Parameters (MRPs)."""
return _as_mrp(self.quat)
def as_rotvec(self, degrees: bool = False) -> jax.Array:
"""Represent as rotation vectors."""
return _as_rotvec(self.quat, degrees)
def as_quat(self) -> jax.Array:
"""Represent as quaternions."""
return self.quat
def inv(self):
"""Invert this rotation."""
return Rotation(_inv(self.quat))
def magnitude(self) -> jax.Array:
"""Get the magnitude(s) of the rotation(s)."""
return _magnitude(self.quat)
def mean(self, weights: typing.Optional[jax.Array] = None):
"""Get the mean of the rotations."""
weights = jnp.where(weights is None, jnp.ones(self.quat.shape[0], dtype=self.quat.dtype), jnp.asarray(weights, dtype=self.quat.dtype))
if weights.ndim != 1:
raise ValueError("Expected `weights` to be 1 dimensional, got "
"shape {}.".format(weights.shape))
if weights.shape[0] != len(self):
raise ValueError("Expected `weights` to have number of values "
"equal to number of rotations, got "
"{} values and {} rotations.".format(weights.shape[0], len(self)))
K = jnp.dot(weights[jnp.newaxis, :] * self.quat.T, self.quat)
_, v = jnp.linalg.eigh(K)
return Rotation(v[:, -1])
@property
def single(self) -> bool:
"""Whether this instance represents a single rotation."""
return self.quat.ndim == 1
@_wraps(scipy.spatial.transform.Slerp)
class Slerp(typing.NamedTuple):
"""Spherical Linear Interpolation of Rotations."""
times: jnp.ndarray
timedelta: jnp.ndarray
rotations: Rotation
rotvecs: jnp.ndarray
@classmethod
def init(cls, times: jax.Array, rotations: Rotation):
if not isinstance(rotations, Rotation):
raise TypeError("`rotations` must be a `Rotation` instance.")
if rotations.single or len(rotations) == 1:
raise ValueError("`rotations` must be a sequence of at least 2 rotations.")
times = jnp.asarray(times, dtype=rotations.quat.dtype)
if times.ndim != 1:
raise ValueError("Expected times to be specified in a 1 "
"dimensional array, got {} "
"dimensions.".format(times.ndim))
if times.shape[0] != len(rotations):
raise ValueError("Expected number of rotations to be equal to "
"number of timestamps given, got {} rotations "
"and {} timestamps.".format(len(rotations), times.shape[0]))
timedelta = jnp.diff(times)
# if jnp.any(timedelta <= 0): # this causes a concretization error...
# raise ValueError("Times must be in strictly increasing order.")
new_rotations = Rotation(rotations.as_quat()[:-1])
return cls(
times=times,
timedelta=timedelta,
rotations=new_rotations,
rotvecs=(new_rotations.inv() * Rotation(rotations.as_quat()[1:])).as_rotvec())
def __call__(self, times: jax.Array):
"""Interpolate rotations."""
compute_times = jnp.asarray(times, dtype=self.times.dtype)
if compute_times.ndim > 1:
raise ValueError("`times` must be at most 1-dimensional.")
single_time = compute_times.ndim == 0
compute_times = jnp.atleast_1d(compute_times)
ind = jnp.maximum(jnp.searchsorted(self.times, compute_times) - 1, 0)
alpha = (compute_times - self.times[ind]) / self.timedelta[ind]
result = (self.rotations[ind] * Rotation.from_rotvec(self.rotvecs[ind] * alpha[:, None]))
if single_time:
return result[0]
return result
@functools.partial(jnp.vectorize, signature='(m,m),(m),()->(m)')
def _apply(matrix: jax.Array, vector: jax.Array, inverse: bool) -> jax.Array:
return jnp.where(inverse, matrix.T, matrix) @ vector
@functools.partial(jnp.vectorize, signature='(m)->(n,n)')
def _as_matrix(quat: jax.Array) -> jax.Array:
x = quat[0]
y = quat[1]
z = quat[2]
w = quat[3]
x2 = x * x
y2 = y * y
z2 = z * z
w2 = w * w
xy = x * y
zw = z * w
xz = x * z
yw = y * w
yz = y * z
xw = x * w
return jnp.array([[+ x2 - y2 - z2 + w2, 2 * (xy - zw), 2 * (xz + yw)],
[2 * (xy + zw), - x2 + y2 - z2 + w2, 2 * (yz - xw)],
[2 * (xz - yw), 2 * (yz + xw), - x2 - y2 + z2 + w2]])
@functools.partial(jnp.vectorize, signature='(m)->(n)')
def _as_mrp(quat: jax.Array) -> jax.Array:
sign = jnp.where(quat[3] < 0, -1., 1.)
denominator = 1. + sign * quat[3]
return sign * quat[:3] / denominator
@functools.partial(jnp.vectorize, signature='(m),()->(n)')
def _as_rotvec(quat: jax.Array, degrees: bool) -> jax.Array:
quat = jnp.where(quat[3] < 0, -quat, quat) # w > 0 to ensure 0 <= angle <= pi
angle = 2. * jnp.arctan2(_vector_norm(quat[:3]), quat[3])
angle2 = angle * angle
small_scale = 2 + angle2 / 12 + 7 * angle2 * angle2 / 2880
large_scale = angle / jnp.sin(angle / 2)
scale = jnp.where(angle <= 1e-3, small_scale, large_scale)
scale = jnp.where(degrees, jnp.rad2deg(scale), scale)
return scale * jnp.array(quat[:3])
@functools.partial(jnp.vectorize, signature='(n),(n)->(n)')
def _compose_quat(p: jax.Array, q: jax.Array) -> jax.Array:
cross = jnp.cross(p[:3], q[:3])
return jnp.array([p[3]*q[0] + q[3]*p[0] + cross[0],
p[3]*q[1] + q[3]*p[1] + cross[1],
p[3]*q[2] + q[3]*p[2] + cross[2],
p[3]*q[3] - p[0]*q[0] - p[1]*q[1] - p[2]*q[2]])
@functools.partial(jnp.vectorize, signature='(m),(l),(),()->(n)')
def _compute_euler_from_quat(quat: jax.Array, axes: jax.Array, extrinsic: bool, degrees: bool) -> jax.Array:
angle_first = jnp.where(extrinsic, 0, 2)
angle_third = jnp.where(extrinsic, 2, 0)
axes = jnp.where(extrinsic, axes, axes[::-1])
i = axes[0]
j = axes[1]
k = axes[2]
symmetric = i == k
k = jnp.where(symmetric, 3 - i - j, k)
sign = jnp.array((i - j) * (j - k) * (k - i) // 2, dtype=quat.dtype)
eps = 1e-7
a = jnp.where(symmetric, quat[3], quat[3] - quat[j])
b = jnp.where(symmetric, quat[i], quat[i] + quat[k] * sign)
c = jnp.where(symmetric, quat[j], quat[j] + quat[3])
d = jnp.where(symmetric, quat[k] * sign, quat[k] * sign - quat[i])
angles = jnp.empty(3, dtype=quat.dtype)
angles = angles.at[1].set(2 * jnp.arctan2(jnp.hypot(c, d), jnp.hypot(a, b)))
case = jnp.where(jnp.abs(angles[1] - jnp.pi) <= eps, 2, 0)
case = jnp.where(jnp.abs(angles[1]) <= eps, 1, case)
half_sum = jnp.arctan2(b, a)
half_diff = jnp.arctan2(d, c)
angles = angles.at[0].set(jnp.where(case == 1, 2 * half_sum, 2 * half_diff * jnp.where(extrinsic, -1, 1))) # any degenerate case
angles = angles.at[angle_first].set(jnp.where(case == 0, half_sum - half_diff, angles[angle_first]))
angles = angles.at[angle_third].set(jnp.where(case == 0, half_sum + half_diff, angles[angle_third]))
angles = angles.at[angle_third].set(jnp.where(symmetric, angles[angle_third], angles[angle_third] * sign))
angles = angles.at[1].set(jnp.where(symmetric, angles[1], angles[1] - jnp.pi / 2))
angles = (angles + jnp.pi) % (2 * jnp.pi) - jnp.pi
return jnp.where(degrees, jnp.rad2deg(angles), angles)
def _elementary_basis_index(axis: str) -> int:
if axis == 'x':
return 0
elif axis == 'y':
return 1
elif axis == 'z':
return 2
raise ValueError("Expected axis to be from ['x', 'y', 'z'], got {}".format(axis))
@functools.partial(jnp.vectorize, signature=('(m),(m),(),()->(n)'))
def _elementary_quat_compose(angles: jax.Array, axes: jax.Array, intrinsic: bool, degrees: bool) -> jax.Array:
angles = jnp.where(degrees, jnp.deg2rad(angles), angles)
result = _make_elementary_quat(axes[0], angles[0])
for idx in range(1, len(axes)):
quat = _make_elementary_quat(axes[idx], angles[idx])
result = jnp.where(intrinsic, _compose_quat(result, quat), _compose_quat(quat, result))
return result
@functools.partial(jnp.vectorize, signature=('(m),()->(n)'))
def _from_rotvec(rotvec: jax.Array, degrees: bool) -> jax.Array:
rotvec = jnp.where(degrees, jnp.deg2rad(rotvec), rotvec)
angle = _vector_norm(rotvec)
angle2 = angle * angle
small_scale = scale = 0.5 - angle2 / 48 + angle2 * angle2 / 3840
large_scale = jnp.sin(angle / 2) / angle
scale = jnp.where(angle <= 1e-3, small_scale, large_scale)
return jnp.hstack([scale * rotvec, jnp.cos(angle / 2)])
@functools.partial(jnp.vectorize, signature=('(m,m)->(n)'))
def _from_matrix(matrix: jax.Array) -> jax.Array:
matrix_trace = matrix[0, 0] + matrix[1, 1] + matrix[2, 2]
decision = jnp.array([matrix[0, 0], matrix[1, 1], matrix[2, 2], matrix_trace], dtype=matrix.dtype)
choice = jnp.argmax(decision)
i = choice
j = (i + 1) % 3
k = (j + 1) % 3
quat_012 = jnp.empty(4, dtype=matrix.dtype)
quat_012 = quat_012.at[i].set(1 - decision[3] + 2 * matrix[i, i])
quat_012 = quat_012.at[j].set(matrix[j, i] + matrix[i, j])
quat_012 = quat_012.at[k].set(matrix[k, i] + matrix[i, k])
quat_012 = quat_012.at[3].set(matrix[k, j] - matrix[j, k])
quat_3 = jnp.empty(4, dtype=matrix.dtype)
quat_3 = quat_3.at[0].set(matrix[2, 1] - matrix[1, 2])
quat_3 = quat_3.at[1].set(matrix[0, 2] - matrix[2, 0])
quat_3 = quat_3.at[2].set(matrix[1, 0] - matrix[0, 1])
quat_3 = quat_3.at[3].set(1 + decision[3])
quat = jnp.where(choice != 3, quat_012, quat_3)
return _normalize_quaternion(quat)
@functools.partial(jnp.vectorize, signature='(m)->(n)')
def _from_mrp(mrp: jax.Array) -> jax.Array:
mrp_squared_plus_1 = jnp.dot(mrp, mrp) + 1
return jnp.hstack([2 * mrp[:3], (2 - mrp_squared_plus_1)]) / mrp_squared_plus_1
@functools.partial(jnp.vectorize, signature='(n)->(n)')
def _inv(quat: jax.Array) -> jax.Array:
return quat.at[3].set(-quat[3])
@functools.partial(jnp.vectorize, signature='(n)->()')
def _magnitude(quat: jax.Array) -> jax.Array:
return 2. * jnp.arctan2(_vector_norm(quat[:3]), jnp.abs(quat[3]))
@functools.partial(jnp.vectorize, signature='(),()->(n)')
def _make_elementary_quat(axis: int, angle: jax.Array) -> jax.Array:
quat = jnp.zeros(4, dtype=angle.dtype)
quat = quat.at[3].set(jnp.cos(angle / 2.))
quat = quat.at[axis].set(jnp.sin(angle / 2.))
return quat
@functools.partial(jnp.vectorize, signature='(n)->(n)')
def _normalize_quaternion(quat: jax.Array) -> jax.Array:
return quat / _vector_norm(quat)
@functools.partial(jnp.vectorize, signature='(n)->()')
def _vector_norm(vector: jax.Array) -> jax.Array:
return jnp.sqrt(jnp.dot(vector, vector))

View File

@ -0,0 +1,13 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,21 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/google/jax/issues/7570
from jax._src.scipy.spatial.transform import (
Rotation as Rotation,
Slerp as Slerp,
)

View File

@ -727,6 +727,12 @@ jax_test(
},
)
jax_test(
name = "scipy_spatial_test",
srcs = ["scipy_spatial_test.py"],
deps = py_deps("scipy"),
)
jax_test(
name = "scipy_stats_test",
srcs = ["scipy_stats_test.py"],

299
tests/scipy_spatial_test.py Normal file
View File

@ -0,0 +1,299 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
import scipy.version
from jax._src import test_util as jtu
from jax.scipy.spatial.transform import Rotation as jsp_Rotation
from scipy.spatial.transform import Rotation as osp_Rotation
from jax.scipy.spatial.transform import Slerp as jsp_Slerp
from scipy.spatial.transform import Slerp as osp_Slerp
import jax.numpy as jnp
import numpy as onp
from jax.config import config
config.parse_flags_with_absl()
scipy_version = tuple(map(int, scipy.version.version.split('.')[:3]))
float_dtypes = jtu.dtypes.floating
real_dtypes = float_dtypes + jtu.dtypes.integer + jtu.dtypes.boolean
num_samples = 2
class LaxBackedScipySpatialTransformTests(jtu.JaxTestCase):
"""Tests for LAX-backed scipy.spatial implementations"""
@jtu.sample_product(
dtype=float_dtypes,
shape=[(4,), (num_samples, 4)],
vector_shape=[(3,), (num_samples, 3)],
inverse=[True, False],
)
def testRotationApply(self, shape, vector_shape, dtype, inverse):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype), rng(vector_shape, dtype),)
jnp_fn = lambda q, v: jsp_Rotation.from_quat(q).apply(v, inverse=inverse)
np_fn = lambda q, v: osp_Rotation.from_quat(q).apply(v, inverse=inverse).astype(dtype) # HACK
tol = 5e-2 if jtu.device_under_test() == 'tpu' else 1e-4
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=tol)
self._CompileAndCheck(jnp_fn, args_maker, tol=tol)
@jtu.sample_product(
dtype=float_dtypes,
shape=[(4,), (num_samples, 4)],
seq=['xyz', 'zyx', 'XYZ', 'ZYX'],
degrees=[True, False],
)
def testRotationAsEuler(self, shape, dtype, seq, degrees):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype),)
jnp_fn = lambda q: jsp_Rotation.from_quat(q).as_euler(seq=seq, degrees=degrees)
np_fn = lambda q: osp_Rotation.from_quat(q).as_euler(seq=seq, degrees=degrees).astype(dtype) # HACK
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
@jtu.sample_product(
dtype=float_dtypes,
shape=[(4,), (num_samples, 4)],
)
def testRotationAsMatrix(self, shape, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype),)
jnp_fn = lambda q: jsp_Rotation.from_quat(q).as_matrix()
np_fn = lambda q: osp_Rotation.from_quat(q).as_matrix().astype(dtype) # HACK
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
@jtu.sample_product(
dtype=float_dtypes,
shape=[(4,), (num_samples, 4)],
)
def testRotationAsMrp(self, shape, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype),)
jnp_fn = lambda q: jsp_Rotation.from_quat(q).as_mrp()
np_fn = lambda q: osp_Rotation.from_quat(q).as_mrp().astype(dtype) # HACK
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
@jtu.sample_product(
dtype=float_dtypes,
shape=[(4,), (num_samples, 4)],
degrees=[True, False],
)
def testRotationAsRotvec(self, shape, dtype, degrees):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype),)
jnp_fn = lambda q: jsp_Rotation.from_quat(q).as_rotvec(degrees=degrees)
np_fn = lambda q: osp_Rotation.from_quat(q).as_rotvec(degrees=degrees).astype(dtype) # HACK
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True,
tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
@jtu.sample_product(
dtype=float_dtypes,
shape=[(4,), (num_samples, 4)],
)
def testRotationAsQuat(self, shape, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype),)
jnp_fn = lambda q: jsp_Rotation.from_quat(q).as_quat()
np_fn = lambda q: osp_Rotation.from_quat(q).as_quat().astype(dtype) # HACK
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
@jtu.sample_product(
dtype=float_dtypes,
shape=[(num_samples, 4)],
other_shape=[(num_samples, 4)],
)
def testRotationConcatenate(self, shape, other_shape, dtype):
if scipy_version < (1, 8, 0):
self.skipTest("Scipy 1.8.0 needed for concatenate.")
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype), rng(other_shape, dtype),)
jnp_fn = lambda q, o: jsp_Rotation.concatenate([jsp_Rotation.from_quat(q), jsp_Rotation.from_quat(o)]).as_quat()
np_fn = lambda q, o: osp_Rotation.concatenate([osp_Rotation.from_quat(q), osp_Rotation.from_quat(o)]).as_quat().astype(dtype) # HACK
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
@jtu.sample_product(
dtype=float_dtypes,
shape=[(10, 4)],
indexer=[slice(1, 5), slice(0), slice(-5, -3)],
)
def testRotationGetItem(self, shape, dtype, indexer):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype),)
jnp_fn = lambda q: jsp_Rotation.from_quat(q)[indexer].as_quat()
np_fn = lambda q: osp_Rotation.from_quat(q)[indexer].as_quat().astype(dtype) # HACK
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
@jtu.sample_product(
dtype=float_dtypes,
size=[1, num_samples],
seq=['x', 'xy', 'xyz', 'XYZ'],
degrees=[True, False],
)
def testRotationFromEuler(self, size, dtype, seq, degrees):
rng = jtu.rand_default(self.rng())
shape = (size, len(seq))
args_maker = lambda: (rng(shape, dtype),)
jnp_fn = lambda a: jsp_Rotation.from_euler(seq, a, degrees).as_rotvec()
np_fn = lambda a: osp_Rotation.from_euler(seq, a, degrees).as_rotvec().astype(dtype) # HACK
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
@jtu.sample_product(
dtype=float_dtypes,
shape=[(3, 3), (num_samples, 3, 3)],
)
def testRotationFromMatrix(self, shape, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype),)
jnp_fn = lambda m: jsp_Rotation.from_matrix(m).as_rotvec()
np_fn = lambda m: osp_Rotation.from_matrix(m).as_rotvec().astype(dtype) # HACK
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
@jtu.sample_product(
dtype=float_dtypes,
shape=[(3,), (num_samples, 3)],
)
def testRotationFromMrp(self, shape, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype),)
jnp_fn = lambda m: jsp_Rotation.from_mrp(m).as_rotvec()
np_fn = lambda m: osp_Rotation.from_mrp(m).as_rotvec().astype(dtype) # HACK
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
@jtu.sample_product(
dtype=float_dtypes,
shape=[(3,), (num_samples, 3)],
)
def testRotationFromRotvec(self, shape, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype),)
jnp_fn = lambda r: jsp_Rotation.from_rotvec(r).as_quat()
np_fn = lambda r: osp_Rotation.from_rotvec(r).as_quat().astype(dtype) # HACK
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
@jtu.sample_product(
dtype=float_dtypes,
num=[None],
)
def testRotationIdentity(self, num, dtype):
args_maker = lambda: (num,)
jnp_fn = lambda n: jsp_Rotation.identity(n, dtype).as_quat()
np_fn = lambda n: osp_Rotation.identity(n).as_quat().astype(dtype) # HACK
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
@jtu.sample_product(
dtype=float_dtypes,
shape=[(4,), (num_samples, 4)],
)
def testRotationMagnitude(self, shape, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype),)
jnp_fn = lambda q: jsp_Rotation.from_quat(q).magnitude()
np_fn = lambda q: jnp.array(osp_Rotation.from_quat(q).magnitude(), dtype=dtype)
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
@jtu.sample_product(
dtype=float_dtypes,
shape=[(num_samples, 4)],
rng_weights =[True, False],
)
def testRotationMean(self, shape, dtype, rng_weights):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype), jnp.abs(rng(shape[0], dtype)) if rng_weights else None)
jnp_fn = lambda q, w: jsp_Rotation.from_quat(q).mean(w).as_rotvec()
np_fn = lambda q, w: osp_Rotation.from_quat(q).mean(w).as_rotvec().astype(dtype) # HACK
tol = 5e-3 if jtu.device_under_test() == 'tpu' else 1e-4
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=tol)
self._CompileAndCheck(jnp_fn, args_maker, tol=tol)
@jtu.sample_product(
dtype=float_dtypes,
shape=[(4,), (num_samples, 4)],
other_shape=[(4,), (num_samples, 4)],
)
def testRotationMultiply(self, shape, other_shape, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype), rng(other_shape, dtype))
jnp_fn = lambda q, o: (jsp_Rotation.from_quat(q) * jsp_Rotation.from_quat(o)).as_rotvec()
np_fn = lambda q, o: (osp_Rotation.from_quat(q) * osp_Rotation.from_quat(o)).as_rotvec().astype(dtype) # HACK
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
@jtu.sample_product(
dtype=float_dtypes,
shape=[(4,), (num_samples, 4)],
)
def testRotationInv(self, shape, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype),)
jnp_fn = lambda q: jsp_Rotation.from_quat(q).inv().as_quat()
np_fn = lambda q: osp_Rotation.from_quat(q).inv().as_quat().astype(dtype) # HACK
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
@jtu.sample_product(
dtype=float_dtypes,
shape=[(num_samples, 4)],
)
def testRotationLen(self, shape, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype),)
jnp_fn = lambda q: len(jsp_Rotation.from_quat(q))
np_fn = lambda q: len(osp_Rotation.from_quat(q))
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
@jtu.sample_product(
dtype=float_dtypes,
shape=[(4,), (num_samples, 4)],
)
def testRotationSingle(self, shape, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype),)
jnp_fn = lambda q: jsp_Rotation.from_quat(q).single
np_fn = lambda q: osp_Rotation.from_quat(q).single
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
@jtu.sample_product(
dtype=float_dtypes,
shape=[(num_samples, 4)],
compute_times=[0., onp.zeros(1), onp.zeros(2)],
)
def testSlerp(self, shape, dtype, compute_times):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype),)
times = jnp.arange(shape[0], dtype=dtype)
jnp_fn = lambda q: jsp_Slerp.init(times, jsp_Rotation.from_quat(q))(compute_times).as_rotvec()
np_fn = lambda q: osp_Slerp(times, osp_Rotation.from_quat(q))(compute_times).as_rotvec().astype(dtype) # HACK
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())