mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Merge pull request #6232 from shoyer:irfft-transpose-fix
PiperOrigin-RevId: 365567972
This commit is contained in:
commit
112ae410fa
@ -114,7 +114,9 @@ def _irfft_transpose(t, fft_lengths):
|
||||
scale = 1 / prod(fft_lengths)
|
||||
out = scale * mask * x
|
||||
assert out.dtype == _complex_dtype(t.dtype), (out.dtype, t.dtype)
|
||||
return out
|
||||
# Use JAX's convention for complex gradients
|
||||
# https://github.com/google/jax/issues/6223#issuecomment-807740707
|
||||
return lax.conj(out)
|
||||
|
||||
def fft_transpose_rule(t, operand, fft_type, fft_lengths):
|
||||
if fft_type == xla_client.FftType.RFFT:
|
||||
|
@ -20,6 +20,7 @@ import numpy as np
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
@ -116,6 +117,21 @@ class FftTest(jtu.JaxTestCase):
|
||||
tol = 0.15
|
||||
jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol)
|
||||
|
||||
def testIrfftTranspose(self):
|
||||
# regression test for https://github.com/google/jax/issues/6223
|
||||
def build_matrix(linear_func, size):
|
||||
return jax.vmap(linear_func)(jnp.eye(size, size))
|
||||
|
||||
def func(x):
|
||||
return jnp.fft.irfft(jnp.concatenate([jnp.zeros(1), x[:2] + 1j*x[2:]]))
|
||||
|
||||
def func_transpose(x):
|
||||
return jax.linear_transpose(func, x)(x)[0]
|
||||
|
||||
matrix = build_matrix(func, 4)
|
||||
matrix2 = build_matrix(func_transpose, 4).T
|
||||
self.assertAllClose(matrix, matrix2)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_inverse={}_real={}".format(inverse, real),
|
||||
"inverse": inverse, "real": real}
|
||||
|
Loading…
x
Reference in New Issue
Block a user