From b9c0658fcf1a98438f3769d7d791ae378f461c0c Mon Sep 17 00:00:00 2001 From: George Necula Date: Sat, 17 Jun 2023 04:50:12 -0700 Subject: [PATCH] Add support for dynamic shapes to jax.fft. The idea is that we take all the values that can contain dimension sizes from the descriptor (shape, strides_in, strides_out) and we pass them as 1-d tensor operands. We also pass as an operand the output_shape, so that we can use the hlo.CustomCallOp `indices_of_output_shapes` attribute to tell the shape refinement how to compute the shape of the result. We keep the old descriptor and the ducc_fft registration for the old C++ custom targets for backwards compatibility (for 6 months). That behavior is tested by back_compat_test.py. The one downside of this implementation is that it moves some of the ducc-specific logic from ducc_fft.py (in jaxlib) into fft.py (in jax). This was necessary because that code computes with dimensions that are now dynamic. In JAX we have support for evaluating dynamic shapes and turning them into 1-d tensors. Also added backwards compatibility test for dynamic_ducc_fft and kept the old test for ducc_fft. PiperOrigin-RevId: 541168692 --- jax/_src/lax/fft.py | 81 +++++++++++- jax/experimental/jax2tf/jax_export.py | 2 +- .../jax2tf/tests/back_compat_test.py | 26 +++- .../back_compat_testdata/cpu_ducc_fft.py | 68 ++++++++++ .../jax2tf/tests/primitives_test.py | 2 +- .../jax2tf/tests/shape_poly_test.py | 1 - jaxlib/cpu/cpu_kernels.cc | 5 +- jaxlib/cpu/ducc_fft.cc | 36 +++--- jaxlib/cpu/ducc_fft.fbs | 8 ++ jaxlib/cpu/ducc_fft_kernels.cc | 118 ++++++++++++------ jaxlib/cpu/ducc_fft_kernels.h | 4 + jaxlib/ducc_fft.py | 113 +++++------------ 12 files changed, 310 insertions(+), 154 deletions(-) diff --git a/jax/_src/lax/fft.py b/jax/_src/lax/fft.py index 294256516..6beb8c118 100644 --- a/jax/_src/lax/fft.py +++ b/jax/_src/lax/fft.py @@ -30,6 +30,7 @@ from jax._src.interpreters import mlir from jax._src.lib.mlir.dialects import hlo from jax._src.lib import xla_client from jax._src.lib import ducc_fft +from jax._src.lib import version as jaxlib_version from jax._src.numpy.util import promote_dtypes_complex, promote_dtypes_inexact __all__ = [ @@ -103,6 +104,9 @@ def fft_abstract_eval(x, fft_type, fft_lengths): return x.update(shape=shape, dtype=dtype) def _fft_lowering(ctx, x, *, fft_type, fft_lengths): + if not is_constant_shape(fft_lengths): + # TODO: https://github.com/openxla/stablehlo/issues/1366 + raise NotImplementedError("Shape polymorphism for FFT with non-constant fft_length is not implemented for TPU and GPU") return [ hlo.FftOp(x, hlo.FftTypeAttr.get(fft_type.name), mlir.dense_int_elements(fft_lengths)).result @@ -110,11 +114,80 @@ def _fft_lowering(ctx, x, *, fft_type, fft_lengths): def _fft_lowering_cpu(ctx, x, *, fft_type, fft_lengths): - if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)): - raise NotImplementedError("Shape polymorphism for custom call is not implemented (fft); b/261671778") x_aval, = ctx.avals_in - return [ducc_fft.ducc_fft_hlo(x, x_aval.dtype, fft_type=fft_type, - fft_lengths=fft_lengths)] + if jaxlib_version < (0, 4, 13): + if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)): + raise NotImplementedError("Shape polymorphism for custom call is not implemented (fft); b/261671778; try updating your jaxlib.") + return [ducc_fft.ducc_fft_hlo(x, x_aval.dtype, fft_type=fft_type, # type: ignore + fft_lengths=fft_lengths)] + + in_shape = x_aval.shape + dtype = x_aval.dtype + out_aval, = ctx.avals_out + out_shape = out_aval.shape + + forward = fft_type in (xla_client.FftType.FFT, xla_client.FftType.RFFT) + ndims = len(in_shape) + assert len(fft_lengths) >= 1 + assert len(fft_lengths) <= ndims, (fft_lengths, ndims) + assert len(in_shape) == len(out_shape) == ndims + + # PocketFft does not allow size 0 dimensions. + if 0 in in_shape or 0 in out_shape: + if fft_type == xla_client.FftType.RFFT: + assert dtype in (np.float32, np.float64), dtype + out_dtype = np.dtype(np.complex64 if dtype == np.float32 else np.complex128) + + elif fft_type == xla_client.FftType.IRFFT: + assert np.issubdtype(dtype, np.complexfloating), dtype + out_dtype = np.dtype(np.float32 if dtype == np.complex64 else np.float64) + + else: + assert np.issubdtype(dtype, np.complexfloating), dtype + out_dtype = dtype + + zero = mlir.ir_constant(np.array(0, dtype=out_dtype), + canonicalize_types=False) + return [ + mlir.broadcast_in_dim(ctx, zero, out_aval, broadcast_dimensions=[])] + + strides_in = [] + stride = 1 + for d in reversed(in_shape): + strides_in.append(stride) + stride *= d + strides_in = mlir.shape_tensor( + mlir.eval_dynamic_shape(ctx, tuple(reversed(strides_in)))) + + strides_out = [] + stride = 1 + for d in reversed(out_shape): + strides_out.append(stride) + stride *= d + strides_out = mlir.shape_tensor( + mlir.eval_dynamic_shape(ctx, tuple(reversed(strides_out)))) + + # scale = 1. if forward else (1. / np.prod(fft_lengths)) as a f64[1] tensor + double_type = mlir.ir.RankedTensorType.get((), mlir.ir.F64Type.get()) + size_fft_length_prod = np.prod(fft_lengths) if fft_lengths else 1 + size_fft_lengths, = mlir.eval_dynamic_shape_as_vals(ctx, (size_fft_length_prod,)) + size_fft_lengths = hlo.ConvertOp(double_type, size_fft_lengths) + one = mlir.ir_constant(np.float64(1.), canonicalize_types=False) + scale = one if forward else hlo.DivOp(one, size_fft_lengths) + scale = hlo.ReshapeOp( + mlir.ir.RankedTensorType.get((1,), mlir.ir.F64Type.get()), + scale).result + + in_shape = mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, in_shape)) + out_shape = mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, out_shape)) + in_shape = in_shape if fft_type != xla_client.FftType.IRFFT else out_shape + + result_type = mlir.aval_to_ir_type(out_aval) + return [ducc_fft.dynamic_ducc_fft_hlo( + result_type, x, + input_dtype=x_aval.dtype, ndims=ndims, input_shape=in_shape, + strides_in=strides_in, strides_out=strides_out, scale=scale, + fft_type=fft_type, fft_lengths=fft_lengths, result_shape=out_shape)] def _naive_rfft(x, fft_lengths): y = fft(x, xla_client.FftType.FFT, fft_lengths) diff --git a/jax/experimental/jax2tf/jax_export.py b/jax/experimental/jax2tf/jax_export.py index 34d1d7236..2a22adc04 100644 --- a/jax/experimental/jax2tf/jax_export.py +++ b/jax/experimental/jax2tf/jax_export.py @@ -687,7 +687,7 @@ def _check_lowering(lowering) -> None: # Their backwards compatibility is tested by back_compat_test.py. _CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = { "Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape", - "ducc_fft", "cu_threefry2x32", + "ducc_fft", "dynamic_ducc_fft", "cu_threefry2x32", # eigh on CPU "lapack_ssyevd", "lapack_dsyevd", "lapack_cheevd", "lapack_zheevd", # eigh on GPU diff --git a/jax/experimental/jax2tf/tests/back_compat_test.py b/jax/experimental/jax2tf/tests/back_compat_test.py index 106f47d63..6303370a6 100644 --- a/jax/experimental/jax2tf/tests/back_compat_test.py +++ b/jax/experimental/jax2tf/tests/back_compat_test.py @@ -194,7 +194,8 @@ class CompatTest(jtu.JaxTestCase): atol: Optional[float] = None, allow_additional_custom_call_targets: Sequence[str] = (), check_results: Optional[Callable[..., None]] = None, - use_tf_graph: bool = False): + use_tf_graph=False, + compare_with_current: bool = True): """Run one compatibility test. Args: @@ -203,13 +204,21 @@ class CompatTest(jtu.JaxTestCase): rtol: relative tolerance for numerical comparisons atol: absolute tolerance for numerical comparisons check_results: invoked with the results obtained from running the - serialized code, and those stored in the test data, and the kwarg rtol. + serialized code, and those stored in the test data, and the kwargs rtol + and atol. allow_additional_custom_call_targets: additional custom call targets to allow. use_tf_graph: if False (default), uses jax_export to serialize JAX functions and to invoke them. If True then uses tf.Graph to serialize and run the functions; expects that `func` contains a `jax2tf.call_tf` and uses `jax2tf.convert` to generate a tf.Graph containing a XlaCallModule with the actual MLIR module. + compare_with_current: whether to compare the current behavior for + `func` with the one stored in `data`. If `True` (default) uses the + current version of JAX and XLA to lower and serialize `func` and check + its results compared to the stored ones; it also dumps the current + test data. If `False`, no current serialization are comparisons are + done, tests only the saved serialization. Use this option for a test + data for which we have changed the serialization. """ if not isinstance(data, CompatTestData): raise ValueError(f"Expecting data: CompatTestData but got {data}. " @@ -324,7 +333,8 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict( else: self.assertAllClose(res_from_serialized_run_now, data.expected_outputs, rtol=rtol, atol=atol) - self.assertListEqual(custom_call_targets, data.custom_call_targets) + if compare_with_current: + self.assertListEqual(custom_call_targets, data.custom_call_targets) def run_serialized(self, data: CompatTestData, use_tf_graph=False): @@ -398,7 +408,8 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict( # Add here all the testdatas that should cover the targets guaranteed # stable covering_testdatas = [ - cpu_ducc_fft.data_2023_03_17, cpu_lapack_syev.data_2023_03_17, + cpu_ducc_fft.data_2023_03_17, cpu_ducc_fft.data_2023_06_14, + cpu_lapack_syev.data_2023_03_17, cpu_lapack_geqrf.data_2023_03_17, cuda_threefry2x32.data_2023_03_15, cuda_cusolver_geqrf.data_2023_03_18, cuda_cusolver_syev.data_2023_03_17, tf_call_tf_function.data_2023_06_02, @@ -422,9 +433,16 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict( def func(x): return lax.fft(x, fft_type="fft", fft_lengths=(4,)) + # An old lowering, with ducc_fft. We keep it for 6 months. data = load_testdata(cpu_ducc_fft.data_2023_03_17) + # We have changed the lowering for fft, do not compare with current. + self.run_one_test(func, data, compare_with_current=False) + + # A newer lowering, with dynamic_ducc_fft. + data = load_testdata(cpu_ducc_fft.data_2023_06_14) self.run_one_test(func, data) + @staticmethod def eigh_input(shape, dtype): # In order to keep inputs small, we construct the input programmatically diff --git a/jax/experimental/jax2tf/tests/back_compat_testdata/cpu_ducc_fft.py b/jax/experimental/jax2tf/tests/back_compat_testdata/cpu_ducc_fft.py index 3b48707cc..a3266b285 100644 --- a/jax/experimental/jax2tf/tests/back_compat_testdata/cpu_ducc_fft.py +++ b/jax/experimental/jax2tf/tests/back_compat_testdata/cpu_ducc_fft.py @@ -44,3 +44,71 @@ module @jit_func { mlir_module_serialized=b'ML\xefR\x03MLIRxxx-trunk\x00\x01\x1d\x05\x01\x05\x01\x03\x05\x03\r\x07\t\x0b\r\x0f\x11\x03\x99s\x15\x01?\x07\x0b\x0f\x17\x0b\x0b\x0b\x0b\x0f\x13\x0b33\x0b\x0b\x0f\x0b\x13\x0b\x0bK\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x035\x0b\x0b\x0f\x0b\x0bO\x0f\x1b\x0b\x0b\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0bf\x04\x0b\x0b\x0b\x13/\x0f\x03\x15\x17\x17\x07\x17\x07\x17\x0b\x07\x13\x13\x02\xca\x05\x1f\x05\x13\x1d\x1b\x07\x17\x1d^\x03\x01\x05\x15\x05\x17\x05\x19\x05\x1b\x1d\'\x07\x03\x03\x03\x15\x05\x1d\x03\x0b\tK\x0b?\rW\x03]\x0f_\x03\x0b\tC\x0b?\rC\x03E\x0fc\x05\x1f\x05!\x1d!\x07\x05#\x03\x03%e\x05%\x05\'\x03\x11+g-A/i1G3k5m7G9q\x05)\x05+\x05-\x05/\x051\x053\x055\x057\x03\x03=E\x059#\x0b\x1d;\x03\x03a\x1d=\x03\x01\x1f\x13!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03M\r\x05OQSU\x1d?\x1dA\x1dC\x1dE\x03\x03Y\r\x03[A\x1dG\x1dI\x1dK\r\x01\x1dM\x1f\x07"\x02\x18\x00\x00\x00\x14\x00$\x00\x00\x00\x00\x00\x08\x00\x0c\x00\x10\x00\x14\x00\x07\x00\x18\x00\x14\x00\x00\x00\x00\x00\x00\x01T\x00\x00\x008\x00\x00\x00\x1c\x00\x00\x00\x10\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf0?\x00\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x0b\x05\x1dO\x05\x01\x03\x05oI\x1f\x11\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03I)\x05\r\x11\r)\x05\r\x11\x05\t)\x03B\x04\x0f\x13\x11\x03\x03\x03\x01\x03\x05!)\x03\x05\t)\x03\t\t\x04\x8f\x05\x01\x11\x01\x13\x07\x03\x01\t\x03\x11\x01\x17\x05\x03\x05\x0b\x03\x03\x01\r\x07\x05;\x03\x01\x03\x01\x05\x04\x01\x03\x03\x03\x11\x05\x19\x05\x03\t\x13\x03\x03\x01\x07\x06\x1f\x03\x01\x03\x01\t\x03\x11#\x03\x07\x0b\x07\x11)\x03\x01\x05\x05\x03\x05\x04\x05\x03\x07\x06\x03\x01\x05\x01\x00\xc2\x0eQ\x13\x11\x0f\x0b!\x1b\x1d\x05\x1b\t\x03\x0f\x1f/!!)#\x1f\x19\x91\r\xaf\x83\x82\x04\x13\x1f\x15\x1d\x15\x13\x11\x1f\x19\x17\x15\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00return_v1\x00convert_v1\x00constant_v1\x00custom_call_v1\x00call_v1\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit_func\x00jit(func)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=fft in_positional_semantics=(<_PositionalSemantics.GLOBAL: 1>,) out_positional_semantics=_PositionalSemantics.GLOBAL keep_unused=False inline=False]\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00jit(func)/jit(main)/jit(fft)/convert_element_type[new_dtype=complex64 weak_type=False]\x00value\x00jit(func)/jit(main)/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(4,)]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00callee\x00\x00fft\x00jax.arg_info\x00x\x00mhlo.sharding\x00{replicated}\x00jax.result_info\x00main\x00public\x00private\x00ducc_fft\x00', xla_call_module_version=4, ) # End paste + +# Pasted from the test output (see back_compat_test.py module docstring) +data_2023_06_14 = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['dynamic_ducc_fft'], + serialized_date=datetime.date(2023, 6, 14), + inputs=(array([[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.]], dtype=float32),), + expected_outputs=(array([[ 6.+0.j, -2.+2.j, -2.+0.j, -2.-2.j], + [22.+0.j, -2.+2.j, -2.+0.j, -2.-2.j], + [38.+0.j, -2.+2.j, -2.+0.j, -2.-2.j]], dtype=complex64),), + mlir_module_text=r""" +#loc = loc(unknown) +module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<3x4xf32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<3x4xcomplex> {jax.result_info = ""}) { + %0 = call @fft(%arg0) : (tensor<3x4xf32>) -> tensor<3x4xcomplex> loc(#loc3) + return %0 : tensor<3x4xcomplex> loc(#loc) + } loc(#loc) + func.func private @fft(%arg0: tensor<3x4xf32> loc(unknown)) -> tensor<3x4xcomplex> { + %0 = stablehlo.convert %arg0 : (tensor<3x4xf32>) -> tensor<3x4xcomplex> loc(#loc4) + %1 = stablehlo.constant dense<4> : tensor loc(#loc5) + %2 = stablehlo.constant dense<1> : tensor loc(#loc5) + %3 = stablehlo.convert %1 : (tensor) -> tensor loc(#loc5) + %4 = stablehlo.reshape %3 : (tensor) -> tensor<1xi32> loc(#loc5) + %5 = stablehlo.convert %2 : (tensor) -> tensor loc(#loc5) + %6 = stablehlo.reshape %5 : (tensor) -> tensor<1xi32> loc(#loc5) + %7 = stablehlo.concatenate %4, %6, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> loc(#loc5) + %8 = stablehlo.constant dense<4> : tensor loc(#loc5) + %9 = stablehlo.constant dense<1> : tensor loc(#loc5) + %10 = stablehlo.convert %8 : (tensor) -> tensor loc(#loc5) + %11 = stablehlo.reshape %10 : (tensor) -> tensor<1xi32> loc(#loc5) + %12 = stablehlo.convert %9 : (tensor) -> tensor loc(#loc5) + %13 = stablehlo.reshape %12 : (tensor) -> tensor<1xi32> loc(#loc5) + %14 = stablehlo.concatenate %11, %13, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> loc(#loc5) + %15 = stablehlo.constant dense<4> : tensor loc(#loc5) + %16 = stablehlo.convert %15 : (tensor) -> tensor loc(#loc5) + %17 = stablehlo.constant dense<1.000000e+00> : tensor loc(#loc5) + %18 = stablehlo.reshape %17 : (tensor) -> tensor<1xf64> loc(#loc5) + %19 = stablehlo.constant dense<3> : tensor loc(#loc5) + %20 = stablehlo.constant dense<4> : tensor loc(#loc5) + %21 = stablehlo.convert %19 : (tensor) -> tensor loc(#loc5) + %22 = stablehlo.reshape %21 : (tensor) -> tensor<1xi32> loc(#loc5) + %23 = stablehlo.convert %20 : (tensor) -> tensor loc(#loc5) + %24 = stablehlo.reshape %23 : (tensor) -> tensor<1xi32> loc(#loc5) + %25 = stablehlo.concatenate %22, %24, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> loc(#loc5) + %26 = stablehlo.constant dense<3> : tensor loc(#loc5) + %27 = stablehlo.constant dense<4> : tensor loc(#loc5) + %28 = stablehlo.convert %26 : (tensor) -> tensor loc(#loc5) + %29 = stablehlo.reshape %28 : (tensor) -> tensor<1xi32> loc(#loc5) + %30 = stablehlo.convert %27 : (tensor) -> tensor loc(#loc5) + %31 = stablehlo.reshape %30 : (tensor) -> tensor<1xi32> loc(#loc5) + %32 = stablehlo.concatenate %29, %31, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> loc(#loc5) + %33 = stablehlo.constant dense<[20, 0, 0, 0, 0, 0, 14, 0, 16, 0, 8, 0, 0, 0, 0, 0, 12, 0, 7, 0, 14, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 4, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0]> : tensor<44xui8> loc(#loc5) + %34 = stablehlo.custom_call @dynamic_ducc_fft(%33, %0, %25, %7, %14, %18, %32) {api_version = 2 : i32, indices_of_shape_operands = dense<6> : tensor<1xi64>, operand_layouts = [dense<0> : tensor<1xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>], result_layouts = [dense<[1, 0]> : tensor<2xindex>]} : (tensor<44xui8>, tensor<3x4xcomplex>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<1xf64>, tensor<2xi32>) -> tensor<3x4xcomplex> loc(#loc5) + return %34 : tensor<3x4xcomplex> loc(#loc3) + } loc(#loc3) +} loc(#loc) +#loc1 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":437:0) +#loc2 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":428:0) +#loc3 = loc("jit(func)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=fft keep_unused=False inline=False]"(#loc1)) +#loc4 = loc("jit(func)/jit(main)/jit(fft)/convert_element_type[new_dtype=complex64 weak_type=False]"(#loc2)) +#loc5 = loc("jit(func)/jit(main)/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(4,)]"(#loc2)) +""", + mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01!\x05\x01\x03\x01\x03\x05\x03\x11\x07\t\x0b\r\x0f\x11\x13\x15\x03\xd3\x95+\x01U\x0f\x07\x13\x0b\x13\x0b\x0f\x0f\x0b\x0b\x0b\x0b\x0b\x17\x13\x13#\x0b\x0b\x0b33\x0b\x17\x0f\x0b\x0b\x0b\x13\x13S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x03A/\x0b\x0b\x0f\x0b\x0bO\x0f\x1b\x0b\x0b\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0b//\x0f//\xbf\x0b\x0b\x0b/'\x0f\x01\x03\x0f\x03)\x0f\x0f\x13\x17\x13\x17\x07\x07\x0f\x07\x07\x13\x07\x17\x0b\x13\x07\x13\x13\x13\x02r\x06\x1d5\x1b\x1f\x03\x03\x07}\x05\x17\x03\x037\x81\x05\x19\x1d-/\x11\x01\x05\x05\x1b\x05\x1d\x05\x1f\x05!\x05#\x17\x19\xb2\x06\x01\x03\x03\x07\x7f\x03\x03\x07\x85\x03\x07#\x0f%\x0f\x0b'\x05%\x05'\x05)\x03\x0b\x11c\x13W\x15o\x0bu\x17w\x03\x0b\x11[\x13W\x15[\x0b]\x17{\x05+\x17\x19\xd6\x06\x01\x1d3\x1b\x05-\x05/\x051\x03\x03\x07\x83\x03\x03\x07\x87\x03\x13?\x89AYC\x8bE_G\x8dI\x8fK\x91M_O\x93\x053\x055\x057\x059\x05;\x05=\x05?\x05A\x05C\x03\x03S]\x05E\x1f'\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x1d\x1dG\x03\x03y\x1dI\x03\x01\x1f)!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03e\r\x05gikm\x1dK\x1dM\x1dO\x1dQ\x03\x03q\r\x03sY\x1dS\x1dU\x1dW\r\x01\x1dY\x1f\x03\x11\x04\x00\x00\x00\x00\x00\x00\x00\x1f\x03\x11\x01\x00\x00\x00\x00\x00\x00\x00\x13\x0f\x01\x1f\x13\x11\x00\x00\x00\x00\x00\x00\xf0?\x1f\x03\x11\x03\x00\x00\x00\x00\x00\x00\x00\x1f\x19Y\x14\x00\x00\x00\x00\x00\x0e\x00\x10\x00\x08\x00\x00\x00\x00\x00\x0c\x00\x07\x00\x0e\x00\x00\x00\x00\x00\x00\x01\x02\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x0b\x05\x1d[\x05\x01\x1f%\x11\x06\x00\x00\x00\x00\x00\x00\x00\x03\x0fUaUUUUU\x03\x03a\x01\x02\x02)\x01\x0f)\x01\x11)\x03\x05\x11)\x05\r\x11\x1f)\x03\t\x11)\x05\r\x11\x15\x1d\x1b)\x01\x17\t\x0b)\x03\xb1#\x13\x11\x03\r\x03\t\x03\x15)\x03\x05\x17!)\x03\x05\x0f)\x03\x05\x1b)\x03\t\x1b\x04\xaa\x04\x05\x01\x11\x03!\x07\x03\x01\t\x0b\x11\x03)\x05\x03\x05\x0b\x03\r\x03\x11\x07\rQ\x03\t\x03\x01\r\x04\x03\x03\x03\x0b\x11\r+\x05\x03I\x93\x03\r\x03\x05\x061\x03\t\x03\x01\x03\x03\x01\x05\x03\x03\x03\x03\x01\x1d\x03\x03\x05\x06\x01\x03\x05\x03\x05\x07\x06\x01\x03\x07\x03\t\x05\x06\x01\x03\x05\x03\x07\x07\x06\x01\x03\x07\x03\r\t\x07\x01\t\x03\x0b\x05\x0b\x0f\x03\x03\x01\x05\x03\x03\x03\x03\x01\x1d\x03\x03\x05\x06\x01\x03\x05\x03\x13\x07\x06\x01\x03\x07\x03\x17\x05\x06\x01\x03\x05\x03\x15\x07\x06\x01\x03\x07\x03\x1b\t\x07\x01\t\x03\x0b\x05\x19\x1d\x03\x03\x01\x05\x03\x03\x05\x06\x01\x03\x13\x03!\x03\x03\x019\x03\x13\x07\x06\x01\x03!\x03%\x03\x03\x01\x1f\x03\x03\x03\x03\x01\x05\x03\x03\x05\x06\x01\x03\x05\x03)\x07\x06\x01\x03\x07\x03-\x05\x06\x01\x03\x05\x03+\x07\x06\x01\x03\x07\x031\t\x07\x01\t\x03\x0b\x05/3\x03\x03\x01\x1f\x03\x03\x03\x03\x01\x05\x03\x03\x05\x06\x01\x03\x05\x037\x07\x06\x01\x03\x07\x03;\x05\x06\x01\x03\x05\x039\x07\x06\x01\x03\x07\x03?\t\x07\x01\t\x03\x0b\x05=A\x03\x03\x01;\x03\x19\x0f\x07\x01=\x03\t\x0fE\x035\x11\x1f'C\r\x04\r\x03G\x06\x03\x01\x05\x01\x00\xc6\x0e]#\x11\x0f\x0b!\x1b\x1d\x05\x1b\t\x03\x0f\x1f/!5!)#\x1f\x19\x15\x91\xaf\xbe\x02\x13%)\x83\x1f\x15\x1d\x15\x13\r\x11\x1f\x15\x11\x1f\x17\x17\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00convert_v1\x00reshape_v1\x00concatenate_v1\x00func_v1\x00return_v1\x00custom_call_v1\x00call_v1\x00value\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=fft keep_unused=False inline=False]\x00jit(func)/jit(main)/jit(fft)/convert_element_type[new_dtype=complex64 weak_type=False]\x00jit(func)/jit(main)/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(4,)]\x00dimension\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00indices_of_shape_operands\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00callee\x00\x00fft\x00jax.arg_info\x00x\x00mhlo.sharding\x00{replicated}\x00jax.result_info\x00main\x00public\x00private\x00dynamic_ducc_fft\x00", + xla_call_module_version=6, +) # End paste diff --git a/jax/experimental/jax2tf/tests/primitives_test.py b/jax/experimental/jax2tf/tests/primitives_test.py index 2e3c0a25e..9b4f4789d 100644 --- a/jax/experimental/jax2tf/tests/primitives_test.py +++ b/jax/experimental/jax2tf/tests/primitives_test.py @@ -14,7 +14,7 @@ """Tests for JAX primitive coverage. The bulk of the testing is done by `test_prim`, which is parameterized by -about 2000+ test harnesses. See `primitive_harness.py` docstring for a +about 3500+ test harnesses. See `primitive_harness.py` docstring for a description of test harnesses. That module contains also the definitions of all the test harnesses, and a specification of which are only partially implemented for JAX. diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index ae1c316d0..023d70c41 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -2742,7 +2742,6 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase): # Set of harness.group_name:platform that are implemented with custom call custom_call_harnesses = { "vmap_cholesky:cpu", "vmap_cholesky:gpu", - "vmap_fft:cpu", "fft:cpu", "householder_product:cpu", "householder_product:gpu", "vmap_geqrf:cpu", "vmap_geqrf:gpu", "vmap_lu:cpu", "vmap_lu:gpu", diff --git a/jaxlib/cpu/cpu_kernels.cc b/jaxlib/cpu/cpu_kernels.cc index 4cdbcd028..dc5d657a9 100644 --- a/jaxlib/cpu/cpu_kernels.cc +++ b/jaxlib/cpu/cpu_kernels.cc @@ -105,7 +105,10 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( "lapack_cgees", ComplexGees>::Kernel, "Host"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( "lapack_zgees", ComplexGees>::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("ducc_fft", DuccFft, "Host"); +XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( + "ducc_fft", DuccFft, "Host"); +XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( + "dynamic_ducc_fft", DynamicDuccFft, "Host"); } // namespace } // namespace jax diff --git a/jaxlib/cpu/ducc_fft.cc b/jaxlib/cpu/ducc_fft.cc index 0bcfa8920..674bc5f9e 100644 --- a/jaxlib/cpu/ducc_fft.cc +++ b/jaxlib/cpu/ducc_fft.cc @@ -27,42 +27,40 @@ namespace py = pybind11; namespace jax { namespace { -py::bytes BuildDuccFftDescriptor(const std::vector &shape, - bool is_double, int fft_type, - const std::vector &fft_lengths, - const std::vector &strides_in, - const std::vector &strides_out, - const std::vector &axes, - bool forward, double scale) { - DuccFftDescriptorT descriptor; - descriptor.shape = shape; + +py::bytes BuildDynamicDuccFftDescriptor( + const uint32_t ndims, + bool is_double, int fft_type, + const std::vector &axes, + bool forward) { + DynamicDuccFftDescriptorT descriptor; + descriptor.ndims = ndims; descriptor.fft_type = static_cast(fft_type); descriptor.dtype = is_double ? DuccFftDtype_COMPLEX128 : DuccFftDtype_COMPLEX64; - descriptor.strides_in = strides_in; - descriptor.strides_out = strides_out; descriptor.axes = axes; descriptor.forward = forward; - descriptor.scale = scale; flatbuffers::FlatBufferBuilder fbb; - fbb.Finish(DuccFftDescriptor::Pack(fbb, &descriptor)); + fbb.Finish(DynamicDuccFftDescriptor::Pack(fbb, &descriptor)); return py::bytes(reinterpret_cast(fbb.GetBufferPointer()), fbb.GetSize()); } py::dict Registrations() { pybind11::dict dict; + // TODO(b/287702203): this must be kept until EOY 2023 for backwards + // of serialized functions using fft. dict["ducc_fft"] = EncapsulateFunction(DuccFft); + dict["dynamic_ducc_fft"] = EncapsulateFunction(DynamicDuccFft); return dict; } PYBIND11_MODULE(_ducc_fft, m) { m.def("registrations", &Registrations); - m.def("ducc_fft_descriptor", &BuildDuccFftDescriptor, py::arg("shape"), - py::arg("is_double"), py::arg("fft_type"), py::arg("fft_lengths"), - py::arg("strides_in"), py::arg("strides_out"), py::arg("axes"), - py::arg("forward"), py::arg("scale")); + m.def("dynamic_ducc_fft_descriptor", &BuildDynamicDuccFftDescriptor, + py::arg("ndims"), py::arg("is_double"), py::arg("fft_type"), + py::arg("axes"), py::arg("forward")); } -} // namespace -} // namespace jax +} // namespace +} // namespace jax diff --git a/jaxlib/cpu/ducc_fft.fbs b/jaxlib/cpu/ducc_fft.fbs index dcc74c360..a58e1dc7c 100644 --- a/jaxlib/cpu/ducc_fft.fbs +++ b/jaxlib/cpu/ducc_fft.fbs @@ -37,4 +37,12 @@ table DuccFftDescriptor { scale:double; } +table DynamicDuccFftDescriptor { + ndims:uint32; + dtype:DuccFftDtype; + fft_type:DuccFftType; + axes:[uint32]; + forward:bool; +} + root_type DuccFftDescriptor; diff --git a/jaxlib/cpu/ducc_fft_kernels.cc b/jaxlib/cpu/ducc_fft_kernels.cc index 27029e328..d4c8e98fd 100644 --- a/jaxlib/cpu/ducc_fft_kernels.cc +++ b/jaxlib/cpu/ducc_fft_kernels.cc @@ -25,76 +25,116 @@ namespace jax { using shape_t = ducc0::fmav_info::shape_t; using stride_t = ducc0::fmav_info::stride_t; -void DuccFft(void *out, void **in, XlaCustomCallStatus *) { - const DuccFftDescriptor *descriptor = GetDuccFftDescriptor(in[0]); - shape_t shape(descriptor->shape()->begin(), descriptor->shape()->end()); - stride_t stride_in(descriptor->strides_in()->begin(), - descriptor->strides_in()->end()); - stride_t stride_out(descriptor->strides_out()->begin(), - descriptor->strides_out()->end()); - shape_t axes(descriptor->axes()->begin(), descriptor->axes()->end()); +namespace { - switch (descriptor->fft_type()) { +void DuccFftImpl(void *out, void *operand, jax::DuccFftDtype dtype, + jax::DuccFftType fft_type, + shape_t shape, stride_t strides_in, stride_t strides_out, shape_t axes, + bool forward, double scale) { + + switch (fft_type) { case DuccFftType_C2C: - if (descriptor->dtype() == DuccFftDtype_COMPLEX64) { + if (dtype == DuccFftDtype_COMPLEX64) { ducc0::cfmav> m_in( - reinterpret_cast *>(in[1]), shape, stride_in); + reinterpret_cast *>(operand), shape, strides_in); ducc0::vfmav> m_out( - reinterpret_cast *>(out), shape, stride_out); - ducc0::c2c(m_in, m_out, axes, descriptor->forward(), - static_cast(descriptor->scale())); + reinterpret_cast *>(out), shape, strides_out); + ducc0::c2c(m_in, m_out, axes, forward, static_cast(scale)); } else { ducc0::cfmav> m_in( - reinterpret_cast *>(in[1]), shape, stride_in); + reinterpret_cast *>(operand), shape, strides_in); ducc0::vfmav> m_out( - reinterpret_cast *>(out), shape, stride_out); - ducc0::c2c(m_in, m_out, axes, descriptor->forward(), - static_cast(descriptor->scale())); + reinterpret_cast *>(out), shape, strides_out); + ducc0::c2c(m_in, m_out, axes, forward, scale); } break; case DuccFftType_C2R: - if (descriptor->dtype() == DuccFftDtype_COMPLEX64) { + if (dtype == DuccFftDtype_COMPLEX64) { auto shape_in = shape; shape_in[axes.back()] = shape_in[axes.back()] / 2 + 1; ducc0::cfmav> m_in( - reinterpret_cast *>(in[1]), shape_in, stride_in); + reinterpret_cast *>(operand), + shape_in, strides_in); ducc0::vfmav m_out(reinterpret_cast(out), shape, - stride_out); - ducc0::c2r(m_in, m_out, axes, descriptor->forward(), - static_cast(descriptor->scale())); + strides_out); + ducc0::c2r(m_in, m_out, axes, forward, static_cast(scale)); } else { auto shape_in = shape; shape_in[axes.back()] = shape_in[axes.back()] / 2 + 1; ducc0::cfmav> m_in( - reinterpret_cast *>(in[1]), shape_in, stride_in); + reinterpret_cast *>(operand), + shape_in, strides_in); ducc0::vfmav m_out(reinterpret_cast(out), shape, - stride_out); - ducc0::c2r(m_in, m_out, axes, descriptor->forward(), - static_cast(descriptor->scale())); + strides_out); + ducc0::c2r(m_in, m_out, axes, forward, scale); } break; case DuccFftType_R2C: - if (descriptor->dtype() == DuccFftDtype_COMPLEX64) { + if (dtype == DuccFftDtype_COMPLEX64) { auto shape_out = shape; shape_out[axes.back()] = shape_out[axes.back()] / 2 + 1; - ducc0::cfmav m_in(reinterpret_cast(in[1]), shape, - stride_in); + ducc0::cfmav m_in(reinterpret_cast(operand), shape, + strides_in); ducc0::vfmav> m_out( - reinterpret_cast *>(out), shape_out, stride_out); - ducc0::r2c(m_in, m_out, axes, descriptor->forward(), - static_cast(descriptor->scale())); + reinterpret_cast *>(out), + shape_out, strides_out); + ducc0::r2c(m_in, m_out, axes, forward, static_cast(scale)); } else { auto shape_out = shape; shape_out[axes.back()] = shape_out[axes.back()] / 2 + 1; - ducc0::cfmav m_in(reinterpret_cast(in[1]), shape, - stride_in); + ducc0::cfmav m_in(reinterpret_cast(operand), shape, + strides_in); ducc0::vfmav> m_out( - reinterpret_cast *>(out), shape_out, stride_out); - ducc0::r2c(m_in, m_out, axes, descriptor->forward(), - static_cast(descriptor->scale())); + reinterpret_cast *>(out), + shape_out, strides_out); + ducc0::r2c(m_in, m_out, axes, forward, scale); } break; } } -} // namespace jax +} // namespace + + +// TODO(b/287702203): this must be kept until EOY 2023 for backwards +// of serialized functions using fft. +void DuccFft(void *out, void **in, XlaCustomCallStatus *) { + const DuccFftDescriptor *descriptor = GetDuccFftDescriptor(in[0]); + shape_t shape(descriptor->shape()->begin(), descriptor->shape()->end()); + stride_t strides_in(descriptor->strides_in()->begin(), + descriptor->strides_in()->end()); + stride_t strides_out(descriptor->strides_out()->begin(), + descriptor->strides_out()->end()); + shape_t axes(descriptor->axes()->begin(), descriptor->axes()->end()); + + DuccFftImpl(out, in[1], descriptor->dtype(), descriptor->fft_type(), + shape, strides_in, strides_out, axes, + descriptor->forward(), descriptor->scale()); +} + + +void DynamicDuccFft(void *out, void **in, XlaCustomCallStatus *) { + // in[0]=descriptor, in[1]=operand, + // in[2]=shape, in[3]=strides_in, in[4]=strides_out, in[5]=scale. + const DynamicDuccFftDescriptor *descriptor = + flatbuffers::GetRoot(in[0]); + const std::uint32_t *dynamic_shape = + reinterpret_cast(in[2]); + shape_t shape(dynamic_shape, dynamic_shape + descriptor->ndims()); + const std::uint32_t *dynamic_strides_in = + reinterpret_cast(in[3]); + stride_t strides_in(dynamic_strides_in, + dynamic_strides_in + descriptor->ndims()); + const std::uint32_t *dynamic_strides_out = + reinterpret_cast(in[4]); + stride_t strides_out(dynamic_strides_out, + dynamic_strides_out + descriptor->ndims()); + shape_t axes(descriptor->axes()->begin(), descriptor->axes()->end()); + const double *dynamic_scale = reinterpret_cast(in[5]); + + DuccFftImpl(out, in[1], descriptor->dtype(), descriptor->fft_type(), + shape, strides_in, strides_out, axes, + descriptor->forward(), *dynamic_scale); +} + +} // namespace jax \ No newline at end of file diff --git a/jaxlib/cpu/ducc_fft_kernels.h b/jaxlib/cpu/ducc_fft_kernels.h index 3a925587c..a3bf6cf46 100644 --- a/jaxlib/cpu/ducc_fft_kernels.h +++ b/jaxlib/cpu/ducc_fft_kernels.h @@ -20,8 +20,12 @@ limitations under the License. namespace jax { +// TODO(b/287702203): this must be kept until EOY 2023 for backwards +// of serialized functions using fft. void DuccFft(void* out, void** in, XlaCustomCallStatus*); +void DynamicDuccFft(void* out, void** in, XlaCustomCallStatus*); + } // namespace jax #endif // JAXLIB_CPU_DUCC_FFT_KERNELS_H_ diff --git a/jaxlib/ducc_fft.py b/jaxlib/ducc_fft.py index 75e2207a1..e2f39605d 100644 --- a/jaxlib/ducc_fft.py +++ b/jaxlib/ducc_fft.py @@ -35,115 +35,60 @@ _C2R = 1 _R2C = 2 -def _ducc_fft_descriptor( - shape: List[int], dtype, fft_type: FftType, fft_lengths: List[int] -) -> Tuple[bytes, np.dtype, List[int]]: - n = len(shape) +def _dynamic_ducc_fft_descriptor( + dtype, ndims: int, fft_type: FftType, fft_lengths: List[int] +) -> Tuple[bytes]: assert len(fft_lengths) >= 1 - assert len(fft_lengths) <= n, (fft_lengths, n) + assert len(fft_lengths) <= ndims, (fft_lengths, ndims) forward = fft_type in (FftType.FFT, FftType.RFFT) is_double = np.finfo(dtype).dtype == np.float64 if fft_type == FftType.RFFT: ducc_fft_type = _R2C - - assert dtype in (np.float32, np.float64), dtype - out_dtype = np.dtype(np.complex64 if dtype == np.float32 else np.complex128) - - assert shape[-len(fft_lengths):] == fft_lengths, (shape, fft_lengths) - out_shape = list(shape) - out_shape[-1] = out_shape[-1] // 2 + 1 - elif fft_type == FftType.IRFFT: ducc_fft_type = _C2R - assert np.issubdtype(dtype, np.complexfloating), dtype - - out_dtype = np.dtype(np.float32 if dtype == np.complex64 else np.float64) - - assert shape[-len(fft_lengths):-1] == fft_lengths[:-1] - out_shape = list(shape) - out_shape[-1] = fft_lengths[-1] - assert (out_shape[-1] // 2 + 1) == shape[-1] else: ducc_fft_type = _C2C - assert np.issubdtype(dtype, np.complexfloating), dtype - out_dtype = dtype - - assert shape[-len(fft_lengths):] == fft_lengths, (shape, fft_lengths) - out_shape = shape - - # PocketFft does not allow size 0 dimensions. - if 0 in shape or 0 in out_shape: - return b"", out_dtype, out_shape - # Builds a PocketFftDescriptor flatbuffer. This descriptor is passed to the # C++ kernel to describe the FFT to perform. - strides_in = [] - stride = 1 - for d in reversed(shape): - strides_in.append(stride) - stride *= d + axes = [ndims - len(fft_lengths) + d for d in range(len(fft_lengths))] - strides_out = [] - stride = 1 - for d in reversed(out_shape): - strides_out.append(stride) - stride *= d - - axes = [n - len(fft_lengths) + d for d in range(len(fft_lengths))] - - scale = 1. if forward else (1. / np.prod(fft_lengths)) - descriptor = _ducc_fft.ducc_fft_descriptor( - shape=shape if fft_type != FftType.IRFFT else out_shape, + descriptor = _ducc_fft.dynamic_ducc_fft_descriptor( + ndims=ndims, is_double=is_double, fft_type=ducc_fft_type, - fft_lengths=fft_lengths, - strides_in=list(reversed(strides_in)), - strides_out=list(reversed(strides_out)), axes=axes, - forward=forward, - scale=scale) + forward=forward) - return descriptor, out_dtype, out_shape + return descriptor -def ducc_fft_hlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]): - """DUCC FFT kernel for CPU.""" - a_type = ir.RankedTensorType(a.type) - n = len(a_type.shape) +def dynamic_ducc_fft_hlo( + result_type: ir.Type, + input: ir.Value, *, + input_dtype: np.dtype, ndims:int, input_shape: ir.Value, + strides_in: ir.Value, strides_out: ir.Value, scale: ir.Value, + fft_type: FftType, fft_lengths: List[int], result_shape: ir.Value): + """DUCC FFT kernel for CPU, with support for dynamic shapes.""" + a_type = ir.RankedTensorType(input.type) fft_lengths = list(fft_lengths) - descriptor_bytes, out_dtype, out_shape = _ducc_fft_descriptor( - list(a_type.shape), dtype, fft_type, fft_lengths) + descriptor_bytes = _dynamic_ducc_fft_descriptor( + input_dtype, ndims, fft_type, fft_lengths) - if out_dtype == np.float32: - out_type = ir.F32Type.get() - elif out_dtype == np.float64: - out_type = ir.F64Type.get() - elif out_dtype == np.complex64: - out_type = ir.ComplexType.get(ir.F32Type.get()) - elif out_dtype == np.complex128: - out_type = ir.ComplexType.get(ir.F64Type.get()) - else: - raise ValueError(f"Unknown output type {out_dtype}") - - if 0 in a_type.shape or 0 in out_shape: - zero = hlo.ConstantOp( - ir.DenseElementsAttr.get( - np.array(0, dtype=out_dtype), type=out_type)) - return hlo.BroadcastOp( - zero, - ir.DenseElementsAttr.get(np.asarray(out_shape, np.int64))).result + # PocketFft does not allow size 0 dimensions, but we handled this in fft.py + assert 0 not in a_type.shape u8_type = ir.IntegerType.get_unsigned(8) descriptor = hlo.ConstantOp( ir.DenseElementsAttr.get( - np.frombuffer(descriptor_bytes, dtype=np.uint8), type=u8_type)) - layout = tuple(range(n - 1, -1, -1)) + np.frombuffer(descriptor_bytes, dtype=np.uint8), type=u8_type)).result + layout = tuple(range(ndims - 1, -1, -1)) return custom_call( - "ducc_fft", - [ir.RankedTensorType.get(out_shape, out_type)], - [descriptor, a], - operand_layouts=[[0], layout], - result_layouts=[layout]) + "dynamic_ducc_fft", + [result_type], + [descriptor, input, input_shape, strides_in, strides_out, scale], + operand_layouts=[[0], layout, [0], [0], [0], [0]], + result_layouts=[layout], + result_shapes=[result_shape])