rocm_jax/jaxlib/pocketfft.cc
Peter Hawkins f58f1ee456 [JAX] Use PocketFFT for FFTs on CPU instead of Eigen.
PocketFFT is the same FFT library used by NumPy (although we use its C++ variant rather than the C variant).

For the benchmark in #2952 on my workstation:

Before:
```
907.3490574884647
max:     4.362646594533903e-08
mean:    6.237288307614869e-09
min:     0.0
numpy fft execution time [ms]:   37.088446617126465
jax fft execution time [ms]:     74.93342399597168
```

After:
```
907.3490574884647
max:     1.9057386696477137e-12
mean:    3.9326737908882566e-13
min:     0.0
numpy fft execution time [ms]:   37.756404876708984
jax fft execution time [ms]:     28.128278255462646
```

Fixes https://github.com/google/jax/issues/2952

PiperOrigin-RevId: 338743753
2020-10-23 14:20:32 -07:00

102 lines
4.0 KiB
C++

/* Copyright 2020 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <complex>
#include "flatbuffers/flatbuffers.h"
#include "pocketfft/pocketfft_hdronly.h"
#include "jaxlib/kernel_pybind11_helpers.h"
#include "jaxlib/pocketfft_generated.h"
#include "include/pybind11/pybind11.h"
namespace jax {
namespace {
void PocketFft(void* out, void** in) {
const PocketFftDescriptor* descriptor = GetPocketFftDescriptor(in[0]);
pocketfft::shape_t shape(descriptor->shape()->begin(),
descriptor->shape()->end());
pocketfft::stride_t stride_in(descriptor->strides_in()->begin(),
descriptor->strides_in()->end());
pocketfft::stride_t stride_out(descriptor->strides_out()->begin(),
descriptor->strides_out()->end());
pocketfft::shape_t axes(descriptor->axes()->begin(),
descriptor->axes()->end());
switch (descriptor->fft_type()) {
case PocketFftType_C2C:
if (descriptor->dtype() == PocketFftDtype_COMPLEX64) {
std::complex<float>* a_in =
reinterpret_cast<std::complex<float>*>(in[1]);
std::complex<float>* a_out =
reinterpret_cast<std::complex<float>*>(out);
pocketfft::c2c(shape, stride_in, stride_out, axes,
descriptor->forward(), a_in, a_out,
static_cast<float>(descriptor->scale()));
} else {
std::complex<double>* a_in =
reinterpret_cast<std::complex<double>*>(in[1]);
std::complex<double>* a_out =
reinterpret_cast<std::complex<double>*>(out);
pocketfft::c2c(shape, stride_in, stride_out, axes,
descriptor->forward(), a_in, a_out, descriptor->scale());
}
break;
case PocketFftType_C2R:
if (descriptor->dtype() == PocketFftDtype_COMPLEX64) {
std::complex<float>* a_in =
reinterpret_cast<std::complex<float>*>(in[1]);
float* a_out = reinterpret_cast<float*>(out);
pocketfft::c2r(shape, stride_in, stride_out, axes,
descriptor->forward(), a_in, a_out,
static_cast<float>(descriptor->scale()));
} else {
std::complex<double>* a_in =
reinterpret_cast<std::complex<double>*>(in[1]);
double* a_out = reinterpret_cast<double*>(out);
pocketfft::c2r(shape, stride_in, stride_out, axes,
descriptor->forward(), a_in, a_out, descriptor->scale());
}
break;
case PocketFftType_R2C:
if (descriptor->dtype() == PocketFftDtype_COMPLEX64) {
float* a_in = reinterpret_cast<float*>(in[1]);
std::complex<float>* a_out =
reinterpret_cast<std::complex<float>*>(out);
pocketfft::r2c(shape, stride_in, stride_out, axes,
descriptor->forward(), a_in, a_out,
static_cast<float>(descriptor->scale()));
} else {
double* a_in = reinterpret_cast<double*>(in[1]);
std::complex<double>* a_out =
reinterpret_cast<std::complex<double>*>(out);
pocketfft::r2c(shape, stride_in, stride_out, axes,
descriptor->forward(), a_in, a_out, descriptor->scale());
}
break;
}
}
// Builds the kernel registration table exposed to Python.
//
// The table has a single entry, "pocketfft", whose value is the PocketFft
// kernel wrapped by EncapsulateFunction.
pybind11::dict Registrations() {
  pybind11::dict targets;
  targets["pocketfft"] = EncapsulateFunction(PocketFft);
  return targets;
}
// Python extension module `_pocketfft`.
// Its only function, `registrations()`, returns the dict built by
// Registrations().
PYBIND11_MODULE(_pocketfft, mod) {
  mod.def("registrations", Registrations);
}
} // namespace
} // namespace jax