Add support for dynamic shapes to jax.fft.

The idea is that we take all the values that can contain dimension sizes
from the descriptor (shape, strides_in, strides_out) and we pass them as
1-d tensor operands. We also pass as an operand the output_shape, so that
we can use the hlo.CustomCallOp `indices_of_shape_operands` attribute to
tell the shape refinement how to compute the shape of the result.

We keep the old descriptor and the ducc_fft registration for the old
C++ custom targets for backwards compatibility (for 6 months). That behavior
is tested by back_compat_test.py.

The one downside of this implementation is that it moves some of the
ducc-specific logic from ducc_fft.py (in jaxlib) into fft.py (in jax). This
was necessary because that code computes with dimensions that are now
dynamic. In JAX we have support for evaluating dynamic shapes and turning
them into 1-d tensors.

Also added a backwards compatibility test for dynamic_ducc_fft and kept the
old test for ducc_fft.

PiperOrigin-RevId: 541168692
This commit is contained in:
George Necula 2023-06-17 04:50:12 -07:00 committed by jax authors
parent 3adfe321b0
commit b9c0658fcf
12 changed files with 310 additions and 154 deletions

View File

@ -30,6 +30,7 @@ from jax._src.interpreters import mlir
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib import xla_client
from jax._src.lib import ducc_fft
from jax._src.lib import version as jaxlib_version
from jax._src.numpy.util import promote_dtypes_complex, promote_dtypes_inexact
__all__ = [
@ -103,6 +104,9 @@ def fft_abstract_eval(x, fft_type, fft_lengths):
return x.update(shape=shape, dtype=dtype)
def _fft_lowering(ctx, x, *, fft_type, fft_lengths):
if not is_constant_shape(fft_lengths):
# TODO: https://github.com/openxla/stablehlo/issues/1366
raise NotImplementedError("Shape polymorphism for FFT with non-constant fft_length is not implemented for TPU and GPU")
return [
hlo.FftOp(x, hlo.FftTypeAttr.get(fft_type.name),
mlir.dense_int_elements(fft_lengths)).result
@ -110,12 +114,81 @@ def _fft_lowering(ctx, x, *, fft_type, fft_lengths):
def _fft_lowering_cpu(ctx, x, *, fft_type, fft_lengths):
if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
raise NotImplementedError("Shape polymorphism for custom call is not implemented (fft); b/261671778")
x_aval, = ctx.avals_in
return [ducc_fft.ducc_fft_hlo(x, x_aval.dtype, fft_type=fft_type,
if jaxlib_version < (0, 4, 13):
if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
raise NotImplementedError("Shape polymorphism for custom call is not implemented (fft); b/261671778; try updating your jaxlib.")
return [ducc_fft.ducc_fft_hlo(x, x_aval.dtype, fft_type=fft_type, # type: ignore
fft_lengths=fft_lengths)]
in_shape = x_aval.shape
dtype = x_aval.dtype
out_aval, = ctx.avals_out
out_shape = out_aval.shape
forward = fft_type in (xla_client.FftType.FFT, xla_client.FftType.RFFT)
ndims = len(in_shape)
assert len(fft_lengths) >= 1
assert len(fft_lengths) <= ndims, (fft_lengths, ndims)
assert len(in_shape) == len(out_shape) == ndims
# PocketFft does not allow size 0 dimensions.
if 0 in in_shape or 0 in out_shape:
if fft_type == xla_client.FftType.RFFT:
assert dtype in (np.float32, np.float64), dtype
out_dtype = np.dtype(np.complex64 if dtype == np.float32 else np.complex128)
elif fft_type == xla_client.FftType.IRFFT:
assert np.issubdtype(dtype, np.complexfloating), dtype
out_dtype = np.dtype(np.float32 if dtype == np.complex64 else np.float64)
else:
assert np.issubdtype(dtype, np.complexfloating), dtype
out_dtype = dtype
zero = mlir.ir_constant(np.array(0, dtype=out_dtype),
canonicalize_types=False)
return [
mlir.broadcast_in_dim(ctx, zero, out_aval, broadcast_dimensions=[])]
strides_in = []
stride = 1
for d in reversed(in_shape):
strides_in.append(stride)
stride *= d
strides_in = mlir.shape_tensor(
mlir.eval_dynamic_shape(ctx, tuple(reversed(strides_in))))
strides_out = []
stride = 1
for d in reversed(out_shape):
strides_out.append(stride)
stride *= d
strides_out = mlir.shape_tensor(
mlir.eval_dynamic_shape(ctx, tuple(reversed(strides_out))))
# scale = 1. if forward else (1. / np.prod(fft_lengths)) as a f64[1] tensor
double_type = mlir.ir.RankedTensorType.get((), mlir.ir.F64Type.get())
size_fft_length_prod = np.prod(fft_lengths) if fft_lengths else 1
size_fft_lengths, = mlir.eval_dynamic_shape_as_vals(ctx, (size_fft_length_prod,))
size_fft_lengths = hlo.ConvertOp(double_type, size_fft_lengths)
one = mlir.ir_constant(np.float64(1.), canonicalize_types=False)
scale = one if forward else hlo.DivOp(one, size_fft_lengths)
scale = hlo.ReshapeOp(
mlir.ir.RankedTensorType.get((1,), mlir.ir.F64Type.get()),
scale).result
in_shape = mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, in_shape))
out_shape = mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, out_shape))
in_shape = in_shape if fft_type != xla_client.FftType.IRFFT else out_shape
result_type = mlir.aval_to_ir_type(out_aval)
return [ducc_fft.dynamic_ducc_fft_hlo(
result_type, x,
input_dtype=x_aval.dtype, ndims=ndims, input_shape=in_shape,
strides_in=strides_in, strides_out=strides_out, scale=scale,
fft_type=fft_type, fft_lengths=fft_lengths, result_shape=out_shape)]
def _naive_rfft(x, fft_lengths):
y = fft(x, xla_client.FftType.FFT, fft_lengths)
n = fft_lengths[-1]

