1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-26 14:46:07 +00:00
rocm_jax/jaxlib/cpu/ducc_fft_kernels.cc
jax authors 42ef649e65 Merge pull request from hawkinsp:openxla
PiperOrigin-RevId: 516316330
2023-03-13 14:04:41 -07:00

101 lines
4.3 KiB
C++

/* Copyright 2020 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <complex>
#include <cstring>
#include <exception>

#include "ducc/src/ducc0/fft/fft.h"
#include "flatbuffers/flatbuffers.h"
#include "jaxlib/cpu/ducc_fft_generated.h"
#include "xla/service/custom_call_status.h"
namespace jax {
using shape_t = ducc0::fmav_info::shape_t;
using stride_t = ducc0::fmav_info::stride_t;
void DuccFft(void *out, void **in, XlaCustomCallStatus *) {
const DuccFftDescriptor *descriptor = GetDuccFftDescriptor(in[0]);
shape_t shape(descriptor->shape()->begin(), descriptor->shape()->end());
stride_t stride_in(descriptor->strides_in()->begin(),
descriptor->strides_in()->end());
stride_t stride_out(descriptor->strides_out()->begin(),
descriptor->strides_out()->end());
shape_t axes(descriptor->axes()->begin(), descriptor->axes()->end());
switch (descriptor->fft_type()) {
case DuccFftType_C2C:
if (descriptor->dtype() == DuccFftDtype_COMPLEX64) {
ducc0::cfmav<std::complex<float>> m_in(
reinterpret_cast<std::complex<float> *>(in[1]), shape, stride_in);
ducc0::vfmav<std::complex<float>> m_out(
reinterpret_cast<std::complex<float> *>(out), shape, stride_out);
ducc0::c2c(m_in, m_out, axes, descriptor->forward(),
static_cast<float>(descriptor->scale()));
} else {
ducc0::cfmav<std::complex<double>> m_in(
reinterpret_cast<std::complex<double> *>(in[1]), shape, stride_in);
ducc0::vfmav<std::complex<double>> m_out(
reinterpret_cast<std::complex<double> *>(out), shape, stride_out);
ducc0::c2c(m_in, m_out, axes, descriptor->forward(),
static_cast<double>(descriptor->scale()));
}
break;
case DuccFftType_C2R:
if (descriptor->dtype() == DuccFftDtype_COMPLEX64) {
auto shape_in = shape;
shape_in[axes.back()] = shape_in[axes.back()] / 2 + 1;
ducc0::cfmav<std::complex<float>> m_in(
reinterpret_cast<std::complex<float> *>(in[1]), shape_in, stride_in);
ducc0::vfmav<float> m_out(reinterpret_cast<float *>(out), shape,
stride_out);
ducc0::c2r(m_in, m_out, axes, descriptor->forward(),
static_cast<float>(descriptor->scale()));
} else {
auto shape_in = shape;
shape_in[axes.back()] = shape_in[axes.back()] / 2 + 1;
ducc0::cfmav<std::complex<double>> m_in(
reinterpret_cast<std::complex<double> *>(in[1]), shape_in, stride_in);
ducc0::vfmav<double> m_out(reinterpret_cast<double *>(out), shape,
stride_out);
ducc0::c2r(m_in, m_out, axes, descriptor->forward(),
static_cast<double>(descriptor->scale()));
}
break;
case DuccFftType_R2C:
if (descriptor->dtype() == DuccFftDtype_COMPLEX64) {
auto shape_out = shape;
shape_out[axes.back()] = shape_out[axes.back()] / 2 + 1;
ducc0::cfmav<float> m_in(reinterpret_cast<float *>(in[1]), shape,
stride_in);
ducc0::vfmav<std::complex<float>> m_out(
reinterpret_cast<std::complex<float> *>(out), shape_out, stride_out);
ducc0::r2c(m_in, m_out, axes, descriptor->forward(),
static_cast<float>(descriptor->scale()));
} else {
auto shape_out = shape;
shape_out[axes.back()] = shape_out[axes.back()] / 2 + 1;
ducc0::cfmav<double> m_in(reinterpret_cast<double *>(in[1]), shape,
stride_in);
ducc0::vfmav<std::complex<double>> m_out(
reinterpret_cast<std::complex<double> *>(out), shape_out, stride_out);
ducc0::r2c(m_in, m_out, axes, descriptor->forward(),
static_cast<double>(descriptor->scale()));
}
break;
}
}
} // namespace jax