diff --git a/CHANGELOG.md b/CHANGELOG.md index 08c13d082..12c6f8ac9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,7 +23,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. * {func}`jax.tree_util.tree_multimap` is deprecated. Use {func}`jax.tree_util.tree_map` instead ({jax-issue}`#5746`). ## jaxlib 0.3.3 (Unreleased) - +* Bug fixes + * Fixed a bug where double-precision complex-to-real IRFFTs would mutate their + input buffers on GPU ({jax-issue}`#9946`). ## jax 0.3.4 (March 18, 2022) * [GitHub diff --git a/tests/fft_test.py b/tests/fft_test.py index 8b15276f9..a08365958 100644 --- a/tests/fft_test.py +++ b/tests/fft_test.py @@ -14,6 +14,7 @@ import itertools +import unittest import numpy as np @@ -107,6 +108,14 @@ class FftTest(jtu.JaxTestCase): self.assertAllClose(np.fft.fft(x).astype(np.complex64), lax.fft(x, "FFT", fft_lengths=(10,))) + @parameterized.parameters((np.float32,), (np.float64,)) + @unittest.skipIf(jax._src.lib.xla_extension_version < 63, + "Test fails for jaxlib <= 0.3.2") + def testLaxIrfftDoesNotMutateInputs(self, dtype): + x = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=dtype) * (1+1j) + y = np.asarray(jnp.fft.irfft2(x)) + z = np.asarray(jnp.fft.irfft2(x)) + self.assertAllClose(y, z) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inverse={}_real={}_shape={}_axes={}_s={}_norm={}".format(