mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Increase precision in LaxBackedScipySpatialTransformTests.testRotationApply
Otherwise the test fails due to small numerical differences. PiperOrigin-RevId: 538921774
This commit is contained in:
parent
4c02f2c748
commit
d8571cfb6b
@ -14,6 +14,8 @@
|
|||||||
|
|
||||||
from absl.testing import absltest
|
from absl.testing import absltest
|
||||||
|
|
||||||
|
import jax
|
||||||
|
|
||||||
import scipy.version
|
import scipy.version
|
||||||
from jax._src import test_util as jtu
|
from jax._src import test_util as jtu
|
||||||
from jax.scipy.spatial.transform import Rotation as jsp_Rotation
|
from jax.scipy.spatial.transform import Rotation as jsp_Rotation
|
||||||
@ -43,6 +45,7 @@ class LaxBackedScipySpatialTransformTests(jtu.JaxTestCase):
|
|||||||
vector_shape=[(3,), (num_samples, 3)],
|
vector_shape=[(3,), (num_samples, 3)],
|
||||||
inverse=[True, False],
|
inverse=[True, False],
|
||||||
)
|
)
|
||||||
|
@jax.default_matmul_precision("float32")
|
||||||
def testRotationApply(self, shape, vector_shape, dtype, inverse):
|
def testRotationApply(self, shape, vector_shape, dtype, inverse):
|
||||||
rng = jtu.rand_default(self.rng())
|
rng = jtu.rand_default(self.rng())
|
||||||
args_maker = lambda: (rng(shape, dtype), rng(vector_shape, dtype),)
|
args_maker = lambda: (rng(shape, dtype), rng(vector_shape, dtype),)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user