mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Add support for dynamic shapes to jax.fft.
The idea is that we take all the values that can contain dimension sizes from the descriptor (shape, strides_in, strides_out) and we pass them as 1-d tensor operands. We also pass as an operand the output_shape, so that we can use the hlo.CustomCallOp `indices_of_output_shapes` attribute to tell the shape refinement how to compute the shape of the result. We keep the old descriptor and the ducc_fft registration for the old C++ custom targets for backwards compatibility (for 6 months). That behavior is tested by back_compat_test.py. The one downside of this implementation is that it moves some of the ducc-specific logic from ducc_fft.py (in jaxlib) into fft.py (in jax). This was necessary because that code computes with dimensions that are now dynamic. In JAX we have support for evaluating dynamic shapes and turning them into 1-d tensors. Also added backwards compatibility test for dynamic_ducc_fft and kept the old test for ducc_fft. PiperOrigin-RevId: 541168692
This commit is contained in:
parent
3adfe321b0
commit
b9c0658fcf
@ -30,6 +30,7 @@ from jax._src.interpreters import mlir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import ducc_fft
|
||||
from jax._src.lib import version as jaxlib_version
|
||||
from jax._src.numpy.util import promote_dtypes_complex, promote_dtypes_inexact
|
||||
|
||||
__all__ = [
|
||||
@ -103,6 +104,9 @@ def fft_abstract_eval(x, fft_type, fft_lengths):
|
||||
return x.update(shape=shape, dtype=dtype)
|
||||
|
||||
def _fft_lowering(ctx, x, *, fft_type, fft_lengths):
|
||||
if not is_constant_shape(fft_lengths):
|
||||
# TODO: https://github.com/openxla/stablehlo/issues/1366
|
||||
raise NotImplementedError("Shape polymorphism for FFT with non-constant fft_length is not implemented for TPU and GPU")
|
||||
return [
|
||||
hlo.FftOp(x, hlo.FftTypeAttr.get(fft_type.name),
|
||||
mlir.dense_int_elements(fft_lengths)).result
|
||||
@ -110,12 +114,81 @@ def _fft_lowering(ctx, x, *, fft_type, fft_lengths):
|
||||
|
||||
|
||||
def _fft_lowering_cpu(ctx, x, *, fft_type, fft_lengths):
|
||||
if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
|
||||
raise NotImplementedError("Shape polymorphism for custom call is not implemented (fft); b/261671778")
|
||||
x_aval, = ctx.avals_in
|
||||
return [ducc_fft.ducc_fft_hlo(x, x_aval.dtype, fft_type=fft_type,
|
||||
if jaxlib_version < (0, 4, 13):
|
||||
if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
|
||||
raise NotImplementedError("Shape polymorphism for custom call is not implemented (fft); b/261671778; try updating your jaxlib.")
|
||||
return [ducc_fft.ducc_fft_hlo(x, x_aval.dtype, fft_type=fft_type, # type: ignore
|
||||
fft_lengths=fft_lengths)]
|
||||
|
||||
in_shape = x_aval.shape
|
||||
dtype = x_aval.dtype
|
||||
out_aval, = ctx.avals_out
|
||||
out_shape = out_aval.shape
|
||||
|
||||
forward = fft_type in (xla_client.FftType.FFT, xla_client.FftType.RFFT)
|
||||
ndims = len(in_shape)
|
||||
assert len(fft_lengths) >= 1
|
||||
assert len(fft_lengths) <= ndims, (fft_lengths, ndims)
|
||||
assert len(in_shape) == len(out_shape) == ndims
|
||||
|
||||
# PocketFft does not allow size 0 dimensions.
|
||||
if 0 in in_shape or 0 in out_shape:
|
||||
if fft_type == xla_client.FftType.RFFT:
|
||||
assert dtype in (np.float32, np.float64), dtype
|
||||
out_dtype = np.dtype(np.complex64 if dtype == np.float32 else np.complex128)
|
||||
|
||||
elif fft_type == xla_client.FftType.IRFFT:
|
||||
assert np.issubdtype(dtype, np.complexfloating), dtype
|
||||
out_dtype = np.dtype(np.float32 if dtype == np.complex64 else np.float64)
|
||||
|
||||
else:
|
||||
assert np.issubdtype(dtype, np.complexfloating), dtype
|
||||
out_dtype = dtype
|
||||
|
||||
zero = mlir.ir_constant(np.array(0, dtype=out_dtype),
|
||||
canonicalize_types=False)
|
||||
return [
|
||||
mlir.broadcast_in_dim(ctx, zero, out_aval, broadcast_dimensions=[])]
|
||||
|
||||
strides_in = []
|
||||
stride = 1
|
||||
for d in reversed(in_shape):
|
||||
strides_in.append(stride)
|
||||
stride *= d
|
||||
strides_in = mlir.shape_tensor(
|
||||
mlir.eval_dynamic_shape(ctx, tuple(reversed(strides_in))))
|
||||
|
||||
strides_out = []
|
||||
stride = 1
|
||||
for d in reversed(out_shape):
|
||||
strides_out.append(stride)
|
||||
stride *= d
|
||||
strides_out = mlir.shape_tensor(
|
||||
mlir.eval_dynamic_shape(ctx, tuple(reversed(strides_out))))
|
||||
|
||||
# scale = 1. if forward else (1. / np.prod(fft_lengths)) as a f64[1] tensor
|
||||
double_type = mlir.ir.RankedTensorType.get((), mlir.ir.F64Type.get())
|
||||
size_fft_length_prod = np.prod(fft_lengths) if fft_lengths else 1
|
||||
size_fft_lengths, = mlir.eval_dynamic_shape_as_vals(ctx, (size_fft_length_prod,))
|
||||
size_fft_lengths = hlo.ConvertOp(double_type, size_fft_lengths)
|
||||
one = mlir.ir_constant(np.float64(1.), canonicalize_types=False)
|
||||
scale = one if forward else hlo.DivOp(one, size_fft_lengths)
|
||||
scale = hlo.ReshapeOp(
|
||||
mlir.ir.RankedTensorType.get((1,), mlir.ir.F64Type.get()),
|
||||
scale).result
|
||||
|
||||
in_shape = mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, in_shape))
|
||||
out_shape = mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, out_shape))
|
||||
in_shape = in_shape if fft_type != xla_client.FftType.IRFFT else out_shape
|
||||
|
||||
result_type = mlir.aval_to_ir_type(out_aval)
|
||||
return [ducc_fft.dynamic_ducc_fft_hlo(
|
||||
result_type, x,
|
||||
input_dtype=x_aval.dtype, ndims=ndims, input_shape=in_shape,
|
||||
strides_in=strides_in, strides_out=strides_out, scale=scale,
|
||||
fft_type=fft_type, fft_lengths=fft_lengths, result_shape=out_shape)]
|
||||
|
||||
def _naive_rfft(x, fft_lengths):
|
||||
y = fft(x, xla_client.FftType.FFT, fft_lengths)
|
||||
n = fft_lengths[-1]
|
||||
|
@ -687,7 +687,7 @@ def _check_lowering(lowering) -> None:
|
||||
# Their backwards compatibility is tested by back_compat_test.py.
|
||||
_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = {
|
||||
"Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape",
|
||||
"ducc_fft", "cu_threefry2x32",
|
||||
"ducc_fft", "dynamic_ducc_fft", "cu_threefry2x32",
|
||||
# eigh on CPU
|
||||
"lapack_ssyevd", "lapack_dsyevd", "lapack_cheevd", "lapack_zheevd",
|
||||
# eigh on GPU
|
||||
|
@ -194,7 +194,8 @@ class CompatTest(jtu.JaxTestCase):
|
||||
atol: Optional[float] = None,
|
||||
allow_additional_custom_call_targets: Sequence[str] = (),
|
||||
check_results: Optional[Callable[..., None]] = None,
|
||||
use_tf_graph: bool = False):
|
||||
use_tf_graph=False,
|
||||
compare_with_current: bool = True):
|
||||
"""Run one compatibility test.
|
||||
|
||||
Args:
|
||||
@ -203,13 +204,21 @@ class CompatTest(jtu.JaxTestCase):
|
||||
rtol: relative tolerance for numerical comparisons
|
||||
atol: absolute tolerance for numerical comparisons
|
||||
check_results: invoked with the results obtained from running the
|
||||
serialized code, and those stored in the test data, and the kwarg rtol.
|
||||
serialized code, and those stored in the test data, and the kwargs rtol
|
||||
and atol.
|
||||
allow_additional_custom_call_targets: additional custom call targets to allow.
|
||||
use_tf_graph: if False (default), uses jax_export to serialize JAX
|
||||
functions and to invoke them. If True then uses tf.Graph to serialize
|
||||
and run the functions; expects that `func` contains a `jax2tf.call_tf`
|
||||
and uses `jax2tf.convert` to generate a tf.Graph containing a
|
||||
XlaCallModule with the actual MLIR module.
|
||||
compare_with_current: whether to compare the current behavior for
|
||||
`func` with the one stored in `data`. If `True` (default) uses the
|
||||
current version of JAX and XLA to lower and serialize `func` and check
|
||||
its results compared to the stored ones; it also dumps the current
|
||||
test data. If `False`, no current serialization are comparisons are
|
||||
done, tests only the saved serialization. Use this option for a test
|
||||
data for which we have changed the serialization.
|
||||
"""
|
||||
if not isinstance(data, CompatTestData):
|
||||
raise ValueError(f"Expecting data: CompatTestData but got {data}. "
|
||||
@ -324,6 +333,7 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
|
||||
else:
|
||||
self.assertAllClose(res_from_serialized_run_now, data.expected_outputs,
|
||||
rtol=rtol, atol=atol)
|
||||
if compare_with_current:
|
||||
self.assertListEqual(custom_call_targets, data.custom_call_targets)
|
||||
|
||||
def run_serialized(self, data: CompatTestData,
|
||||
@ -398,7 +408,8 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
|
||||
# Add here all the testdatas that should cover the targets guaranteed
|
||||
# stable
|
||||
covering_testdatas = [
|
||||
cpu_ducc_fft.data_2023_03_17, cpu_lapack_syev.data_2023_03_17,
|
||||
cpu_ducc_fft.data_2023_03_17, cpu_ducc_fft.data_2023_06_14,
|
||||
cpu_lapack_syev.data_2023_03_17,
|
||||
cpu_lapack_geqrf.data_2023_03_17, cuda_threefry2x32.data_2023_03_15,
|
||||
cuda_cusolver_geqrf.data_2023_03_18, cuda_cusolver_syev.data_2023_03_17,
|
||||
tf_call_tf_function.data_2023_06_02,
|
||||
@ -422,9 +433,16 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
|
||||
def func(x):
|
||||
return lax.fft(x, fft_type="fft", fft_lengths=(4,))
|
||||
|
||||
# An old lowering, with ducc_fft. We keep it for 6 months.
|
||||
data = load_testdata(cpu_ducc_fft.data_2023_03_17)
|
||||
# We have changed the lowering for fft, do not compare with current.
|
||||
self.run_one_test(func, data, compare_with_current=False)
|
||||
|
||||
# A newer lowering, with dynamic_ducc_fft.
|
||||
data = load_testdata(cpu_ducc_fft.data_2023_06_14)
|
||||
self.run_one_test(func, data)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def eigh_input(shape, dtype):
|
||||
# In order to keep inputs small, we construct the input programmatically
|
||||
|
@ -44,3 +44,71 @@ module @jit_func {
|
||||
mlir_module_serialized=b'ML\xefR\x03MLIRxxx-trunk\x00\x01\x1d\x05\x01\x05\x01\x03\x05\x03\r\x07\t\x0b\r\x0f\x11\x03\x99s\x15\x01?\x07\x0b\x0f\x17\x0b\x0b\x0b\x0b\x0f\x13\x0b33\x0b\x0b\x0f\x0b\x13\x0b\x0bK\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x035\x0b\x0b\x0f\x0b\x0bO\x0f\x1b\x0b\x0b\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0bf\x04\x0b\x0b\x0b\x13/\x0f\x03\x15\x17\x17\x07\x17\x07\x17\x0b\x07\x13\x13\x02\xca\x05\x1f\x05\x13\x1d\x1b\x07\x17\x1d^\x03\x01\x05\x15\x05\x17\x05\x19\x05\x1b\x1d\'\x07\x03\x03\x03\x15\x05\x1d\x03\x0b\tK\x0b?\rW\x03]\x0f_\x03\x0b\tC\x0b?\rC\x03E\x0fc\x05\x1f\x05!\x1d!\x07\x05#\x03\x03%e\x05%\x05\'\x03\x11+g-A/i1G3k5m7G9q\x05)\x05+\x05-\x05/\x051\x053\x055\x057\x03\x03=E\x059#\x0b\x1d;\x03\x03a\x1d=\x03\x01\x1f\x13!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03M\r\x05OQSU\x1d?\x1dA\x1dC\x1dE\x03\x03Y\r\x03[A\x1dG\x1dI\x1dK\r\x01\x1dM\x1f\x07"\x02\x18\x00\x00\x00\x14\x00$\x00\x00\x00\x00\x00\x08\x00\x0c\x00\x10\x00\x14\x00\x07\x00\x18\x00\x14\x00\x00\x00\x00\x00\x00\x01T\x00\x00\x008\x00\x00\x00\x1c\x00\x00\x00\x10\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf0?\x00\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x0b\x05\x1dO\x05\x01\x03\x05oI\x1f\x11\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03I)\x05\r\x11\r)\x05\r\x11\x05\t)\x03B\x04\x0f\x13\x11\x03\x03\x03\x01\x03\x05!)\x03\x05\t)\x03\t\t\x04\x8f\x05\x01\x11\x01\x13\x07\x03\x01\t\x03\x11\x01\x17\x05\x03\x05\x0b\x03\x03\x01\r\x07\x05;\x03\x01\x03\x01\x05\x04\x01\x03\x03\x03\x11\x05\x19\x05\x03\t\x13\x03\x03\x01\x07\x06\x1f\x03\x01\x03\x01\t\x03\x11#\x03\x07\x0b\x07\x11)\x03\x01\x05\x05\x03\x05\x04\x05\x03\x07\x06\x03\x01\x05\x01\x00\xc2\x0eQ\x13\x11\x0f\x0b!\x1b\x1d\x05\x1b\t\x03\x0f\x1f/!!)#\x1f\x19\x91\r\xaf\x83\x82\x04\x13\x1f\x15\x1d\x15\x13\x11\x1f\x19\x17\x15\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00return_v1\x00convert_v1\x00constant_v1\x00custom_call_v1\x00call_v1\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit_func\x00jit(func)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=fft in_positional_semantics=(<_PositionalSemantics.GLOBAL: 1>,) out_positional_semantics=_PositionalSemantics.GLOBAL keep_unused=False inline=False]\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00jit(func)/jit(main)/jit(fft)/convert_element_type[new_dtype=complex64 weak_type=False]\x00value\x00jit(func)/jit(main)/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(4,)]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00callee\x00\x00fft\x00jax.arg_info\x00x\x00mhlo.sharding\x00{replicated}\x00jax.result_info\x00main\x00public\x00private\x00ducc_fft\x00',
|
||||
xla_call_module_version=4,
|
||||
) # End paste
|
||||
|
||||
# Pasted from the test output (see back_compat_test.py module docstring)
|
||||
data_2023_06_14 = dict(
|
||||
testdata_version=1,
|
||||
platform='cpu',
|
||||
custom_call_targets=['dynamic_ducc_fft'],
|
||||
serialized_date=datetime.date(2023, 6, 14),
|
||||
inputs=(array([[ 0., 1., 2., 3.],
|
||||
[ 4., 5., 6., 7.],
|
||||
[ 8., 9., 10., 11.]], dtype=float32),),
|
||||
expected_outputs=(array([[ 6.+0.j, -2.+2.j, -2.+0.j, -2.-2.j],
|
||||
[22.+0.j, -2.+2.j, -2.+0.j, -2.-2.j],
|
||||
[38.+0.j, -2.+2.j, -2.+0.j, -2.-2.j]], dtype=complex64),),
|
||||
mlir_module_text=r"""
|
||||
#loc = loc(unknown)
|
||||
module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
|
||||
func.func public @main(%arg0: tensor<3x4xf32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<3x4xcomplex<f32>> {jax.result_info = ""}) {
|
||||
%0 = call @fft(%arg0) : (tensor<3x4xf32>) -> tensor<3x4xcomplex<f32>> loc(#loc3)
|
||||
return %0 : tensor<3x4xcomplex<f32>> loc(#loc)
|
||||
} loc(#loc)
|
||||
func.func private @fft(%arg0: tensor<3x4xf32> loc(unknown)) -> tensor<3x4xcomplex<f32>> {
|
||||
%0 = stablehlo.convert %arg0 : (tensor<3x4xf32>) -> tensor<3x4xcomplex<f32>> loc(#loc4)
|
||||
%1 = stablehlo.constant dense<4> : tensor<i64> loc(#loc5)
|
||||
%2 = stablehlo.constant dense<1> : tensor<i64> loc(#loc5)
|
||||
%3 = stablehlo.convert %1 : (tensor<i64>) -> tensor<i32> loc(#loc5)
|
||||
%4 = stablehlo.reshape %3 : (tensor<i32>) -> tensor<1xi32> loc(#loc5)
|
||||
%5 = stablehlo.convert %2 : (tensor<i64>) -> tensor<i32> loc(#loc5)
|
||||
%6 = stablehlo.reshape %5 : (tensor<i32>) -> tensor<1xi32> loc(#loc5)
|
||||
%7 = stablehlo.concatenate %4, %6, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> loc(#loc5)
|
||||
%8 = stablehlo.constant dense<4> : tensor<i64> loc(#loc5)
|
||||
%9 = stablehlo.constant dense<1> : tensor<i64> loc(#loc5)
|
||||
%10 = stablehlo.convert %8 : (tensor<i64>) -> tensor<i32> loc(#loc5)
|
||||
%11 = stablehlo.reshape %10 : (tensor<i32>) -> tensor<1xi32> loc(#loc5)
|
||||
%12 = stablehlo.convert %9 : (tensor<i64>) -> tensor<i32> loc(#loc5)
|
||||
%13 = stablehlo.reshape %12 : (tensor<i32>) -> tensor<1xi32> loc(#loc5)
|
||||
%14 = stablehlo.concatenate %11, %13, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> loc(#loc5)
|
||||
%15 = stablehlo.constant dense<4> : tensor<i64> loc(#loc5)
|
||||
%16 = stablehlo.convert %15 : (tensor<i64>) -> tensor<f64> loc(#loc5)
|
||||
%17 = stablehlo.constant dense<1.000000e+00> : tensor<f64> loc(#loc5)
|
||||
%18 = stablehlo.reshape %17 : (tensor<f64>) -> tensor<1xf64> loc(#loc5)
|
||||
%19 = stablehlo.constant dense<3> : tensor<i64> loc(#loc5)
|
||||
%20 = stablehlo.constant dense<4> : tensor<i64> loc(#loc5)
|
||||
%21 = stablehlo.convert %19 : (tensor<i64>) -> tensor<i32> loc(#loc5)
|
||||
%22 = stablehlo.reshape %21 : (tensor<i32>) -> tensor<1xi32> loc(#loc5)
|
||||
%23 = stablehlo.convert %20 : (tensor<i64>) -> tensor<i32> loc(#loc5)
|
||||
%24 = stablehlo.reshape %23 : (tensor<i32>) -> tensor<1xi32> loc(#loc5)
|
||||
%25 = stablehlo.concatenate %22, %24, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> loc(#loc5)
|
||||
%26 = stablehlo.constant dense<3> : tensor<i64> loc(#loc5)
|
||||
%27 = stablehlo.constant dense<4> : tensor<i64> loc(#loc5)
|
||||
%28 = stablehlo.convert %26 : (tensor<i64>) -> tensor<i32> loc(#loc5)
|
||||
%29 = stablehlo.reshape %28 : (tensor<i32>) -> tensor<1xi32> loc(#loc5)
|
||||
%30 = stablehlo.convert %27 : (tensor<i64>) -> tensor<i32> loc(#loc5)
|
||||
%31 = stablehlo.reshape %30 : (tensor<i32>) -> tensor<1xi32> loc(#loc5)
|
||||
%32 = stablehlo.concatenate %29, %31, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> loc(#loc5)
|
||||
%33 = stablehlo.constant dense<[20, 0, 0, 0, 0, 0, 14, 0, 16, 0, 8, 0, 0, 0, 0, 0, 12, 0, 7, 0, 14, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 4, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0]> : tensor<44xui8> loc(#loc5)
|
||||
%34 = stablehlo.custom_call @dynamic_ducc_fft(%33, %0, %25, %7, %14, %18, %32) {api_version = 2 : i32, indices_of_shape_operands = dense<6> : tensor<1xi64>, operand_layouts = [dense<0> : tensor<1xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>], result_layouts = [dense<[1, 0]> : tensor<2xindex>]} : (tensor<44xui8>, tensor<3x4xcomplex<f32>>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<1xf64>, tensor<2xi32>) -> tensor<3x4xcomplex<f32>> loc(#loc5)
|
||||
return %34 : tensor<3x4xcomplex<f32>> loc(#loc3)
|
||||
} loc(#loc3)
|
||||
} loc(#loc)
|
||||
#loc1 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":437:0)
|
||||
#loc2 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":428:0)
|
||||
#loc3 = loc("jit(func)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=fft keep_unused=False inline=False]"(#loc1))
|
||||
#loc4 = loc("jit(func)/jit(main)/jit(fft)/convert_element_type[new_dtype=complex64 weak_type=False]"(#loc2))
|
||||
#loc5 = loc("jit(func)/jit(main)/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(4,)]"(#loc2))
|
||||
""",
|
||||
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01!\x05\x01\x03\x01\x03\x05\x03\x11\x07\t\x0b\r\x0f\x11\x13\x15\x03\xd3\x95+\x01U\x0f\x07\x13\x0b\x13\x0b\x0f\x0f\x0b\x0b\x0b\x0b\x0b\x17\x13\x13#\x0b\x0b\x0b33\x0b\x17\x0f\x0b\x0b\x0b\x13\x13S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x03A/\x0b\x0b\x0f\x0b\x0bO\x0f\x1b\x0b\x0b\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0b//\x0f//\xbf\x0b\x0b\x0b/'\x0f\x01\x03\x0f\x03)\x0f\x0f\x13\x17\x13\x17\x07\x07\x0f\x07\x07\x13\x07\x17\x0b\x13\x07\x13\x13\x13\x02r\x06\x1d5\x1b\x1f\x03\x03\x07}\x05\x17\x03\x037\x81\x05\x19\x1d-/\x11\x01\x05\x05\x1b\x05\x1d\x05\x1f\x05!\x05#\x17\x19\xb2\x06\x01\x03\x03\x07\x7f\x03\x03\x07\x85\x03\x07#\x0f%\x0f\x0b'\x05%\x05'\x05)\x03\x0b\x11c\x13W\x15o\x0bu\x17w\x03\x0b\x11[\x13W\x15[\x0b]\x17{\x05+\x17\x19\xd6\x06\x01\x1d3\x1b\x05-\x05/\x051\x03\x03\x07\x83\x03\x03\x07\x87\x03\x13?\x89AYC\x8bE_G\x8dI\x8fK\x91M_O\x93\x053\x055\x057\x059\x05;\x05=\x05?\x05A\x05C\x03\x03S]\x05E\x1f'\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x1d\x1dG\x03\x03y\x1dI\x03\x01\x1f)!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03e\r\x05gikm\x1dK\x1dM\x1dO\x1dQ\x03\x03q\r\x03sY\x1dS\x1dU\x1dW\r\x01\x1dY\x1f\x03\x11\x04\x00\x00\x00\x00\x00\x00\x00\x1f\x03\x11\x01\x00\x00\x00\x00\x00\x00\x00\x13\x0f\x01\x1f\x13\x11\x00\x00\x00\x00\x00\x00\xf0?\x1f\x03\x11\x03\x00\x00\x00\x00\x00\x00\x00\x1f\x19Y\x14\x00\x00\x00\x00\x00\x0e\x00\x10\x00\x08\x00\x00\x00\x00\x00\x0c\x00\x07\x00\x0e\x00\x00\x00\x00\x00\x00\x01\x02\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x0b\x05\x1d[\x05\x01\x1f%\x11\x06\x00\x00\x00\x00\x00\x00\x00\x03\x0fUaUUUUU\x03\x03a\x01\x02\x02)\x01\x0f)\x01\x11)\x03\x05\x11)\x05\r\x11\x1f)\x03\t\x11)\x05\r\x11\x15\x1d\x1b)\x01\x17\t\x0b)\x03\xb1#\x13\x11\x03\r\x03\t\x03\x15)\x03\x05\x17!)\x03\x05\x0f)\x03\x05\x1b)\x03\t\x1b\x04\xaa\x04\x05\x01\x11\x03!\x07\x03\x01\t\x0b\x11\x03)\x05\x03\x05\x0b\x03\r\x03\x11\x07\rQ\x03\t\x03\x01\r\x04\x03\x03\x03\x0b\x11\r+\x05\x03I\x93\x03\r\x03\x05\x061\x03\t\x03\x01\x03\x03\x01\x05\x03\x03\x03\x03\x01\x1d\x03\x03\x05\x06\x01\x03\x05\x03\x05\x07\x06\x01\x03\x07\x03\t\x05\x06\x01\x03\x05\x03\x07\x07\x06\x01\x03\x07\x03\r\t\x07\x01\t\x03\x0b\x05\x0b\x0f\x03\x03\x01\x05\x03\x03\x03\x03\x01\x1d\x03\x03\x05\x06\x01\x03\x05\x03\x13\x07\x06\x01\x03\x07\x03\x17\x05\x06\x01\x03\x05\x03\x15\x07\x06\x01\x03\x07\x03\x1b\t\x07\x01\t\x03\x0b\x05\x19\x1d\x03\x03\x01\x05\x03\x03\x05\x06\x01\x03\x13\x03!\x03\x03\x019\x03\x13\x07\x06\x01\x03!\x03%\x03\x03\x01\x1f\x03\x03\x03\x03\x01\x05\x03\x03\x05\x06\x01\x03\x05\x03)\x07\x06\x01\x03\x07\x03-\x05\x06\x01\x03\x05\x03+\x07\x06\x01\x03\x07\x031\t\x07\x01\t\x03\x0b\x05/3\x03\x03\x01\x1f\x03\x03\x03\x03\x01\x05\x03\x03\x05\x06\x01\x03\x05\x037\x07\x06\x01\x03\x07\x03;\x05\x06\x01\x03\x05\x039\x07\x06\x01\x03\x07\x03?\t\x07\x01\t\x03\x0b\x05=A\x03\x03\x01;\x03\x19\x0f\x07\x01=\x03\t\x0fE\x035\x11\x1f'C\r\x04\r\x03G\x06\x03\x01\x05\x01\x00\xc6\x0e]#\x11\x0f\x0b!\x1b\x1d\x05\x1b\t\x03\x0f\x1f/!5!)#\x1f\x19\x15\x91\xaf\xbe\x02\x13%)\x83\x1f\x15\x1d\x15\x13\r\x11\x1f\x15\x11\x1f\x17\x17\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00convert_v1\x00reshape_v1\x00concatenate_v1\x00func_v1\x00return_v1\x00custom_call_v1\x00call_v1\x00value\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=fft keep_unused=False inline=False]\x00jit(func)/jit(main)/jit(fft)/convert_element_type[new_dtype=complex64 weak_type=False]\x00jit(func)/jit(main)/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(4,)]\x00dimension\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00indices_of_shape_operands\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00callee\x00\x00fft\x00jax.arg_info\x00x\x00mhlo.sharding\x00{replicated}\x00jax.result_info\x00main\x00public\x00private\x00dynamic_ducc_fft\x00",
|
||||
xla_call_module_version=6,
|
||||
) # End paste
|
||||
|
@ -14,7 +14,7 @@
|
||||
"""Tests for JAX primitive coverage.
|
||||
|
||||
The bulk of the testing is done by `test_prim`, which is parameterized by
|
||||
about 2000+ test harnesses. See `primitive_harness.py` docstring for a
|
||||
about 3500+ test harnesses. See `primitive_harness.py` docstring for a
|
||||
description of test harnesses. That module contains also the definitions
|
||||
of all the test harnesses, and a specification of which are only partially
|
||||
implemented for JAX.
|
||||
|
@ -2742,7 +2742,6 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
|
||||
# Set of harness.group_name:platform that are implemented with custom call
|
||||
custom_call_harnesses = {
|
||||
"vmap_cholesky:cpu", "vmap_cholesky:gpu",
|
||||
"vmap_fft:cpu", "fft:cpu",
|
||||
"householder_product:cpu", "householder_product:gpu",
|
||||
"vmap_geqrf:cpu", "vmap_geqrf:gpu",
|
||||
"vmap_lu:cpu", "vmap_lu:gpu",
|
||||
|
@ -105,7 +105,10 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
|
||||
"lapack_cgees", ComplexGees<std::complex<float>>::Kernel, "Host");
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
|
||||
"lapack_zgees", ComplexGees<std::complex<double>>::Kernel, "Host");
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("ducc_fft", DuccFft, "Host");
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
|
||||
"ducc_fft", DuccFft, "Host");
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
|
||||
"dynamic_ducc_fft", DynamicDuccFft, "Host");
|
||||
|
||||
} // namespace
|
||||
} // namespace jax
|
||||
|
@ -27,41 +27,39 @@ namespace py = pybind11;
|
||||
namespace jax {
|
||||
namespace {
|
||||
|
||||
py::bytes BuildDuccFftDescriptor(const std::vector<uint64_t> &shape,
|
||||
|
||||
py::bytes BuildDynamicDuccFftDescriptor(
|
||||
const uint32_t ndims,
|
||||
bool is_double, int fft_type,
|
||||
const std::vector<uint64_t> &fft_lengths,
|
||||
const std::vector<uint64_t> &strides_in,
|
||||
const std::vector<uint64_t> &strides_out,
|
||||
const std::vector<uint32_t> &axes,
|
||||
bool forward, double scale) {
|
||||
DuccFftDescriptorT descriptor;
|
||||
descriptor.shape = shape;
|
||||
bool forward) {
|
||||
DynamicDuccFftDescriptorT descriptor;
|
||||
descriptor.ndims = ndims;
|
||||
descriptor.fft_type = static_cast<DuccFftType>(fft_type);
|
||||
descriptor.dtype =
|
||||
is_double ? DuccFftDtype_COMPLEX128 : DuccFftDtype_COMPLEX64;
|
||||
descriptor.strides_in = strides_in;
|
||||
descriptor.strides_out = strides_out;
|
||||
descriptor.axes = axes;
|
||||
descriptor.forward = forward;
|
||||
descriptor.scale = scale;
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
fbb.Finish(DuccFftDescriptor::Pack(fbb, &descriptor));
|
||||
fbb.Finish(DynamicDuccFftDescriptor::Pack(fbb, &descriptor));
|
||||
return py::bytes(reinterpret_cast<char *>(fbb.GetBufferPointer()),
|
||||
fbb.GetSize());
|
||||
}
|
||||
|
||||
py::dict Registrations() {
|
||||
pybind11::dict dict;
|
||||
// TODO(b/287702203): this must be kept until EOY 2023 for backwards
|
||||
// of serialized functions using fft.
|
||||
dict["ducc_fft"] = EncapsulateFunction(DuccFft);
|
||||
dict["dynamic_ducc_fft"] = EncapsulateFunction(DynamicDuccFft);
|
||||
return dict;
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(_ducc_fft, m) {
|
||||
m.def("registrations", &Registrations);
|
||||
m.def("ducc_fft_descriptor", &BuildDuccFftDescriptor, py::arg("shape"),
|
||||
py::arg("is_double"), py::arg("fft_type"), py::arg("fft_lengths"),
|
||||
py::arg("strides_in"), py::arg("strides_out"), py::arg("axes"),
|
||||
py::arg("forward"), py::arg("scale"));
|
||||
m.def("dynamic_ducc_fft_descriptor", &BuildDynamicDuccFftDescriptor,
|
||||
py::arg("ndims"), py::arg("is_double"), py::arg("fft_type"),
|
||||
py::arg("axes"), py::arg("forward"));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -37,4 +37,12 @@ table DuccFftDescriptor {
|
||||
scale:double;
|
||||
}
|
||||
|
||||
table DynamicDuccFftDescriptor {
|
||||
ndims:uint32;
|
||||
dtype:DuccFftDtype;
|
||||
fft_type:DuccFftType;
|
||||
axes:[uint32];
|
||||
forward:bool;
|
||||
}
|
||||
|
||||
root_type DuccFftDescriptor;
|
||||
|
@ -25,76 +25,116 @@ namespace jax {
|
||||
using shape_t = ducc0::fmav_info::shape_t;
|
||||
using stride_t = ducc0::fmav_info::stride_t;
|
||||
|
||||
void DuccFft(void *out, void **in, XlaCustomCallStatus *) {
|
||||
const DuccFftDescriptor *descriptor = GetDuccFftDescriptor(in[0]);
|
||||
shape_t shape(descriptor->shape()->begin(), descriptor->shape()->end());
|
||||
stride_t stride_in(descriptor->strides_in()->begin(),
|
||||
descriptor->strides_in()->end());
|
||||
stride_t stride_out(descriptor->strides_out()->begin(),
|
||||
descriptor->strides_out()->end());
|
||||
shape_t axes(descriptor->axes()->begin(), descriptor->axes()->end());
|
||||
namespace {
|
||||
|
||||
switch (descriptor->fft_type()) {
|
||||
void DuccFftImpl(void *out, void *operand, jax::DuccFftDtype dtype,
|
||||
jax::DuccFftType fft_type,
|
||||
shape_t shape, stride_t strides_in, stride_t strides_out, shape_t axes,
|
||||
bool forward, double scale) {
|
||||
|
||||
switch (fft_type) {
|
||||
case DuccFftType_C2C:
|
||||
if (descriptor->dtype() == DuccFftDtype_COMPLEX64) {
|
||||
if (dtype == DuccFftDtype_COMPLEX64) {
|
||||
ducc0::cfmav<std::complex<float>> m_in(
|
||||
reinterpret_cast<std::complex<float> *>(in[1]), shape, stride_in);
|
||||
reinterpret_cast<std::complex<float> *>(operand), shape, strides_in);
|
||||
ducc0::vfmav<std::complex<float>> m_out(
|
||||
reinterpret_cast<std::complex<float> *>(out), shape, stride_out);
|
||||
ducc0::c2c(m_in, m_out, axes, descriptor->forward(),
|
||||
static_cast<float>(descriptor->scale()));
|
||||
reinterpret_cast<std::complex<float> *>(out), shape, strides_out);
|
||||
ducc0::c2c(m_in, m_out, axes, forward, static_cast<float>(scale));
|
||||
} else {
|
||||
ducc0::cfmav<std::complex<double>> m_in(
|
||||
reinterpret_cast<std::complex<double> *>(in[1]), shape, stride_in);
|
||||
reinterpret_cast<std::complex<double> *>(operand), shape, strides_in);
|
||||
ducc0::vfmav<std::complex<double>> m_out(
|
||||
reinterpret_cast<std::complex<double> *>(out), shape, stride_out);
|
||||
ducc0::c2c(m_in, m_out, axes, descriptor->forward(),
|
||||
static_cast<double>(descriptor->scale()));
|
||||
reinterpret_cast<std::complex<double> *>(out), shape, strides_out);
|
||||
ducc0::c2c(m_in, m_out, axes, forward, scale);
|
||||
}
|
||||
break;
|
||||
case DuccFftType_C2R:
|
||||
if (descriptor->dtype() == DuccFftDtype_COMPLEX64) {
|
||||
if (dtype == DuccFftDtype_COMPLEX64) {
|
||||
auto shape_in = shape;
|
||||
shape_in[axes.back()] = shape_in[axes.back()] / 2 + 1;
|
||||
ducc0::cfmav<std::complex<float>> m_in(
|
||||
reinterpret_cast<std::complex<float> *>(in[1]), shape_in, stride_in);
|
||||
reinterpret_cast<std::complex<float> *>(operand),
|
||||
shape_in, strides_in);
|
||||
ducc0::vfmav<float> m_out(reinterpret_cast<float *>(out), shape,
|
||||
stride_out);
|
||||
ducc0::c2r(m_in, m_out, axes, descriptor->forward(),
|
||||
static_cast<float>(descriptor->scale()));
|
||||
strides_out);
|
||||
ducc0::c2r(m_in, m_out, axes, forward, static_cast<float>(scale));
|
||||
} else {
|
||||
auto shape_in = shape;
|
||||
shape_in[axes.back()] = shape_in[axes.back()] / 2 + 1;
|
||||
ducc0::cfmav<std::complex<double>> m_in(
|
||||
reinterpret_cast<std::complex<double> *>(in[1]), shape_in, stride_in);
|
||||
reinterpret_cast<std::complex<double> *>(operand),
|
||||
shape_in, strides_in);
|
||||
ducc0::vfmav<double> m_out(reinterpret_cast<double *>(out), shape,
|
||||
stride_out);
|
||||
ducc0::c2r(m_in, m_out, axes, descriptor->forward(),
|
||||
static_cast<double>(descriptor->scale()));
|
||||
strides_out);
|
||||
ducc0::c2r(m_in, m_out, axes, forward, scale);
|
||||
}
|
||||
break;
|
||||
case DuccFftType_R2C:
|
||||
if (descriptor->dtype() == DuccFftDtype_COMPLEX64) {
|
||||
if (dtype == DuccFftDtype_COMPLEX64) {
|
||||
auto shape_out = shape;
|
||||
shape_out[axes.back()] = shape_out[axes.back()] / 2 + 1;
|
||||
ducc0::cfmav<float> m_in(reinterpret_cast<float *>(in[1]), shape,
|
||||
stride_in);
|
||||
ducc0::cfmav<float> m_in(reinterpret_cast<float *>(operand), shape,
|
||||
strides_in);
|
||||
ducc0::vfmav<std::complex<float>> m_out(
|
||||
reinterpret_cast<std::complex<float> *>(out), shape_out, stride_out);
|
||||
ducc0::r2c(m_in, m_out, axes, descriptor->forward(),
|
||||
static_cast<float>(descriptor->scale()));
|
||||
reinterpret_cast<std::complex<float> *>(out),
|
||||
shape_out, strides_out);
|
||||
ducc0::r2c(m_in, m_out, axes, forward, static_cast<float>(scale));
|
||||
} else {
|
||||
auto shape_out = shape;
|
||||
shape_out[axes.back()] = shape_out[axes.back()] / 2 + 1;
|
||||
ducc0::cfmav<double> m_in(reinterpret_cast<double *>(in[1]), shape,
|
||||
stride_in);
|
||||
ducc0::cfmav<double> m_in(reinterpret_cast<double *>(operand), shape,
|
||||
strides_in);
|
||||
ducc0::vfmav<std::complex<double>> m_out(
|
||||
reinterpret_cast<std::complex<double> *>(out), shape_out, stride_out);
|
||||
ducc0::r2c(m_in, m_out, axes, descriptor->forward(),
|
||||
static_cast<double>(descriptor->scale()));
|
||||
reinterpret_cast<std::complex<double> *>(out),
|
||||
shape_out, strides_out);
|
||||
ducc0::r2c(m_in, m_out, axes, forward, scale);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
|
||||
// TODO(b/287702203): this must be kept until EOY 2023 for backwards
|
||||
// of serialized functions using fft.
|
||||
void DuccFft(void *out, void **in, XlaCustomCallStatus *) {
|
||||
const DuccFftDescriptor *descriptor = GetDuccFftDescriptor(in[0]);
|
||||
shape_t shape(descriptor->shape()->begin(), descriptor->shape()->end());
|
||||
stride_t strides_in(descriptor->strides_in()->begin(),
|
||||
descriptor->strides_in()->end());
|
||||
stride_t strides_out(descriptor->strides_out()->begin(),
|
||||
descriptor->strides_out()->end());
|
||||
shape_t axes(descriptor->axes()->begin(), descriptor->axes()->end());
|
||||
|
||||
DuccFftImpl(out, in[1], descriptor->dtype(), descriptor->fft_type(),
|
||||
shape, strides_in, strides_out, axes,
|
||||
descriptor->forward(), descriptor->scale());
|
||||
}
|
||||
|
||||
|
||||
void DynamicDuccFft(void *out, void **in, XlaCustomCallStatus *) {
|
||||
// in[0]=descriptor, in[1]=operand,
|
||||
// in[2]=shape, in[3]=strides_in, in[4]=strides_out, in[5]=scale.
|
||||
const DynamicDuccFftDescriptor *descriptor =
|
||||
flatbuffers::GetRoot<DynamicDuccFftDescriptor>(in[0]);
|
||||
const std::uint32_t *dynamic_shape =
|
||||
reinterpret_cast<const std::uint32_t*>(in[2]);
|
||||
shape_t shape(dynamic_shape, dynamic_shape + descriptor->ndims());
|
||||
const std::uint32_t *dynamic_strides_in =
|
||||
reinterpret_cast<const std::uint32_t*>(in[3]);
|
||||
stride_t strides_in(dynamic_strides_in,
|
||||
dynamic_strides_in + descriptor->ndims());
|
||||
const std::uint32_t *dynamic_strides_out =
|
||||
reinterpret_cast<const std::uint32_t*>(in[4]);
|
||||
stride_t strides_out(dynamic_strides_out,
|
||||
dynamic_strides_out + descriptor->ndims());
|
||||
shape_t axes(descriptor->axes()->begin(), descriptor->axes()->end());
|
||||
const double *dynamic_scale = reinterpret_cast<const double*>(in[5]);
|
||||
|
||||
DuccFftImpl(out, in[1], descriptor->dtype(), descriptor->fft_type(),
|
||||
shape, strides_in, strides_out, axes,
|
||||
descriptor->forward(), *dynamic_scale);
|
||||
}
|
||||
|
||||
} // namespace jax
|
@ -20,8 +20,12 @@ limitations under the License.
|
||||
|
||||
namespace jax {
|
||||
|
||||
// TODO(b/287702203): this must be kept until EOY 2023 for backwards
|
||||
// of serialized functions using fft.
|
||||
void DuccFft(void* out, void** in, XlaCustomCallStatus*);
|
||||
|
||||
void DynamicDuccFft(void* out, void** in, XlaCustomCallStatus*);
|
||||
|
||||
} // namespace jax
|
||||
|
||||
#endif // JAXLIB_CPU_DUCC_FFT_KERNELS_H_
|
||||
|
@ -35,115 +35,60 @@ _C2R = 1
|
||||
_R2C = 2
|
||||
|
||||
|
||||
def _ducc_fft_descriptor(
|
||||
shape: List[int], dtype, fft_type: FftType, fft_lengths: List[int]
|
||||
) -> Tuple[bytes, np.dtype, List[int]]:
|
||||
n = len(shape)
|
||||
def _dynamic_ducc_fft_descriptor(
|
||||
dtype, ndims: int, fft_type: FftType, fft_lengths: List[int]
|
||||
) -> Tuple[bytes]:
|
||||
assert len(fft_lengths) >= 1
|
||||
assert len(fft_lengths) <= n, (fft_lengths, n)
|
||||
assert len(fft_lengths) <= ndims, (fft_lengths, ndims)
|
||||
|
||||
forward = fft_type in (FftType.FFT, FftType.RFFT)
|
||||
is_double = np.finfo(dtype).dtype == np.float64
|
||||
if fft_type == FftType.RFFT:
|
||||
ducc_fft_type = _R2C
|
||||
|
||||
assert dtype in (np.float32, np.float64), dtype
|
||||
out_dtype = np.dtype(np.complex64 if dtype == np.float32 else np.complex128)
|
||||
|
||||
assert shape[-len(fft_lengths):] == fft_lengths, (shape, fft_lengths)
|
||||
out_shape = list(shape)
|
||||
out_shape[-1] = out_shape[-1] // 2 + 1
|
||||
|
||||
elif fft_type == FftType.IRFFT:
|
||||
ducc_fft_type = _C2R
|
||||
assert np.issubdtype(dtype, np.complexfloating), dtype
|
||||
|
||||
out_dtype = np.dtype(np.float32 if dtype == np.complex64 else np.float64)
|
||||
|
||||
assert shape[-len(fft_lengths):-1] == fft_lengths[:-1]
|
||||
out_shape = list(shape)
|
||||
out_shape[-1] = fft_lengths[-1]
|
||||
assert (out_shape[-1] // 2 + 1) == shape[-1]
|
||||
else:
|
||||
ducc_fft_type = _C2C
|
||||
|
||||
assert np.issubdtype(dtype, np.complexfloating), dtype
|
||||
out_dtype = dtype
|
||||
|
||||
assert shape[-len(fft_lengths):] == fft_lengths, (shape, fft_lengths)
|
||||
out_shape = shape
|
||||
|
||||
# PocketFft does not allow size 0 dimensions.
|
||||
if 0 in shape or 0 in out_shape:
|
||||
return b"", out_dtype, out_shape
|
||||
|
||||
# Builds a PocketFftDescriptor flatbuffer. This descriptor is passed to the
|
||||
# C++ kernel to describe the FFT to perform.
|
||||
strides_in = []
|
||||
stride = 1
|
||||
for d in reversed(shape):
|
||||
strides_in.append(stride)
|
||||
stride *= d
|
||||
axes = [ndims - len(fft_lengths) + d for d in range(len(fft_lengths))]
|
||||
|
||||
strides_out = []
|
||||
stride = 1
|
||||
for d in reversed(out_shape):
|
||||
strides_out.append(stride)
|
||||
stride *= d
|
||||
|
||||
axes = [n - len(fft_lengths) + d for d in range(len(fft_lengths))]
|
||||
|
||||
scale = 1. if forward else (1. / np.prod(fft_lengths))
|
||||
descriptor = _ducc_fft.ducc_fft_descriptor(
|
||||
shape=shape if fft_type != FftType.IRFFT else out_shape,
|
||||
descriptor = _ducc_fft.dynamic_ducc_fft_descriptor(
|
||||
ndims=ndims,
|
||||
is_double=is_double,
|
||||
fft_type=ducc_fft_type,
|
||||
fft_lengths=fft_lengths,
|
||||
strides_in=list(reversed(strides_in)),
|
||||
strides_out=list(reversed(strides_out)),
|
||||
axes=axes,
|
||||
forward=forward,
|
||||
scale=scale)
|
||||
forward=forward)
|
||||
|
||||
return descriptor, out_dtype, out_shape
|
||||
return descriptor
|
||||
|
||||
|
||||
def ducc_fft_hlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]):
|
||||
"""DUCC FFT kernel for CPU."""
|
||||
a_type = ir.RankedTensorType(a.type)
|
||||
n = len(a_type.shape)
|
||||
def dynamic_ducc_fft_hlo(
|
||||
result_type: ir.Type,
|
||||
input: ir.Value, *,
|
||||
input_dtype: np.dtype, ndims:int, input_shape: ir.Value,
|
||||
strides_in: ir.Value, strides_out: ir.Value, scale: ir.Value,
|
||||
fft_type: FftType, fft_lengths: List[int], result_shape: ir.Value):
|
||||
"""DUCC FFT kernel for CPU, with support for dynamic shapes."""
|
||||
a_type = ir.RankedTensorType(input.type)
|
||||
|
||||
fft_lengths = list(fft_lengths)
|
||||
descriptor_bytes, out_dtype, out_shape = _ducc_fft_descriptor(
|
||||
list(a_type.shape), dtype, fft_type, fft_lengths)
|
||||
descriptor_bytes = _dynamic_ducc_fft_descriptor(
|
||||
input_dtype, ndims, fft_type, fft_lengths)
|
||||
|
||||
if out_dtype == np.float32:
|
||||
out_type = ir.F32Type.get()
|
||||
elif out_dtype == np.float64:
|
||||
out_type = ir.F64Type.get()
|
||||
elif out_dtype == np.complex64:
|
||||
out_type = ir.ComplexType.get(ir.F32Type.get())
|
||||
elif out_dtype == np.complex128:
|
||||
out_type = ir.ComplexType.get(ir.F64Type.get())
|
||||
else:
|
||||
raise ValueError(f"Unknown output type {out_dtype}")
|
||||
|
||||
if 0 in a_type.shape or 0 in out_shape:
|
||||
zero = hlo.ConstantOp(
|
||||
ir.DenseElementsAttr.get(
|
||||
np.array(0, dtype=out_dtype), type=out_type))
|
||||
return hlo.BroadcastOp(
|
||||
zero,
|
||||
ir.DenseElementsAttr.get(np.asarray(out_shape, np.int64))).result
|
||||
# PocketFft does not allow size 0 dimensions, but we handled this in fft.py
|
||||
assert 0 not in a_type.shape
|
||||
|
||||
u8_type = ir.IntegerType.get_unsigned(8)
|
||||
descriptor = hlo.ConstantOp(
|
||||
ir.DenseElementsAttr.get(
|
||||
np.frombuffer(descriptor_bytes, dtype=np.uint8), type=u8_type))
|
||||
layout = tuple(range(n - 1, -1, -1))
|
||||
np.frombuffer(descriptor_bytes, dtype=np.uint8), type=u8_type)).result
|
||||
layout = tuple(range(ndims - 1, -1, -1))
|
||||
return custom_call(
|
||||
"ducc_fft",
|
||||
[ir.RankedTensorType.get(out_shape, out_type)],
|
||||
[descriptor, a],
|
||||
operand_layouts=[[0], layout],
|
||||
result_layouts=[layout])
|
||||
"dynamic_ducc_fft",
|
||||
[result_type],
|
||||
[descriptor, input, input_shape, strides_in, strides_out, scale],
|
||||
operand_layouts=[[0], layout, [0], [0], [0], [0]],
|
||||
result_layouts=[layout],
|
||||
result_shapes=[result_shape])
|
||||
|
Loading…
x
Reference in New Issue
Block a user