rocm_jax/jaxlib/cpu/ducc_fft.fbs
George Necula b9c0658fcf Add support for dynamic shapes to jax.fft.
The idea is that we take all the values that can contain dimension sizes
from the descriptor (shape, strides_in, strides_out) and pass them as
1-d tensor operands. We also pass the output_shape as an operand, so that
we can use the hlo.CustomCallOp `indices_of_output_shapes` attribute to
tell shape refinement how to compute the shape of the result.
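A minimal sketch of that split, mirroring the two tables of the schema as Python dataclasses: the rank, dtype, transform type, axes and direction stay in the static descriptor, while every list that can hold a dimension size becomes a 1-d operand. The operand order is an assumption, `split_descriptor` is a hypothetical helper, and because `DynamicDuccFftDescriptor` has no `scale` field the sketch assumes `scale` also travels as an operand.

```python
from dataclasses import dataclass
from enum import IntEnum
from typing import List, Optional, Tuple


class DuccFftDtype(IntEnum):  # values copied from ducc_fft.fbs
    COMPLEX64 = 0
    COMPLEX128 = 1


class DuccFftType(IntEnum):  # values copied from ducc_fft.fbs
    C2C = 0
    C2R = 1
    R2C = 2


@dataclass
class DuccFftDescriptor:
    """Python mirror of the old, fully static table."""
    dtype: DuccFftDtype
    fft_type: DuccFftType
    shape: List[int]
    strides_in: List[int]
    strides_out: List[int]
    axes: List[int]
    forward: bool
    scale: float


@dataclass
class DynamicDuccFftDescriptor:
    """Python mirror of the new table: only fields free of dimension sizes."""
    ndims: int
    dtype: DuccFftDtype
    fft_type: DuccFftType
    axes: List[int]
    forward: bool


def split_descriptor(old: DuccFftDescriptor, output_shape: List[int]):
    """Hypothetical helper: static attributes vs. 1-d operand values."""
    static = DynamicDuccFftDescriptor(
        ndims=len(old.shape),  # the rank is static even when sizes are not
        dtype=old.dtype,
        fft_type=old.fft_type,
        axes=list(old.axes),
        forward=old.forward,
    )
    # Assumed operand order; operand 0 stands in for the input array.
    # `scale` has no field in DynamicDuccFftDescriptor, so it is assumed
    # to be passed as an operand as well.
    operands: List[Tuple[str, Optional[list]]] = [
        ("x", None),
        ("shape", list(old.shape)),
        ("strides_in", list(old.strides_in)),
        ("strides_out", list(old.strides_out)),
        ("scale", [old.scale]),
        ("output_shape", list(output_shape)),
    ]
    # Position of output_shape: the value an `indices_of_output_shapes`
    # attribute would hold so shape refinement can derive the result shape.
    names = [name for name, _ in operands]
    return static, operands, [names.index("output_shape")]


# Usage: a real-to-complex FFT over the last axis of a (4, 8) input.
old = DuccFftDescriptor(DuccFftDtype.COMPLEX64, DuccFftType.R2C,
                        shape=[4, 8], strides_in=[8, 1], strides_out=[5, 1],
                        axes=[1], forward=True, scale=1.0)
static, operands, out_idx = split_descriptor(old, output_shape=[4, 5])
print(static.ndims, out_idx)  # prints: 2 [5]
```

Keeping `ndims` in the static descriptor lets the custom call know how long each 1-d operand is without reading its contents.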

We keep the old descriptor and the ducc_fft registration for the old
C++ custom call targets for backwards compatibility (for 6 months). That
behavior is tested by back_compat_test.py.

The one downside of this implementation is that it moves some of the
ducc-specific logic from ducc_fft.py (in jaxlib) into fft.py (in jax). This
was necessary because that code computes with dimensions that are now
dynamic. In JAX we have support for evaluating dynamic shapes and turning
them into 1-d tensors.
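A minimal sketch of the kind of dimension arithmetic that now has to live on the JAX side, written with plain Python ints. It assumes the usual rfft/irfft conventions (R2C keeps n // 2 + 1 entries on the last axis, C2R produces the requested length), row-major element strides, transforms over the trailing axes, and 1/N scaling for inverse transforms. `FftKind`, `row_major_strides` and `fft_dims` are hypothetical names; this is not the code in fft.py, where the same quantities may be symbolic dimensions.

```python
import math
from enum import Enum
from typing import Dict, List, Sequence


class FftKind(Enum):
    """Hypothetical stand-in for the four kinds in jax.lax.FftType."""
    FFT = 0
    IFFT = 1
    RFFT = 2
    IRFFT = 3


def row_major_strides(shape: Sequence[int]) -> List[int]:
    """Element (not byte) strides of a C-contiguous array of `shape`."""
    strides = [1] * len(shape)
    for d in range(len(shape) - 1, 0, -1):
        strides[d - 1] = strides[d] * shape[d]
    return strides


def fft_dims(in_shape: Sequence[int], kind: FftKind,
             fft_lengths: Sequence[int]) -> Dict[str, object]:
    """Dimension arithmetic for one FFT call (hypothetical helper)."""
    n, k = len(in_shape), len(fft_lengths)
    assert 1 <= k <= n, "fft_lengths must cover 1..ndims trailing axes"
    out_shape = list(in_shape)
    if kind is FftKind.RFFT:
        # Real-to-complex keeps only the n // 2 + 1 non-redundant outputs.
        out_shape[-1] = in_shape[-1] // 2 + 1
    elif kind is FftKind.IRFFT:
        # Complex-to-real produces the requested length on the last axis.
        out_shape[-1] = fft_lengths[-1]
    forward = kind in (FftKind.FFT, FftKind.RFFT)
    return {
        "out_shape": out_shape,
        "strides_in": row_major_strides(in_shape),
        "strides_out": row_major_strides(out_shape),
        "axes": list(range(n - k, n)),  # the trailing k axes
        "forward": forward,
        # Inverse transforms are normalized by the number of points.
        "scale": 1.0 if forward else 1.0 / math.prod(fft_lengths),
    }


d = fft_dims([4, 5], FftKind.IRFFT, [8])
print(d["out_shape"], d["scale"])  # prints: [4, 8] 0.125
```

Every value returned here depends on the input dimensions, which is why none of it can stay in a static descriptor once those dimensions are dynamic.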

Also added a backwards compatibility test for dynamic_ducc_fft and kept the
old test for ducc_fft.

PiperOrigin-RevId: 541168692
2023-06-17 04:50:54 -07:00


/* Copyright 2020 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
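namespace jax;

enum DuccFftDtype : byte {
  COMPLEX64 = 0,
  COMPLEX128 = 1,
}

enum DuccFftType : byte {
  C2C = 0,
  C2R = 1,
  R2C = 2,
}

// Original descriptor with all dimension sizes stored in it. Kept for
// backwards compatibility with the old ducc_fft custom call target.
table DuccFftDescriptor {
  dtype:DuccFftDtype;
  fft_type:DuccFftType;
  shape:[uint64];
  strides_in:[uint64];
  strides_out:[uint64];
  axes:[uint32];
  forward:bool;
  scale:double;
}

// Descriptor for dynamic shapes. Values that can contain dimension sizes
// (shape, strides_in, strides_out) are passed as 1-d tensor operands
// instead of being stored here.
table DynamicDuccFftDescriptor {
  ndims:uint32;
  dtype:DuccFftDtype;
  fft_type:DuccFftType;
  axes:[uint32];
  forward:bool;
}

root_type DuccFftDescriptor;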
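The enums above only name the transform; a caller still has to map each FFT kind onto them. A minimal sketch of that mapping, assuming the four kinds of `jax.lax.FftType` (FFT, IFFT, RFFT, IRFFT) and that `dtype` records the precision of the complex side of the transform. `ducc_dtype` and `static_fields` are hypothetical helpers that fill in the fields of `DynamicDuccFftDescriptor` as a plain dict; they do not serialize a FlatBuffer.

```python
from enum import IntEnum
from typing import Dict, Sequence, Tuple


class DuccFftDtype(IntEnum):  # values copied from ducc_fft.fbs
    COMPLEX64 = 0
    COMPLEX128 = 1


class DuccFftType(IntEnum):  # values copied from ducc_fft.fbs
    C2C = 0
    C2R = 1
    R2C = 2


# Assumed mapping from FFT kind name to (fft_type, forward).
_KIND_TO_DUCC: Dict[str, Tuple[DuccFftType, bool]] = {
    "FFT": (DuccFftType.C2C, True),
    "IFFT": (DuccFftType.C2C, False),
    "RFFT": (DuccFftType.R2C, True),
    "IRFFT": (DuccFftType.C2R, False),
}


def ducc_dtype(input_dtype: str) -> DuccFftDtype:
    """Single precision maps to COMPLEX64, double precision to COMPLEX128."""
    if input_dtype in ("float32", "complex64"):
        return DuccFftDtype.COMPLEX64
    if input_dtype in ("float64", "complex128"):
        return DuccFftDtype.COMPLEX128
    raise ValueError(f"unsupported dtype: {input_dtype}")


def static_fields(kind: str, input_dtype: str, ndims: int,
                  axes: Sequence[int]) -> dict:
    """Fields of DynamicDuccFftDescriptor as a plain dict (no serialization)."""
    fft_type, forward = _KIND_TO_DUCC[kind]
    return {
        "ndims": ndims,
        "dtype": ducc_dtype(input_dtype),
        "fft_type": fft_type,
        "axes": list(axes),
        "forward": forward,
    }


fields = static_fields("RFFT", "float32", ndims=2, axes=[1])
print(fields["fft_type"].name, fields["dtype"].name)  # prints: R2C COMPLEX64
```

Both enums are declared `: byte` in the schema, so every value has to fit in a signed 8-bit integer, which these small values do.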