mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 03:46:06 +00:00)

The idea is to take every value in the descriptor that can contain dimension sizes (shape, strides_in, strides_out) and pass it as a 1-D tensor operand. We also pass output_shape as an operand, so that the hlo.CustomCallOp `indices_of_shape_operands` attribute can tell shape refinement how to compute the shape of the result.

We keep the old descriptor and the ducc_fft registration for the old C++ custom call targets for backwards compatibility (for 6 months). That behavior is tested by back_compat_test.py.

The one downside of this implementation is that it moves some of the ducc-specific logic from ducc_fft.py (in jaxlib) into fft.py (in jax). This was necessary because that code computes with dimensions that are now dynamic. In JAX we have support for evaluating dynamic shapes and turning them into 1-D tensors.

Also adds a backwards compatibility test for dynamic_ducc_fft and keeps the old test for ducc_fft.

PiperOrigin-RevId: 541168692
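
As an illustration of the operands described above, here is a minimal sketch of how shape, strides_in, strides_out and output_shape could be derived for a real-to-complex FFT over the last axis. It is not the code in fft.py: the helper names `row_major_strides` and `rfft_operands` are hypothetical, byte strides and a dense row-major layout are assumptions, and the dimensions are plain Python ints here, whereas in the dynamic case they would be symbolic dimensions that are evaluated and packed into 1-D tensors at lowering time.

```python
from typing import List, Sequence, Tuple


def row_major_strides(shape: Sequence[int], itemsize: int) -> List[int]:
    """Byte strides of a dense row-major array (assumed layout)."""
    strides = []
    acc = itemsize
    # Walk from the innermost dimension outwards, accumulating the extent.
    for dim in reversed(shape):
        strides.append(acc)
        acc *= dim
    return strides[::-1]


def rfft_operands(
    shape: Sequence[int], real_itemsize: int = 4
) -> Tuple[List[int], List[int], List[int], List[int]]:
    """Hypothetical helper: values that would become 1-D operands.

    Models a real-to-complex FFT over the last axis, whose output keeps
    n // 2 + 1 complex entries along that axis.
    """
    out_shape = list(shape[:-1]) + [shape[-1] // 2 + 1]
    strides_in = row_major_strides(shape, real_itemsize)
    # A complex element is assumed to be twice the size of a real one.
    strides_out = row_major_strides(out_shape, 2 * real_itemsize)
    return list(shape), strides_in, strides_out, out_shape


shape, strides_in, strides_out, output_shape = rfft_operands((2, 8))
print(shape, strides_in, strides_out, output_shape)
```

Every value returned here depends on dimension sizes, which is why none of them can remain in a static descriptor once the input shape is dynamic.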