mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Increase precision in LaxBackedScipySpatialTransformTests.testRotationApply
Otherwise the test fails due to small numerical differences. PiperOrigin-RevId: 538921774
This commit is contained in:
parent
4c02f2c748
commit
d8571cfb6b
@ -14,6 +14,8 @@
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
import jax
|
||||
|
||||
import scipy.version
|
||||
from jax._src import test_util as jtu
|
||||
from jax.scipy.spatial.transform import Rotation as jsp_Rotation
|
||||
@ -43,6 +45,7 @@ class LaxBackedScipySpatialTransformTests(jtu.JaxTestCase):
|
||||
vector_shape=[(3,), (num_samples, 3)],
|
||||
inverse=[True, False],
|
||||
)
|
||||
@jax.default_matmul_precision("float32")
|
||||
def testRotationApply(self, shape, vector_shape, dtype, inverse):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: (rng(shape, dtype), rng(vector_shape, dtype),)
|
||||
|
Loading…
x
Reference in New Issue
Block a user