mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #12122 from hawkinsp:fft
PiperOrigin-RevId: 470294824
This commit is contained in:
commit
498fd2083e
@ -27,8 +27,8 @@ http_archive(
|
||||
# path = "/path/to/tensorflow",
|
||||
# )
|
||||
|
||||
load("//third_party/pocketfft:workspace.bzl", pocketfft = "repo")
|
||||
pocketfft()
|
||||
load("//third_party/ducc:workspace.bzl", ducc = "repo")
|
||||
ducc()
|
||||
|
||||
# Initialize TensorFlow's external dependencies.
|
||||
load("@org_tensorflow//tensorflow:workspace3.bzl", "tf_workspace3")
|
||||
|
@ -4332,8 +4332,8 @@ Copyright 2019 The TensorFlow Authors. All rights reserved.
|
||||
limitations under the License.
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
License for pocketfft:
|
||||
Copyright (C) 2010-2018 Max-Planck-Society
|
||||
License for the FFT components of ducc0:
|
||||
Copyright (C) 2010-2022 Max-Planck-Society
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without modification,
|
||||
|
@ -171,8 +171,8 @@ def prepare_wheel(sources_path):
|
||||
copy_to_jaxlib("__main__/jaxlib/lapack.py")
|
||||
copy_to_jaxlib(f"__main__/jaxlib/_lapack.{pyext}")
|
||||
copy_to_jaxlib("__main__/jaxlib/mhlo_helpers.py")
|
||||
copy_to_jaxlib(f"__main__/jaxlib/_pocketfft.{pyext}")
|
||||
copy_to_jaxlib("__main__/jaxlib/pocketfft.py")
|
||||
copy_to_jaxlib(f"__main__/jaxlib/_ducc_fft.{pyext}")
|
||||
copy_to_jaxlib("__main__/jaxlib/ducc_fft.py")
|
||||
copy_to_jaxlib("__main__/jaxlib/gpu_prng.py")
|
||||
copy_to_jaxlib("__main__/jaxlib/gpu_linalg.py")
|
||||
copy_to_jaxlib("__main__/jaxlib/gpu_solver.py")
|
||||
|
@ -31,6 +31,9 @@ from jax.interpreters import batching
|
||||
from jax._src.lib.mlir.dialects import mhlo
|
||||
from jax._src.lib import mlir_api_version
|
||||
from jax._src.lib import xla_client
|
||||
# TODO(phawkins): remove pocketfft references when the minimum jaxlib version
|
||||
# is 0.3.17 or newer.
|
||||
from jax._src.lib import ducc_fft
|
||||
from jax._src.lib import pocketfft
|
||||
from jax._src.numpy.util import _promote_dtypes_complex, _promote_dtypes_inexact
|
||||
|
||||
@ -123,8 +126,12 @@ def _fft_lowering(ctx, x, *, fft_type, fft_lengths):
|
||||
|
||||
def _fft_lowering_cpu(ctx, x, *, fft_type, fft_lengths):
|
||||
x_aval, = ctx.avals_in
|
||||
return [pocketfft.pocketfft_mhlo(x, x_aval.dtype, fft_type=fft_type,
|
||||
if ducc_fft:
|
||||
return [ducc_fft.ducc_fft_mhlo(x, x_aval.dtype, fft_type=fft_type,
|
||||
fft_lengths=fft_lengths)]
|
||||
else:
|
||||
return [pocketfft.pocketfft_mhlo(x, x_aval.dtype, fft_type=fft_type,
|
||||
fft_lengths=fft_lengths)]
|
||||
|
||||
|
||||
def _naive_rfft(x, fft_lengths):
|
||||
@ -188,5 +195,4 @@ fft_p.def_abstract_eval(fft_abstract_eval)
|
||||
mlir.register_lowering(fft_p, _fft_lowering)
|
||||
ad.deflinear2(fft_p, _fft_transpose_rule)
|
||||
batching.primitive_batchers[fft_p] = _fft_batching_rule
|
||||
if pocketfft:
|
||||
mlir.register_lowering(fft_p, _fft_lowering_cpu, platform='cpu')
|
||||
mlir.register_lowering(fft_p, _fft_lowering_cpu, platform='cpu')
|
||||
|
@ -99,7 +99,17 @@ cpu_feature_guard.check_cpu_features()
|
||||
|
||||
import jaxlib.xla_client as xla_client
|
||||
import jaxlib.lapack as lapack
|
||||
import jaxlib.pocketfft as pocketfft
|
||||
|
||||
# TODO(phawkins): remove pocketfft references when the minimum jaxlib version
|
||||
# is 0.3.17 or newer.
|
||||
try:
|
||||
import jaxlib.pocketfft as pocketfft # pytype: disable=import-error
|
||||
except ImportError:
|
||||
pocketfft = None # type: ignore
|
||||
try:
|
||||
import jaxlib.ducc_fft as ducc_fft # pytype: disable=import-error
|
||||
except ImportError:
|
||||
ducc_fft = None # type: ignore
|
||||
|
||||
xla_extension = xla_client._xla
|
||||
pytree = xla_client._xla.pytree
|
||||
|
34
jaxlib/BUILD
34
jaxlib/BUILD
@ -29,6 +29,7 @@ package(default_visibility = ["//:__subpackages__"])
|
||||
py_library(
|
||||
name = "jaxlib",
|
||||
srcs = [
|
||||
"ducc_fft.py",
|
||||
"gpu_linalg.py",
|
||||
"gpu_prng.py",
|
||||
"gpu_solver.py",
|
||||
@ -36,14 +37,13 @@ py_library(
|
||||
"init.py",
|
||||
"lapack.py",
|
||||
"mhlo_helpers.py",
|
||||
"pocketfft.py",
|
||||
":version",
|
||||
":xla_client",
|
||||
],
|
||||
data = [":xla_extension"],
|
||||
deps = [
|
||||
":_ducc_fft",
|
||||
":_lapack",
|
||||
":_pocketfft",
|
||||
":cpu_feature_guard",
|
||||
"//jaxlib/mlir",
|
||||
"//jaxlib/mlir:builtin_dialect",
|
||||
@ -178,40 +178,40 @@ pybind_extension(
|
||||
],
|
||||
)
|
||||
|
||||
# PocketFFT
|
||||
# DUCC (CPU FFTs)
|
||||
|
||||
flatbuffer_cc_library(
|
||||
name = "pocketfft_flatbuffers_cc",
|
||||
srcs = ["pocketfft.fbs"],
|
||||
name = "ducc_fft_flatbuffers_cc",
|
||||
srcs = ["ducc_fft.fbs"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "pocketfft_kernels",
|
||||
srcs = ["pocketfft_kernels.cc"],
|
||||
hdrs = ["pocketfft_kernels.h"],
|
||||
copts = ["-fexceptions"], # PocketFFT may throw.
|
||||
name = "ducc_fft_kernels",
|
||||
srcs = ["ducc_fft_kernels.cc"],
|
||||
hdrs = ["ducc_fft_kernels.h"],
|
||||
copts = ["-fexceptions"], # DUCC may throw.
|
||||
features = ["-use_header_modules"],
|
||||
deps = [
|
||||
":pocketfft_flatbuffers_cc",
|
||||
":ducc_fft_flatbuffers_cc",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
|
||||
"@ducc",
|
||||
"@flatbuffers//:runtime_cc",
|
||||
"@pocketfft",
|
||||
],
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
name = "_pocketfft",
|
||||
srcs = ["pocketfft.cc"],
|
||||
name = "_ducc_fft",
|
||||
srcs = ["ducc_fft.cc"],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
],
|
||||
features = ["-use_header_modules"],
|
||||
module_name = "_pocketfft",
|
||||
module_name = "_ducc_fft",
|
||||
deps = [
|
||||
":ducc_fft_flatbuffers_cc",
|
||||
":ducc_fft_kernels",
|
||||
":kernel_pybind11_helpers",
|
||||
":pocketfft_flatbuffers_cc",
|
||||
":pocketfft_kernels",
|
||||
"@flatbuffers//:runtime_cc",
|
||||
"@pybind11",
|
||||
],
|
||||
@ -222,9 +222,9 @@ cc_library(
|
||||
srcs = ["cpu_kernels.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":ducc_fft_kernels",
|
||||
":lapack_kernels",
|
||||
":lapack_kernels_using_lapack",
|
||||
":pocketfft_kernels",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_target_registry",
|
||||
],
|
||||
alwayslink = 1,
|
||||
|
@ -17,7 +17,7 @@ limitations under the License.
|
||||
// JAX-generated HLO code from outside of JAX.
|
||||
|
||||
#include "jaxlib/lapack_kernels.h"
|
||||
#include "jaxlib/pocketfft_kernels.h"
|
||||
#include "jaxlib/ducc_fft_kernels.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
|
||||
|
||||
namespace jax {
|
||||
@ -105,7 +105,7 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
|
||||
"lapack_cgees", ComplexGees<std::complex<float>>::Kernel, "Host");
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
|
||||
"lapack_zgees", ComplexGees<std::complex<double>>::Kernel, "Host");
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("pocketfft", PocketFft, "Host");
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("ducc_fft", DuccFft, "Host");
|
||||
|
||||
} // namespace
|
||||
} // namespace jax
|
||||
|
@ -16,53 +16,53 @@ limitations under the License.
|
||||
#include <complex>
|
||||
#include <vector>
|
||||
|
||||
#include "jaxlib/kernel_pybind11_helpers.h"
|
||||
#include "jaxlib/pocketfft_generated.h"
|
||||
#include "jaxlib/pocketfft_kernels.h"
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "include/pybind11/stl.h"
|
||||
#include "jaxlib/ducc_fft_generated.h"
|
||||
#include "jaxlib/ducc_fft_kernels.h"
|
||||
#include "jaxlib/kernel_pybind11_helpers.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace jax {
|
||||
namespace {
|
||||
|
||||
py::bytes BuildPocketFftDescriptor(const std::vector<uint64_t>& shape,
|
||||
bool is_double, int fft_type,
|
||||
const std::vector<uint64_t>& fft_lengths,
|
||||
const std::vector<uint64_t>& strides_in,
|
||||
const std::vector<uint64_t>& strides_out,
|
||||
const std::vector<uint32_t>& axes,
|
||||
bool forward, double scale) {
|
||||
PocketFftDescriptorT descriptor;
|
||||
py::bytes BuildDuccFftDescriptor(const std::vector<uint64_t> &shape,
|
||||
bool is_double, int fft_type,
|
||||
const std::vector<uint64_t> &fft_lengths,
|
||||
const std::vector<uint64_t> &strides_in,
|
||||
const std::vector<uint64_t> &strides_out,
|
||||
const std::vector<uint32_t> &axes,
|
||||
bool forward, double scale) {
|
||||
DuccFftDescriptorT descriptor;
|
||||
descriptor.shape = shape;
|
||||
descriptor.fft_type = static_cast<PocketFftType>(fft_type);
|
||||
descriptor.fft_type = static_cast<DuccFftType>(fft_type);
|
||||
descriptor.dtype =
|
||||
is_double ? PocketFftDtype_COMPLEX128 : PocketFftDtype_COMPLEX64;
|
||||
is_double ? DuccFftDtype_COMPLEX128 : DuccFftDtype_COMPLEX64;
|
||||
descriptor.strides_in = strides_in;
|
||||
descriptor.strides_out = strides_out;
|
||||
descriptor.axes = axes;
|
||||
descriptor.forward = forward;
|
||||
descriptor.scale = scale;
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
fbb.Finish(PocketFftDescriptor::Pack(fbb, &descriptor));
|
||||
return py::bytes(reinterpret_cast<char*>(fbb.GetBufferPointer()),
|
||||
fbb.Finish(DuccFftDescriptor::Pack(fbb, &descriptor));
|
||||
return py::bytes(reinterpret_cast<char *>(fbb.GetBufferPointer()),
|
||||
fbb.GetSize());
|
||||
}
|
||||
|
||||
py::dict Registrations() {
|
||||
pybind11::dict dict;
|
||||
dict["pocketfft"] = EncapsulateFunction(PocketFft);
|
||||
dict["ducc_fft"] = EncapsulateFunction(DuccFft);
|
||||
return dict;
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(_pocketfft, m) {
|
||||
PYBIND11_MODULE(_ducc_fft, m) {
|
||||
m.def("registrations", &Registrations);
|
||||
m.def("pocketfft_descriptor", &BuildPocketFftDescriptor, py::arg("shape"),
|
||||
m.def("ducc_fft_descriptor", &BuildDuccFftDescriptor, py::arg("shape"),
|
||||
py::arg("is_double"), py::arg("fft_type"), py::arg("fft_lengths"),
|
||||
py::arg("strides_in"), py::arg("strides_out"), py::arg("axes"),
|
||||
py::arg("forward"), py::arg("scale"));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace jax
|
||||
} // namespace
|
||||
} // namespace jax
|
@ -15,20 +15,20 @@ limitations under the License.
|
||||
|
||||
namespace jax;
|
||||
|
||||
enum PocketFftDtype : byte {
|
||||
enum DuccFftDtype : byte {
|
||||
COMPLEX64 = 0,
|
||||
COMPLEX128 = 1,
|
||||
}
|
||||
|
||||
enum PocketFftType : byte {
|
||||
enum DuccFftType : byte {
|
||||
C2C = 0,
|
||||
C2R = 1,
|
||||
R2C = 2,
|
||||
}
|
||||
|
||||
table PocketFftDescriptor {
|
||||
dtype:PocketFftDtype;
|
||||
fft_type:PocketFftType;
|
||||
table DuccFftDescriptor {
|
||||
dtype:DuccFftDtype;
|
||||
fft_type:DuccFftType;
|
||||
shape:[uint64];
|
||||
strides_in:[uint64];
|
||||
strides_out:[uint64];
|
||||
@ -37,4 +37,4 @@ table PocketFftDescriptor {
|
||||
scale:double;
|
||||
}
|
||||
|
||||
root_type PocketFftDescriptor;
|
||||
root_type DuccFftDescriptor;
|
@ -19,12 +19,12 @@ import jaxlib.mlir.dialects.mhlo as mhlo
|
||||
|
||||
|
||||
from .mhlo_helpers import custom_call
|
||||
from . import _pocketfft
|
||||
from . import _ducc_fft
|
||||
import numpy as np
|
||||
|
||||
from jaxlib import xla_client
|
||||
|
||||
for _name, _value in _pocketfft.registrations().items():
|
||||
for _name, _value in _ducc_fft.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="cpu")
|
||||
|
||||
FftType = xla_client.FftType
|
||||
@ -34,7 +34,7 @@ _C2C = 0
|
||||
_C2R = 1
|
||||
_R2C = 2
|
||||
|
||||
def _pocketfft_descriptor(shape: List[int], dtype, fft_type: FftType,
|
||||
def _ducc_fft_descriptor(shape: List[int], dtype, fft_type: FftType,
|
||||
fft_lengths: List[int]) -> bytes:
|
||||
n = len(shape)
|
||||
assert len(fft_lengths) >= 1
|
||||
@ -44,7 +44,7 @@ def _pocketfft_descriptor(shape: List[int], dtype, fft_type: FftType,
|
||||
forward = fft_type in (FftType.FFT, FftType.RFFT)
|
||||
is_double = np.finfo(dtype).dtype == np.float64
|
||||
if fft_type == FftType.RFFT:
|
||||
pocketfft_type = _R2C
|
||||
ducc_fft_type = _R2C
|
||||
|
||||
assert dtype in (np.float32, np.float64), dtype
|
||||
out_dtype = np.dtype(np.complex64 if dtype == np.float32 else np.complex128)
|
||||
@ -54,7 +54,7 @@ def _pocketfft_descriptor(shape: List[int], dtype, fft_type: FftType,
|
||||
out_shape[-1] = out_shape[-1] // 2 + 1
|
||||
|
||||
elif fft_type == FftType.IRFFT:
|
||||
pocketfft_type = _C2R
|
||||
ducc_fft_type = _C2R
|
||||
assert np.issubdtype(dtype, np.complexfloating), dtype
|
||||
|
||||
out_dtype = np.dtype(np.float32 if dtype == np.complex64 else np.float64)
|
||||
@ -64,7 +64,7 @@ def _pocketfft_descriptor(shape: List[int], dtype, fft_type: FftType,
|
||||
out_shape[-1] = fft_lengths[-1]
|
||||
assert (out_shape[-1] // 2 + 1) == shape[-1]
|
||||
else:
|
||||
pocketfft_type = _C2C
|
||||
ducc_fft_type = _C2C
|
||||
|
||||
assert np.issubdtype(dtype, np.complexfloating), dtype
|
||||
out_dtype = dtype
|
||||
@ -79,13 +79,13 @@ def _pocketfft_descriptor(shape: List[int], dtype, fft_type: FftType,
|
||||
# Builds a PocketFftDescriptor flatbuffer. This descriptor is passed to the
|
||||
# C++ kernel to describe the FFT to perform.
|
||||
strides_in = []
|
||||
stride = dtype.itemsize
|
||||
stride = 1
|
||||
for d in reversed(shape):
|
||||
strides_in.append(stride)
|
||||
stride *= d
|
||||
|
||||
strides_out = []
|
||||
stride = out_dtype.itemsize
|
||||
stride = 1
|
||||
for d in reversed(out_shape):
|
||||
strides_out.append(stride)
|
||||
stride *= d
|
||||
@ -93,10 +93,10 @@ def _pocketfft_descriptor(shape: List[int], dtype, fft_type: FftType,
|
||||
axes = [n - len(fft_lengths) + d for d in range(len(fft_lengths))]
|
||||
|
||||
scale = 1. if forward else (1. / np.prod(fft_lengths))
|
||||
descriptor = _pocketfft.pocketfft_descriptor(
|
||||
descriptor = _ducc_fft.ducc_fft_descriptor(
|
||||
shape=shape if fft_type != FftType.IRFFT else out_shape,
|
||||
is_double=is_double,
|
||||
fft_type=pocketfft_type,
|
||||
fft_type=ducc_fft_type,
|
||||
fft_lengths=fft_lengths,
|
||||
strides_in=list(reversed(strides_in)),
|
||||
strides_out=list(reversed(strides_out)),
|
||||
@ -107,13 +107,13 @@ def _pocketfft_descriptor(shape: List[int], dtype, fft_type: FftType,
|
||||
return descriptor, out_dtype, out_shape
|
||||
|
||||
|
||||
def pocketfft_mhlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]):
|
||||
"""PocketFFT kernel for CPU."""
|
||||
def ducc_fft_mhlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]):
|
||||
"""DUCC FFT kernel for CPU."""
|
||||
a_type = ir.RankedTensorType(a.type)
|
||||
n = len(a_type.shape)
|
||||
|
||||
fft_lengths = list(fft_lengths)
|
||||
descriptor_bytes, out_dtype, out_shape = _pocketfft_descriptor(
|
||||
descriptor_bytes, out_dtype, out_shape = _ducc_fft_descriptor(
|
||||
list(a_type.shape), dtype, fft_type, fft_lengths)
|
||||
|
||||
if out_dtype == np.float32:
|
||||
@ -141,7 +141,7 @@ def pocketfft_mhlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]):
|
||||
np.frombuffer(descriptor_bytes, dtype=np.uint8), type=u8_type))
|
||||
layout = tuple(range(n - 1, -1, -1))
|
||||
return custom_call(
|
||||
"pocketfft",
|
||||
"ducc_fft",
|
||||
[ir.RankedTensorType.get(out_shape, out_type)],
|
||||
[descriptor, a],
|
||||
operand_layouts=[[0], layout],
|
100
jaxlib/ducc_fft_kernels.cc
Normal file
100
jaxlib/ducc_fft_kernels.cc
Normal file
@ -0,0 +1,100 @@
|
||||
/* Copyright 2020 Google LLC
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <complex>
|
||||
|
||||
#include "flatbuffers/flatbuffers.h"
|
||||
#include "jaxlib/ducc_fft_generated.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
#include "ducc/src/ducc0/fft/fft.h"
|
||||
|
||||
namespace jax {
|
||||
|
||||
using shape_t = ducc0::fmav_info::shape_t;
|
||||
using stride_t = ducc0::fmav_info::stride_t;
|
||||
|
||||
void DuccFft(void *out, void **in, XlaCustomCallStatus *) {
|
||||
const DuccFftDescriptor *descriptor = GetDuccFftDescriptor(in[0]);
|
||||
shape_t shape(descriptor->shape()->begin(), descriptor->shape()->end());
|
||||
stride_t stride_in(descriptor->strides_in()->begin(),
|
||||
descriptor->strides_in()->end());
|
||||
stride_t stride_out(descriptor->strides_out()->begin(),
|
||||
descriptor->strides_out()->end());
|
||||
shape_t axes(descriptor->axes()->begin(), descriptor->axes()->end());
|
||||
|
||||
switch (descriptor->fft_type()) {
|
||||
case DuccFftType_C2C:
|
||||
if (descriptor->dtype() == DuccFftDtype_COMPLEX64) {
|
||||
ducc0::cfmav<std::complex<float>> m_in(
|
||||
reinterpret_cast<std::complex<float> *>(in[1]), shape, stride_in);
|
||||
ducc0::vfmav<std::complex<float>> m_out(
|
||||
reinterpret_cast<std::complex<float> *>(out), shape, stride_out);
|
||||
ducc0::c2c(m_in, m_out, axes, descriptor->forward(),
|
||||
static_cast<float>(descriptor->scale()));
|
||||
} else {
|
||||
ducc0::cfmav<std::complex<double>> m_in(
|
||||
reinterpret_cast<std::complex<double> *>(in[1]), shape, stride_in);
|
||||
ducc0::vfmav<std::complex<double>> m_out(
|
||||
reinterpret_cast<std::complex<double> *>(out), shape, stride_out);
|
||||
ducc0::c2c(m_in, m_out, axes, descriptor->forward(),
|
||||
static_cast<double>(descriptor->scale()));
|
||||
}
|
||||
break;
|
||||
case DuccFftType_C2R:
|
||||
if (descriptor->dtype() == DuccFftDtype_COMPLEX64) {
|
||||
auto shape_in = shape;
|
||||
shape_in[axes.back()] = shape_in[axes.back()] / 2 + 1;
|
||||
ducc0::cfmav<std::complex<float>> m_in(
|
||||
reinterpret_cast<std::complex<float> *>(in[1]), shape_in, stride_in);
|
||||
ducc0::vfmav<float> m_out(reinterpret_cast<float *>(out), shape,
|
||||
stride_out);
|
||||
ducc0::c2r(m_in, m_out, axes, descriptor->forward(),
|
||||
static_cast<float>(descriptor->scale()));
|
||||
} else {
|
||||
auto shape_in = shape;
|
||||
shape_in[axes.back()] = shape_in[axes.back()] / 2 + 1;
|
||||
ducc0::cfmav<std::complex<double>> m_in(
|
||||
reinterpret_cast<std::complex<double> *>(in[1]), shape_in, stride_in);
|
||||
ducc0::vfmav<double> m_out(reinterpret_cast<double *>(out), shape,
|
||||
stride_out);
|
||||
ducc0::c2r(m_in, m_out, axes, descriptor->forward(),
|
||||
static_cast<double>(descriptor->scale()));
|
||||
}
|
||||
break;
|
||||
case DuccFftType_R2C:
|
||||
if (descriptor->dtype() == DuccFftDtype_COMPLEX64) {
|
||||
auto shape_out = shape;
|
||||
shape_out[axes.back()] = shape_out[axes.back()] / 2 + 1;
|
||||
ducc0::cfmav<float> m_in(reinterpret_cast<float *>(in[1]), shape,
|
||||
stride_in);
|
||||
ducc0::vfmav<std::complex<float>> m_out(
|
||||
reinterpret_cast<std::complex<float> *>(out), shape_out, stride_out);
|
||||
ducc0::r2c(m_in, m_out, axes, descriptor->forward(),
|
||||
static_cast<float>(descriptor->scale()));
|
||||
} else {
|
||||
auto shape_out = shape;
|
||||
shape_out[axes.back()] = shape_out[axes.back()] / 2 + 1;
|
||||
ducc0::cfmav<double> m_in(reinterpret_cast<double *>(in[1]), shape,
|
||||
stride_in);
|
||||
ducc0::vfmav<std::complex<double>> m_out(
|
||||
reinterpret_cast<std::complex<double> *>(out), shape_out, stride_out);
|
||||
ducc0::r2c(m_in, m_out, axes, descriptor->forward(),
|
||||
static_cast<double>(descriptor->scale()));
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace jax
|
@ -17,6 +17,6 @@ limitations under the License.
|
||||
|
||||
namespace jax {
|
||||
|
||||
void PocketFft(void* out, void** in, XlaCustomCallStatus*);
|
||||
void DuccFft(void* out, void** in, XlaCustomCallStatus*);
|
||||
|
||||
} // namespace jax
|
@ -1,90 +0,0 @@
|
||||
/* Copyright 2020 Google LLC
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <complex>
|
||||
|
||||
#include "flatbuffers/flatbuffers.h"
|
||||
#include "pocketfft/pocketfft_hdronly.h"
|
||||
#include "jaxlib/pocketfft_generated.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
|
||||
namespace jax {
|
||||
|
||||
void PocketFft(void* out, void** in, XlaCustomCallStatus*) {
|
||||
const PocketFftDescriptor* descriptor = GetPocketFftDescriptor(in[0]);
|
||||
pocketfft::shape_t shape(descriptor->shape()->begin(),
|
||||
descriptor->shape()->end());
|
||||
pocketfft::stride_t stride_in(descriptor->strides_in()->begin(),
|
||||
descriptor->strides_in()->end());
|
||||
pocketfft::stride_t stride_out(descriptor->strides_out()->begin(),
|
||||
descriptor->strides_out()->end());
|
||||
pocketfft::shape_t axes(descriptor->axes()->begin(),
|
||||
descriptor->axes()->end());
|
||||
|
||||
switch (descriptor->fft_type()) {
|
||||
case PocketFftType_C2C:
|
||||
if (descriptor->dtype() == PocketFftDtype_COMPLEX64) {
|
||||
std::complex<float>* a_in =
|
||||
reinterpret_cast<std::complex<float>*>(in[1]);
|
||||
std::complex<float>* a_out =
|
||||
reinterpret_cast<std::complex<float>*>(out);
|
||||
pocketfft::c2c(shape, stride_in, stride_out, axes,
|
||||
descriptor->forward(), a_in, a_out,
|
||||
static_cast<float>(descriptor->scale()));
|
||||
} else {
|
||||
std::complex<double>* a_in =
|
||||
reinterpret_cast<std::complex<double>*>(in[1]);
|
||||
std::complex<double>* a_out =
|
||||
reinterpret_cast<std::complex<double>*>(out);
|
||||
pocketfft::c2c(shape, stride_in, stride_out, axes,
|
||||
descriptor->forward(), a_in, a_out, descriptor->scale());
|
||||
}
|
||||
break;
|
||||
case PocketFftType_C2R:
|
||||
if (descriptor->dtype() == PocketFftDtype_COMPLEX64) {
|
||||
std::complex<float>* a_in =
|
||||
reinterpret_cast<std::complex<float>*>(in[1]);
|
||||
float* a_out = reinterpret_cast<float*>(out);
|
||||
pocketfft::c2r(shape, stride_in, stride_out, axes,
|
||||
descriptor->forward(), a_in, a_out,
|
||||
static_cast<float>(descriptor->scale()));
|
||||
} else {
|
||||
std::complex<double>* a_in =
|
||||
reinterpret_cast<std::complex<double>*>(in[1]);
|
||||
double* a_out = reinterpret_cast<double*>(out);
|
||||
pocketfft::c2r(shape, stride_in, stride_out, axes,
|
||||
descriptor->forward(), a_in, a_out, descriptor->scale());
|
||||
}
|
||||
break;
|
||||
case PocketFftType_R2C:
|
||||
if (descriptor->dtype() == PocketFftDtype_COMPLEX64) {
|
||||
float* a_in = reinterpret_cast<float*>(in[1]);
|
||||
std::complex<float>* a_out =
|
||||
reinterpret_cast<std::complex<float>*>(out);
|
||||
pocketfft::r2c(shape, stride_in, stride_out, axes,
|
||||
descriptor->forward(), a_in, a_out,
|
||||
static_cast<float>(descriptor->scale()));
|
||||
} else {
|
||||
double* a_in = reinterpret_cast<double*>(in[1]);
|
||||
std::complex<double>* a_out =
|
||||
reinterpret_cast<std::complex<double>*>(out);
|
||||
pocketfft::r2c(shape, stride_in, stride_out, axes,
|
||||
descriptor->forward(), a_in, a_out, descriptor->scale());
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace jax
|
@ -665,7 +665,8 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
axes = range(ndims - fft_ndims, ndims)
|
||||
fft_lengths = tuple(shape[axis] for axis in axes)
|
||||
op = lambda x: lax.fft(x, xla_client.FftType.FFT, fft_lengths)
|
||||
self._CheckBatching(op, 5, bdims, [shape], [np.complex64], rng)
|
||||
self._CheckBatching(op, 5, bdims, [shape], [np.complex64], rng,
|
||||
rtol=1e-5)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_idxs={}_dnums={}_slice_sizes={}_bdims={}"
|
||||
|
29
third_party/ducc/BUILD.bazel
vendored
Normal file
29
third_party/ducc/BUILD.bazel
vendored
Normal file
@ -0,0 +1,29 @@
|
||||
licenses(["notice"])
|
||||
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
cc_library(
|
||||
name = "ducc",
|
||||
srcs = [
|
||||
"src/ducc0/fft/fft1d.h",
|
||||
"src/ducc0/infra/aligned_array.h",
|
||||
"src/ducc0/infra/error_handling.h",
|
||||
"src/ducc0/infra/mav.h",
|
||||
"src/ducc0/infra/simd.h",
|
||||
"src/ducc0/infra/threading.cc",
|
||||
"src/ducc0/infra/threading.h",
|
||||
"src/ducc0/infra/useful_macros.h",
|
||||
"src/ducc0/math/cmplx.h",
|
||||
"src/ducc0/math/unity_roots.h",
|
||||
],
|
||||
hdrs = ["src/ducc0/fft/fft.h"],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-ffast-math",
|
||||
],
|
||||
features = ["-use_header_modules"],
|
||||
include_prefix = "ducc",
|
||||
includes = [
|
||||
"src",
|
||||
],
|
||||
)
|
@ -12,18 +12,18 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Bazel workspace for PocketFFT."""
|
||||
"""Bazel workspace for DUCC (CPU FFTs)."""
|
||||
|
||||
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
|
||||
|
||||
def repo():
|
||||
http_archive(
|
||||
name = "pocketfft",
|
||||
sha256 = "66eda977b195965d27aeb9d74f46e0029a6a02e75fbbc47bb554aad68615a260",
|
||||
strip_prefix = "pocketfft-f800d91ba695b6e19ae2687dd60366900b928002",
|
||||
name = "ducc",
|
||||
strip_prefix = "ducc-356d619a4b5f6f8940d15913c14a043355ef23be",
|
||||
sha256 = "d23eb2d06f03604867ad40af4fe92dec7cccc2c59f5119e9f01b35b973885c61",
|
||||
urls = [
|
||||
"https://github.com/mreineck/pocketfft/archive/f800d91ba695b6e19ae2687dd60366900b928002.tar.gz",
|
||||
"https://storage.googleapis.com/jax-releases/mirror/pocketfft/pocketfft-f800d91ba695b6e19ae2687dd60366900b928002.tar.gz",
|
||||
"https://github.com/mreineck/ducc/archive/356d619a4b5f6f8940d15913c14a043355ef23be.tar.gz",
|
||||
"https://storage.googleapis.com/jax-releases/mirror/ducc/ducc-356d619a4b5f6f8940d15913c14a043355ef23be.tar.gz",
|
||||
],
|
||||
build_file = "@//third_party/pocketfft:BUILD.bazel",
|
||||
build_file = "@//third_party/ducc:BUILD.bazel",
|
||||
)
|
11
third_party/pocketfft/BUILD.bazel
vendored
11
third_party/pocketfft/BUILD.bazel
vendored
@ -1,11 +0,0 @@
|
||||
licenses(["notice"])
|
||||
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
cc_library(
|
||||
name = "pocketfft",
|
||||
hdrs = ["pocketfft_hdronly.h"],
|
||||
copts = ["-fexceptions"],
|
||||
features = ["-use_header_modules"],
|
||||
include_prefix = "pocketfft",
|
||||
)
|
Loading…
x
Reference in New Issue
Block a user