diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_ducc_fft.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_ducc_fft.py index d92a90b36..6418f3742 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_ducc_fft.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_ducc_fft.py @@ -15,36 +15,6 @@ import datetime from numpy import array, float32, complex64 -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_03_17 = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['ducc_fft'], - serialized_date=datetime.date(2023, 3, 17), - inputs=(array([[ 0., 1., 2., 3.], - [ 4., 5., 6., 7.], - [ 8., 9., 10., 11.]], dtype=float32),), - expected_outputs=(array([[ 6.+0.j, -2.+2.j, -2.+0.j, -2.-2.j], - [22.+0.j, -2.+2.j, -2.+0.j, -2.-2.j], - [38.+0.j, -2.+2.j, -2.+0.j, -2.-2.j]], dtype=complex64),), - mlir_module_text=r""" -module @jit_func { - func.func public @main(%arg0: tensor<3x4xf32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<3x4xcomplex> {jax.result_info = ""}) { - %0 = call @fft(%arg0) : (tensor<3x4xf32>) -> tensor<3x4xcomplex> - return %0 : tensor<3x4xcomplex> - } - func.func private @fft(%arg0: tensor<3x4xf32>) -> tensor<3x4xcomplex> { - %0 = stablehlo.convert %arg0 : (tensor<3x4xf32>) -> tensor<3x4xcomplex> - %1 = stablehlo.constant dense<"0x18000000140024000000000008000C001000140007001800140000000000000154000000380000001C00000010000000000000000000F03F0000000001000000010000000200000004000000000000000100000000000000000000000200000004000000000000000100000000000000000000000200000003000000000000000400000000000000"> : tensor<136xui8> - %2 = stablehlo.custom_call @ducc_fft(%1, %0) {api_version = 2 : i32, operand_layouts = [dense<0> : tensor<1xindex>, dense<[1, 0]> : tensor<2xindex>], result_layouts = [dense<[1, 0]> : tensor<2xindex>]} : (tensor<136xui8>, tensor<3x4xcomplex>) -> tensor<3x4xcomplex> - return %2 : tensor<3x4xcomplex> - } -} -""", - mlir_module_serialized=b'ML\xefR\x03MLIRxxx-trunk\x00\x01\x1d\x05\x01\x05\x01\x03\x05\x03\r\x07\t\x0b\r\x0f\x11\x03\x99s\x15\x01?\x07\x0b\x0f\x17\x0b\x0b\x0b\x0b\x0f\x13\x0b33\x0b\x0b\x0f\x0b\x13\x0b\x0bK\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x035\x0b\x0b\x0f\x0b\x0bO\x0f\x1b\x0b\x0b\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0bf\x04\x0b\x0b\x0b\x13/\x0f\x03\x15\x17\x17\x07\x17\x07\x17\x0b\x07\x13\x13\x02\xca\x05\x1f\x05\x13\x1d\x1b\x07\x17\x1d^\x03\x01\x05\x15\x05\x17\x05\x19\x05\x1b\x1d\'\x07\x03\x03\x03\x15\x05\x1d\x03\x0b\tK\x0b?\rW\x03]\x0f_\x03\x0b\tC\x0b?\rC\x03E\x0fc\x05\x1f\x05!\x1d!\x07\x05#\x03\x03%e\x05%\x05\'\x03\x11+g-A/i1G3k5m7G9q\x05)\x05+\x05-\x05/\x051\x053\x055\x057\x03\x03=E\x059#\x0b\x1d;\x03\x03a\x1d=\x03\x01\x1f\x13!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03M\r\x05OQSU\x1d?\x1dA\x1dC\x1dE\x03\x03Y\r\x03[A\x1dG\x1dI\x1dK\r\x01\x1dM\x1f\x07"\x02\x18\x00\x00\x00\x14\x00$\x00\x00\x00\x00\x00\x08\x00\x0c\x00\x10\x00\x14\x00\x07\x00\x18\x00\x14\x00\x00\x00\x00\x00\x00\x01T\x00\x00\x008\x00\x00\x00\x1c\x00\x00\x00\x10\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf0?\x00\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x0b\x05\x1dO\x05\x01\x03\x05oI\x1f\x11\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03I)\x05\r\x11\r)\x05\r\x11\x05\t)\x03B\x04\x0f\x13\x11\x03\x03\x03\x01\x03\x05!)\x03\x05\t)\x03\t\t\x04\x8f\x05\x01\x11\x01\x13\x07\x03\x01\t\x03\x11\x01\x17\x05\x03\x05\x0b\x03\x03\x01\r\x07\x05;\x03\x01\x03\x01\x05\x04\x01\x03\x03\x03\x11\x05\x19\x05\x03\t\x13\x03\x03\x01\x07\x06\x1f\x03\x01\x03\x01\t\x03\x11#\x03\x07\x0b\x07\x11)\x03\x01\x05\x05\x03\x05\x04\x05\x03\x07\x06\x03\x01\x05\x01\x00\xc2\x0eQ\x13\x11\x0f\x0b!\x1b\x1d\x05\x1b\t\x03\x0f\x1f/!!)#\x1f\x19\x91\r\xaf\x83\x82\x04\x13\x1f\x15\x1d\x15\x13\x11\x1f\x19\x17\x15\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00return_v1\x00convert_v1\x00constant_v1\x00custom_call_v1\x00call_v1\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit_func\x00jit(func)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=fft in_positional_semantics=(<_PositionalSemantics.GLOBAL: 1>,) out_positional_semantics=_PositionalSemantics.GLOBAL keep_unused=False inline=False]\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00jit(func)/jit(main)/jit(fft)/convert_element_type[new_dtype=complex64 weak_type=False]\x00value\x00jit(func)/jit(main)/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(4,)]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00callee\x00\x00fft\x00jax.arg_info\x00x\x00mhlo.sharding\x00{replicated}\x00jax.result_info\x00main\x00public\x00private\x00ducc_fft\x00', - xla_call_module_version=4, -) # End paste - # Pasted from the test output (see back_compat_test.py module docstring) data_2023_06_14 = dict( testdata_version=1, diff --git a/jax/experimental/export/_export.py b/jax/experimental/export/_export.py index 5540219ec..c1d4dda49 100644 --- a/jax/experimental/export/_export.py +++ b/jax/experimental/export/_export.py @@ -790,7 +790,7 @@ def _check_lowering(lowering) -> None: # Their backwards compatibility is tested by back_compat_test.py. _CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = { "Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape", - "ducc_fft", "dynamic_ducc_fft", "cu_threefry2x32", + "dynamic_ducc_fft", "cu_threefry2x32", # cholesky on CPU "lapack_spotrf", "lapack_dpotrf", "lapack_cpotrf", "lapack_zpotrf", # eigh on CPU diff --git a/jaxlib/cpu/ducc_fft.cc b/jaxlib/cpu/ducc_fft.cc index 33e73c5f4..a8f0490ac 100644 --- a/jaxlib/cpu/ducc_fft.cc +++ b/jaxlib/cpu/ducc_fft.cc @@ -48,9 +48,8 @@ nb::bytes BuildDynamicDuccFftDescriptor( nb::dict Registrations() { nb::dict dict; - // TODO(b/287702203): this must be kept until EOY 2023 for backwards + // TODO(b/311175955): this must be kept until May 2024 for backwards // of serialized functions using fft. - dict["ducc_fft"] = EncapsulateFunction(DuccFft); dict["dynamic_ducc_fft"] = EncapsulateFunction(DynamicDuccFft); return dict; } diff --git a/jaxlib/cpu/ducc_fft.fbs b/jaxlib/cpu/ducc_fft.fbs index a58e1dc7c..bc8572ad5 100644 --- a/jaxlib/cpu/ducc_fft.fbs +++ b/jaxlib/cpu/ducc_fft.fbs @@ -26,17 +26,6 @@ enum DuccFftType : byte { R2C = 2, } -table DuccFftDescriptor { - dtype:DuccFftDtype; - fft_type:DuccFftType; - shape:[uint64]; - strides_in:[uint64]; - strides_out:[uint64]; - axes:[uint32]; - forward:bool; - scale:double; -} - table DynamicDuccFftDescriptor { ndims:uint32; dtype:DuccFftDtype; @@ -45,4 +34,4 @@ table DynamicDuccFftDescriptor { forward:bool; } -root_type DuccFftDescriptor; +root_type DynamicDuccFftDescriptor; diff --git a/jaxlib/cpu/ducc_fft_kernels.cc b/jaxlib/cpu/ducc_fft_kernels.cc index eab35905f..12f8327b6 100644 --- a/jaxlib/cpu/ducc_fft_kernels.cc +++ b/jaxlib/cpu/ducc_fft_kernels.cc @@ -98,23 +98,8 @@ void DuccFftImpl(void *out, void *operand, jax::DuccFftDtype dtype, } // namespace -// TODO(b/287702203): this must be kept until EOY 2023 for backwards +// TODO(b/311175955): this must be kept until May 2024 for backwards // of serialized functions using fft. -void DuccFft(void *out, void **in, XlaCustomCallStatus *) { - const DuccFftDescriptor *descriptor = GetDuccFftDescriptor(in[0]); - shape_t shape(descriptor->shape()->begin(), descriptor->shape()->end()); - stride_t strides_in(descriptor->strides_in()->begin(), - descriptor->strides_in()->end()); - stride_t strides_out(descriptor->strides_out()->begin(), - descriptor->strides_out()->end()); - shape_t axes(descriptor->axes()->begin(), descriptor->axes()->end()); - - DuccFftImpl(out, in[1], descriptor->dtype(), descriptor->fft_type(), - shape, strides_in, strides_out, axes, - descriptor->forward(), descriptor->scale()); -} - - void DynamicDuccFft(void *out, void **in, XlaCustomCallStatus *) { // in[0]=descriptor, in[1]=operand, // in[2]=shape, in[3]=strides_in, in[4]=strides_out, in[5]=scale. diff --git a/jaxlib/cpu/ducc_fft_kernels.h b/jaxlib/cpu/ducc_fft_kernels.h index a3bf6cf46..13d0b1d40 100644 --- a/jaxlib/cpu/ducc_fft_kernels.h +++ b/jaxlib/cpu/ducc_fft_kernels.h @@ -20,10 +20,9 @@ limitations under the License. namespace jax { -// TODO(b/287702203): this must be kept until EOY 2023 for backwards -// of serialized functions using fft. -void DuccFft(void* out, void** in, XlaCustomCallStatus*); +// TODO(b/311175955): this must be kept until May 2024 for backwards +// of serialized functions using fft. void DynamicDuccFft(void* out, void** in, XlaCustomCallStatus*); } // namespace jax diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index 8a4a8304f..d2feef2e1 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -100,7 +100,7 @@ class CompatTest(bctu.CompatTestBase): # Add here all the testdatas that should cover the targets guaranteed # stable covering_testdatas = [ - cpu_ducc_fft.data_2023_03_17, cpu_ducc_fft.data_2023_06_14, + cpu_ducc_fft.data_2023_06_14, cpu_cholesky_lapack_potrf.data_2023_06_19, cpu_eig_lapack_geev.data_2023_06_19, cpu_eigh_lapack_syev.data_2023_03_17, @@ -142,11 +142,6 @@ class CompatTest(bctu.CompatTestBase): return lax.fft(x, fft_type="fft", fft_lengths=(4,)) # TODO(b/311175955): Remove this test and the corresponding custom calls. - - # An old lowering, with ducc_fft. - data = self.load_testdata(cpu_ducc_fft.data_2023_03_17) - self.run_one_test(func, data, expect_current_custom_calls=[]) - # A newer lowering, with dynamic_ducc_fft. data = self.load_testdata(cpu_ducc_fft.data_2023_06_14) # FFT no longer lowers to a custom call.