mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Skip test outside x64
This commit is contained in:
parent
1246b6fc73
commit
5a96c0cb18
@ -112,6 +112,8 @@ class FftTest(jtu.JaxTestCase):
|
||||
@unittest.skipIf(jax._src.lib.xla_extension_version < 63,
|
||||
"Test fails for jaxlib <= 0.3.2")
|
||||
def testLaxIrfftDoesNotMutateInputs(self, dtype):
|
||||
if dtype == np.float64 and not config.x64_enabled:
|
||||
raise self.skipTest("float64 requires jax_enable_x64=true")
|
||||
x = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=dtype) * (1+1j)
|
||||
y = np.asarray(jnp.fft.irfft2(x))
|
||||
z = np.asarray(jnp.fft.irfft2(x))
|
||||
|
Loading…
x
Reference in New Issue
Block a user