Skip test outside x64

This commit is contained in:
Jake VanderPlas 2022-04-04 16:00:18 -07:00
parent 1246b6fc73
commit 5a96c0cb18

View File

@ -112,6 +112,8 @@ class FftTest(jtu.JaxTestCase):
@unittest.skipIf(jax._src.lib.xla_extension_version < 63,
"Test fails for jaxlib <= 0.3.2")
def testLaxIrfftDoesNotMutateInputs(self, dtype):
if dtype == np.float64 and not config.x64_enabled:
raise self.skipTest("float64 requires jax_enable_x64=true")
x = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=dtype) * (1+1j)
y = np.asarray(jnp.fft.irfft2(x))
z = np.asarray(jnp.fft.irfft2(x))