mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[jax2tf] First draft for the conversion of FFTs. (#3871)
* [jax2tf] First draft for the conversion of FFTs. Co-authored-by: Benjamin Chetioui <bchetioui@google.com>
This commit is contained in:
parent
60db0a0737
commit
15a9a70bb7
@ -39,6 +39,8 @@ from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import xla
|
||||
from jax.interpreters import pxla
|
||||
|
||||
from jaxlib import xla_client
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf # type: ignore[import]
|
||||
|
||||
@ -372,7 +374,7 @@ tf_not_yet_impl = [
|
||||
lax_linalg.lu_p, lax_linalg.svd_p,
|
||||
lax_linalg.triangular_solve_p,
|
||||
|
||||
lax_fft.fft_p, lax.igamma_grad_a_p,
|
||||
lax.igamma_grad_a_p,
|
||||
random.random_gamma_p,
|
||||
lax.random_gamma_grad_p,
|
||||
|
||||
@ -1187,6 +1189,27 @@ def _sort(*operand: TfVal, dimension: int, is_stable: bool, num_keys: int) -> Tu
|
||||
|
||||
tf_impl[lax.sort_p] = _sort
|
||||
|
||||
def _fft(x, fft_type, fft_lengths):
|
||||
shape = x.shape
|
||||
assert len(fft_lengths) <= len(shape)
|
||||
if ((fft_type == xla_client.FftType.IRFFT and
|
||||
fft_lengths != shape[-len(fft_lengths):-1] + ((shape[-1] - 1) * 2,)) or
|
||||
(fft_type != xla_client.FftType.IRFFT and
|
||||
fft_lengths != shape[-len(fft_lengths):])):
|
||||
raise NotImplementedError(f"Unsupported fft_lengths={fft_lengths} for fft_type={fft_type} of array with shape={shape}.")
|
||||
tf_funcs = {xla_client.FftType.FFT: [tf.signal.fft, tf.signal.fft2d,
|
||||
tf.signal.fft3d],
|
||||
xla_client.FftType.IFFT: [tf.signal.ifft, tf.signal.ifft2d,
|
||||
tf.signal.ifft3d],
|
||||
xla_client.FftType.RFFT: [tf.signal.rfft, tf.signal.rfft2d,
|
||||
tf.signal.rfft3d],
|
||||
xla_client.FftType.IRFFT: [tf.signal.irfft, tf.signal.irfft2d,
|
||||
tf.signal.irfft3d]}
|
||||
|
||||
return tf_funcs[fft_type][len(fft_lengths) - 1](x)
|
||||
|
||||
tf_impl[lax_fft.fft_p] = _fft
|
||||
|
||||
def _qr(operand, full_matrices):
|
||||
return tf.linalg.qr(operand, full_matrices=full_matrices)
|
||||
|
||||
|
@ -27,6 +27,8 @@ from jax import lax
|
||||
from jax import lax_linalg
|
||||
from jax import numpy as jnp
|
||||
|
||||
from jaxlib import xla_client
|
||||
|
||||
import numpy as np
|
||||
|
||||
FLAGS = config.FLAGS
|
||||
@ -416,6 +418,38 @@ lax_linalg_qr = tuple(
|
||||
for full_matrices in [False, True]
|
||||
)
|
||||
|
||||
def _fft_harness_gen(nb_axes):
|
||||
def _fft_rng_factory(dtype):
|
||||
_all_integers = jtu.dtypes.all_integer + jtu.dtypes.all_unsigned + jtu.dtypes.boolean
|
||||
# For integer types, use small values to keep the errors small
|
||||
if dtype in _all_integers:
|
||||
return jtu.rand_small
|
||||
else:
|
||||
return jtu.rand_default
|
||||
|
||||
return tuple(
|
||||
Harness(f"{nb_axes}d_shape={jtu.format_shape_dtype_string(shape, dtype)}_ffttype={fft_type}_fftlengths={fft_lengths}",
|
||||
lax.lax_fft.fft,
|
||||
[RandArg(shape, dtype), StaticArg(fft_type), StaticArg(fft_lengths)],
|
||||
rng_factory=_fft_rng_factory(dtype),
|
||||
shape=shape,
|
||||
dtype=dtype,
|
||||
fft_type=fft_type,
|
||||
fft_lengths=fft_lengths)
|
||||
for dtype in jtu.dtypes.all
|
||||
for shape in filter(lambda x: len(x) >= nb_axes,
|
||||
[(10,), (12, 13), (14, 15, 16), (14, 15, 16, 17)])
|
||||
for fft_type, fft_lengths in [(xla_client.FftType.FFT, shape[-nb_axes:]),
|
||||
(xla_client.FftType.IFFT, shape[-nb_axes:]),
|
||||
(xla_client.FftType.RFFT, shape[-nb_axes:]),
|
||||
(xla_client.FftType.IRFFT,
|
||||
shape[-nb_axes:-1] + ((shape[-1] - 1) * 2,))]
|
||||
if not (dtype in jtu.dtypes.complex and fft_type == xla_client.FftType.RFFT)
|
||||
)
|
||||
|
||||
lax_fft = tuple(_fft_harness_gen(1) + _fft_harness_gen(2) + _fft_harness_gen(3) +
|
||||
_fft_harness_gen(4))
|
||||
|
||||
lax_slice = tuple(
|
||||
Harness(f"_shape={shape}_start_indices={start_indices}_limit_indices={limit_indices}_strides={strides}", # type: ignore
|
||||
lax.slice,
|
||||
|
@ -27,6 +27,7 @@ from jax.config import config
|
||||
from jax.experimental import jax2tf
|
||||
from jax.experimental.jax2tf.tests import tf_test_util
|
||||
from jax.interpreters import xla
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf # type: ignore[import]
|
||||
|
||||
@ -154,6 +155,27 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
|
||||
raise unittest.SkipTest("GPU tests are running TF on CPU")
|
||||
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
|
||||
|
||||
@primitive_harness.parameterized(primitive_harness.lax_fft)
|
||||
def test_fft(self, harness: primitive_harness.Harness):
|
||||
if len(harness.params["fft_lengths"]) > 3:
|
||||
with self.assertRaisesRegex(RuntimeError, "FFT only supports ranks 1-3"):
|
||||
harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
|
||||
elif harness.params["dtype"] is dtypes.bfloat16:
|
||||
raise unittest.SkipTest("bfloat16 support not implemented")
|
||||
elif jtu.device_under_test() == "tpu" and len(harness.params["fft_lengths"]) > 1:
|
||||
# TODO(b/140351181): FFT is mostly unimplemented on TPU, even for JAX
|
||||
with self.assertRaisesRegex(RuntimeError, "only 1D FFT is currently supported."):
|
||||
harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
|
||||
else:
|
||||
tol = None
|
||||
if jtu.device_under_test() == "gpu":
|
||||
if harness.params["dtype"] in jtu.dtypes.boolean:
|
||||
tol = 0.01
|
||||
else:
|
||||
tol = 1e-3
|
||||
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
|
||||
atol=tol, rtol=tol)
|
||||
|
||||
@primitive_harness.parameterized(primitive_harness.lax_linalg_qr)
|
||||
def test_qr(self, harness: primitive_harness.Harness):
|
||||
# See jax.lib.lapack.geqrf for the list of compatible types
|
||||
|
Loading…
x
Reference in New Issue
Block a user