mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Support string fft_type values in lax.fft.
This commit is contained in:
parent
e29bfce85c
commit
84bccb2420
@ -14,6 +14,7 @@
|
||||
|
||||
|
||||
from functools import partial
|
||||
from typing import Union, Sequence
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -43,9 +44,28 @@ def _promote_to_real(arg):
|
||||
dtype = dtypes.result_type(arg, np.float32)
|
||||
return lax.convert_element_type(arg, dtype)
|
||||
|
||||
def _str_to_fft_type(s: str) -> xla_client.FftType:
|
||||
if s == "FFT":
|
||||
return xla_client.FftType.FFT
|
||||
elif s == "IFFT":
|
||||
return xla_client.FftType.IFFT
|
||||
elif s == "RFFT":
|
||||
return xla_client.FftType.RFFT
|
||||
elif s == "IRFFT":
|
||||
return xla_client.FftType.IRFFT
|
||||
else:
|
||||
raise ValueError(f"Unknown FFT type '{s}'")
|
||||
|
||||
@partial(jit, static_argnums=(1, 2))
|
||||
def fft(x, fft_type, fft_lengths):
|
||||
if fft_type == xla_client.FftType.RFFT:
|
||||
def fft(x, fft_type: Union[xla_client.FftType, str], fft_lengths: Sequence[int]):
|
||||
if isinstance(fft_type, str):
|
||||
typ = _str_to_fft_type(fft_type)
|
||||
elif isinstance(fft_type, xla_client.FftType):
|
||||
typ = fft_type
|
||||
else:
|
||||
raise TypeError(f"Unknown FFT type value '{fft_type}'")
|
||||
|
||||
if typ == xla_client.FftType.RFFT:
|
||||
if np.iscomplexobj(x):
|
||||
raise ValueError("only real valued inputs supported for rfft")
|
||||
x = _promote_to_real(x)
|
||||
@ -55,7 +75,7 @@ def fft(x, fft_type, fft_lengths):
|
||||
# XLA FFT doesn't support 0-rank.
|
||||
return x
|
||||
fft_lengths = tuple(fft_lengths)
|
||||
return fft_p.bind(x, fft_type=fft_type, fft_lengths=fft_lengths)
|
||||
return fft_p.bind(x, fft_type=typ, fft_lengths=fft_lengths)
|
||||
|
||||
def fft_impl(x, fft_type, fft_lengths):
|
||||
return xla.apply_primitive(fft_p, x, fft_type=fft_type, fft_lengths=fft_lengths)
|
||||
|
@ -102,6 +102,13 @@ class FftTest(jtu.JaxTestCase):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
func()
|
||||
|
||||
def testLaxFftAcceptsStringTypes(self):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
x = rng((10,), np.complex64)
|
||||
self.assertAllClose(np.fft.fft(x).astype(np.complex64),
|
||||
lax.fft(x, "FFT", fft_lengths=(10,)))
|
||||
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_inverse={}_real={}_shape={}_axes={}_s={}_norm={}".format(
|
||||
inverse, real, jtu.format_shape_dtype_string(shape, dtype), axes, s, norm),
|
||||
|
Loading…
x
Reference in New Issue
Block a user