Switch from pocketfft to ducc

All credit goes to Martin Reinecke <martin@mpa-garching.mpg.de>.
This commit is contained in:
Gordian Edenhofer 2022-08-26 11:37:04 +02:00 committed by Peter Hawkins
parent 4aa96c0e90
commit 024ae47e79
4 changed files with 115 additions and 71 deletions

View File

@ -4332,8 +4332,8 @@ Copyright 2019 The TensorFlow Authors. All rights reserved.
limitations under the License.
--------------------------------------------------------------------------------
License for pocketfft:
Copyright (C) 2010-2018 Max-Planck-Society
License for the FFT components of ducc0:
Copyright (C) 2010-2022 Max-Planck-Society
All rights reserved.
Redistribution and use in source and binary forms, with or without modification,

View File

@ -16,75 +16,107 @@ limitations under the License.
#include <complex>
#include "flatbuffers/flatbuffers.h"
#include "pocketfft/pocketfft_hdronly.h"
#include "jaxlib/pocketfft_generated.h"
#include "pocketfft/src/ducc0/fft/fft.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
namespace jax {
void PocketFft(void* out, void** in, XlaCustomCallStatus*) {
const PocketFftDescriptor* descriptor = GetPocketFftDescriptor(in[0]);
pocketfft::shape_t shape(descriptor->shape()->begin(),
descriptor->shape()->end());
pocketfft::stride_t stride_in(descriptor->strides_in()->begin(),
descriptor->strides_in()->end());
pocketfft::stride_t stride_out(descriptor->strides_out()->begin(),
descriptor->strides_out()->end());
pocketfft::shape_t axes(descriptor->axes()->begin(),
descriptor->axes()->end());
using shape_t = ducc0::fmav_info::shape_t;
using stride_t = ducc0::fmav_info::stride_t;
switch (descriptor->fft_type()) {
case PocketFftType_C2C:
if (descriptor->dtype() == PocketFftDtype_COMPLEX64) {
std::complex<float>* a_in =
reinterpret_cast<std::complex<float>*>(in[1]);
std::complex<float>* a_out =
reinterpret_cast<std::complex<float>*>(out);
pocketfft::c2c(shape, stride_in, stride_out, axes,
descriptor->forward(), a_in, a_out,
static_cast<float>(descriptor->scale()));
} else {
std::complex<double>* a_in =
reinterpret_cast<std::complex<double>*>(in[1]);
std::complex<double>* a_out =
reinterpret_cast<std::complex<double>*>(out);
pocketfft::c2c(shape, stride_in, stride_out, axes,
descriptor->forward(), a_in, a_out, descriptor->scale());
}
break;
case PocketFftType_C2R:
if (descriptor->dtype() == PocketFftDtype_COMPLEX64) {
std::complex<float>* a_in =
reinterpret_cast<std::complex<float>*>(in[1]);
float* a_out = reinterpret_cast<float*>(out);
pocketfft::c2r(shape, stride_in, stride_out, axes,
descriptor->forward(), a_in, a_out,
static_cast<float>(descriptor->scale()));
} else {
std::complex<double>* a_in =
reinterpret_cast<std::complex<double>*>(in[1]);
double* a_out = reinterpret_cast<double*>(out);
pocketfft::c2r(shape, stride_in, stride_out, axes,
descriptor->forward(), a_in, a_out, descriptor->scale());
}
break;
case PocketFftType_R2C:
if (descriptor->dtype() == PocketFftDtype_COMPLEX64) {
float* a_in = reinterpret_cast<float*>(in[1]);
std::complex<float>* a_out =
reinterpret_cast<std::complex<float>*>(out);
pocketfft::r2c(shape, stride_in, stride_out, axes,
descriptor->forward(), a_in, a_out,
static_cast<float>(descriptor->scale()));
} else {
double* a_in = reinterpret_cast<double*>(in[1]);
std::complex<double>* a_out =
reinterpret_cast<std::complex<double>*>(out);
pocketfft::r2c(shape, stride_in, stride_out, axes,
descriptor->forward(), a_in, a_out, descriptor->scale());
}
break;
void fixstrides(stride_t &str, size_t size) {
ptrdiff_t ssize = ptrdiff_t(size);
for (auto &s : str) {
auto tmp = s / ssize;
if (tmp * ssize != s)
throw "Bad stride";
s = tmp;
}
}
} // namespace jax
void PocketFft(void *out, void **in, XlaCustomCallStatus *) {
const PocketFftDescriptor *descriptor = GetPocketFftDescriptor(in[0]);
shape_t shape(descriptor->shape()->begin(), descriptor->shape()->end());
stride_t stride_in(descriptor->strides_in()->begin(),
descriptor->strides_in()->end());
stride_t stride_out(descriptor->strides_out()->begin(),
descriptor->strides_out()->end());
shape_t axes(descriptor->axes()->begin(), descriptor->axes()->end());
switch (descriptor->fft_type()) {
case PocketFftType_C2C:
if (descriptor->dtype() == PocketFftDtype_COMPLEX64) {
fixstrides(stride_in, sizeof(std::complex<float>));
fixstrides(stride_out, sizeof(std::complex<float>));
ducc0::cfmav<std::complex<float>> m_in(
reinterpret_cast<std::complex<float> *>(in[1]), shape, stride_in);
ducc0::vfmav<std::complex<float>> m_out(
reinterpret_cast<std::complex<float> *>(out), shape, stride_out);
ducc0::c2c(m_in, m_out, axes, descriptor->forward(),
static_cast<float>(descriptor->scale()));
} else {
fixstrides(stride_in, sizeof(std::complex<double>));
fixstrides(stride_out, sizeof(std::complex<double>));
ducc0::cfmav<std::complex<double>> m_in(
reinterpret_cast<std::complex<double> *>(in[1]), shape, stride_in);
ducc0::vfmav<std::complex<double>> m_out(
reinterpret_cast<std::complex<double> *>(out), shape, stride_out);
ducc0::c2c(m_in, m_out, axes, descriptor->forward(),
static_cast<double>(descriptor->scale()));
}
break;
case PocketFftType_C2R:
if (descriptor->dtype() == PocketFftDtype_COMPLEX64) {
fixstrides(stride_in, sizeof(std::complex<float>));
fixstrides(stride_out, sizeof(float));
auto shape_in = shape;
shape_in[axes.back()] = shape_in[axes.back()] / 2 + 1;
ducc0::cfmav<std::complex<float>> m_in(
reinterpret_cast<std::complex<float> *>(in[1]), shape_in, stride_in);
ducc0::vfmav<float> m_out(reinterpret_cast<float *>(out), shape,
stride_out);
ducc0::c2r(m_in, m_out, axes, descriptor->forward(),
static_cast<float>(descriptor->scale()));
} else {
fixstrides(stride_in, sizeof(std::complex<double>));
fixstrides(stride_out, sizeof(double));
auto shape_in = shape;
shape_in[axes.back()] = shape_in[axes.back()] / 2 + 1;
ducc0::cfmav<std::complex<double>> m_in(
reinterpret_cast<std::complex<double> *>(in[1]), shape_in, stride_in);
ducc0::vfmav<double> m_out(reinterpret_cast<double *>(out), shape,
stride_out);
ducc0::c2r(m_in, m_out, axes, descriptor->forward(),
static_cast<double>(descriptor->scale()));
}
break;
case PocketFftType_R2C:
if (descriptor->dtype() == PocketFftDtype_COMPLEX64) {
fixstrides(stride_in, sizeof(float));
fixstrides(stride_out, sizeof(std::complex<float>));
auto shape_out = shape;
shape_out[axes.back()] = shape_out[axes.back()] / 2 + 1;
ducc0::cfmav<float> m_in(reinterpret_cast<float *>(in[1]), shape,
stride_in);
ducc0::vfmav<std::complex<float>> m_out(
reinterpret_cast<std::complex<float> *>(out), shape_out, stride_out);
ducc0::r2c(m_in, m_out, axes, descriptor->forward(),
static_cast<float>(descriptor->scale()));
} else {
fixstrides(stride_in, sizeof(double));
fixstrides(stride_out, sizeof(std::complex<double>));
auto shape_out = shape;
shape_out[axes.back()] = shape_out[axes.back()] / 2 + 1;
ducc0::cfmav<double> m_in(reinterpret_cast<double *>(in[1]), shape,
stride_in);
ducc0::vfmav<std::complex<double>> m_out(
reinterpret_cast<std::complex<double> *>(out), shape_out, stride_out);
ducc0::r2c(m_in, m_out, axes, descriptor->forward(),
static_cast<double>(descriptor->scale()));
}
break;
}
}
} // namespace jax

View File

@ -4,8 +4,20 @@ package(default_visibility = ["//visibility:public"])
cc_library(
name = "pocketfft",
hdrs = ["pocketfft_hdronly.h"],
copts = ["-fexceptions"],
hdrs = ["src/ducc0/fft/fft.h"],
srcs = ["src/ducc0/fft/fft1d.h",
"src/ducc0/infra/aligned_array.h",
"src/ducc0/infra/error_handling.h",
"src/ducc0/infra/mav.h",
"src/ducc0/infra/simd.h",
"src/ducc0/infra/threading.h",
"src/ducc0/infra/threading.cc",
"src/ducc0/infra/useful_macros.h",
"src/ducc0/math/cmplx.h",
"src/ducc0/math/unity_roots.h",
],
copts = ["-fexceptions", "-ffast-math"],
features = ["-use_header_modules"],
include_prefix = "pocketfft",
includes = ["pocketfft", "src"],
)

View File

@ -19,11 +19,11 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
def repo():
http_archive(
name = "pocketfft",
sha256 = "66eda977b195965d27aeb9d74f46e0029a6a02e75fbbc47bb554aad68615a260",
strip_prefix = "pocketfft-f800d91ba695b6e19ae2687dd60366900b928002",
strip_prefix = "ducc-356d619a4b5f6f8940d15913c14a043355ef23be",
sha256 = "d23eb2d06f03604867ad40af4fe92dec7cccc2c59f5119e9f01b35b973885c61,
urls = [
"https://github.com/mreineck/pocketfft/archive/f800d91ba695b6e19ae2687dd60366900b928002.tar.gz",
"https://storage.googleapis.com/jax-releases/mirror/pocketfft/pocketfft-f800d91ba695b6e19ae2687dd60366900b928002.tar.gz",
"https://github.com/mreineck/ducc/archive/356d619a4b5f6f8940d15913c14a043355ef23be.tar.gz",
"https://storage.googleapis.com/jax-releases/mirror/ducc/ducc-356d619a4b5f6f8940d15913c14a043355ef23be.tar.gz",
],
build_file = "@//third_party/pocketfft:BUILD.bazel",
)