Remove old ducc_fft custom call.

Starting in June 2023 we have switched the CPU lowering for FFT to use
the new custom call dynamic_ducc_fft. We are now out of the backwards
compatibility window and we remove the old ducc_fft.

We need to keep dynamic_ducc_fft a little bit longer (May 2024).

PiperOrigin-RevId: 627981921
This commit is contained in:
George Necula 2024-04-25 00:28:19 -07:00 committed by jax authors
parent 53c7c3708b
commit 6bfbb4593a
7 changed files with 7 additions and 70 deletions

View File

@ -15,36 +15,6 @@
import datetime import datetime
from numpy import array, float32, complex64 from numpy import array, float32, complex64
# Pasted from the test output (see back_compat_test.py module docstring)
data_2023_03_17 = dict(
testdata_version=1,
platform='cpu',
custom_call_targets=['ducc_fft'],
serialized_date=datetime.date(2023, 3, 17),
inputs=(array([[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.]], dtype=float32),),
expected_outputs=(array([[ 6.+0.j, -2.+2.j, -2.+0.j, -2.-2.j],
[22.+0.j, -2.+2.j, -2.+0.j, -2.-2.j],
[38.+0.j, -2.+2.j, -2.+0.j, -2.-2.j]], dtype=complex64),),
mlir_module_text=r"""
module @jit_func {
func.func public @main(%arg0: tensor<3x4xf32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<3x4xcomplex<f32>> {jax.result_info = ""}) {
%0 = call @fft(%arg0) : (tensor<3x4xf32>) -> tensor<3x4xcomplex<f32>>
return %0 : tensor<3x4xcomplex<f32>>
}
func.func private @fft(%arg0: tensor<3x4xf32>) -> tensor<3x4xcomplex<f32>> {
%0 = stablehlo.convert %arg0 : (tensor<3x4xf32>) -> tensor<3x4xcomplex<f32>>
%1 = stablehlo.constant dense<"0x18000000140024000000000008000C001000140007001800140000000000000154000000380000001C00000010000000000000000000F03F0000000001000000010000000200000004000000000000000100000000000000000000000200000004000000000000000100000000000000000000000200000003000000000000000400000000000000"> : tensor<136xui8>
%2 = stablehlo.custom_call @ducc_fft(%1, %0) {api_version = 2 : i32, operand_layouts = [dense<0> : tensor<1xindex>, dense<[1, 0]> : tensor<2xindex>], result_layouts = [dense<[1, 0]> : tensor<2xindex>]} : (tensor<136xui8>, tensor<3x4xcomplex<f32>>) -> tensor<3x4xcomplex<f32>>
return %2 : tensor<3x4xcomplex<f32>>
}
}
""",
mlir_module_serialized=b'ML\xefR\x03MLIRxxx-trunk\x00\x01\x1d\x05\x01\x05\x01\x03\x05\x03\r\x07\t\x0b\r\x0f\x11\x03\x99s\x15\x01?\x07\x0b\x0f\x17\x0b\x0b\x0b\x0b\x0f\x13\x0b33\x0b\x0b\x0f\x0b\x13\x0b\x0bK\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x035\x0b\x0b\x0f\x0b\x0bO\x0f\x1b\x0b\x0b\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0bf\x04\x0b\x0b\x0b\x13/\x0f\x03\x15\x17\x17\x07\x17\x07\x17\x0b\x07\x13\x13\x02\xca\x05\x1f\x05\x13\x1d\x1b\x07\x17\x1d^\x03\x01\x05\x15\x05\x17\x05\x19\x05\x1b\x1d\'\x07\x03\x03\x03\x15\x05\x1d\x03\x0b\tK\x0b?\rW\x03]\x0f_\x03\x0b\tC\x0b?\rC\x03E\x0fc\x05\x1f\x05!\x1d!\x07\x05#\x03\x03%e\x05%\x05\'\x03\x11+g-A/i1G3k5m7G9q\x05)\x05+\x05-\x05/\x051\x053\x055\x057\x03\x03=E\x059#\x0b\x1d;\x03\x03a\x1d=\x03\x01\x1f\x13!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03M\r\x05OQSU\x1d?\x1dA\x1dC\x1dE\x03\x03Y\r\x03[A\x1dG\x1dI\x1dK\r\x01\x1dM\x1f\x07"\x02\x18\x00\x00\x00\x14\x00$\x00\x00\x00\x00\x00\x08\x00\x0c\x00\x10\x00\x14\x00\x07\x00\x18\x00\x14\x00\x00\x00\x00\x00\x00\x01T\x00\x00\x008\x00\x00\x00\x1c\x00\x00\x00\x10\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf0?\x00\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x0b\x05\x1dO\x05\x01\x03\x05oI\x1f\x11\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03I)\x05\r\x11\r)\x05\r\x11\x05\t)\x03B\x04\x0f\x13\x11\x03\x03\x03\x01\x03\x05!)\x03\x05\t)\x03\t\t\x04\x8f\x05\x01\x11\x01\x13\x07\x03\x01\t\x03\x11\x01\x17\x05\x03\x05\x0b\x03\x03\x01\r\x07\x05;\x03\x01\x03\x01\x05\x04\x01\x03\x03\x03\x11\x05\x19\x05\x03\t\x13\x03\x03\x01\x07\x06\x1f\x03\x01\x03\x01\t\x03\x11#\x03\x07\x0b\x07\x11)\x03\x01\x05\x05\x03\x05\x04\x05\x03\x07\x06\x03\x01\x05\x01\x00\xc2\x0eQ\x13\x11\x0f\x0b!\x1b\x1d\x05\x1b\t\x03\x0f\x1f/!!)#\x1f\x19\x91\r\xaf\x83\x82\x04\x13\x1f\x15\x1d\x15\x13\x11\x1f\x19\x17\x15\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00return_v1\x00convert_v1\x00constant_v1\x00custom_call_v1\x00call_v1\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit_func\x00jit(func)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=fft in_positional_semantics=(<_PositionalSemantics.GLOBAL: 1>,) out_positional_semantics=_PositionalSemantics.GLOBAL keep_unused=False inline=False]\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00jit(func)/jit(main)/jit(fft)/convert_element_type[new_dtype=complex64 weak_type=False]\x00value\x00jit(func)/jit(main)/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(4,)]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00callee\x00\x00fft\x00jax.arg_info\x00x\x00mhlo.sharding\x00{replicated}\x00jax.result_info\x00main\x00public\x00private\x00ducc_fft\x00',
xla_call_module_version=4,
) # End paste
# Pasted from the test output (see back_compat_test.py module docstring) # Pasted from the test output (see back_compat_test.py module docstring)
data_2023_06_14 = dict( data_2023_06_14 = dict(
testdata_version=1, testdata_version=1,

View File

@ -790,7 +790,7 @@ def _check_lowering(lowering) -> None:
# Their backwards compatibility is tested by back_compat_test.py. # Their backwards compatibility is tested by back_compat_test.py.
_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = { _CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = {
"Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape", "Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape",
"ducc_fft", "dynamic_ducc_fft", "cu_threefry2x32", "dynamic_ducc_fft", "cu_threefry2x32",
# cholesky on CPU # cholesky on CPU
"lapack_spotrf", "lapack_dpotrf", "lapack_cpotrf", "lapack_zpotrf", "lapack_spotrf", "lapack_dpotrf", "lapack_cpotrf", "lapack_zpotrf",
# eigh on CPU # eigh on CPU

View File

@ -48,9 +48,8 @@ nb::bytes BuildDynamicDuccFftDescriptor(
nb::dict Registrations() { nb::dict Registrations() {
nb::dict dict; nb::dict dict;
// TODO(b/287702203): this must be kept until EOY 2023 for backwards // TODO(b/311175955): this must be kept until May 2024 for backwards
// of serialized functions using fft. // of serialized functions using fft.
dict["ducc_fft"] = EncapsulateFunction(DuccFft);
dict["dynamic_ducc_fft"] = EncapsulateFunction(DynamicDuccFft); dict["dynamic_ducc_fft"] = EncapsulateFunction(DynamicDuccFft);
return dict; return dict;
} }

View File

@ -26,17 +26,6 @@ enum DuccFftType : byte {
R2C = 2, R2C = 2,
} }
table DuccFftDescriptor {
dtype:DuccFftDtype;
fft_type:DuccFftType;
shape:[uint64];
strides_in:[uint64];
strides_out:[uint64];
axes:[uint32];
forward:bool;
scale:double;
}
table DynamicDuccFftDescriptor { table DynamicDuccFftDescriptor {
ndims:uint32; ndims:uint32;
dtype:DuccFftDtype; dtype:DuccFftDtype;
@ -45,4 +34,4 @@ table DynamicDuccFftDescriptor {
forward:bool; forward:bool;
} }
root_type DuccFftDescriptor; root_type DynamicDuccFftDescriptor;

View File

@ -98,23 +98,8 @@ void DuccFftImpl(void *out, void *operand, jax::DuccFftDtype dtype,
} // namespace } // namespace
// TODO(b/287702203): this must be kept until EOY 2023 for backwards // TODO(b/311175955): this must be kept until May 2024 for backwards
// of serialized functions using fft. // of serialized functions using fft.
void DuccFft(void *out, void **in, XlaCustomCallStatus *) {
const DuccFftDescriptor *descriptor = GetDuccFftDescriptor(in[0]);
shape_t shape(descriptor->shape()->begin(), descriptor->shape()->end());
stride_t strides_in(descriptor->strides_in()->begin(),
descriptor->strides_in()->end());
stride_t strides_out(descriptor->strides_out()->begin(),
descriptor->strides_out()->end());
shape_t axes(descriptor->axes()->begin(), descriptor->axes()->end());
DuccFftImpl(out, in[1], descriptor->dtype(), descriptor->fft_type(),
shape, strides_in, strides_out, axes,
descriptor->forward(), descriptor->scale());
}
void DynamicDuccFft(void *out, void **in, XlaCustomCallStatus *) { void DynamicDuccFft(void *out, void **in, XlaCustomCallStatus *) {
// in[0]=descriptor, in[1]=operand, // in[0]=descriptor, in[1]=operand,
// in[2]=shape, in[3]=strides_in, in[4]=strides_out, in[5]=scale. // in[2]=shape, in[3]=strides_in, in[4]=strides_out, in[5]=scale.

View File

@ -20,10 +20,9 @@ limitations under the License.
namespace jax { namespace jax {
// TODO(b/287702203): this must be kept until EOY 2023 for backwards
// of serialized functions using fft.
void DuccFft(void* out, void** in, XlaCustomCallStatus*);
// TODO(b/311175955): this must be kept until May 2024 for backwards
// of serialized functions using fft.
void DynamicDuccFft(void* out, void** in, XlaCustomCallStatus*); void DynamicDuccFft(void* out, void** in, XlaCustomCallStatus*);
} // namespace jax } // namespace jax

View File

@ -100,7 +100,7 @@ class CompatTest(bctu.CompatTestBase):
# Add here all the testdatas that should cover the targets guaranteed # Add here all the testdatas that should cover the targets guaranteed
# stable # stable
covering_testdatas = [ covering_testdatas = [
cpu_ducc_fft.data_2023_03_17, cpu_ducc_fft.data_2023_06_14, cpu_ducc_fft.data_2023_06_14,
cpu_cholesky_lapack_potrf.data_2023_06_19, cpu_cholesky_lapack_potrf.data_2023_06_19,
cpu_eig_lapack_geev.data_2023_06_19, cpu_eig_lapack_geev.data_2023_06_19,
cpu_eigh_lapack_syev.data_2023_03_17, cpu_eigh_lapack_syev.data_2023_03_17,
@ -142,11 +142,6 @@ class CompatTest(bctu.CompatTestBase):
return lax.fft(x, fft_type="fft", fft_lengths=(4,)) return lax.fft(x, fft_type="fft", fft_lengths=(4,))
# TODO(b/311175955): Remove this test and the corresponding custom calls. # TODO(b/311175955): Remove this test and the corresponding custom calls.
# An old lowering, with ducc_fft.
data = self.load_testdata(cpu_ducc_fft.data_2023_03_17)
self.run_one_test(func, data, expect_current_custom_calls=[])
# A newer lowering, with dynamic_ducc_fft. # A newer lowering, with dynamic_ducc_fft.
data = self.load_testdata(cpu_ducc_fft.data_2023_06_14) data = self.load_testdata(cpu_ducc_fft.data_2023_06_14)
# FFT no longer lowers to a custom call. # FFT no longer lowers to a custom call.