Support string fft_type values in lax.fft.

This commit is contained in:
Peter Hawkins 2022-02-02 13:45:15 -05:00
parent e29bfce85c
commit 84bccb2420
2 changed files with 30 additions and 3 deletions

View File

@ -14,6 +14,7 @@
from functools import partial
from typing import Union, Sequence
import numpy as np
@ -43,9 +44,28 @@ def _promote_to_real(arg):
dtype = dtypes.result_type(arg, np.float32)
return lax.convert_element_type(arg, dtype)
def _str_to_fft_type(s: str) -> xla_client.FftType:
if s == "FFT":
return xla_client.FftType.FFT
elif s == "IFFT":
return xla_client.FftType.IFFT
elif s == "RFFT":
return xla_client.FftType.RFFT
elif s == "IRFFT":
return xla_client.FftType.IRFFT
else:
raise ValueError(f"Unknown FFT type '{s}'")
@partial(jit, static_argnums=(1, 2))
def fft(x, fft_type, fft_lengths):
if fft_type == xla_client.FftType.RFFT:
def fft(x, fft_type: Union[xla_client.FftType, str], fft_lengths: Sequence[int]):
if isinstance(fft_type, str):
typ = _str_to_fft_type(fft_type)
elif isinstance(fft_type, xla_client.FftType):
typ = fft_type
else:
raise TypeError(f"Unknown FFT type value '{fft_type}'")
if typ == xla_client.FftType.RFFT:
if np.iscomplexobj(x):
raise ValueError("only real valued inputs supported for rfft")
x = _promote_to_real(x)
@ -55,7 +75,7 @@ def fft(x, fft_type, fft_lengths):
# XLA FFT doesn't support 0-rank.
return x
fft_lengths = tuple(fft_lengths)
return fft_p.bind(x, fft_type=fft_type, fft_lengths=fft_lengths)
return fft_p.bind(x, fft_type=typ, fft_lengths=fft_lengths)
def fft_impl(x, fft_type, fft_lengths):
return xla.apply_primitive(fft_p, x, fft_type=fft_type, fft_lengths=fft_lengths)

View File

@ -102,6 +102,13 @@ class FftTest(jtu.JaxTestCase):
with self.assertRaises(NotImplementedError):
func()
def testLaxFftAcceptsStringTypes(self):
rng = jtu.rand_default(self.rng())
x = rng((10,), np.complex64)
self.assertAllClose(np.fft.fft(x).astype(np.complex64),
lax.fft(x, "FFT", fft_lengths=(10,)))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inverse={}_real={}_shape={}_axes={}_s={}_norm={}".format(
inverse, real, jtu.format_shape_dtype_string(shape, dtype), axes, s, norm),