George Necula b9c0658fcf Add support for dynamic shapes to jax.fft.
The idea is to take all the values in the descriptor that can contain
dimension sizes (shape, strides_in, strides_out) and pass them as 1-d
tensor operands. We also pass output_shape as an operand, so that we can
use the hlo.CustomCallOp `indices_of_output_shapes` attribute to tell
shape refinement how to compute the shape of the result.
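
As a rough illustration of the values involved, the sketch below packs
the dimension-dependent quantities into 1-d integer arrays. It is not the
code from this change: `row_major_strides` and `fft_operands` are
hypothetical helpers, NumPy stands in for the real lowering, the int64
dtype is an arbitrary choice, and a complex-to-complex transform on a
contiguous array is assumed, so that output_shape equals shape.

```python
import numpy as np


def row_major_strides(shape, itemsize):
    """Byte strides of a C-contiguous array with the given shape."""
    strides = []
    stride = itemsize
    # Walk dimensions from innermost to outermost.
    for dim in reversed(shape):
        strides.append(stride)
        stride *= dim
    return tuple(reversed(strides))


def fft_operands(shape, in_itemsize, out_itemsize):
    """Hypothetical helper: pack size-dependent values as 1-d arrays.

    Assumes a complex-to-complex FFT, so the output shape is the
    same as the input shape.
    """
    def as_1d(values):
        return np.asarray(values, dtype=np.int64)

    return {
        "shape": as_1d(shape),
        "strides_in": as_1d(row_major_strides(shape, in_itemsize)),
        "strides_out": as_1d(row_major_strides(shape, out_itemsize)),
        "output_shape": as_1d(shape),
    }


# complex64 elements are 8 bytes wide.
operands = fft_operands((2, 3, 4), in_itemsize=8, out_itemsize=8)
```

Each entry of `operands` is a rank-1 array whose contents depend only on
dimension sizes, which is what lets them be computed at run time when
those sizes are dynamic.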

We keep the old descriptor and the ducc_fft registration for the old
C++ custom targets for backwards compatibility (for 6 months). That behavior
is tested by back_compat_test.py.

The one downside of this implementation is that it moves some of the
ducc-specific logic from ducc_fft.py (in jaxlib) into fft.py (in jax). This
was necessary because that code computes with dimensions that are now
dynamic. In JAX we have support for evaluating dynamic shapes and turning
them into 1-d tensors.

Also added a backwards compatibility test for dynamic_ducc_fft and kept
the old test for ducc_fft.

PiperOrigin-RevId: 541168692
2023-06-17 04:50:54 -07:00