mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Switch from pocketfft to ducc
All credit goes to Martin Reinecke <martin@mpa-garching.mpg.de>.
This commit is contained in:
parent
4aa96c0e90
commit
024ae47e79
@ -4332,8 +4332,8 @@ Copyright 2019 The TensorFlow Authors. All rights reserved.
|
||||
limitations under the License.
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
License for pocketfft:
|
||||
Copyright (C) 2010-2018 Max-Planck-Society
|
||||
License for the FFT components of ducc0:
|
||||
Copyright (C) 2010-2022 Max-Planck-Society
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without modification,
|
||||
|
@ -16,75 +16,107 @@ limitations under the License.
|
||||
#include <complex>
|
||||
|
||||
#include "flatbuffers/flatbuffers.h"
|
||||
#include "pocketfft/pocketfft_hdronly.h"
|
||||
#include "jaxlib/pocketfft_generated.h"
|
||||
#include "pocketfft/src/ducc0/fft/fft.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
|
||||
namespace jax {
|
||||
|
||||
void PocketFft(void* out, void** in, XlaCustomCallStatus*) {
|
||||
const PocketFftDescriptor* descriptor = GetPocketFftDescriptor(in[0]);
|
||||
pocketfft::shape_t shape(descriptor->shape()->begin(),
|
||||
descriptor->shape()->end());
|
||||
pocketfft::stride_t stride_in(descriptor->strides_in()->begin(),
|
||||
descriptor->strides_in()->end());
|
||||
pocketfft::stride_t stride_out(descriptor->strides_out()->begin(),
|
||||
descriptor->strides_out()->end());
|
||||
pocketfft::shape_t axes(descriptor->axes()->begin(),
|
||||
descriptor->axes()->end());
|
||||
using shape_t = ducc0::fmav_info::shape_t;
|
||||
using stride_t = ducc0::fmav_info::stride_t;
|
||||
|
||||
switch (descriptor->fft_type()) {
|
||||
case PocketFftType_C2C:
|
||||
if (descriptor->dtype() == PocketFftDtype_COMPLEX64) {
|
||||
std::complex<float>* a_in =
|
||||
reinterpret_cast<std::complex<float>*>(in[1]);
|
||||
std::complex<float>* a_out =
|
||||
reinterpret_cast<std::complex<float>*>(out);
|
||||
pocketfft::c2c(shape, stride_in, stride_out, axes,
|
||||
descriptor->forward(), a_in, a_out,
|
||||
static_cast<float>(descriptor->scale()));
|
||||
} else {
|
||||
std::complex<double>* a_in =
|
||||
reinterpret_cast<std::complex<double>*>(in[1]);
|
||||
std::complex<double>* a_out =
|
||||
reinterpret_cast<std::complex<double>*>(out);
|
||||
pocketfft::c2c(shape, stride_in, stride_out, axes,
|
||||
descriptor->forward(), a_in, a_out, descriptor->scale());
|
||||
}
|
||||
break;
|
||||
case PocketFftType_C2R:
|
||||
if (descriptor->dtype() == PocketFftDtype_COMPLEX64) {
|
||||
std::complex<float>* a_in =
|
||||
reinterpret_cast<std::complex<float>*>(in[1]);
|
||||
float* a_out = reinterpret_cast<float*>(out);
|
||||
pocketfft::c2r(shape, stride_in, stride_out, axes,
|
||||
descriptor->forward(), a_in, a_out,
|
||||
static_cast<float>(descriptor->scale()));
|
||||
} else {
|
||||
std::complex<double>* a_in =
|
||||
reinterpret_cast<std::complex<double>*>(in[1]);
|
||||
double* a_out = reinterpret_cast<double*>(out);
|
||||
pocketfft::c2r(shape, stride_in, stride_out, axes,
|
||||
descriptor->forward(), a_in, a_out, descriptor->scale());
|
||||
}
|
||||
break;
|
||||
case PocketFftType_R2C:
|
||||
if (descriptor->dtype() == PocketFftDtype_COMPLEX64) {
|
||||
float* a_in = reinterpret_cast<float*>(in[1]);
|
||||
std::complex<float>* a_out =
|
||||
reinterpret_cast<std::complex<float>*>(out);
|
||||
pocketfft::r2c(shape, stride_in, stride_out, axes,
|
||||
descriptor->forward(), a_in, a_out,
|
||||
static_cast<float>(descriptor->scale()));
|
||||
} else {
|
||||
double* a_in = reinterpret_cast<double*>(in[1]);
|
||||
std::complex<double>* a_out =
|
||||
reinterpret_cast<std::complex<double>*>(out);
|
||||
pocketfft::r2c(shape, stride_in, stride_out, axes,
|
||||
descriptor->forward(), a_in, a_out, descriptor->scale());
|
||||
}
|
||||
break;
|
||||
void fixstrides(stride_t &str, size_t size) {
|
||||
ptrdiff_t ssize = ptrdiff_t(size);
|
||||
for (auto &s : str) {
|
||||
auto tmp = s / ssize;
|
||||
if (tmp * ssize != s)
|
||||
throw "Bad stride";
|
||||
s = tmp;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace jax
|
||||
void PocketFft(void *out, void **in, XlaCustomCallStatus *) {
|
||||
const PocketFftDescriptor *descriptor = GetPocketFftDescriptor(in[0]);
|
||||
shape_t shape(descriptor->shape()->begin(), descriptor->shape()->end());
|
||||
stride_t stride_in(descriptor->strides_in()->begin(),
|
||||
descriptor->strides_in()->end());
|
||||
stride_t stride_out(descriptor->strides_out()->begin(),
|
||||
descriptor->strides_out()->end());
|
||||
shape_t axes(descriptor->axes()->begin(), descriptor->axes()->end());
|
||||
|
||||
switch (descriptor->fft_type()) {
|
||||
case PocketFftType_C2C:
|
||||
if (descriptor->dtype() == PocketFftDtype_COMPLEX64) {
|
||||
fixstrides(stride_in, sizeof(std::complex<float>));
|
||||
fixstrides(stride_out, sizeof(std::complex<float>));
|
||||
ducc0::cfmav<std::complex<float>> m_in(
|
||||
reinterpret_cast<std::complex<float> *>(in[1]), shape, stride_in);
|
||||
ducc0::vfmav<std::complex<float>> m_out(
|
||||
reinterpret_cast<std::complex<float> *>(out), shape, stride_out);
|
||||
ducc0::c2c(m_in, m_out, axes, descriptor->forward(),
|
||||
static_cast<float>(descriptor->scale()));
|
||||
} else {
|
||||
fixstrides(stride_in, sizeof(std::complex<double>));
|
||||
fixstrides(stride_out, sizeof(std::complex<double>));
|
||||
ducc0::cfmav<std::complex<double>> m_in(
|
||||
reinterpret_cast<std::complex<double> *>(in[1]), shape, stride_in);
|
||||
ducc0::vfmav<std::complex<double>> m_out(
|
||||
reinterpret_cast<std::complex<double> *>(out), shape, stride_out);
|
||||
ducc0::c2c(m_in, m_out, axes, descriptor->forward(),
|
||||
static_cast<double>(descriptor->scale()));
|
||||
}
|
||||
break;
|
||||
case PocketFftType_C2R:
|
||||
if (descriptor->dtype() == PocketFftDtype_COMPLEX64) {
|
||||
fixstrides(stride_in, sizeof(std::complex<float>));
|
||||
fixstrides(stride_out, sizeof(float));
|
||||
auto shape_in = shape;
|
||||
shape_in[axes.back()] = shape_in[axes.back()] / 2 + 1;
|
||||
ducc0::cfmav<std::complex<float>> m_in(
|
||||
reinterpret_cast<std::complex<float> *>(in[1]), shape_in, stride_in);
|
||||
ducc0::vfmav<float> m_out(reinterpret_cast<float *>(out), shape,
|
||||
stride_out);
|
||||
ducc0::c2r(m_in, m_out, axes, descriptor->forward(),
|
||||
static_cast<float>(descriptor->scale()));
|
||||
} else {
|
||||
fixstrides(stride_in, sizeof(std::complex<double>));
|
||||
fixstrides(stride_out, sizeof(double));
|
||||
auto shape_in = shape;
|
||||
shape_in[axes.back()] = shape_in[axes.back()] / 2 + 1;
|
||||
ducc0::cfmav<std::complex<double>> m_in(
|
||||
reinterpret_cast<std::complex<double> *>(in[1]), shape_in, stride_in);
|
||||
ducc0::vfmav<double> m_out(reinterpret_cast<double *>(out), shape,
|
||||
stride_out);
|
||||
ducc0::c2r(m_in, m_out, axes, descriptor->forward(),
|
||||
static_cast<double>(descriptor->scale()));
|
||||
}
|
||||
break;
|
||||
case PocketFftType_R2C:
|
||||
if (descriptor->dtype() == PocketFftDtype_COMPLEX64) {
|
||||
fixstrides(stride_in, sizeof(float));
|
||||
fixstrides(stride_out, sizeof(std::complex<float>));
|
||||
auto shape_out = shape;
|
||||
shape_out[axes.back()] = shape_out[axes.back()] / 2 + 1;
|
||||
ducc0::cfmav<float> m_in(reinterpret_cast<float *>(in[1]), shape,
|
||||
stride_in);
|
||||
ducc0::vfmav<std::complex<float>> m_out(
|
||||
reinterpret_cast<std::complex<float> *>(out), shape_out, stride_out);
|
||||
ducc0::r2c(m_in, m_out, axes, descriptor->forward(),
|
||||
static_cast<float>(descriptor->scale()));
|
||||
} else {
|
||||
fixstrides(stride_in, sizeof(double));
|
||||
fixstrides(stride_out, sizeof(std::complex<double>));
|
||||
auto shape_out = shape;
|
||||
shape_out[axes.back()] = shape_out[axes.back()] / 2 + 1;
|
||||
ducc0::cfmav<double> m_in(reinterpret_cast<double *>(in[1]), shape,
|
||||
stride_in);
|
||||
ducc0::vfmav<std::complex<double>> m_out(
|
||||
reinterpret_cast<std::complex<double> *>(out), shape_out, stride_out);
|
||||
ducc0::r2c(m_in, m_out, axes, descriptor->forward(),
|
||||
static_cast<double>(descriptor->scale()));
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace jax
|
||||
|
16
third_party/pocketfft/BUILD.bazel
vendored
16
third_party/pocketfft/BUILD.bazel
vendored
@ -4,8 +4,20 @@ package(default_visibility = ["//visibility:public"])
|
||||
|
||||
cc_library(
|
||||
name = "pocketfft",
|
||||
hdrs = ["pocketfft_hdronly.h"],
|
||||
copts = ["-fexceptions"],
|
||||
hdrs = ["src/ducc0/fft/fft.h"],
|
||||
srcs = ["src/ducc0/fft/fft1d.h",
|
||||
"src/ducc0/infra/aligned_array.h",
|
||||
"src/ducc0/infra/error_handling.h",
|
||||
"src/ducc0/infra/mav.h",
|
||||
"src/ducc0/infra/simd.h",
|
||||
"src/ducc0/infra/threading.h",
|
||||
"src/ducc0/infra/threading.cc",
|
||||
"src/ducc0/infra/useful_macros.h",
|
||||
"src/ducc0/math/cmplx.h",
|
||||
"src/ducc0/math/unity_roots.h",
|
||||
],
|
||||
copts = ["-fexceptions", "-ffast-math"],
|
||||
features = ["-use_header_modules"],
|
||||
include_prefix = "pocketfft",
|
||||
includes = ["pocketfft", "src"],
|
||||
)
|
||||
|
8
third_party/pocketfft/workspace.bzl
vendored
8
third_party/pocketfft/workspace.bzl
vendored
@ -19,11 +19,11 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
|
||||
def repo():
|
||||
http_archive(
|
||||
name = "pocketfft",
|
||||
sha256 = "66eda977b195965d27aeb9d74f46e0029a6a02e75fbbc47bb554aad68615a260",
|
||||
strip_prefix = "pocketfft-f800d91ba695b6e19ae2687dd60366900b928002",
|
||||
strip_prefix = "ducc-356d619a4b5f6f8940d15913c14a043355ef23be",
|
||||
sha256 = "d23eb2d06f03604867ad40af4fe92dec7cccc2c59f5119e9f01b35b973885c61,
|
||||
urls = [
|
||||
"https://github.com/mreineck/pocketfft/archive/f800d91ba695b6e19ae2687dd60366900b928002.tar.gz",
|
||||
"https://storage.googleapis.com/jax-releases/mirror/pocketfft/pocketfft-f800d91ba695b6e19ae2687dd60366900b928002.tar.gz",
|
||||
"https://github.com/mreineck/ducc/archive/356d619a4b5f6f8940d15913c14a043355ef23be.tar.gz",
|
||||
"https://storage.googleapis.com/jax-releases/mirror/ducc/ducc-356d619a4b5f6f8940d15913c14a043355ef23be.tar.gz",
|
||||
],
|
||||
build_file = "@//third_party/pocketfft:BUILD.bazel",
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user