View File

@ -687,7 +687,7 @@ def _check_lowering(lowering) -> None:
# Their backwards compatibility is tested by back_compat_test.py.
_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = {
"Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape",
"ducc_fft", "cu_threefry2x32",
"ducc_fft", "dynamic_ducc_fft", "cu_threefry2x32",
# eigh on CPU
"lapack_ssyevd", "lapack_dsyevd", "lapack_cheevd", "lapack_zheevd",
# eigh on GPU

View File

@ -194,7 +194,8 @@ class CompatTest(jtu.JaxTestCase):
atol: Optional[float] = None,
allow_additional_custom_call_targets: Sequence[str] = (),
check_results: Optional[Callable[..., None]] = None,
use_tf_graph: bool = False):
use_tf_graph=False,
compare_with_current: bool = True):
"""Run one compatibility test.
Args:
@ -203,13 +204,21 @@ class CompatTest(jtu.JaxTestCase):
rtol: relative tolerance for numerical comparisons
atol: absolute tolerance for numerical comparisons
check_results: invoked with the results obtained from running the
serialized code, and those stored in the test data, and the kwarg rtol.
serialized code, and those stored in the test data, and the kwargs rtol
and atol.
allow_additional_custom_call_targets: additional custom call targets to allow.
use_tf_graph: if False (default), uses jax_export to serialize JAX
functions and to invoke them. If True then uses tf.Graph to serialize
and run the functions; expects that `func` contains a `jax2tf.call_tf`
and uses `jax2tf.convert` to generate a tf.Graph containing a
XlaCallModule with the actual MLIR module.
compare_with_current: whether to compare the current behavior for
`func` with the one stored in `data`. If `True` (default) uses the
current version of JAX and XLA to lower and serialize `func` and check
its results compared to the stored ones; it also dumps the current
test data. If `False`, no current serialization or comparisons are
done; only the saved serialization is tested. Use this option for test
data for which we have changed the serialization.
"""
if not isinstance(data, CompatTestData):
raise ValueError(f"Expecting data: CompatTestData but got {data}. "
@ -324,6 +333,7 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
else:
self.assertAllClose(res_from_serialized_run_now, data.expected_outputs,
rtol=rtol, atol=atol)
if compare_with_current:
self.assertListEqual(custom_call_targets, data.custom_call_targets)
def run_serialized(self, data: CompatTestData,
@ -398,7 +408,8 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
# Add here all the testdatas that should cover the targets guaranteed
# stable
covering_testdatas = [
cpu_ducc_fft.data_2023_03_17, cpu_lapack_syev.data_2023_03_17,
cpu_ducc_fft.data_2023_03_17, cpu_ducc_fft.data_2023_06_14,
cpu_lapack_syev.data_2023_03_17,
cpu_lapack_geqrf.data_2023_03_17, cuda_threefry2x32.data_2023_03_15,
cuda_cusolver_geqrf.data_2023_03_18, cuda_cusolver_syev.data_2023_03_17,
tf_call_tf_function.data_2023_06_02,
@ -422,9 +433,16 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
def func(x):
return lax.fft(x, fft_type="fft", fft_lengths=(4,))
# An old lowering, with ducc_fft. We keep it for 6 months.
data = load_testdata(cpu_ducc_fft.data_2023_03_17)
# We have changed the lowering for fft, do not compare with current.
self.run_one_test(func, data, compare_with_current=False)
# A newer lowering, with dynamic_ducc_fft.
data = load_testdata(cpu_ducc_fft.data_2023_06_14)
self.run_one_test(func, data)
@staticmethod
def eigh_input(shape, dtype):
# In order to keep inputs small, we construct the input programmatically

View File

@ -44,3 +44,71 @@ module @jit_func {
mlir_module_serialized=b'ML\xefR\x03MLIRxxx-trunk\x00\x01\x1d\x05\x01\x05\x01\x03\x05\x03\r\x07\t\x0b\r\x0f\x11\x03\x99s\x15\x01?\x07\x0b\x0f\x17\x0b\x0b\x0b\x0b\x0f\x13\x0b33\x0b\x0b\x0f\x0b\x13\x0b\x0bK\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x035\x0b\x0b\x0f\x0b\x0bO\x0f\x1b\x0b\x0b\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0bf\x04\x0b\x0b\x0b\x13/\x0f\x03\x15\x17\x17\x07\x17\x07\x17\x0b\x07\x13\x13\x02\xca\x05\x1f\x05\x13\x1d\x1b\x07\x17\x1d^\x03\x01\x05\x15\x05\x17\x05\x19\x05\x1b\x1d\'\x07\x03\x03\x03\x15\x05\x1d\x03\x0b\tK\x0b?\rW\x03]\x0f_\x03\x0b\tC\x0b?\rC\x03E\x0fc\x05\x1f\x05!\x1d!\x07\x05#\x03\x03%e\x05%\x05\'\x03\x11+g-A/i1G3k5m7G9q\x05)\x05+\x05-\x05/\x051\x053\x055\x057\x03\x03=E\x059#\x0b\x1d;\x03\x03a\x1d=\x03\x01\x1f\x13!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03M\r\x05OQSU\x1d?\x1dA\x1dC\x1dE\x03\x03Y\r\x03[A\x1dG\x1dI\x1dK\r\x01\x1dM\x1f\x07"\x02\x18\x00\x00\x00\x14\x00$\x00\x00\x00\x00\x00\x08\x00\x0c\x00\x10\x00\x14\x00\x07\x00\x18\x00\x14\x00\x00\x00\x00\x00\x00\x01T\x00\x00\x008\x00\x00\x00\x1c\x00\x00\x00\x10\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf0?\x00\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x0b\x05\x1dO\x05\x01\x03\x05oI\x1f\x11\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03I)\x05\r\x11\r)\x05\r\x11\x05\t)\x03B\x04\x0f\x13\x11\x03\x03\x03\x01\x03\x05!)\x03\x05\t)\x03\t\t\x04\x8f\x05\x01\x11\x01\x13\x07\x03\x01\t\x03\x11\x01\x17\x05\x03\x05\x0b\x03\x03\x01\r\x07\x05;\x03\x01\x03\x01\x05\x04\x01\x03\x03\x03\x11\x05\x19\x05\x03\t\x13\x03\x03\x01\x07\x06\x1f\x03\x01\x03\x01\t\x03\x11#\x03\x07\x0b\x07\x11)\x03\x01\x05\x05\x03\x05\x04\x05\x03\x07\x06\x03\x01\x05\x01\x00\xc2\x0eQ\x13\x11\x0f\x0b!\x1b\x1d\x05\x1b\t\x03\x0f\x1f/!!)#\x1f\x19\x91\r\xaf\x83\x82\
x04\x13\x1f\x15\x1d\x15\x13\x11\x1f\x19\x17\x15\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00return_v1\x00convert_v1\x00constant_v1\x00custom_call_v1\x00call_v1\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit_func\x00jit(func)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=fft in_positional_semantics=(<_PositionalSemantics.GLOBAL: 1>,) out_positional_semantics=_PositionalSemantics.GLOBAL keep_unused=False inline=False]\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00jit(func)/jit(main)/jit(fft)/convert_element_type[new_dtype=complex64 weak_type=False]\x00value\x00jit(func)/jit(main)/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(4,)]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00callee\x00\x00fft\x00jax.arg_info\x00x\x00mhlo.sharding\x00{replicated}\x00jax.result_info\x00main\x00public\x00private\x00ducc_fft\x00',
xla_call_module_version=4,
) # End paste
# Pasted from the test output (see back_compat_test.py module docstring)
data_2023_06_14 = dict(
testdata_version=1,
platform='cpu',
custom_call_targets=['dynamic_ducc_fft'],
serialized_date=datetime.date(2023, 6, 14),
inputs=(array([[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.]], dtype=float32),),
expected_outputs=(array([[ 6.+0.j, -2.+2.j, -2.+0.j, -2.-2.j],
[22.+0.j, -2.+2.j, -2.+0.j, -2.-2.j],
[38.+0.j, -2.+2.j, -2.+0.j, -2.-2.j]], dtype=complex64),),
mlir_module_text=r"""
#loc = loc(unknown)
module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<3x4xf32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<3x4xcomplex<f32>> {jax.result_info = ""}) {
%0 = call @fft(%arg0) : (tensor<3x4xf32>) -> tensor<3x4xcomplex<f32>> loc(#loc3)
return %0 : tensor<3x4xcomplex<f32>> loc(#loc)
} loc(#loc)
func.func private @fft(%arg0: tensor<3x4xf32> loc(unknown)) -> tensor<3x4xcomplex<f32>> {
%0 = stablehlo.convert %arg0 : (tensor<3x4xf32>) -> tensor<3x4xcomplex<f32>> loc(#loc4)
%1 = stablehlo.constant dense<4> : tensor<i64> loc(#loc5)
%2 = stablehlo.constant dense<1> : tensor<i64> loc(#loc5)
%3 = stablehlo.convert %1 : (tensor<i64>) -> tensor<i32> loc(#loc5)
%4 = stablehlo.reshape %3 : (tensor<i32>) -> tensor<1xi32> loc(#loc5)
%5 = stablehlo.convert %2 : (tensor<i64>) -> tensor<i32> loc(#loc5)
%6 = stablehlo.reshape %5 : (tensor<i32>) -> tensor<1xi32> loc(#loc5)
%7 = stablehlo.concatenate %4, %6, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> loc(#loc5)
%8 = stablehlo.constant dense<4> : tensor<i64> loc(#loc5)
%9 = stablehlo.constant dense<1> : tensor<i64> loc(#loc5)
%10 = stablehlo.convert %8 : (tensor<i64>) -> tensor<i32> loc(#loc5)
%11 = stablehlo.reshape %10 : (tensor<i32>) -> tensor<1xi32> loc(#loc5)
%12 = stablehlo.convert %9 : (tensor<i64>) -> tensor<i32> loc(#loc5)
%13 = stablehlo.reshape %12 : (tensor<i32>) -> tensor<1xi32> loc(#loc5)
%14 = stablehlo.concatenate %11, %13, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> loc(#loc5)
%15 = stablehlo.constant dense<4> : tensor<i64> loc(#loc5)
%16 = stablehlo.convert %15 : (tensor<i64>) -> tensor<f64> loc(#loc5)
%17 = stablehlo.constant dense<1.000000e+00> : tensor<f64> loc(#loc5)
%18 = stablehlo.reshape %17 : (tensor<f64>) -> tensor<1xf64> loc(#loc5)
%19 = stablehlo.constant dense<3> : tensor<i64> loc(#loc5)
%20 = stablehlo.constant dense<4> : tensor<i64> loc(#loc5)
%21 = stablehlo.convert %19 : (tensor<i64>) -> tensor<i32> loc(#loc5)
%22 = stablehlo.reshape %21 : (tensor<i32>) -> tensor<1xi32> loc(#loc5)
%23 = stablehlo.convert %20 : (tensor<i64>) -> tensor<i32> loc(#loc5)
%24 = stablehlo.reshape %23 : (tensor<i32>) -> tensor<1xi32> loc(#loc5)
%25 = stablehlo.concatenate %22, %24, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> loc(#loc5)
%26 = stablehlo.constant dense<3> : tensor<i64> loc(#loc5)
%27 = stablehlo.constant dense<4> : tensor<i64> loc(#loc5)
%28 = stablehlo.convert %26 : (tensor<i64>) -> tensor<i32> loc(#loc5)
%29 = stablehlo.reshape %28 : (tensor<i32>) -> tensor<1xi32> loc(#loc5)
%30 = stablehlo.convert %27 : (tensor<i64>) -> tensor<i32> loc(#loc5)
%31 = stablehlo.reshape %30 : (tensor<i32>) -> tensor<1xi32> loc(#loc5)
%32 = stablehlo.concatenate %29, %31, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> loc(#loc5)
%33 = stablehlo.constant dense<[20, 0, 0, 0, 0, 0, 14, 0, 16, 0, 8, 0, 0, 0, 0, 0, 12, 0, 7, 0, 14, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 4, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0]> : tensor<44xui8> loc(#loc5)
%34 = stablehlo.custom_call @dynamic_ducc_fft(%33, %0, %25, %7, %14, %18, %32) {api_version = 2 : i32, indices_of_shape_operands = dense<6> : tensor<1xi64>, operand_layouts = [dense<0> : tensor<1xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>], result_layouts = [dense<[1, 0]> : tensor<2xindex>]} : (tensor<44xui8>, tensor<3x4xcomplex<f32>>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<1xf64>, tensor<2xi32>) -> tensor<3x4xcomplex<f32>> loc(#loc5)
return %34 : tensor<3x4xcomplex<f32>> loc(#loc3)
} loc(#loc3)
} loc(#loc)
#loc1 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":437:0)
#loc2 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":428:0)
#loc3 = loc("jit(func)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=fft keep_unused=False inline=False]"(#loc1))
#loc4 = loc("jit(func)/jit(main)/jit(fft)/convert_element_type[new_dtype=complex64 weak_type=False]"(#loc2))
#loc5 = loc("jit(func)/jit(main)/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(4,)]"(#loc2))
""",
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01!\x05\x01\x03\x01\x03\x05\x03\x11\x07\t\x0b\r\x0f\x11\x13\x15\x03\xd3\x95+\x01U\x0f\x07\x13\x0b\x13\x0b\x0f\x0f\x0b\x0b\x0b\x0b\x0b\x17\x13\x13#\x0b\x0b\x0b33\x0b\x17\x0f\x0b\x0b\x0b\x13\x13S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x03A/\x0b\x0b\x0f\x0b\x0bO\x0f\x1b\x0b\x0b\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0b//\x0f//\xbf\x0b\x0b\x0b/'\x0f\x01\x03\x0f\x03)\x0f\x0f\x13\x17\x13\x17\x07\x07\x0f\x07\x07\x13\x07\x17\x0b\x13\x07\x13\x13\x13\x02r\x06\x1d5\x1b\x1f\x03\x03\x07}\x05\x17\x03\x037\x81\x05\x19\x1d-/\x11\x01\x05\x05\x1b\x05\x1d\x05\x1f\x05!\x05#\x17\x19\xb2\x06\x01\x03\x03\x07\x7f\x03\x03\x07\x85\x03\x07#\x0f%\x0f\x0b'\x05%\x05'\x05)\x03\x0b\x11c\x13W\x15o\x0bu\x17w\x03\x0b\x11[\x13W\x15[\x0b]\x17{\x05+\x17\x19\xd6\x06\x01\x1d3\x1b\x05-\x05/\x051\x03\x03\x07\x83\x03\x03\x07\x87\x03\x13?\x89AYC\x8bE_G\x8dI\x8fK\x91M_O\x93\x053\x055\x057\x059\x05;\x05=\x05?\x05A\x05C\x03\x03S]\x05E\x1f'\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x1d\x1dG\x03\x03y\x1dI\x03\x01\x1f)!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03e\r\x05gikm\x1dK\x1dM\x1dO\x1dQ\x03\x03q\r\x03sY\x1dS\x1dU\x1dW\r\x01\x1dY\x1f\x03\x11\x04\x00\x00\x00\x00\x00\x00\x00\x1f\x03\x11\x01\x00\x00\x00\x00\x00\x00\x00\x13\x0f\x01\x1f\x13\x11\x00\x00\x00\x00\x00\x00\xf0?\x1f\x03\x11\x03\x00\x00\x00\x00\x00\x00\x00\x1f\x19Y\x14\x00\x00\x00\x00\x00\x0e\x00\x10\x00\x08\x00\x00\x00\x00\x00\x0c\x00\x07\x00\x0e\x00\x00\x00\x00\x00\x00\x01\x02\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x0b\x05\x1d[\x05\x01\x1f%\x11\x06\x00\x00\x00\x00\x00\x00\x00\x03\x0fUaUUUUU\x03\x03a\x01\x02\x02)\x01\x0f)\x01\x11)\x03\x05\x11)\x05\r\x11\x1f)\x03\t\x11)\x05\r\x11\x15\x1d\x1b)\x01\x17\t\x0b)\x03\xb1#\x13\x11\x03\r\x03\t\x03\x15)\x03\x05\x17!)\x03\x05\x0f)\x03\x05\x1b)\x03\t\x1b\x04\xaa\x04\x05\x01\x11\x03!\x07\x03\x01\t\x0b\x11\x03)\x05\x03\x05\x0b\x03\r\x03\x11\x07\rQ\x03\t\x03\x01\r\x04\x03\x03\x03\x0b\x11\r+\x05\x03I\x93\x03\r\x03\x05\x
061\x03\t\x03\x01\x03\x03\x01\x05\x03\x03\x03\x03\x01\x1d\x03\x03\x05\x06\x01\x03\x05\x03\x05\x07\x06\x01\x03\x07\x03\t\x05\x06\x01\x03\x05\x03\x07\x07\x06\x01\x03\x07\x03\r\t\x07\x01\t\x03\x0b\x05\x0b\x0f\x03\x03\x01\x05\x03\x03\x03\x03\x01\x1d\x03\x03\x05\x06\x01\x03\x05\x03\x13\x07\x06\x01\x03\x07\x03\x17\x05\x06\x01\x03\x05\x03\x15\x07\x06\x01\x03\x07\x03\x1b\t\x07\x01\t\x03\x0b\x05\x19\x1d\x03\x03\x01\x05\x03\x03\x05\x06\x01\x03\x13\x03!\x03\x03\x019\x03\x13\x07\x06\x01\x03!\x03%\x03\x03\x01\x1f\x03\x03\x03\x03\x01\x05\x03\x03\x05\x06\x01\x03\x05\x03)\x07\x06\x01\x03\x07\x03-\x05\x06\x01\x03\x05\x03+\x07\x06\x01\x03\x07\x031\t\x07\x01\t\x03\x0b\x05/3\x03\x03\x01\x1f\x03\x03\x03\x03\x01\x05\x03\x03\x05\x06\x01\x03\x05\x037\x07\x06\x01\x03\x07\x03;\x05\x06\x01\x03\x05\x039\x07\x06\x01\x03\x07\x03?\t\x07\x01\t\x03\x0b\x05=A\x03\x03\x01;\x03\x19\x0f\x07\x01=\x03\t\x0fE\x035\x11\x1f'C\r\x04\r\x03G\x06\x03\x01\x05\x01\x00\xc6\x0e]#\x11\x0f\x0b!\x1b\x1d\x05\x1b\t\x03\x0f\x1f/!5!)#\x1f\x19\x15\x91\xaf\xbe\x02\x13%)\x83\x1f\x15\x1d\x15\x13\r\x11\x1f\x15\x11\x1f\x17\x17\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00convert_v1\x00reshape_v1\x00concatenate_v1\x00func_v1\x00return_v1\x00custom_call_v1\x00call_v1\x00value\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=fft keep_unused=False inline=False]\x00jit(func)/jit(main)/jit(fft)/convert_element_type[new_dtype=complex64 weak_type=False]\x00jit(func)/jit(main)/jit(fft)/fft[fft_type=FftType.FFT 
fft_lengths=(4,)]\x00dimension\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00indices_of_shape_operands\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00callee\x00\x00fft\x00jax.arg_info\x00x\x00mhlo.sharding\x00{replicated}\x00jax.result_info\x00main\x00public\x00private\x00dynamic_ducc_fft\x00",
xla_call_module_version=6,
) # End paste

View File

@ -14,7 +14,7 @@
"""Tests for JAX primitive coverage.
The bulk of the testing is done by `test_prim`, which is parameterized by
about 2000+ test harnesses. See `primitive_harness.py` docstring for a
about 3500+ test harnesses. See `primitive_harness.py` docstring for a
description of test harnesses. That module contains also the definitions
of all the test harnesses, and a specification of which are only partially
implemented for JAX.

View File

@ -2742,7 +2742,6 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
# Set of harness.group_name:platform that are implemented with custom call
custom_call_harnesses = {
"vmap_cholesky:cpu", "vmap_cholesky:gpu",
"vmap_fft:cpu", "fft:cpu",
"householder_product:cpu", "householder_product:gpu",
"vmap_geqrf:cpu", "vmap_geqrf:gpu",
"vmap_lu:cpu", "vmap_lu:gpu",

View File

@ -105,7 +105,10 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
"lapack_cgees", ComplexGees<std::complex<float>>::Kernel, "Host");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
"lapack_zgees", ComplexGees<std::complex<double>>::Kernel, "Host");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("ducc_fft", DuccFft, "Host");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
"ducc_fft", DuccFft, "Host");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
"dynamic_ducc_fft", DynamicDuccFft, "Host");
} // namespace
} // namespace jax

View File

@ -27,41 +27,39 @@ namespace py = pybind11;
namespace jax {
namespace {
py::bytes BuildDuccFftDescriptor(const std::vector<uint64_t> &shape,
py::bytes BuildDynamicDuccFftDescriptor(
const uint32_t ndims,
bool is_double, int fft_type,
const std::vector<uint64_t> &fft_lengths,
const std::vector<uint64_t> &strides_in,
const std::vector<uint64_t> &strides_out,
const std::vector<uint32_t> &axes,
bool forward, double scale) {
DuccFftDescriptorT descriptor;
descriptor.shape = shape;
bool forward) {
DynamicDuccFftDescriptorT descriptor;
descriptor.ndims = ndims;
descriptor.fft_type = static_cast<DuccFftType>(fft_type);
descriptor.dtype =
is_double ? DuccFftDtype_COMPLEX128 : DuccFftDtype_COMPLEX64;
descriptor.strides_in = strides_in;
descriptor.strides_out = strides_out;
descriptor.axes = axes;
descriptor.forward = forward;
descriptor.scale = scale;
flatbuffers::FlatBufferBuilder fbb;
fbb.Finish(DuccFftDescriptor::Pack(fbb, &descriptor));
fbb.Finish(DynamicDuccFftDescriptor::Pack(fbb, &descriptor));
return py::bytes(reinterpret_cast<char *>(fbb.GetBufferPointer()),
fbb.GetSize());
}
py::dict Registrations() {
pybind11::dict dict;
// TODO(b/287702203): this must be kept until EOY 2023 for backwards
// compatibility of serialized functions using fft.
dict["ducc_fft"] = EncapsulateFunction(DuccFft);
dict["dynamic_ducc_fft"] = EncapsulateFunction(DynamicDuccFft);
return dict;
}
PYBIND11_MODULE(_ducc_fft, m) {
m.def("registrations", &Registrations);
m.def("ducc_fft_descriptor", &BuildDuccFftDescriptor, py::arg("shape"),
py::arg("is_double"), py::arg("fft_type"), py::arg("fft_lengths"),
py::arg("strides_in"), py::arg("strides_out"), py::arg("axes"),
py::arg("forward"), py::arg("scale"));
m.def("dynamic_ducc_fft_descriptor", &BuildDynamicDuccFftDescriptor,
py::arg("ndims"), py::arg("is_double"), py::arg("fft_type"),
py::arg("axes"), py::arg("forward"));
}
} // namespace

View File

@ -37,4 +37,12 @@ table DuccFftDescriptor {
scale:double;
}
table DynamicDuccFftDescriptor {
ndims:uint32;
dtype:DuccFftDtype;
fft_type:DuccFftType;
axes:[uint32];
forward:bool;
}
root_type DuccFftDescriptor;

View File

@ -25,76 +25,116 @@ namespace jax {
using shape_t = ducc0::fmav_info::shape_t;
using stride_t = ducc0::fmav_info::stride_t;
void DuccFft(void *out, void **in, XlaCustomCallStatus *) {
const DuccFftDescriptor *descriptor = GetDuccFftDescriptor(in[0]);
shape_t shape(descriptor->shape()->begin(), descriptor->shape()->end());
stride_t stride_in(descriptor->strides_in()->begin(),
descriptor->strides_in()->end());
stride_t stride_out(descriptor->strides_out()->begin(),
descriptor->strides_out()->end());
shape_t axes(descriptor->axes()->begin(), descriptor->axes()->end());
namespace {
switch (descriptor->fft_type()) {
void DuccFftImpl(void *out, void *operand, jax::DuccFftDtype dtype,
jax::DuccFftType fft_type,
shape_t shape, stride_t strides_in, stride_t strides_out, shape_t axes,
bool forward, double scale) {
switch (fft_type) {
case DuccFftType_C2C:
if (descriptor->dtype() == DuccFftDtype_COMPLEX64) {
if (dtype == DuccFftDtype_COMPLEX64) {
ducc0::cfmav<std::complex<float>> m_in(
reinterpret_cast<std::complex<float> *>(in[1]), shape, stride_in);
reinterpret_cast<std::complex<float> *>(operand), shape, strides_in);
ducc0::vfmav<std::complex<float>> m_out(
reinterpret_cast<std::complex<float> *>(out), shape, stride_out);
ducc0::c2c(m_in, m_out, axes, descriptor->forward(),
static_cast<float>(descriptor->scale()));
reinterpret_cast<std::complex<float> *>(out), shape, strides_out);
ducc0::c2c(m_in, m_out, axes, forward, static_cast<float>(scale));
} else {
ducc0::cfmav<std::complex<double>> m_in(
reinterpret_cast<std::complex<double> *>(in[1]), shape, stride_in);
reinterpret_cast<std::complex<double> *>(operand), shape, strides_in);
ducc0::vfmav<std::complex<double>> m_out(
reinterpret_cast<std::complex<double> *>(out), shape, stride_out);
ducc0::c2c(m_in, m_out, axes, descriptor->forward(),
static_cast<double>(descriptor->scale()));
reinterpret_cast<std::complex<double> *>(out), shape, strides_out);
ducc0::c2c(m_in, m_out, axes, forward, scale);
}
break;
case DuccFftType_C2R:
if (descriptor->dtype() == DuccFftDtype_COMPLEX64) {
if (dtype == DuccFftDtype_COMPLEX64) {
auto shape_in = shape;
shape_in[axes.back()] = shape_in[axes.back()] / 2 + 1;
ducc0::cfmav<std::complex<float>> m_in(
reinterpret_cast<std::complex<float> *>(in[1]), shape_in, stride_in);
reinterpret_cast<std::complex<float> *>(operand),
shape_in, strides_in);
ducc0::vfmav<float> m_out(reinterpret_cast<float *>(out), shape,
stride_out);
ducc0::c2r(m_in, m_out, axes, descriptor->forward(),
static_cast<float>(descriptor->scale()));
strides_out);
ducc0::c2r(m_in, m_out, axes, forward, static_cast<float>(scale));
} else {
auto shape_in = shape;
shape_in[axes.back()] = shape_in[axes.back()] / 2 + 1;
ducc0::cfmav<std::complex<double>> m_in(
reinterpret_cast<std::complex<double> *>(in[1]), shape_in, stride_in);
reinterpret_cast<std::complex<double> *>(operand),
shape_in, strides_in);
ducc0::vfmav<double> m_out(reinterpret_cast<double *>(out), shape,
stride_out);
ducc0::c2r(m_in, m_out, axes, descriptor->forward(),
static_cast<double>(descriptor->scale()));
strides_out);
ducc0::c2r(m_in, m_out, axes, forward, scale);
}
break;
case DuccFftType_R2C:
if (descriptor->dtype() == DuccFftDtype_COMPLEX64) {
if (dtype == DuccFftDtype_COMPLEX64) {
auto shape_out = shape;
shape_out[axes.back()] = shape_out[axes.back()] / 2 + 1;
ducc0::cfmav<float> m_in(reinterpret_cast<float *>(in[1]), shape,
stride_in);
ducc0::cfmav<float> m_in(reinterpret_cast<float *>(operand), shape,
strides_in);
ducc0::vfmav<std::complex<float>> m_out(
reinterpret_cast<std::complex<float> *>(out), shape_out, stride_out);
ducc0::r2c(m_in, m_out, axes, descriptor->forward(),
static_cast<float>(descriptor->scale()));
reinterpret_cast<std::complex<float> *>(out),
shape_out, strides_out);
ducc0::r2c(m_in, m_out, axes, forward, static_cast<float>(scale));
} else {
auto shape_out = shape;
shape_out[axes.back()] = shape_out[axes.back()] / 2 + 1;
ducc0::cfmav<double> m_in(reinterpret_cast<double *>(in[1]), shape,
stride_in);
ducc0::cfmav<double> m_in(reinterpret_cast<double *>(operand), shape,
strides_in);
ducc0::vfmav<std::complex<double>> m_out(
reinterpret_cast<std::complex<double> *>(out), shape_out, stride_out);
ducc0::r2c(m_in, m_out, axes, descriptor->forward(),
static_cast<double>(descriptor->scale()));
reinterpret_cast<std::complex<double> *>(out),
shape_out, strides_out);
ducc0::r2c(m_in, m_out, axes, forward, scale);
}
break;
}
}
} // namespace
// TODO(b/287702203): this must be kept until EOY 2023 for backwards
// compatibility of serialized functions using fft.
void DuccFft(void *out, void **in, XlaCustomCallStatus *) {
const DuccFftDescriptor *descriptor = GetDuccFftDescriptor(in[0]);
shape_t shape(descriptor->shape()->begin(), descriptor->shape()->end());
stride_t strides_in(descriptor->strides_in()->begin(),
descriptor->strides_in()->end());
stride_t strides_out(descriptor->strides_out()->begin(),
descriptor->strides_out()->end());
shape_t axes(descriptor->axes()->begin(), descriptor->axes()->end());
DuccFftImpl(out, in[1], descriptor->dtype(), descriptor->fft_type(),
shape, strides_in, strides_out, axes,
descriptor->forward(), descriptor->scale());
}
void DynamicDuccFft(void *out, void **in, XlaCustomCallStatus *) {
// in[0]=descriptor, in[1]=operand,
// in[2]=shape, in[3]=strides_in, in[4]=strides_out, in[5]=scale.
const DynamicDuccFftDescriptor *descriptor =
flatbuffers::GetRoot<DynamicDuccFftDescriptor>(in[0]);
const std::uint32_t *dynamic_shape =
reinterpret_cast<const std::uint32_t*>(in[2]);
shape_t shape(dynamic_shape, dynamic_shape + descriptor->ndims());
const std::uint32_t *dynamic_strides_in =
reinterpret_cast<const std::uint32_t*>(in[3]);
stride_t strides_in(dynamic_strides_in,
dynamic_strides_in + descriptor->ndims());
const std::uint32_t *dynamic_strides_out =
reinterpret_cast<const std::uint32_t*>(in[4]);
stride_t strides_out(dynamic_strides_out,
dynamic_strides_out + descriptor->ndims());
shape_t axes(descriptor->axes()->begin(), descriptor->axes()->end());
const double *dynamic_scale = reinterpret_cast<const double*>(in[5]);
DuccFftImpl(out, in[1], descriptor->dtype(), descriptor->fft_type(),
shape, strides_in, strides_out, axes,
descriptor->forward(), *dynamic_scale);
}
} // namespace jax

View File

@ -20,8 +20,12 @@ limitations under the License.
namespace jax {
// TODO(b/287702203): this must be kept until EOY 2023 for backwards
// compatibility of serialized functions using fft.
void DuccFft(void* out, void** in, XlaCustomCallStatus*);
void DynamicDuccFft(void* out, void** in, XlaCustomCallStatus*);
} // namespace jax
#endif // JAXLIB_CPU_DUCC_FFT_KERNELS_H_

View File

@ -35,115 +35,60 @@ _C2R = 1
_R2C = 2
def _ducc_fft_descriptor(
shape: List[int], dtype, fft_type: FftType, fft_lengths: List[int]
) -> Tuple[bytes, np.dtype, List[int]]:
n = len(shape)
def _dynamic_ducc_fft_descriptor(
dtype, ndims: int, fft_type: FftType, fft_lengths: List[int]
) -> Tuple[bytes]:
assert len(fft_lengths) >= 1
assert len(fft_lengths) <= n, (fft_lengths, n)
assert len(fft_lengths) <= ndims, (fft_lengths, ndims)
forward = fft_type in (FftType.FFT, FftType.RFFT)
is_double = np.finfo(dtype).dtype == np.float64
if fft_type == FftType.RFFT:
ducc_fft_type = _R2C
assert dtype in (np.float32, np.float64), dtype
out_dtype = np.dtype(np.complex64 if dtype == np.float32 else np.complex128)
assert shape[-len(fft_lengths):] == fft_lengths, (shape, fft_lengths)
out_shape = list(shape)
out_shape[-1] = out_shape[-1] // 2 + 1
elif fft_type == FftType.IRFFT:
ducc_fft_type = _C2R
assert np.issubdtype(dtype, np.complexfloating), dtype
out_dtype = np.dtype(np.float32 if dtype == np.complex64 else np.float64)
assert shape[-len(fft_lengths):-1] == fft_lengths[:-1]
out_shape = list(shape)
out_shape[-1] = fft_lengths[-1]
assert (out_shape[-1] // 2 + 1) == shape[-1]
else:
ducc_fft_type = _C2C
assert np.issubdtype(dtype, np.complexfloating), dtype
out_dtype = dtype
assert shape[-len(fft_lengths):] == fft_lengths, (shape, fft_lengths)
out_shape = shape
# PocketFft does not allow size 0 dimensions.
if 0 in shape or 0 in out_shape:
return b"", out_dtype, out_shape
# Builds a PocketFftDescriptor flatbuffer. This descriptor is passed to the
# C++ kernel to describe the FFT to perform.
strides_in = []
stride = 1
for d in reversed(shape):
strides_in.append(stride)
stride *= d
axes = [ndims - len(fft_lengths) + d for d in range(len(fft_lengths))]
strides_out = []
stride = 1
for d in reversed(out_shape):
strides_out.append(stride)
stride *= d
axes = [n - len(fft_lengths) + d for d in range(len(fft_lengths))]
scale = 1. if forward else (1. / np.prod(fft_lengths))
descriptor = _ducc_fft.ducc_fft_descriptor(
shape=shape if fft_type != FftType.IRFFT else out_shape,
descriptor = _ducc_fft.dynamic_ducc_fft_descriptor(
ndims=ndims,
is_double=is_double,
fft_type=ducc_fft_type,
fft_lengths=fft_lengths,
strides_in=list(reversed(strides_in)),
strides_out=list(reversed(strides_out)),
axes=axes,
forward=forward,
scale=scale)
forward=forward)
return descriptor, out_dtype, out_shape
return descriptor
def ducc_fft_hlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]):
"""DUCC FFT kernel for CPU."""
a_type = ir.RankedTensorType(a.type)
n = len(a_type.shape)
def dynamic_ducc_fft_hlo(
result_type: ir.Type,
input: ir.Value, *,
input_dtype: np.dtype, ndims:int, input_shape: ir.Value,
strides_in: ir.Value, strides_out: ir.Value, scale: ir.Value,
fft_type: FftType, fft_lengths: List[int], result_shape: ir.Value):
"""DUCC FFT kernel for CPU, with support for dynamic shapes."""
a_type = ir.RankedTensorType(input.type)
fft_lengths = list(fft_lengths)
descriptor_bytes, out_dtype, out_shape = _ducc_fft_descriptor(
list(a_type.shape), dtype, fft_type, fft_lengths)
descriptor_bytes = _dynamic_ducc_fft_descriptor(
input_dtype, ndims, fft_type, fft_lengths)
if out_dtype == np.float32:
out_type = ir.F32Type.get()
elif out_dtype == np.float64:
out_type = ir.F64Type.get()
elif out_dtype == np.complex64:
out_type = ir.ComplexType.get(ir.F32Type.get())
elif out_dtype == np.complex128:
out_type = ir.ComplexType.get(ir.F64Type.get())
else:
raise ValueError(f"Unknown output type {out_dtype}")
if 0 in a_type.shape or 0 in out_shape:
zero = hlo.ConstantOp(
ir.DenseElementsAttr.get(
np.array(0, dtype=out_dtype), type=out_type))
return hlo.BroadcastOp(
zero,
ir.DenseElementsAttr.get(np.asarray(out_shape, np.int64))).result
# PocketFft does not allow size 0 dimensions, but we handled this in fft.py
assert 0 not in a_type.shape
u8_type = ir.IntegerType.get_unsigned(8)
descriptor = hlo.ConstantOp(
ir.DenseElementsAttr.get(
np.frombuffer(descriptor_bytes, dtype=np.uint8), type=u8_type))
layout = tuple(range(n - 1, -1, -1))
np.frombuffer(descriptor_bytes, dtype=np.uint8), type=u8_type)).result
layout = tuple(range(ndims - 1, -1, -1))
return custom_call(
"ducc_fft",
[ir.RankedTensorType.get(out_shape, out_type)],
[descriptor, a],
operand_layouts=[[0], layout],
result_layouts=[layout])
"dynamic_ducc_fft",
[result_type],
[descriptor, input, input_shape, strides_in, strides_out, scale],
operand_layouts=[[0], layout, [0], [0], [0], [0]],
result_layouts=[layout],
result_shapes=[result_shape])