Merge pull request #6232 from shoyer:irfft-transpose-fix

PiperOrigin-RevId: 365567972
This commit is contained in:
jax authors 2021-03-29 07:17:01 -07:00
commit 112ae410fa
2 changed files with 19 additions and 1 deletions

View File

@ -114,7 +114,9 @@ def _irfft_transpose(t, fft_lengths):
scale = 1 / prod(fft_lengths)
out = scale * mask * x
assert out.dtype == _complex_dtype(t.dtype), (out.dtype, t.dtype)
return out
# Use JAX's convention for complex gradients
# https://github.com/google/jax/issues/6223#issuecomment-807740707
return lax.conj(out)
def fft_transpose_rule(t, operand, fft_type, fft_lengths):
if fft_type == xla_client.FftType.RFFT:

View File

@ -20,6 +20,7 @@ import numpy as np
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import lax
from jax import numpy as jnp
from jax import test_util as jtu
@ -116,6 +117,21 @@ class FftTest(jtu.JaxTestCase):
tol = 0.15
jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol)
def testIrfftTranspose(self):
# regression test for https://github.com/google/jax/issues/6223
def build_matrix(linear_func, size):
return jax.vmap(linear_func)(jnp.eye(size, size))
def func(x):
return jnp.fft.irfft(jnp.concatenate([jnp.zeros(1), x[:2] + 1j*x[2:]]))
def func_transpose(x):
return jax.linear_transpose(func, x)(x)[0]
matrix = build_matrix(func, 4)
matrix2 = build_matrix(func_transpose, 4).T
self.assertAllClose(matrix, matrix2)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inverse={}_real={}".format(inverse, real),
"inverse": inverse, "real": real}