[jax2tf] First draft for the conversion of FFTs. (#3871)

* [jax2tf] First draft for the conversion of FFTs.

Co-authored-by: Benjamin Chetioui <bchetioui@google.com>
This commit is contained in:
Benjamin Chetioui 2020-08-05 15:41:43 +02:00 committed by GitHub
parent 60db0a0737
commit 15a9a70bb7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 80 additions and 1 deletions

View File

@ -39,6 +39,8 @@ from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.interpreters import pxla
from jaxlib import xla_client
import numpy as np
import tensorflow as tf # type: ignore[import]
@ -372,7 +374,7 @@ tf_not_yet_impl = [
lax_linalg.lu_p, lax_linalg.svd_p,
lax_linalg.triangular_solve_p,
lax_fft.fft_p, lax.igamma_grad_a_p,
lax.igamma_grad_a_p,
random.random_gamma_p,
lax.random_gamma_grad_p,
@ -1187,6 +1189,27 @@ def _sort(*operand: TfVal, dimension: int, is_stable: bool, num_keys: int) -> Tu
tf_impl[lax.sort_p] = _sort
def _fft(x, fft_type, fft_lengths):
shape = x.shape
assert len(fft_lengths) <= len(shape)
if ((fft_type == xla_client.FftType.IRFFT and
fft_lengths != shape[-len(fft_lengths):-1] + ((shape[-1] - 1) * 2,)) or
(fft_type != xla_client.FftType.IRFFT and
fft_lengths != shape[-len(fft_lengths):])):
raise NotImplementedError(f"Unsupported fft_lengths={fft_lengths} for fft_type={fft_type} of array with shape={shape}.")
tf_funcs = {xla_client.FftType.FFT: [tf.signal.fft, tf.signal.fft2d,
tf.signal.fft3d],
xla_client.FftType.IFFT: [tf.signal.ifft, tf.signal.ifft2d,
tf.signal.ifft3d],
xla_client.FftType.RFFT: [tf.signal.rfft, tf.signal.rfft2d,
tf.signal.rfft3d],
xla_client.FftType.IRFFT: [tf.signal.irfft, tf.signal.irfft2d,
tf.signal.irfft3d]}
return tf_funcs[fft_type][len(fft_lengths) - 1](x)
tf_impl[lax_fft.fft_p] = _fft
def _qr(operand, full_matrices):
return tf.linalg.qr(operand, full_matrices=full_matrices)

View File

@ -27,6 +27,8 @@ from jax import lax
from jax import lax_linalg
from jax import numpy as jnp
from jaxlib import xla_client
import numpy as np
FLAGS = config.FLAGS
@ -416,6 +418,38 @@ lax_linalg_qr = tuple(
for full_matrices in [False, True]
)
def _fft_harness_gen(nb_axes):
def _fft_rng_factory(dtype):
_all_integers = jtu.dtypes.all_integer + jtu.dtypes.all_unsigned + jtu.dtypes.boolean
# For integer types, use small values to keep the errors small
if dtype in _all_integers:
return jtu.rand_small
else:
return jtu.rand_default
return tuple(
Harness(f"{nb_axes}d_shape={jtu.format_shape_dtype_string(shape, dtype)}_ffttype={fft_type}_fftlengths={fft_lengths}",
lax.lax_fft.fft,
[RandArg(shape, dtype), StaticArg(fft_type), StaticArg(fft_lengths)],
rng_factory=_fft_rng_factory(dtype),
shape=shape,
dtype=dtype,
fft_type=fft_type,
fft_lengths=fft_lengths)
for dtype in jtu.dtypes.all
for shape in filter(lambda x: len(x) >= nb_axes,
[(10,), (12, 13), (14, 15, 16), (14, 15, 16, 17)])
for fft_type, fft_lengths in [(xla_client.FftType.FFT, shape[-nb_axes:]),
(xla_client.FftType.IFFT, shape[-nb_axes:]),
(xla_client.FftType.RFFT, shape[-nb_axes:]),
(xla_client.FftType.IRFFT,
shape[-nb_axes:-1] + ((shape[-1] - 1) * 2,))]
if not (dtype in jtu.dtypes.complex and fft_type == xla_client.FftType.RFFT)
)
lax_fft = tuple(_fft_harness_gen(1) + _fft_harness_gen(2) + _fft_harness_gen(3) +
_fft_harness_gen(4))
lax_slice = tuple(
Harness(f"_shape={shape}_start_indices={start_indices}_limit_indices={limit_indices}_strides={strides}", # type: ignore
lax.slice,

View File

@ -27,6 +27,7 @@ from jax.config import config
from jax.experimental import jax2tf
from jax.experimental.jax2tf.tests import tf_test_util
from jax.interpreters import xla
import numpy as np
import tensorflow as tf # type: ignore[import]
@ -154,6 +155,27 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
raise unittest.SkipTest("GPU tests are running TF on CPU")
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
@primitive_harness.parameterized(primitive_harness.lax_fft)
def test_fft(self, harness: primitive_harness.Harness):
if len(harness.params["fft_lengths"]) > 3:
with self.assertRaisesRegex(RuntimeError, "FFT only supports ranks 1-3"):
harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
elif harness.params["dtype"] is dtypes.bfloat16:
raise unittest.SkipTest("bfloat16 support not implemented")
elif jtu.device_under_test() == "tpu" and len(harness.params["fft_lengths"]) > 1:
# TODO(b/140351181): FFT is mostly unimplemented on TPU, even for JAX
with self.assertRaisesRegex(RuntimeError, "only 1D FFT is currently supported."):
harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
else:
tol = None
if jtu.device_under_test() == "gpu":
if harness.params["dtype"] in jtu.dtypes.boolean:
tol = 0.01
else:
tol = 1e-3
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
atol=tol, rtol=tol)
@primitive_harness.parameterized(primitive_harness.lax_linalg_qr)
def test_qr(self, harness: primitive_harness.Harness):
# See jax.lib.lapack.geqrf for the list of compatible types