Add jit to lax.fft

The main motivation here is ensuring that FFTs are always marked in
profiler results, which is not necessarily the case where running on
TPUs.

I would jit decorate the user facing functions in jax.numpy.fft, but
these functions also accept parameters as lists, e.g., for axes, which
are mutable and hence not valid as direct input into jit decorated
functions. This might be worth doing, but would be a breaking change.
This commit is contained in:
Stephan Hoyer 2021-08-30 09:04:13 -07:00
parent 6ad9291a9a
commit 22943ef839
3 changed files with 3 additions and 2 deletions

View File

@ -43,6 +43,7 @@ def _promote_to_real(arg):
dtype = dtypes.result_type(arg, np.float32)
return lax.convert_element_type(arg, dtype)
@partial(jit, static_argnums=(1, 2))
def fft(x, fft_type, fft_lengths):
if fft_type == xla_client.FftType.RFFT:
if np.iscomplexobj(x):

View File

@ -78,7 +78,7 @@ def _fft_core(func_name, fft_type, a, s, axes, norm):
else:
s = [a.shape[axis] for axis in axes]
transformed = lax.fft(a, fft_type, s)
transformed = lax.fft(a, fft_type, tuple(s))
if orig_axes is not None:
transformed = jnp.moveaxis(transformed, axes, orig_axes)

View File

@ -628,7 +628,7 @@ class LaxVmapTest(jtu.JaxTestCase):
rng = jtu.rand_default(self.rng())
ndims = len(shape)
axes = range(ndims - fft_ndims, ndims)
fft_lengths = [shape[axis] for axis in axes]
fft_lengths = tuple(shape[axis] for axis in axes)
op = lambda x: lax.fft(x, xla_client.FftType.FFT, fft_lengths)
self._CheckBatching(op, 5, bdims, [shape], [np.complex64], rng)