Merge pull request #12122 from hawkinsp:fft

PiperOrigin-RevId: 470294824
This commit is contained in:
jax authors 2022-08-26 11:32:07 -07:00
commit 498fd2083e
17 changed files with 224 additions and 179 deletions

View File

@ -27,8 +27,8 @@ http_archive(
# path = "/path/to/tensorflow",
# )
load("//third_party/pocketfft:workspace.bzl", pocketfft = "repo")
pocketfft()
load("//third_party/ducc:workspace.bzl", ducc = "repo")
ducc()
# Initialize TensorFlow's external dependencies.
load("@org_tensorflow//tensorflow:workspace3.bzl", "tf_workspace3")

View File

@ -4332,8 +4332,8 @@ Copyright 2019 The TensorFlow Authors. All rights reserved.
limitations under the License.
--------------------------------------------------------------------------------
License for pocketfft:
Copyright (C) 2010-2018 Max-Planck-Society
License for the FFT components of ducc0:
Copyright (C) 2010-2022 Max-Planck-Society
All rights reserved.
Redistribution and use in source and binary forms, with or without modification,

View File

@ -171,8 +171,8 @@ def prepare_wheel(sources_path):
copy_to_jaxlib("__main__/jaxlib/lapack.py")
copy_to_jaxlib(f"__main__/jaxlib/_lapack.{pyext}")
copy_to_jaxlib("__main__/jaxlib/mhlo_helpers.py")
copy_to_jaxlib(f"__main__/jaxlib/_pocketfft.{pyext}")
copy_to_jaxlib("__main__/jaxlib/pocketfft.py")
copy_to_jaxlib(f"__main__/jaxlib/_ducc_fft.{pyext}")
copy_to_jaxlib("__main__/jaxlib/ducc_fft.py")
copy_to_jaxlib("__main__/jaxlib/gpu_prng.py")
copy_to_jaxlib("__main__/jaxlib/gpu_linalg.py")
copy_to_jaxlib("__main__/jaxlib/gpu_solver.py")

View File

@ -31,6 +31,9 @@ from jax.interpreters import batching
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib import mlir_api_version
from jax._src.lib import xla_client
# TODO(phawkins): remove pocketfft references when the minimum jaxlib version
# is 0.3.17 or newer.
from jax._src.lib import ducc_fft
from jax._src.lib import pocketfft
from jax._src.numpy.util import _promote_dtypes_complex, _promote_dtypes_inexact
@ -123,8 +126,12 @@ def _fft_lowering(ctx, x, *, fft_type, fft_lengths):
def _fft_lowering_cpu(ctx, x, *, fft_type, fft_lengths):
x_aval, = ctx.avals_in
return [pocketfft.pocketfft_mhlo(x, x_aval.dtype, fft_type=fft_type,
if ducc_fft:
return [ducc_fft.ducc_fft_mhlo(x, x_aval.dtype, fft_type=fft_type,
fft_lengths=fft_lengths)]
else:
return [pocketfft.pocketfft_mhlo(x, x_aval.dtype, fft_type=fft_type,
fft_lengths=fft_lengths)]
def _naive_rfft(x, fft_lengths):
@ -188,5 +195,4 @@ fft_p.def_abstract_eval(fft_abstract_eval)
mlir.register_lowering(fft_p, _fft_lowering)
ad.deflinear2(fft_p, _fft_transpose_rule)
batching.primitive_batchers[fft_p] = _fft_batching_rule
if pocketfft:
mlir.register_lowering(fft_p, _fft_lowering_cpu, platform='cpu')
mlir.register_lowering(fft_p, _fft_lowering_cpu, platform='cpu')

View File

@ -99,7 +99,17 @@ cpu_feature_guard.check_cpu_features()
import jaxlib.xla_client as xla_client
import jaxlib.lapack as lapack
import jaxlib.pocketfft as pocketfft
# TODO(phawkins): remove pocketfft references when the minimum jaxlib version
# is 0.3.17 or newer.
try:
import jaxlib.pocketfft as pocketfft # pytype: disable=import-error
except ImportError:
pocketfft = None # type: ignore
try:
import jaxlib.ducc_fft as ducc_fft # pytype: disable=import-error
except ImportError:
ducc_fft = None # type: ignore
xla_extension = xla_client._xla
pytree = xla_client._xla.pytree

View File

@ -29,6 +29,7 @@ package(default_visibility = ["//:__subpackages__"])
py_library(
name = "jaxlib",
srcs = [
"ducc_fft.py",
"gpu_linalg.py",
"gpu_prng.py",
"gpu_solver.py",
@ -36,14 +37,13 @@ py_library(
"init.py",
"lapack.py",
"mhlo_helpers.py",
"pocketfft.py",
":version",
":xla_client",
],
data = [":xla_extension"],
deps = [
":_ducc_fft",
":_lapack",
":_pocketfft",
":cpu_feature_guard",
"//jaxlib/mlir",
"//jaxlib/mlir:builtin_dialect",
@ -178,40 +178,40 @@ pybind_extension(
],
)
# PocketFFT
# DUCC (CPU FFTs)
flatbuffer_cc_library(
name = "pocketfft_flatbuffers_cc",
srcs = ["pocketfft.fbs"],
name = "ducc_fft_flatbuffers_cc",
srcs = ["ducc_fft.fbs"],
)
cc_library(
name = "pocketfft_kernels",
srcs = ["pocketfft_kernels.cc"],
hdrs = ["pocketfft_kernels.h"],
copts = ["-fexceptions"], # PocketFFT may throw.
name = "ducc_fft_kernels",
srcs = ["ducc_fft_kernels.cc"],
hdrs = ["ducc_fft_kernels.h"],
copts = ["-fexceptions"], # DUCC may throw.
features = ["-use_header_modules"],
deps = [
":pocketfft_flatbuffers_cc",
":ducc_fft_flatbuffers_cc",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@ducc",
"@flatbuffers//:runtime_cc",
"@pocketfft",
],
)
pybind_extension(
name = "_pocketfft",
srcs = ["pocketfft.cc"],
name = "_ducc_fft",
srcs = ["ducc_fft.cc"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
module_name = "_pocketfft",
module_name = "_ducc_fft",
deps = [
":ducc_fft_flatbuffers_cc",
":ducc_fft_kernels",
":kernel_pybind11_helpers",
":pocketfft_flatbuffers_cc",
":pocketfft_kernels",
"@flatbuffers//:runtime_cc",
"@pybind11",
],
@ -222,9 +222,9 @@ cc_library(
srcs = ["cpu_kernels.cc"],
visibility = ["//visibility:public"],
deps = [
":ducc_fft_kernels",
":lapack_kernels",
":lapack_kernels_using_lapack",
":pocketfft_kernels",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_target_registry",
],
alwayslink = 1,

View File

@ -17,7 +17,7 @@ limitations under the License.
// JAX-generated HLO code from outside of JAX.
#include "jaxlib/lapack_kernels.h"
#include "jaxlib/pocketfft_kernels.h"
#include "jaxlib/ducc_fft_kernels.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
namespace jax {
@ -105,7 +105,7 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
"lapack_cgees", ComplexGees<std::complex<float>>::Kernel, "Host");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
"lapack_zgees", ComplexGees<std::complex<double>>::Kernel, "Host");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("pocketfft", PocketFft, "Host");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("ducc_fft", DuccFft, "Host");
} // namespace
} // namespace jax

View File

@ -16,53 +16,53 @@ limitations under the License.
#include <complex>
#include <vector>
#include "jaxlib/kernel_pybind11_helpers.h"
#include "jaxlib/pocketfft_generated.h"
#include "jaxlib/pocketfft_kernels.h"
#include "include/pybind11/pybind11.h"
#include "include/pybind11/stl.h"
#include "jaxlib/ducc_fft_generated.h"
#include "jaxlib/ducc_fft_kernels.h"
#include "jaxlib/kernel_pybind11_helpers.h"
namespace py = pybind11;
namespace jax {
namespace {
py::bytes BuildPocketFftDescriptor(const std::vector<uint64_t>& shape,
bool is_double, int fft_type,
const std::vector<uint64_t>& fft_lengths,
const std::vector<uint64_t>& strides_in,
const std::vector<uint64_t>& strides_out,
const std::vector<uint32_t>& axes,
bool forward, double scale) {
PocketFftDescriptorT descriptor;
py::bytes BuildDuccFftDescriptor(const std::vector<uint64_t> &shape,
bool is_double, int fft_type,
const std::vector<uint64_t> &fft_lengths,
const std::vector<uint64_t> &strides_in,
const std::vector<uint64_t> &strides_out,
const std::vector<uint32_t> &axes,
bool forward, double scale) {
DuccFftDescriptorT descriptor;
descriptor.shape = shape;
descriptor.fft_type = static_cast<PocketFftType>(fft_type);
descriptor.fft_type = static_cast<DuccFftType>(fft_type);
descriptor.dtype =
is_double ? PocketFftDtype_COMPLEX128 : PocketFftDtype_COMPLEX64;
is_double ? DuccFftDtype_COMPLEX128 : DuccFftDtype_COMPLEX64;
descriptor.strides_in = strides_in;
descriptor.strides_out = strides_out;
descriptor.axes = axes;
descriptor.forward = forward;
descriptor.scale = scale;
flatbuffers::FlatBufferBuilder fbb;
fbb.Finish(PocketFftDescriptor::Pack(fbb, &descriptor));
return py::bytes(reinterpret_cast<char*>(fbb.GetBufferPointer()),
fbb.Finish(DuccFftDescriptor::Pack(fbb, &descriptor));
return py::bytes(reinterpret_cast<char *>(fbb.GetBufferPointer()),
fbb.GetSize());
}
py::dict Registrations() {
pybind11::dict dict;
dict["pocketfft"] = EncapsulateFunction(PocketFft);
dict["ducc_fft"] = EncapsulateFunction(DuccFft);
return dict;
}
PYBIND11_MODULE(_pocketfft, m) {
PYBIND11_MODULE(_ducc_fft, m) {
m.def("registrations", &Registrations);
m.def("pocketfft_descriptor", &BuildPocketFftDescriptor, py::arg("shape"),
m.def("ducc_fft_descriptor", &BuildDuccFftDescriptor, py::arg("shape"),
py::arg("is_double"), py::arg("fft_type"), py::arg("fft_lengths"),
py::arg("strides_in"), py::arg("strides_out"), py::arg("axes"),
py::arg("forward"), py::arg("scale"));
}
} // namespace
} // namespace jax
} // namespace
} // namespace jax

View File

@ -15,20 +15,20 @@ limitations under the License.
namespace jax;
enum PocketFftDtype : byte {
enum DuccFftDtype : byte {
COMPLEX64 = 0,
COMPLEX128 = 1,
}
enum PocketFftType : byte {
enum DuccFftType : byte {
C2C = 0,
C2R = 1,
R2C = 2,
}
table PocketFftDescriptor {
dtype:PocketFftDtype;
fft_type:PocketFftType;
table DuccFftDescriptor {
dtype:DuccFftDtype;
fft_type:DuccFftType;
shape:[uint64];
strides_in:[uint64];
strides_out:[uint64];
@ -37,4 +37,4 @@ table PocketFftDescriptor {
scale:double;
}
root_type PocketFftDescriptor;
root_type DuccFftDescriptor;

View File

@ -19,12 +19,12 @@ import jaxlib.mlir.dialects.mhlo as mhlo
from .mhlo_helpers import custom_call
from . import _pocketfft
from . import _ducc_fft
import numpy as np
from jaxlib import xla_client
for _name, _value in _pocketfft.registrations().items():
for _name, _value in _ducc_fft.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="cpu")
FftType = xla_client.FftType
@ -34,7 +34,7 @@ _C2C = 0
_C2R = 1
_R2C = 2
def _pocketfft_descriptor(shape: List[int], dtype, fft_type: FftType,
def _ducc_fft_descriptor(shape: List[int], dtype, fft_type: FftType,
fft_lengths: List[int]) -> bytes:
n = len(shape)
assert len(fft_lengths) >= 1
@ -44,7 +44,7 @@ def _pocketfft_descriptor(shape: List[int], dtype, fft_type: FftType,
forward = fft_type in (FftType.FFT, FftType.RFFT)
is_double = np.finfo(dtype).dtype == np.float64
if fft_type == FftType.RFFT:
pocketfft_type = _R2C
ducc_fft_type = _R2C
assert dtype in (np.float32, np.float64), dtype
out_dtype = np.dtype(np.complex64 if dtype == np.float32 else np.complex128)
@ -54,7 +54,7 @@ def _pocketfft_descriptor(shape: List[int], dtype, fft_type: FftType,
out_shape[-1] = out_shape[-1] // 2 + 1
elif fft_type == FftType.IRFFT:
pocketfft_type = _C2R
ducc_fft_type = _C2R
assert np.issubdtype(dtype, np.complexfloating), dtype
out_dtype = np.dtype(np.float32 if dtype == np.complex64 else np.float64)
@ -64,7 +64,7 @@ def _pocketfft_descriptor(shape: List[int], dtype, fft_type: FftType,
out_shape[-1] = fft_lengths[-1]
assert (out_shape[-1] // 2 + 1) == shape[-1]
else:
pocketfft_type = _C2C
ducc_fft_type = _C2C
assert np.issubdtype(dtype, np.complexfloating), dtype
out_dtype = dtype
@ -79,13 +79,13 @@ def _pocketfft_descriptor(shape: List[int], dtype, fft_type: FftType,
# Builds a PocketFftDescriptor flatbuffer. This descriptor is passed to the
# C++ kernel to describe the FFT to perform.
strides_in = []
stride = dtype.itemsize
stride = 1
for d in reversed(shape):
strides_in.append(stride)
stride *= d
strides_out = []
stride = out_dtype.itemsize
stride = 1
for d in reversed(out_shape):
strides_out.append(stride)
stride *= d
@ -93,10 +93,10 @@ def _pocketfft_descriptor(shape: List[int], dtype, fft_type: FftType,
axes = [n - len(fft_lengths) + d for d in range(len(fft_lengths))]
scale = 1. if forward else (1. / np.prod(fft_lengths))
descriptor = _pocketfft.pocketfft_descriptor(
descriptor = _ducc_fft.ducc_fft_descriptor(
shape=shape if fft_type != FftType.IRFFT else out_shape,
is_double=is_double,
fft_type=pocketfft_type,
fft_type=ducc_fft_type,
fft_lengths=fft_lengths,
strides_in=list(reversed(strides_in)),
strides_out=list(reversed(strides_out)),
@ -107,13 +107,13 @@ def _pocketfft_descriptor(shape: List[int], dtype, fft_type: FftType,
return descriptor, out_dtype, out_shape
def pocketfft_mhlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]):
"""PocketFFT kernel for CPU."""
def ducc_fft_mhlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]):
"""DUCC FFT kernel for CPU."""
a_type = ir.RankedTensorType(a.type)
n = len(a_type.shape)
fft_lengths = list(fft_lengths)
descriptor_bytes, out_dtype, out_shape = _pocketfft_descriptor(
descriptor_bytes, out_dtype, out_shape = _ducc_fft_descriptor(
list(a_type.shape), dtype, fft_type, fft_lengths)
if out_dtype == np.float32:
@ -141,7 +141,7 @@ def pocketfft_mhlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]):
np.frombuffer(descriptor_bytes, dtype=np.uint8), type=u8_type))
layout = tuple(range(n - 1, -1, -1))
return custom_call(
"pocketfft",
"ducc_fft",
[ir.RankedTensorType.get(out_shape, out_type)],
[descriptor, a],
operand_layouts=[[0], layout],

100
jaxlib/ducc_fft_kernels.cc Normal file
View File

@ -0,0 +1,100 @@
/* Copyright 2020 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <complex>
#include "flatbuffers/flatbuffers.h"
#include "jaxlib/ducc_fft_generated.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
#include "ducc/src/ducc0/fft/fft.h"
namespace jax {
using shape_t = ducc0::fmav_info::shape_t;
using stride_t = ducc0::fmav_info::stride_t;
void DuccFft(void *out, void **in, XlaCustomCallStatus *) {
const DuccFftDescriptor *descriptor = GetDuccFftDescriptor(in[0]);
shape_t shape(descriptor->shape()->begin(), descriptor->shape()->end());
stride_t stride_in(descriptor->strides_in()->begin(),
descriptor->strides_in()->end());
stride_t stride_out(descriptor->strides_out()->begin(),
descriptor->strides_out()->end());
shape_t axes(descriptor->axes()->begin(), descriptor->axes()->end());
switch (descriptor->fft_type()) {
case DuccFftType_C2C:
if (descriptor->dtype() == DuccFftDtype_COMPLEX64) {
ducc0::cfmav<std::complex<float>> m_in(
reinterpret_cast<std::complex<float> *>(in[1]), shape, stride_in);
ducc0::vfmav<std::complex<float>> m_out(
reinterpret_cast<std::complex<float> *>(out), shape, stride_out);
ducc0::c2c(m_in, m_out, axes, descriptor->forward(),
static_cast<float>(descriptor->scale()));
} else {
ducc0::cfmav<std::complex<double>> m_in(
reinterpret_cast<std::complex<double> *>(in[1]), shape, stride_in);
ducc0::vfmav<std::complex<double>> m_out(
reinterpret_cast<std::complex<double> *>(out), shape, stride_out);
ducc0::c2c(m_in, m_out, axes, descriptor->forward(),
static_cast<double>(descriptor->scale()));
}
break;
case DuccFftType_C2R:
if (descriptor->dtype() == DuccFftDtype_COMPLEX64) {
auto shape_in = shape;
shape_in[axes.back()] = shape_in[axes.back()] / 2 + 1;
ducc0::cfmav<std::complex<float>> m_in(
reinterpret_cast<std::complex<float> *>(in[1]), shape_in, stride_in);
ducc0::vfmav<float> m_out(reinterpret_cast<float *>(out), shape,
stride_out);
ducc0::c2r(m_in, m_out, axes, descriptor->forward(),
static_cast<float>(descriptor->scale()));
} else {
auto shape_in = shape;
shape_in[axes.back()] = shape_in[axes.back()] / 2 + 1;
ducc0::cfmav<std::complex<double>> m_in(
reinterpret_cast<std::complex<double> *>(in[1]), shape_in, stride_in);
ducc0::vfmav<double> m_out(reinterpret_cast<double *>(out), shape,
stride_out);
ducc0::c2r(m_in, m_out, axes, descriptor->forward(),
static_cast<double>(descriptor->scale()));
}
break;
case DuccFftType_R2C:
if (descriptor->dtype() == DuccFftDtype_COMPLEX64) {
auto shape_out = shape;
shape_out[axes.back()] = shape_out[axes.back()] / 2 + 1;
ducc0::cfmav<float> m_in(reinterpret_cast<float *>(in[1]), shape,
stride_in);
ducc0::vfmav<std::complex<float>> m_out(
reinterpret_cast<std::complex<float> *>(out), shape_out, stride_out);
ducc0::r2c(m_in, m_out, axes, descriptor->forward(),
static_cast<float>(descriptor->scale()));
} else {
auto shape_out = shape;
shape_out[axes.back()] = shape_out[axes.back()] / 2 + 1;
ducc0::cfmav<double> m_in(reinterpret_cast<double *>(in[1]), shape,
stride_in);
ducc0::vfmav<std::complex<double>> m_out(
reinterpret_cast<std::complex<double> *>(out), shape_out, stride_out);
ducc0::r2c(m_in, m_out, axes, descriptor->forward(),
static_cast<double>(descriptor->scale()));
}
break;
}
}
} // namespace jax

View File

@ -17,6 +17,6 @@ limitations under the License.
namespace jax {
void PocketFft(void* out, void** in, XlaCustomCallStatus*);
void DuccFft(void* out, void** in, XlaCustomCallStatus*);
} // namespace jax

View File

@ -1,90 +0,0 @@
/* Copyright 2020 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <complex>
#include "flatbuffers/flatbuffers.h"
#include "pocketfft/pocketfft_hdronly.h"
#include "jaxlib/pocketfft_generated.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
namespace jax {
void PocketFft(void* out, void** in, XlaCustomCallStatus*) {
const PocketFftDescriptor* descriptor = GetPocketFftDescriptor(in[0]);
pocketfft::shape_t shape(descriptor->shape()->begin(),
descriptor->shape()->end());
pocketfft::stride_t stride_in(descriptor->strides_in()->begin(),
descriptor->strides_in()->end());
pocketfft::stride_t stride_out(descriptor->strides_out()->begin(),
descriptor->strides_out()->end());
pocketfft::shape_t axes(descriptor->axes()->begin(),
descriptor->axes()->end());
switch (descriptor->fft_type()) {
case PocketFftType_C2C:
if (descriptor->dtype() == PocketFftDtype_COMPLEX64) {
std::complex<float>* a_in =
reinterpret_cast<std::complex<float>*>(in[1]);
std::complex<float>* a_out =
reinterpret_cast<std::complex<float>*>(out);
pocketfft::c2c(shape, stride_in, stride_out, axes,
descriptor->forward(), a_in, a_out,
static_cast<float>(descriptor->scale()));
} else {
std::complex<double>* a_in =
reinterpret_cast<std::complex<double>*>(in[1]);
std::complex<double>* a_out =
reinterpret_cast<std::complex<double>*>(out);
pocketfft::c2c(shape, stride_in, stride_out, axes,
descriptor->forward(), a_in, a_out, descriptor->scale());
}
break;
case PocketFftType_C2R:
if (descriptor->dtype() == PocketFftDtype_COMPLEX64) {
std::complex<float>* a_in =
reinterpret_cast<std::complex<float>*>(in[1]);
float* a_out = reinterpret_cast<float*>(out);
pocketfft::c2r(shape, stride_in, stride_out, axes,
descriptor->forward(), a_in, a_out,
static_cast<float>(descriptor->scale()));
} else {
std::complex<double>* a_in =
reinterpret_cast<std::complex<double>*>(in[1]);
double* a_out = reinterpret_cast<double*>(out);
pocketfft::c2r(shape, stride_in, stride_out, axes,
descriptor->forward(), a_in, a_out, descriptor->scale());
}
break;
case PocketFftType_R2C:
if (descriptor->dtype() == PocketFftDtype_COMPLEX64) {
float* a_in = reinterpret_cast<float*>(in[1]);
std::complex<float>* a_out =
reinterpret_cast<std::complex<float>*>(out);
pocketfft::r2c(shape, stride_in, stride_out, axes,
descriptor->forward(), a_in, a_out,
static_cast<float>(descriptor->scale()));
} else {
double* a_in = reinterpret_cast<double*>(in[1]);
std::complex<double>* a_out =
reinterpret_cast<std::complex<double>*>(out);
pocketfft::r2c(shape, stride_in, stride_out, axes,
descriptor->forward(), a_in, a_out, descriptor->scale());
}
break;
}
}
} // namespace jax

View File

@ -665,7 +665,8 @@ class LaxVmapTest(jtu.JaxTestCase):
axes = range(ndims - fft_ndims, ndims)
fft_lengths = tuple(shape[axis] for axis in axes)
op = lambda x: lax.fft(x, xla_client.FftType.FFT, fft_lengths)
self._CheckBatching(op, 5, bdims, [shape], [np.complex64], rng)
self._CheckBatching(op, 5, bdims, [shape], [np.complex64], rng,
rtol=1e-5)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_idxs={}_dnums={}_slice_sizes={}_bdims={}"

29
third_party/ducc/BUILD.bazel vendored Normal file
View File

@ -0,0 +1,29 @@
licenses(["notice"])
package(default_visibility = ["//visibility:public"])
cc_library(
name = "ducc",
srcs = [
"src/ducc0/fft/fft1d.h",
"src/ducc0/infra/aligned_array.h",
"src/ducc0/infra/error_handling.h",
"src/ducc0/infra/mav.h",
"src/ducc0/infra/simd.h",
"src/ducc0/infra/threading.cc",
"src/ducc0/infra/threading.h",
"src/ducc0/infra/useful_macros.h",
"src/ducc0/math/cmplx.h",
"src/ducc0/math/unity_roots.h",
],
hdrs = ["src/ducc0/fft/fft.h"],
copts = [
"-fexceptions",
"-ffast-math",
],
features = ["-use_header_modules"],
include_prefix = "ducc",
includes = [
"src",
],
)

View File

@ -12,18 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Bazel workspace for PocketFFT."""
"""Bazel workspace for DUCC (CPU FFTs)."""
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
def repo():
http_archive(
name = "pocketfft",
sha256 = "66eda977b195965d27aeb9d74f46e0029a6a02e75fbbc47bb554aad68615a260",
strip_prefix = "pocketfft-f800d91ba695b6e19ae2687dd60366900b928002",
name = "ducc",
strip_prefix = "ducc-356d619a4b5f6f8940d15913c14a043355ef23be",
sha256 = "d23eb2d06f03604867ad40af4fe92dec7cccc2c59f5119e9f01b35b973885c61",
urls = [
"https://github.com/mreineck/pocketfft/archive/f800d91ba695b6e19ae2687dd60366900b928002.tar.gz",
"https://storage.googleapis.com/jax-releases/mirror/pocketfft/pocketfft-f800d91ba695b6e19ae2687dd60366900b928002.tar.gz",
"https://github.com/mreineck/ducc/archive/356d619a4b5f6f8940d15913c14a043355ef23be.tar.gz",
"https://storage.googleapis.com/jax-releases/mirror/ducc/ducc-356d619a4b5f6f8940d15913c14a043355ef23be.tar.gz",
],
build_file = "@//third_party/pocketfft:BUILD.bazel",
build_file = "@//third_party/ducc:BUILD.bazel",
)

View File

@ -1,11 +0,0 @@
licenses(["notice"])
package(default_visibility = ["//visibility:public"])
cc_library(
name = "pocketfft",
hdrs = ["pocketfft_hdronly.h"],
copts = ["-fexceptions"],
features = ["-use_header_modules"],
include_prefix = "pocketfft",
)