mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[GPU] Force an input buffer copy for double precision complex-to-real IRFFTs.
Fixes https://github.com/google/jax/issues/9946 PiperOrigin-RevId: 439414091
This commit is contained in:
parent
6825f654b1
commit
71a5eb263b
@ -23,7 +23,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
* {func}`jax.tree_util.tree_multimap` is deprecated. Use {func}`jax.tree_util.tree_map` instead ({jax-issue}`#5746`).
|
||||
|
||||
## jaxlib 0.3.3 (Unreleased)
|
||||
|
||||
* Bug fixes
|
||||
* Fixed a bug where double-precision complex-to-real IRFFTs would mutate their
|
||||
input buffers on GPU ({jax-issue}`#9946`).
|
||||
|
||||
## jax 0.3.4 (March 18, 2022)
|
||||
* [GitHub
|
||||
|
@ -14,6 +14,7 @@
|
||||
|
||||
|
||||
import itertools
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -107,6 +108,14 @@ class FftTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(np.fft.fft(x).astype(np.complex64),
|
||||
lax.fft(x, "FFT", fft_lengths=(10,)))
|
||||
|
||||
@parameterized.parameters((np.float32,), (np.float64,))
|
||||
@unittest.skipIf(jax._src.lib.xla_extension_version < 63,
|
||||
"Test fails for jaxlib <= 0.3.2")
|
||||
def testLaxIrfftDoesNotMutateInputs(self, dtype):
|
||||
x = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=dtype) * (1+1j)
|
||||
y = np.asarray(jnp.fft.irfft2(x))
|
||||
z = np.asarray(jnp.fft.irfft2(x))
|
||||
self.assertAllClose(y, z)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_inverse={}_real={}_shape={}_axes={}_s={}_norm={}".format(
|
||||
|
Loading…
x
Reference in New Issue
Block a user