mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #16246 from chrisflesher:scipy-rotation-v3
PiperOrigin-RevId: 538788621
This commit is contained in:
commit
8d27f20637
@ -8,6 +8,9 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
|
||||
## jax 0.4.12
|
||||
|
||||
* Changes
|
||||
* Added {class}`scipy.spatial.transform.Rotation` and {class}`scipy.spatial.transform.Slerp`
|
||||
|
||||
* Deprecations
|
||||
* `jax.abstract_arrays` and its contents are now deprecated. See related
|
||||
functionality in :mod:`jax.core`.
|
||||
|
@ -89,6 +89,17 @@ jax.scipy.signal
|
||||
stft
|
||||
welch
|
||||
|
||||
jax.scipy.spatial.transform
|
||||
---------------------------
|
||||
|
||||
.. automodule:: jax.scipy.spatial.transform
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
Rotation
|
||||
Slerp
|
||||
|
||||
jax.scipy.sparse.linalg
|
||||
-----------------------
|
||||
|
||||
|
13
jax/_src/scipy/spatial/__init__.py
Normal file
13
jax/_src/scipy/spatial/__init__.py
Normal file
@ -0,0 +1,13 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
388
jax/_src/scipy/spatial/transform.py
Normal file
388
jax/_src/scipy/spatial/transform.py
Normal file
@ -0,0 +1,388 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
import re
|
||||
import typing
|
||||
|
||||
import scipy.spatial.transform
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax._src.numpy.util import _wraps
|
||||
|
||||
|
||||
@_wraps(scipy.spatial.transform.Rotation)
|
||||
class Rotation(typing.NamedTuple):
|
||||
"""Rotation in 3 dimensions."""
|
||||
|
||||
quat: jax.Array
|
||||
|
||||
@classmethod
|
||||
def concatenate(cls, rotations: typing.Sequence):
|
||||
"""Concatenate a sequence of `Rotation` objects."""
|
||||
return cls(jnp.concatenate([rotation.quat for rotation in rotations]))
|
||||
|
||||
@classmethod
|
||||
def from_euler(cls, seq: str, angles: jax.Array, degrees: bool = False):
|
||||
"""Initialize from Euler angles."""
|
||||
num_axes = len(seq)
|
||||
if num_axes < 1 or num_axes > 3:
|
||||
raise ValueError("Expected axis specification to be a non-empty "
|
||||
"string of upto 3 characters, got {}".format(seq))
|
||||
intrinsic = (re.match(r'^[XYZ]{1,3}$', seq) is not None)
|
||||
extrinsic = (re.match(r'^[xyz]{1,3}$', seq) is not None)
|
||||
if not (intrinsic or extrinsic):
|
||||
raise ValueError("Expected axes from `seq` to be from ['x', 'y', "
|
||||
"'z'] or ['X', 'Y', 'Z'], got {}".format(seq))
|
||||
if any(seq[i] == seq[i+1] for i in range(num_axes - 1)):
|
||||
raise ValueError("Expected consecutive axes to be different, "
|
||||
"got {}".format(seq))
|
||||
angles = jnp.atleast_1d(angles)
|
||||
axes = jnp.array([_elementary_basis_index(x) for x in seq.lower()])
|
||||
return cls(_elementary_quat_compose(angles, axes, intrinsic, degrees))
|
||||
|
||||
@classmethod
|
||||
def from_matrix(cls, matrix: jax.Array):
|
||||
"""Initialize from rotation matrix."""
|
||||
return cls(_from_matrix(matrix))
|
||||
|
||||
@classmethod
|
||||
def from_mrp(cls, mrp: jax.Array):
|
||||
"""Initialize from Modified Rodrigues Parameters (MRPs)."""
|
||||
return cls(_from_mrp(mrp))
|
||||
|
||||
@classmethod
|
||||
def from_quat(cls, quat: jax.Array):
|
||||
"""Initialize from quaternions."""
|
||||
return cls(_normalize_quaternion(quat))
|
||||
|
||||
@classmethod
|
||||
def from_rotvec(cls, rotvec: jax.Array, degrees: bool = False):
|
||||
"""Initialize from rotation vectors."""
|
||||
return cls(_from_rotvec(rotvec, degrees))
|
||||
|
||||
@classmethod
|
||||
def identity(cls, num: typing.Optional[int] = None, dtype=float):
|
||||
"""Get identity rotation(s)."""
|
||||
assert num is None
|
||||
quat = jnp.array([0., 0., 0., 1.], dtype=dtype)
|
||||
return cls(quat)
|
||||
|
||||
@classmethod
|
||||
def random(cls, random_key: jax.Array, num: typing.Optional[int] = None):
|
||||
"""Generate uniformly distributed rotations."""
|
||||
# Need to implement scipy.stats.special_ortho_group for this to work...
|
||||
raise NotImplementedError
|
||||
|
||||
def __getitem__(self, indexer):
|
||||
"""Extract rotation(s) at given index(es) from object."""
|
||||
if self.single:
|
||||
raise TypeError("Single rotation is not subscriptable.")
|
||||
return Rotation(self.quat[indexer])
|
||||
|
||||
def __len__(self):
|
||||
"""Number of rotations contained in this object."""
|
||||
if self.single:
|
||||
raise TypeError('Single rotation has no len().')
|
||||
else:
|
||||
return self.quat.shape[0]
|
||||
|
||||
def __mul__(self, other):
|
||||
"""Compose this rotation with the other."""
|
||||
return Rotation.from_quat(_compose_quat(self.quat, other.quat))
|
||||
|
||||
def apply(self, vectors: jax.Array, inverse: bool = False) -> jax.Array:
|
||||
"""Apply this rotation to one or more vectors."""
|
||||
return _apply(self.as_matrix(), vectors, inverse)
|
||||
|
||||
def as_euler(self, seq: str, degrees: bool = False):
|
||||
"""Represent as Euler angles."""
|
||||
if len(seq) != 3:
|
||||
raise ValueError("Expected 3 axes, got {}.".format(seq))
|
||||
intrinsic = (re.match(r'^[XYZ]{1,3}$', seq) is not None)
|
||||
extrinsic = (re.match(r'^[xyz]{1,3}$', seq) is not None)
|
||||
if not (intrinsic or extrinsic):
|
||||
raise ValueError("Expected axes from `seq` to be from "
|
||||
"['x', 'y', 'z'] or ['X', 'Y', 'Z'], "
|
||||
"got {}".format(seq))
|
||||
if any(seq[i] == seq[i+1] for i in range(2)):
|
||||
raise ValueError("Expected consecutive axes to be different, "
|
||||
"got {}".format(seq))
|
||||
axes = jnp.array([_elementary_basis_index(x) for x in seq.lower()])
|
||||
return _compute_euler_from_quat(self.quat, axes, extrinsic, degrees)
|
||||
|
||||
def as_matrix(self) -> jax.Array:
|
||||
"""Represent as rotation matrix."""
|
||||
return _as_matrix(self.quat)
|
||||
|
||||
def as_mrp(self) -> jax.Array:
|
||||
"""Represent as Modified Rodrigues Parameters (MRPs)."""
|
||||
return _as_mrp(self.quat)
|
||||
|
||||
def as_rotvec(self, degrees: bool = False) -> jax.Array:
|
||||
"""Represent as rotation vectors."""
|
||||
return _as_rotvec(self.quat, degrees)
|
||||
|
||||
def as_quat(self) -> jax.Array:
|
||||
"""Represent as quaternions."""
|
||||
return self.quat
|
||||
|
||||
def inv(self):
|
||||
"""Invert this rotation."""
|
||||
return Rotation(_inv(self.quat))
|
||||
|
||||
def magnitude(self) -> jax.Array:
|
||||
"""Get the magnitude(s) of the rotation(s)."""
|
||||
return _magnitude(self.quat)
|
||||
|
||||
def mean(self, weights: typing.Optional[jax.Array] = None):
|
||||
"""Get the mean of the rotations."""
|
||||
weights = jnp.where(weights is None, jnp.ones(self.quat.shape[0], dtype=self.quat.dtype), jnp.asarray(weights, dtype=self.quat.dtype))
|
||||
if weights.ndim != 1:
|
||||
raise ValueError("Expected `weights` to be 1 dimensional, got "
|
||||
"shape {}.".format(weights.shape))
|
||||
if weights.shape[0] != len(self):
|
||||
raise ValueError("Expected `weights` to have number of values "
|
||||
"equal to number of rotations, got "
|
||||
"{} values and {} rotations.".format(weights.shape[0], len(self)))
|
||||
K = jnp.dot(weights[jnp.newaxis, :] * self.quat.T, self.quat)
|
||||
_, v = jnp.linalg.eigh(K)
|
||||
return Rotation(v[:, -1])
|
||||
|
||||
@property
|
||||
def single(self) -> bool:
|
||||
"""Whether this instance represents a single rotation."""
|
||||
return self.quat.ndim == 1
|
||||
|
||||
|
||||
@_wraps(scipy.spatial.transform.Slerp)
|
||||
class Slerp(typing.NamedTuple):
|
||||
"""Spherical Linear Interpolation of Rotations."""
|
||||
|
||||
times: jnp.ndarray
|
||||
timedelta: jnp.ndarray
|
||||
rotations: Rotation
|
||||
rotvecs: jnp.ndarray
|
||||
|
||||
@classmethod
|
||||
def init(cls, times: jax.Array, rotations: Rotation):
|
||||
if not isinstance(rotations, Rotation):
|
||||
raise TypeError("`rotations` must be a `Rotation` instance.")
|
||||
if rotations.single or len(rotations) == 1:
|
||||
raise ValueError("`rotations` must be a sequence of at least 2 rotations.")
|
||||
times = jnp.asarray(times, dtype=rotations.quat.dtype)
|
||||
if times.ndim != 1:
|
||||
raise ValueError("Expected times to be specified in a 1 "
|
||||
"dimensional array, got {} "
|
||||
"dimensions.".format(times.ndim))
|
||||
if times.shape[0] != len(rotations):
|
||||
raise ValueError("Expected number of rotations to be equal to "
|
||||
"number of timestamps given, got {} rotations "
|
||||
"and {} timestamps.".format(len(rotations), times.shape[0]))
|
||||
timedelta = jnp.diff(times)
|
||||
# if jnp.any(timedelta <= 0): # this causes a concretization error...
|
||||
# raise ValueError("Times must be in strictly increasing order.")
|
||||
new_rotations = Rotation(rotations.as_quat()[:-1])
|
||||
return cls(
|
||||
times=times,
|
||||
timedelta=timedelta,
|
||||
rotations=new_rotations,
|
||||
rotvecs=(new_rotations.inv() * Rotation(rotations.as_quat()[1:])).as_rotvec())
|
||||
|
||||
def __call__(self, times: jax.Array):
|
||||
"""Interpolate rotations."""
|
||||
compute_times = jnp.asarray(times, dtype=self.times.dtype)
|
||||
if compute_times.ndim > 1:
|
||||
raise ValueError("`times` must be at most 1-dimensional.")
|
||||
single_time = compute_times.ndim == 0
|
||||
compute_times = jnp.atleast_1d(compute_times)
|
||||
ind = jnp.maximum(jnp.searchsorted(self.times, compute_times) - 1, 0)
|
||||
alpha = (compute_times - self.times[ind]) / self.timedelta[ind]
|
||||
result = (self.rotations[ind] * Rotation.from_rotvec(self.rotvecs[ind] * alpha[:, None]))
|
||||
if single_time:
|
||||
return result[0]
|
||||
return result
|
||||
|
||||
|
||||
@functools.partial(jnp.vectorize, signature='(m,m),(m),()->(m)')
|
||||
def _apply(matrix: jax.Array, vector: jax.Array, inverse: bool) -> jax.Array:
|
||||
return jnp.where(inverse, matrix.T, matrix) @ vector
|
||||
|
||||
|
||||
@functools.partial(jnp.vectorize, signature='(m)->(n,n)')
|
||||
def _as_matrix(quat: jax.Array) -> jax.Array:
|
||||
x = quat[0]
|
||||
y = quat[1]
|
||||
z = quat[2]
|
||||
w = quat[3]
|
||||
x2 = x * x
|
||||
y2 = y * y
|
||||
z2 = z * z
|
||||
w2 = w * w
|
||||
xy = x * y
|
||||
zw = z * w
|
||||
xz = x * z
|
||||
yw = y * w
|
||||
yz = y * z
|
||||
xw = x * w
|
||||
return jnp.array([[+ x2 - y2 - z2 + w2, 2 * (xy - zw), 2 * (xz + yw)],
|
||||
[2 * (xy + zw), - x2 + y2 - z2 + w2, 2 * (yz - xw)],
|
||||
[2 * (xz - yw), 2 * (yz + xw), - x2 - y2 + z2 + w2]])
|
||||
|
||||
|
||||
@functools.partial(jnp.vectorize, signature='(m)->(n)')
|
||||
def _as_mrp(quat: jax.Array) -> jax.Array:
|
||||
sign = jnp.where(quat[3] < 0, -1., 1.)
|
||||
denominator = 1. + sign * quat[3]
|
||||
return sign * quat[:3] / denominator
|
||||
|
||||
|
||||
@functools.partial(jnp.vectorize, signature='(m),()->(n)')
|
||||
def _as_rotvec(quat: jax.Array, degrees: bool) -> jax.Array:
|
||||
quat = jnp.where(quat[3] < 0, -quat, quat) # w > 0 to ensure 0 <= angle <= pi
|
||||
angle = 2. * jnp.arctan2(_vector_norm(quat[:3]), quat[3])
|
||||
angle2 = angle * angle
|
||||
small_scale = 2 + angle2 / 12 + 7 * angle2 * angle2 / 2880
|
||||
large_scale = angle / jnp.sin(angle / 2)
|
||||
scale = jnp.where(angle <= 1e-3, small_scale, large_scale)
|
||||
scale = jnp.where(degrees, jnp.rad2deg(scale), scale)
|
||||
return scale * jnp.array(quat[:3])
|
||||
|
||||
|
||||
@functools.partial(jnp.vectorize, signature='(n),(n)->(n)')
|
||||
def _compose_quat(p: jax.Array, q: jax.Array) -> jax.Array:
|
||||
cross = jnp.cross(p[:3], q[:3])
|
||||
return jnp.array([p[3]*q[0] + q[3]*p[0] + cross[0],
|
||||
p[3]*q[1] + q[3]*p[1] + cross[1],
|
||||
p[3]*q[2] + q[3]*p[2] + cross[2],
|
||||
p[3]*q[3] - p[0]*q[0] - p[1]*q[1] - p[2]*q[2]])
|
||||
|
||||
|
||||
@functools.partial(jnp.vectorize, signature='(m),(l),(),()->(n)')
|
||||
def _compute_euler_from_quat(quat: jax.Array, axes: jax.Array, extrinsic: bool, degrees: bool) -> jax.Array:
|
||||
angle_first = jnp.where(extrinsic, 0, 2)
|
||||
angle_third = jnp.where(extrinsic, 2, 0)
|
||||
axes = jnp.where(extrinsic, axes, axes[::-1])
|
||||
i = axes[0]
|
||||
j = axes[1]
|
||||
k = axes[2]
|
||||
symmetric = i == k
|
||||
k = jnp.where(symmetric, 3 - i - j, k)
|
||||
sign = jnp.array((i - j) * (j - k) * (k - i) // 2, dtype=quat.dtype)
|
||||
eps = 1e-7
|
||||
a = jnp.where(symmetric, quat[3], quat[3] - quat[j])
|
||||
b = jnp.where(symmetric, quat[i], quat[i] + quat[k] * sign)
|
||||
c = jnp.where(symmetric, quat[j], quat[j] + quat[3])
|
||||
d = jnp.where(symmetric, quat[k] * sign, quat[k] * sign - quat[i])
|
||||
angles = jnp.empty(3, dtype=quat.dtype)
|
||||
angles = angles.at[1].set(2 * jnp.arctan2(jnp.hypot(c, d), jnp.hypot(a, b)))
|
||||
case = jnp.where(jnp.abs(angles[1] - jnp.pi) <= eps, 2, 0)
|
||||
case = jnp.where(jnp.abs(angles[1]) <= eps, 1, case)
|
||||
half_sum = jnp.arctan2(b, a)
|
||||
half_diff = jnp.arctan2(d, c)
|
||||
angles = angles.at[0].set(jnp.where(case == 1, 2 * half_sum, 2 * half_diff * jnp.where(extrinsic, -1, 1))) # any degenerate case
|
||||
angles = angles.at[angle_first].set(jnp.where(case == 0, half_sum - half_diff, angles[angle_first]))
|
||||
angles = angles.at[angle_third].set(jnp.where(case == 0, half_sum + half_diff, angles[angle_third]))
|
||||
angles = angles.at[angle_third].set(jnp.where(symmetric, angles[angle_third], angles[angle_third] * sign))
|
||||
angles = angles.at[1].set(jnp.where(symmetric, angles[1], angles[1] - jnp.pi / 2))
|
||||
angles = (angles + jnp.pi) % (2 * jnp.pi) - jnp.pi
|
||||
return jnp.where(degrees, jnp.rad2deg(angles), angles)
|
||||
|
||||
|
||||
def _elementary_basis_index(axis: str) -> int:
|
||||
if axis == 'x':
|
||||
return 0
|
||||
elif axis == 'y':
|
||||
return 1
|
||||
elif axis == 'z':
|
||||
return 2
|
||||
raise ValueError("Expected axis to be from ['x', 'y', 'z'], got {}".format(axis))
|
||||
|
||||
|
||||
@functools.partial(jnp.vectorize, signature=('(m),(m),(),()->(n)'))
|
||||
def _elementary_quat_compose(angles: jax.Array, axes: jax.Array, intrinsic: bool, degrees: bool) -> jax.Array:
|
||||
angles = jnp.where(degrees, jnp.deg2rad(angles), angles)
|
||||
result = _make_elementary_quat(axes[0], angles[0])
|
||||
for idx in range(1, len(axes)):
|
||||
quat = _make_elementary_quat(axes[idx], angles[idx])
|
||||
result = jnp.where(intrinsic, _compose_quat(result, quat), _compose_quat(quat, result))
|
||||
return result
|
||||
|
||||
|
||||
@functools.partial(jnp.vectorize, signature=('(m),()->(n)'))
|
||||
def _from_rotvec(rotvec: jax.Array, degrees: bool) -> jax.Array:
|
||||
rotvec = jnp.where(degrees, jnp.deg2rad(rotvec), rotvec)
|
||||
angle = _vector_norm(rotvec)
|
||||
angle2 = angle * angle
|
||||
small_scale = scale = 0.5 - angle2 / 48 + angle2 * angle2 / 3840
|
||||
large_scale = jnp.sin(angle / 2) / angle
|
||||
scale = jnp.where(angle <= 1e-3, small_scale, large_scale)
|
||||
return jnp.hstack([scale * rotvec, jnp.cos(angle / 2)])
|
||||
|
||||
|
||||
@functools.partial(jnp.vectorize, signature=('(m,m)->(n)'))
|
||||
def _from_matrix(matrix: jax.Array) -> jax.Array:
|
||||
matrix_trace = matrix[0, 0] + matrix[1, 1] + matrix[2, 2]
|
||||
decision = jnp.array([matrix[0, 0], matrix[1, 1], matrix[2, 2], matrix_trace], dtype=matrix.dtype)
|
||||
choice = jnp.argmax(decision)
|
||||
i = choice
|
||||
j = (i + 1) % 3
|
||||
k = (j + 1) % 3
|
||||
quat_012 = jnp.empty(4, dtype=matrix.dtype)
|
||||
quat_012 = quat_012.at[i].set(1 - decision[3] + 2 * matrix[i, i])
|
||||
quat_012 = quat_012.at[j].set(matrix[j, i] + matrix[i, j])
|
||||
quat_012 = quat_012.at[k].set(matrix[k, i] + matrix[i, k])
|
||||
quat_012 = quat_012.at[3].set(matrix[k, j] - matrix[j, k])
|
||||
quat_3 = jnp.empty(4, dtype=matrix.dtype)
|
||||
quat_3 = quat_3.at[0].set(matrix[2, 1] - matrix[1, 2])
|
||||
quat_3 = quat_3.at[1].set(matrix[0, 2] - matrix[2, 0])
|
||||
quat_3 = quat_3.at[2].set(matrix[1, 0] - matrix[0, 1])
|
||||
quat_3 = quat_3.at[3].set(1 + decision[3])
|
||||
quat = jnp.where(choice != 3, quat_012, quat_3)
|
||||
return _normalize_quaternion(quat)
|
||||
|
||||
|
||||
@functools.partial(jnp.vectorize, signature='(m)->(n)')
|
||||
def _from_mrp(mrp: jax.Array) -> jax.Array:
|
||||
mrp_squared_plus_1 = jnp.dot(mrp, mrp) + 1
|
||||
return jnp.hstack([2 * mrp[:3], (2 - mrp_squared_plus_1)]) / mrp_squared_plus_1
|
||||
|
||||
|
||||
@functools.partial(jnp.vectorize, signature='(n)->(n)')
|
||||
def _inv(quat: jax.Array) -> jax.Array:
|
||||
return quat.at[3].set(-quat[3])
|
||||
|
||||
|
||||
@functools.partial(jnp.vectorize, signature='(n)->()')
|
||||
def _magnitude(quat: jax.Array) -> jax.Array:
|
||||
return 2. * jnp.arctan2(_vector_norm(quat[:3]), jnp.abs(quat[3]))
|
||||
|
||||
|
||||
@functools.partial(jnp.vectorize, signature='(),()->(n)')
|
||||
def _make_elementary_quat(axis: int, angle: jax.Array) -> jax.Array:
|
||||
quat = jnp.zeros(4, dtype=angle.dtype)
|
||||
quat = quat.at[3].set(jnp.cos(angle / 2.))
|
||||
quat = quat.at[axis].set(jnp.sin(angle / 2.))
|
||||
return quat
|
||||
|
||||
|
||||
@functools.partial(jnp.vectorize, signature='(n)->(n)')
|
||||
def _normalize_quaternion(quat: jax.Array) -> jax.Array:
|
||||
return quat / _vector_norm(quat)
|
||||
|
||||
|
||||
@functools.partial(jnp.vectorize, signature='(n)->()')
|
||||
def _vector_norm(vector: jax.Array) -> jax.Array:
|
||||
return jnp.sqrt(jnp.dot(vector, vector))
|
13
jax/scipy/spatial/__init__.py
Normal file
13
jax/scipy/spatial/__init__.py
Normal file
@ -0,0 +1,13 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
21
jax/scipy/spatial/transform.py
Normal file
21
jax/scipy/spatial/transform.py
Normal file
@ -0,0 +1,21 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Note: import <name> as <name> is required for names to be exported.
|
||||
# See PEP 484 & https://github.com/google/jax/issues/7570
|
||||
|
||||
from jax._src.scipy.spatial.transform import (
|
||||
Rotation as Rotation,
|
||||
Slerp as Slerp,
|
||||
)
|
@ -727,6 +727,12 @@ jax_test(
|
||||
},
|
||||
)
|
||||
|
||||
jax_test(
|
||||
name = "scipy_spatial_test",
|
||||
srcs = ["scipy_spatial_test.py"],
|
||||
deps = py_deps("scipy"),
|
||||
)
|
||||
|
||||
jax_test(
|
||||
name = "scipy_stats_test",
|
||||
srcs = ["scipy_stats_test.py"],
|
||||
|
299
tests/scipy_spatial_test.py
Normal file
299
tests/scipy_spatial_test.py
Normal file
@ -0,0 +1,299 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
import scipy.version
|
||||
from jax._src import test_util as jtu
|
||||
from jax.scipy.spatial.transform import Rotation as jsp_Rotation
|
||||
from scipy.spatial.transform import Rotation as osp_Rotation
|
||||
from jax.scipy.spatial.transform import Slerp as jsp_Slerp
|
||||
from scipy.spatial.transform import Slerp as osp_Slerp
|
||||
|
||||
import jax.numpy as jnp
|
||||
import numpy as onp
|
||||
from jax.config import config
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
scipy_version = tuple(map(int, scipy.version.version.split('.')[:3]))
|
||||
|
||||
float_dtypes = jtu.dtypes.floating
|
||||
real_dtypes = float_dtypes + jtu.dtypes.integer + jtu.dtypes.boolean
|
||||
|
||||
num_samples = 2
|
||||
|
||||
class LaxBackedScipySpatialTransformTests(jtu.JaxTestCase):
|
||||
"""Tests for LAX-backed scipy.spatial implementations"""
|
||||
|
||||
@jtu.sample_product(
|
||||
dtype=float_dtypes,
|
||||
shape=[(4,), (num_samples, 4)],
|
||||
vector_shape=[(3,), (num_samples, 3)],
|
||||
inverse=[True, False],
|
||||
)
|
||||
def testRotationApply(self, shape, vector_shape, dtype, inverse):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: (rng(shape, dtype), rng(vector_shape, dtype),)
|
||||
jnp_fn = lambda q, v: jsp_Rotation.from_quat(q).apply(v, inverse=inverse)
|
||||
np_fn = lambda q, v: osp_Rotation.from_quat(q).apply(v, inverse=inverse).astype(dtype) # HACK
|
||||
tol = 5e-2 if jtu.device_under_test() == 'tpu' else 1e-4
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=tol)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, tol=tol)
|
||||
|
||||
@jtu.sample_product(
|
||||
dtype=float_dtypes,
|
||||
shape=[(4,), (num_samples, 4)],
|
||||
seq=['xyz', 'zyx', 'XYZ', 'ZYX'],
|
||||
degrees=[True, False],
|
||||
)
|
||||
def testRotationAsEuler(self, shape, dtype, seq, degrees):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: (rng(shape, dtype),)
|
||||
jnp_fn = lambda q: jsp_Rotation.from_quat(q).as_euler(seq=seq, degrees=degrees)
|
||||
np_fn = lambda q: osp_Rotation.from_quat(q).as_euler(seq=seq, degrees=degrees).astype(dtype) # HACK
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
|
||||
|
||||
@jtu.sample_product(
|
||||
dtype=float_dtypes,
|
||||
shape=[(4,), (num_samples, 4)],
|
||||
)
|
||||
def testRotationAsMatrix(self, shape, dtype):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: (rng(shape, dtype),)
|
||||
jnp_fn = lambda q: jsp_Rotation.from_quat(q).as_matrix()
|
||||
np_fn = lambda q: osp_Rotation.from_quat(q).as_matrix().astype(dtype) # HACK
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
|
||||
|
||||
@jtu.sample_product(
|
||||
dtype=float_dtypes,
|
||||
shape=[(4,), (num_samples, 4)],
|
||||
)
|
||||
def testRotationAsMrp(self, shape, dtype):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: (rng(shape, dtype),)
|
||||
jnp_fn = lambda q: jsp_Rotation.from_quat(q).as_mrp()
|
||||
np_fn = lambda q: osp_Rotation.from_quat(q).as_mrp().astype(dtype) # HACK
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
|
||||
|
||||
@jtu.sample_product(
|
||||
dtype=float_dtypes,
|
||||
shape=[(4,), (num_samples, 4)],
|
||||
degrees=[True, False],
|
||||
)
|
||||
def testRotationAsRotvec(self, shape, dtype, degrees):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: (rng(shape, dtype),)
|
||||
jnp_fn = lambda q: jsp_Rotation.from_quat(q).as_rotvec(degrees=degrees)
|
||||
np_fn = lambda q: osp_Rotation.from_quat(q).as_rotvec(degrees=degrees).astype(dtype) # HACK
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
|
||||
|
||||
@jtu.sample_product(
|
||||
dtype=float_dtypes,
|
||||
shape=[(4,), (num_samples, 4)],
|
||||
)
|
||||
def testRotationAsQuat(self, shape, dtype):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: (rng(shape, dtype),)
|
||||
jnp_fn = lambda q: jsp_Rotation.from_quat(q).as_quat()
|
||||
np_fn = lambda q: osp_Rotation.from_quat(q).as_quat().astype(dtype) # HACK
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
|
||||
|
||||
@jtu.sample_product(
|
||||
dtype=float_dtypes,
|
||||
shape=[(num_samples, 4)],
|
||||
other_shape=[(num_samples, 4)],
|
||||
)
|
||||
def testRotationConcatenate(self, shape, other_shape, dtype):
|
||||
if scipy_version < (1, 8, 0):
|
||||
self.skipTest("Scipy 1.8.0 needed for concatenate.")
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: (rng(shape, dtype), rng(other_shape, dtype),)
|
||||
jnp_fn = lambda q, o: jsp_Rotation.concatenate([jsp_Rotation.from_quat(q), jsp_Rotation.from_quat(o)]).as_quat()
|
||||
np_fn = lambda q, o: osp_Rotation.concatenate([osp_Rotation.from_quat(q), osp_Rotation.from_quat(o)]).as_quat().astype(dtype) # HACK
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
|
||||
|
||||
@jtu.sample_product(
|
||||
dtype=float_dtypes,
|
||||
shape=[(10, 4)],
|
||||
indexer=[slice(1, 5), slice(0), slice(-5, -3)],
|
||||
)
|
||||
def testRotationGetItem(self, shape, dtype, indexer):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: (rng(shape, dtype),)
|
||||
jnp_fn = lambda q: jsp_Rotation.from_quat(q)[indexer].as_quat()
|
||||
np_fn = lambda q: osp_Rotation.from_quat(q)[indexer].as_quat().astype(dtype) # HACK
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
|
||||
|
||||
@jtu.sample_product(
|
||||
dtype=float_dtypes,
|
||||
size=[1, num_samples],
|
||||
seq=['x', 'xy', 'xyz', 'XYZ'],
|
||||
degrees=[True, False],
|
||||
)
|
||||
def testRotationFromEuler(self, size, dtype, seq, degrees):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
shape = (size, len(seq))
|
||||
args_maker = lambda: (rng(shape, dtype),)
|
||||
jnp_fn = lambda a: jsp_Rotation.from_euler(seq, a, degrees).as_rotvec()
|
||||
np_fn = lambda a: osp_Rotation.from_euler(seq, a, degrees).as_rotvec().astype(dtype) # HACK
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
|
||||
|
||||
@jtu.sample_product(
|
||||
dtype=float_dtypes,
|
||||
shape=[(3, 3), (num_samples, 3, 3)],
|
||||
)
|
||||
def testRotationFromMatrix(self, shape, dtype):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: (rng(shape, dtype),)
|
||||
jnp_fn = lambda m: jsp_Rotation.from_matrix(m).as_rotvec()
|
||||
np_fn = lambda m: osp_Rotation.from_matrix(m).as_rotvec().astype(dtype) # HACK
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
|
||||
|
||||
@jtu.sample_product(
|
||||
dtype=float_dtypes,
|
||||
shape=[(3,), (num_samples, 3)],
|
||||
)
|
||||
def testRotationFromMrp(self, shape, dtype):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: (rng(shape, dtype),)
|
||||
jnp_fn = lambda m: jsp_Rotation.from_mrp(m).as_rotvec()
|
||||
np_fn = lambda m: osp_Rotation.from_mrp(m).as_rotvec().astype(dtype) # HACK
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
|
||||
|
||||
@jtu.sample_product(
|
||||
dtype=float_dtypes,
|
||||
shape=[(3,), (num_samples, 3)],
|
||||
)
|
||||
def testRotationFromRotvec(self, shape, dtype):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: (rng(shape, dtype),)
|
||||
jnp_fn = lambda r: jsp_Rotation.from_rotvec(r).as_quat()
|
||||
np_fn = lambda r: osp_Rotation.from_rotvec(r).as_quat().astype(dtype) # HACK
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
|
||||
|
||||
@jtu.sample_product(
|
||||
dtype=float_dtypes,
|
||||
num=[None],
|
||||
)
|
||||
def testRotationIdentity(self, num, dtype):
|
||||
args_maker = lambda: (num,)
|
||||
jnp_fn = lambda n: jsp_Rotation.identity(n, dtype).as_quat()
|
||||
np_fn = lambda n: osp_Rotation.identity(n).as_quat().astype(dtype) # HACK
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
|
||||
|
||||
@jtu.sample_product(
|
||||
dtype=float_dtypes,
|
||||
shape=[(4,), (num_samples, 4)],
|
||||
)
|
||||
def testRotationMagnitude(self, shape, dtype):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: (rng(shape, dtype),)
|
||||
jnp_fn = lambda q: jsp_Rotation.from_quat(q).magnitude()
|
||||
np_fn = lambda q: jnp.array(osp_Rotation.from_quat(q).magnitude(), dtype=dtype)
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
|
||||
|
||||
@jtu.sample_product(
|
||||
dtype=float_dtypes,
|
||||
shape=[(num_samples, 4)],
|
||||
rng_weights =[True, False],
|
||||
)
|
||||
def testRotationMean(self, shape, dtype, rng_weights):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: (rng(shape, dtype), jnp.abs(rng(shape[0], dtype)) if rng_weights else None)
|
||||
jnp_fn = lambda q, w: jsp_Rotation.from_quat(q).mean(w).as_rotvec()
|
||||
np_fn = lambda q, w: osp_Rotation.from_quat(q).mean(w).as_rotvec().astype(dtype) # HACK
|
||||
tol = 5e-3 if jtu.device_under_test() == 'tpu' else 1e-4
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=tol)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, tol=tol)
|
||||
|
||||
@jtu.sample_product(
|
||||
dtype=float_dtypes,
|
||||
shape=[(4,), (num_samples, 4)],
|
||||
other_shape=[(4,), (num_samples, 4)],
|
||||
)
|
||||
def testRotationMultiply(self, shape, other_shape, dtype):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: (rng(shape, dtype), rng(other_shape, dtype))
|
||||
jnp_fn = lambda q, o: (jsp_Rotation.from_quat(q) * jsp_Rotation.from_quat(o)).as_rotvec()
|
||||
np_fn = lambda q, o: (osp_Rotation.from_quat(q) * osp_Rotation.from_quat(o)).as_rotvec().astype(dtype) # HACK
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
|
||||
|
||||
@jtu.sample_product(
|
||||
dtype=float_dtypes,
|
||||
shape=[(4,), (num_samples, 4)],
|
||||
)
|
||||
def testRotationInv(self, shape, dtype):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: (rng(shape, dtype),)
|
||||
jnp_fn = lambda q: jsp_Rotation.from_quat(q).inv().as_quat()
|
||||
np_fn = lambda q: osp_Rotation.from_quat(q).inv().as_quat().astype(dtype) # HACK
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
|
||||
|
||||
@jtu.sample_product(
|
||||
dtype=float_dtypes,
|
||||
shape=[(num_samples, 4)],
|
||||
)
|
||||
def testRotationLen(self, shape, dtype):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: (rng(shape, dtype),)
|
||||
jnp_fn = lambda q: len(jsp_Rotation.from_quat(q))
|
||||
np_fn = lambda q: len(osp_Rotation.from_quat(q))
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
|
||||
|
||||
@jtu.sample_product(
|
||||
dtype=float_dtypes,
|
||||
shape=[(4,), (num_samples, 4)],
|
||||
)
|
||||
def testRotationSingle(self, shape, dtype):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: (rng(shape, dtype),)
|
||||
jnp_fn = lambda q: jsp_Rotation.from_quat(q).single
|
||||
np_fn = lambda q: osp_Rotation.from_quat(q).single
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
|
||||
|
||||
@jtu.sample_product(
|
||||
dtype=float_dtypes,
|
||||
shape=[(num_samples, 4)],
|
||||
compute_times=[0., onp.zeros(1), onp.zeros(2)],
|
||||
)
|
||||
def testSlerp(self, shape, dtype, compute_times):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: (rng(shape, dtype),)
|
||||
times = jnp.arange(shape[0], dtype=dtype)
|
||||
jnp_fn = lambda q: jsp_Slerp.init(times, jsp_Rotation.from_quat(q))(compute_times).as_rotvec()
|
||||
np_fn = lambda q: osp_Slerp(times, osp_Rotation.from_quat(q))(compute_times).as_rotvec().astype(dtype) # HACK
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
Loading…
x
Reference in New Issue
Block a user