mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Add jit to lax.fft
The main motivation here is ensuring that FFTs are always marked in profiler results, which is not necessarily the case where running on TPUs. I would jit decorate the user facing functions in jax.numpy.fft, but these functions also accept parameters as lists, e.g., for axes, which are mutable and hence not valid as direct input into jit decorated functions. This might be worth doing, but would be a breaking change.
This commit is contained in:
parent
6ad9291a9a
commit
22943ef839
@ -43,6 +43,7 @@ def _promote_to_real(arg):
|
||||
dtype = dtypes.result_type(arg, np.float32)
|
||||
return lax.convert_element_type(arg, dtype)
|
||||
|
||||
@partial(jit, static_argnums=(1, 2))
|
||||
def fft(x, fft_type, fft_lengths):
|
||||
if fft_type == xla_client.FftType.RFFT:
|
||||
if np.iscomplexobj(x):
|
||||
|
@ -78,7 +78,7 @@ def _fft_core(func_name, fft_type, a, s, axes, norm):
|
||||
else:
|
||||
s = [a.shape[axis] for axis in axes]
|
||||
|
||||
transformed = lax.fft(a, fft_type, s)
|
||||
transformed = lax.fft(a, fft_type, tuple(s))
|
||||
|
||||
if orig_axes is not None:
|
||||
transformed = jnp.moveaxis(transformed, axes, orig_axes)
|
||||
|
@ -628,7 +628,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
ndims = len(shape)
|
||||
axes = range(ndims - fft_ndims, ndims)
|
||||
fft_lengths = [shape[axis] for axis in axes]
|
||||
fft_lengths = tuple(shape[axis] for axis in axes)
|
||||
op = lambda x: lax.fft(x, xla_client.FftType.FFT, fft_lengths)
|
||||
self._CheckBatching(op, 5, bdims, [shape], [np.complex64], rng)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user