Increase precision in LaxBackedScipySpatialTransformTests.testRotationApply

Otherwise the test fails due to small numerical differences.

PiperOrigin-RevId: 538921774
This commit is contained in:
Skye Wanderman-Milne 2023-06-08 16:30:15 -07:00 committed by jax authors
parent 4c02f2c748
commit d8571cfb6b

View File

@ -14,6 +14,8 @@
from absl.testing import absltest
import jax
import scipy.version
from jax._src import test_util as jtu
from jax.scipy.spatial.transform import Rotation as jsp_Rotation
@ -43,6 +45,7 @@ class LaxBackedScipySpatialTransformTests(jtu.JaxTestCase):
vector_shape=[(3,), (num_samples, 3)],
inverse=[True, False],
)
@jax.default_matmul_precision("float32")
def testRotationApply(self, shape, vector_shape, dtype, inverse):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype), rng(vector_shape, dtype),)