mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #9815 from eelregit:fft_norm_dtype_patch
PiperOrigin-RevId: 434772719
This commit is contained in:
commit
98ad016794
@ -86,7 +86,8 @@ def _fft_core(func_name, fft_type, a, s, axes, norm):
|
||||
s += [max(0, 2 * (a.shape[axes[-1]] - 1))]
|
||||
else:
|
||||
s = [a.shape[axis] for axis in axes]
|
||||
transformed = lax.fft(a, fft_type, tuple(s)) * _fft_norm(jnp.array(s), func_name, norm)
|
||||
transformed = lax.fft(a, fft_type, tuple(s))
|
||||
transformed *= _fft_norm(jnp.array(s, dtype=transformed.real.dtype), func_name, norm)
|
||||
|
||||
if orig_axes is not None:
|
||||
transformed = jnp.moveaxis(transformed, axes, orig_axes)
|
||||
|
@ -138,6 +138,11 @@ class FftTest(jtu.JaxTestCase):
|
||||
tol = 0.15
|
||||
jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol)
|
||||
|
||||
# check dtypes
|
||||
dtype = jnp_fn(rng(shape, dtype)).dtype
|
||||
expected_dtype = jnp.promote_types(float if inverse and real else complex, dtype)
|
||||
self.assertEqual(dtype, expected_dtype)
|
||||
|
||||
def testIrfftTranspose(self):
|
||||
# regression test for https://github.com/google/jax/issues/6223
|
||||
def build_matrix(linear_func, size):
|
||||
|
Loading…
x
Reference in New Issue
Block a user