diff --git a/tests/scipy_spatial_test.py b/tests/scipy_spatial_test.py index 61de11820..996ad3e83 100644 --- a/tests/scipy_spatial_test.py +++ b/tests/scipy_spatial_test.py @@ -14,6 +14,8 @@ from absl.testing import absltest +import jax + import scipy.version from jax._src import test_util as jtu from jax.scipy.spatial.transform import Rotation as jsp_Rotation @@ -43,6 +45,7 @@ class LaxBackedScipySpatialTransformTests(jtu.JaxTestCase): vector_shape=[(3,), (num_samples, 3)], inverse=[True, False], ) + @jax.default_matmul_precision("float32") def testRotationApply(self, shape, vector_shape, dtype, inverse): rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng(shape, dtype), rng(vector_shape, dtype),)