[GPU] Force an input buffer copy for double precision complex-to-real IRFFTs.

Fixes https://github.com/google/jax/issues/9946

PiperOrigin-RevId: 439414091
This commit is contained in:
Peter Hawkins 2022-04-04 14:34:00 -07:00 committed by jax authors
parent 6825f654b1
commit 71a5eb263b
2 changed files with 12 additions and 1 deletions

View File

@ -23,7 +23,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* {func}`jax.tree_util.tree_multimap` is deprecated. Use {func}`jax.tree_util.tree_map` instead ({jax-issue}`#5746`).
## jaxlib 0.3.3 (Unreleased)
* Bug fixes
* Fixed a bug where double-precision complex-to-real IRFFTs would mutate their
input buffers on GPU ({jax-issue}`#9946`).
## jax 0.3.4 (March 18, 2022)
* [GitHub

View File

@ -14,6 +14,7 @@
import itertools
import unittest
import numpy as np
@ -107,6 +108,14 @@ class FftTest(jtu.JaxTestCase):
self.assertAllClose(np.fft.fft(x).astype(np.complex64),
lax.fft(x, "FFT", fft_lengths=(10,)))
@parameterized.parameters((np.float32,), (np.float64,))
@unittest.skipIf(jax._src.lib.xla_extension_version < 63,
"Test fails for jaxlib <= 0.3.2")
def testLaxIrfftDoesNotMutateInputs(self, dtype):
x = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=dtype) * (1+1j)
y = np.asarray(jnp.fft.irfft2(x))
z = np.asarray(jnp.fft.irfft2(x))
self.assertAllClose(y, z)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inverse={}_real={}_shape={}_axes={}_s={}_norm={}".format(