mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[JAX] Use PocketFFT for FFTs on CPU instead of Eigen.
PocketFFT is the same FFT library used by NumPy (although we are using the C++ variant rather than the C variant.) For the benchmark in #2952 on my workstation: Before: ``` 907.3490574884647 max: 4.362646594533903e-08 mean: 6.237288307614869e-09 min: 0.0 numpy fft execution time [ms]: 37.088446617126465 jax fft execution time [ms]: 74.93342399597168 ``` After: ``` 907.3490574884647 max: 1.9057386696477137e-12 mean: 3.9326737908882566e-13 min: 0.0 numpy fft execution time [ms]: 37.756404876708984 jax fft execution time [ms]: 28.128278255462646 ``` Fixes https://github.com/google/jax/issues/2952 PiperOrigin-RevId: 338743753
This commit is contained in:
parent
8121255d7b
commit
f58f1ee456
@ -50,6 +50,9 @@ tf_workspace(
|
||||
|
||||
tf_bind()
|
||||
|
||||
load("//third_party/pocketfft:workspace.bzl", pocketfft = "repo")
|
||||
pocketfft()
|
||||
|
||||
# Required for TensorFlow dependency on @com_github_grpc_grpc
|
||||
|
||||
load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")
|
||||
|
@ -28,6 +28,8 @@ sh_binary(
|
||||
"@org_tensorflow//tensorflow/compiler/xla/python/tpu_driver/client:py_tpu_client",
|
||||
"//jaxlib",
|
||||
"//jaxlib:lapack.so",
|
||||
"//jaxlib:_pocketfft.so",
|
||||
"//jaxlib:pocketfft_flatbuffers_py",
|
||||
] + if_cuda([
|
||||
"//jaxlib:cublas_kernels",
|
||||
"//jaxlib:cusolver_kernels",
|
||||
@ -35,4 +37,3 @@ sh_binary(
|
||||
]),
|
||||
deps = ["@bazel_tools//tools/bash/runfiles"],
|
||||
)
|
||||
|
||||
|
@ -53,6 +53,9 @@ fi
|
||||
# Copy the XLA dependencies into jax/lib, fixing up some imports to point to the
|
||||
# new location.
|
||||
cp -f "$(rlocation __main__/jaxlib/lapack.so)" "${TARGET}/jaxlib"
|
||||
cp -f "$(rlocation __main__/jaxlib/_pocketfft.so)" "${TARGET}/jaxlib"
|
||||
cp -f "$(rlocation __main__/jaxlib/pocketfft_flatbuffers_py_generated.py)" "${TARGET}/jaxlib"
|
||||
cp -f "$(rlocation __main__/jaxlib/pocketfft.py)" "${TARGET}/jaxlib"
|
||||
if [[ -x "$(rlocation __main__/jaxlib/cusolver_kernels.so)" ]]; then
|
||||
cp -f "$(rlocation __main__/jaxlib/cublas_kernels.so)" "${TARGET}/jaxlib"
|
||||
cp -f "$(rlocation __main__/jaxlib/cusolver_kernels.so)" "${TARGET}/jaxlib"
|
||||
|
@ -35,7 +35,7 @@ setup(
|
||||
author_email='jax-dev@google.com',
|
||||
packages=['jaxlib'],
|
||||
python_requires='>=3.6',
|
||||
install_requires=['scipy', 'numpy>=1.12', 'absl-py'],
|
||||
install_requires=['scipy', 'numpy>=1.12', 'absl-py', 'flatbuffers'],
|
||||
url='https://github.com/google/jax',
|
||||
license='Apache-2.0',
|
||||
package_data={'jaxlib': binary_libs},
|
||||
|
@ -44,7 +44,10 @@ pytype_library(
|
||||
],
|
||||
),
|
||||
srcs_version = "PY3",
|
||||
deps = ["@org_tensorflow//tensorflow/compiler/xla/python:xla_client"],
|
||||
deps = [
|
||||
"//third_party/py/jax/jaxlib:_pocketfft",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/python:xla_client",
|
||||
],
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
|
@ -27,6 +27,7 @@ from jax import lib
|
||||
from jax.lib import xla_client
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax.lib import pocketfft
|
||||
|
||||
xops = xla_client.ops
|
||||
|
||||
@ -146,3 +147,5 @@ fft_p.def_abstract_eval(fft_abstract_eval)
|
||||
xla.translations[fft_p] = fft_translation_rule
|
||||
ad.deflinear(fft_p, fft_transpose_rule)
|
||||
batching.primitive_batchers[fft_p] = fft_batching_rule
|
||||
if pocketfft:
|
||||
xla.backend_specific_translations['cpu'][fft_p] = pocketfft.pocketfft
|
||||
|
@ -144,8 +144,12 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
|
||||
@jtu.skip_on_flag("jax_skip_slow_tests", True)
|
||||
def test_fft(self, harness: primitive_harness.Harness):
|
||||
if len(harness.params["fft_lengths"]) > 3:
|
||||
with self.assertRaisesRegex(RuntimeError, "FFT only supports ranks 1-3"):
|
||||
harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
|
||||
if jtu.device_under_test() == "gpu":
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
"FFT only supports ranks 1-3"):
|
||||
harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
|
||||
else:
|
||||
raise unittest.SkipTest("TF does not support >3D FFTs.")
|
||||
elif (jtu.device_under_test() == "tpu" and
|
||||
len(harness.params["fft_lengths"]) > 1):
|
||||
# TODO(b/140351181): FFT is mostly unimplemented on TPU, even for JAX
|
||||
@ -154,7 +158,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
|
||||
harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
|
||||
else:
|
||||
tol = None
|
||||
if jtu.device_under_test() == "gpu":
|
||||
if jtu.device_under_test() in ("cpu", "gpu"):
|
||||
if harness.params["dtype"] in jtu.dtypes.boolean:
|
||||
tol = 0.01
|
||||
else:
|
||||
|
@ -67,3 +67,10 @@ try:
|
||||
from jaxlib import tpu_client # pytype: disable=import-error
|
||||
except:
|
||||
tpu_client = None
|
||||
|
||||
# TODO(phawkins): Make this import unconditional once the minimum jaxlib version
|
||||
# is 0.1.57 or greater.
|
||||
try:
|
||||
from jaxlib import pocketfft # pytype: disable=import-error
|
||||
except:
|
||||
pocketfft = None
|
||||
|
33
jaxlib/BUILD
33
jaxlib/BUILD
@ -17,6 +17,7 @@
|
||||
load("@org_tensorflow//tensorflow/core/platform/default:build_config.bzl", "pyx_library")
|
||||
load("@org_tensorflow//tensorflow:tensorflow.bzl", "pybind_extension")
|
||||
load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library")
|
||||
load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library", "flatbuffer_py_library")
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
@ -32,6 +33,7 @@ cc_library(
|
||||
features = ["-use_header_modules"],
|
||||
deps = [
|
||||
":kernel_helpers",
|
||||
"@com_google_absl//absl/base",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
@ -76,8 +78,10 @@ py_library(
|
||||
srcs = [
|
||||
"cuda_prng.py",
|
||||
"cusolver.py",
|
||||
"pocketfft.py",
|
||||
"version.py",
|
||||
],
|
||||
deps = [":pocketfft_flatbuffers_py"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
@ -174,3 +178,32 @@ pybind_extension(
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
flatbuffer_cc_library(
|
||||
name = "pocketfft_flatbuffers_cc",
|
||||
srcs = ["pocketfft.fbs"],
|
||||
)
|
||||
|
||||
flatbuffer_py_library(
|
||||
name = "pocketfft_flatbuffers_py",
|
||||
srcs = ["pocketfft.fbs"],
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
name = "_pocketfft",
|
||||
srcs = ["pocketfft.cc"],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
],
|
||||
features = ["-use_header_modules"],
|
||||
module_name = "_pocketfft",
|
||||
deps = [
|
||||
":kernel_pybind11_helpers",
|
||||
":pocketfft_flatbuffers_cc",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@flatbuffers//:runtime_cc",
|
||||
"@pocketfft",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
@ -16,8 +16,9 @@ limitations under the License.
|
||||
#ifndef JAXLIB_KERNEL_PYBIND11_HELPERS_H_
|
||||
#define JAXLIB_KERNEL_PYBIND11_HELPERS_H_
|
||||
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "absl/base/casts.h"
|
||||
#include "jaxlib/kernel_helpers.h"
|
||||
#include "include/pybind11/pybind11.h"
|
||||
|
||||
namespace jax {
|
||||
|
||||
@ -36,7 +37,8 @@ pybind11::bytes PackDescriptor(const T& descriptor) {
|
||||
|
||||
template <typename T>
|
||||
pybind11::capsule EncapsulateFunction(T* fn) {
|
||||
return pybind11::capsule(absl::bit_cast<void*>(fn), "xla._CUSTOM_CALL_TARGET");
|
||||
return pybind11::capsule(absl::bit_cast<void*>(fn),
|
||||
"xla._CUSTOM_CALL_TARGET");
|
||||
}
|
||||
|
||||
} // namespace jax
|
||||
|
101
jaxlib/pocketfft.cc
Normal file
101
jaxlib/pocketfft.cc
Normal file
@ -0,0 +1,101 @@
|
||||
/* Copyright 2020 Google LLC
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <complex>
|
||||
|
||||
#include "flatbuffers/flatbuffers.h"
|
||||
#include "pocketfft/pocketfft_hdronly.h"
|
||||
#include "jaxlib/kernel_pybind11_helpers.h"
|
||||
#include "jaxlib/pocketfft_generated.h"
|
||||
#include "include/pybind11/pybind11.h"
|
||||
|
||||
namespace jax {
|
||||
namespace {
|
||||
|
||||
void PocketFft(void* out, void** in) {
|
||||
const PocketFftDescriptor* descriptor = GetPocketFftDescriptor(in[0]);
|
||||
pocketfft::shape_t shape(descriptor->shape()->begin(),
|
||||
descriptor->shape()->end());
|
||||
pocketfft::stride_t stride_in(descriptor->strides_in()->begin(),
|
||||
descriptor->strides_in()->end());
|
||||
pocketfft::stride_t stride_out(descriptor->strides_out()->begin(),
|
||||
descriptor->strides_out()->end());
|
||||
pocketfft::shape_t axes(descriptor->axes()->begin(),
|
||||
descriptor->axes()->end());
|
||||
|
||||
switch (descriptor->fft_type()) {
|
||||
case PocketFftType_C2C:
|
||||
if (descriptor->dtype() == PocketFftDtype_COMPLEX64) {
|
||||
std::complex<float>* a_in =
|
||||
reinterpret_cast<std::complex<float>*>(in[1]);
|
||||
std::complex<float>* a_out =
|
||||
reinterpret_cast<std::complex<float>*>(out);
|
||||
pocketfft::c2c(shape, stride_in, stride_out, axes,
|
||||
descriptor->forward(), a_in, a_out,
|
||||
static_cast<float>(descriptor->scale()));
|
||||
} else {
|
||||
std::complex<double>* a_in =
|
||||
reinterpret_cast<std::complex<double>*>(in[1]);
|
||||
std::complex<double>* a_out =
|
||||
reinterpret_cast<std::complex<double>*>(out);
|
||||
pocketfft::c2c(shape, stride_in, stride_out, axes,
|
||||
descriptor->forward(), a_in, a_out, descriptor->scale());
|
||||
}
|
||||
break;
|
||||
case PocketFftType_C2R:
|
||||
if (descriptor->dtype() == PocketFftDtype_COMPLEX64) {
|
||||
std::complex<float>* a_in =
|
||||
reinterpret_cast<std::complex<float>*>(in[1]);
|
||||
float* a_out = reinterpret_cast<float*>(out);
|
||||
pocketfft::c2r(shape, stride_in, stride_out, axes,
|
||||
descriptor->forward(), a_in, a_out,
|
||||
static_cast<float>(descriptor->scale()));
|
||||
} else {
|
||||
std::complex<double>* a_in =
|
||||
reinterpret_cast<std::complex<double>*>(in[1]);
|
||||
double* a_out = reinterpret_cast<double*>(out);
|
||||
pocketfft::c2r(shape, stride_in, stride_out, axes,
|
||||
descriptor->forward(), a_in, a_out, descriptor->scale());
|
||||
}
|
||||
break;
|
||||
case PocketFftType_R2C:
|
||||
if (descriptor->dtype() == PocketFftDtype_COMPLEX64) {
|
||||
float* a_in = reinterpret_cast<float*>(in[1]);
|
||||
std::complex<float>* a_out =
|
||||
reinterpret_cast<std::complex<float>*>(out);
|
||||
pocketfft::r2c(shape, stride_in, stride_out, axes,
|
||||
descriptor->forward(), a_in, a_out,
|
||||
static_cast<float>(descriptor->scale()));
|
||||
} else {
|
||||
double* a_in = reinterpret_cast<double*>(in[1]);
|
||||
std::complex<double>* a_out =
|
||||
reinterpret_cast<std::complex<double>*>(out);
|
||||
pocketfft::r2c(shape, stride_in, stride_out, axes,
|
||||
descriptor->forward(), a_in, a_out, descriptor->scale());
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
pybind11::dict Registrations() {
|
||||
pybind11::dict dict;
|
||||
dict["pocketfft"] = EncapsulateFunction(PocketFft);
|
||||
return dict;
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(_pocketfft, m) { m.def("registrations", &Registrations); }
|
||||
|
||||
} // namespace
|
||||
} // namespace jax
|
40
jaxlib/pocketfft.fbs
Normal file
40
jaxlib/pocketfft.fbs
Normal file
@ -0,0 +1,40 @@
|
||||
/* Copyright 2020 Google LLC
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
namespace jax;
|
||||
|
||||
enum PocketFftDtype : byte {
|
||||
COMPLEX64 = 0,
|
||||
COMPLEX128 = 1,
|
||||
}
|
||||
|
||||
enum PocketFftType : byte {
|
||||
C2C = 0,
|
||||
C2R = 1,
|
||||
R2C = 2,
|
||||
}
|
||||
|
||||
table PocketFftDescriptor {
|
||||
dtype:PocketFftDtype;
|
||||
fft_type:PocketFftType;
|
||||
shape:[uint64];
|
||||
strides_in:[uint64];
|
||||
strides_out:[uint64];
|
||||
axes:[uint32];
|
||||
forward:bool;
|
||||
scale:double;
|
||||
}
|
||||
|
||||
root_type PocketFftDescriptor;
|
142
jaxlib/pocketfft.py
Normal file
142
jaxlib/pocketfft.py
Normal file
@ -0,0 +1,142 @@
|
||||
# Copyright 2020 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List
|
||||
|
||||
from jaxlib import _pocketfft
|
||||
from jaxlib import pocketfft_flatbuffers_py_generated as pd
|
||||
import numpy as np
|
||||
|
||||
import flatbuffers
|
||||
from jaxlib import xla_client
|
||||
|
||||
for _name, _value in _pocketfft.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="cpu")
|
||||
|
||||
FftType = xla_client.FftType
|
||||
|
||||
|
||||
def pocketfft(c, a, *, fft_type: FftType, fft_lengths: List[int]):
|
||||
"""PocketFFT kernel for CPU."""
|
||||
shape = c.get_shape(a)
|
||||
n = len(shape.dimensions())
|
||||
dtype = shape.element_type()
|
||||
builder = flatbuffers.Builder(128)
|
||||
|
||||
fft_lengths = list(fft_lengths)
|
||||
assert len(fft_lengths) >= 1
|
||||
assert len(fft_lengths) <= n, (fft_lengths, n)
|
||||
|
||||
forward = fft_type in (FftType.FFT, FftType.RFFT)
|
||||
if fft_type == FftType.RFFT:
|
||||
pocketfft_type = pd.PocketFftType.R2C
|
||||
|
||||
assert dtype in (np.float32, np.float64), dtype
|
||||
out_dtype = np.dtype(np.complex64 if dtype == np.float32 else np.complex128)
|
||||
pocketfft_dtype = (
|
||||
pd.PocketFftDtype.COMPLEX64
|
||||
if dtype == np.float32 else pd.PocketFftDtype.COMPLEX128)
|
||||
|
||||
assert list(shape.dimensions())[-len(fft_lengths):] == fft_lengths, (
|
||||
shape, fft_lengths)
|
||||
out_shape = list(shape.dimensions())
|
||||
out_shape[-1] = out_shape[-1] // 2 + 1
|
||||
|
||||
elif fft_type == FftType.IRFFT:
|
||||
pocketfft_type = pd.PocketFftType.C2R
|
||||
assert np.issubdtype(dtype, np.complexfloating), dtype
|
||||
|
||||
out_dtype = np.dtype(np.float32 if dtype == np.complex64 else np.float64)
|
||||
pocketfft_dtype = (
|
||||
pd.PocketFftDtype.COMPLEX64
|
||||
if dtype == np.complex64 else pd.PocketFftDtype.COMPLEX128)
|
||||
|
||||
assert list(shape.dimensions())[-len(fft_lengths):-1] == fft_lengths[:-1]
|
||||
out_shape = list(shape.dimensions())
|
||||
out_shape[-1] = fft_lengths[-1]
|
||||
assert (out_shape[-1] // 2 + 1) == shape.dimensions()[-1]
|
||||
else:
|
||||
pocketfft_type = pd.PocketFftType.C2C
|
||||
|
||||
assert np.issubdtype(dtype, np.complexfloating), dtype
|
||||
out_dtype = dtype
|
||||
pocketfft_dtype = (
|
||||
pd.PocketFftDtype.COMPLEX64
|
||||
if dtype == np.complex64 else pd.PocketFftDtype.COMPLEX128)
|
||||
|
||||
assert list(shape.dimensions())[-len(fft_lengths):] == fft_lengths, (
|
||||
shape, fft_lengths)
|
||||
out_shape = shape.dimensions()
|
||||
|
||||
# PocketFft does not allow size 0 dimensions.
|
||||
if 0 in shape.dimensions() or 0 in out_shape:
|
||||
return xla_client.ops.Broadcast(
|
||||
xla_client.ops.Constant(c, np.array(0, dtype=out_dtype)), out_shape)
|
||||
|
||||
# Builds a PocketFftDescriptor flatbuffer. This descriptor is passed to the
|
||||
# C++ kernel to describe the FFT to perform.
|
||||
pd.PocketFftDescriptorStartShapeVector(builder, n)
|
||||
for d in reversed(
|
||||
shape.dimensions() if fft_type != FftType.IRFFT else out_shape):
|
||||
builder.PrependUint64(d)
|
||||
pocketfft_shape = builder.EndVector(n)
|
||||
|
||||
pd.PocketFftDescriptorStartStridesInVector(builder, n)
|
||||
stride = dtype.itemsize
|
||||
for d in reversed(shape.dimensions()):
|
||||
builder.PrependUint64(stride)
|
||||
stride *= d
|
||||
strides_in = builder.EndVector(n)
|
||||
pd.PocketFftDescriptorStartStridesOutVector(builder, n)
|
||||
stride = out_dtype.itemsize
|
||||
for d in reversed(out_shape):
|
||||
builder.PrependUint64(stride)
|
||||
stride *= d
|
||||
strides_out = builder.EndVector(n)
|
||||
|
||||
pd.PocketFftDescriptorStartAxesVector(builder, len(fft_lengths))
|
||||
for d in range(len(fft_lengths)):
|
||||
builder.PrependUint32(n - d - 1)
|
||||
axes = builder.EndVector(len(fft_lengths))
|
||||
|
||||
scale = 1. if forward else (1. / np.prod(fft_lengths))
|
||||
pd.PocketFftDescriptorStart(builder)
|
||||
pd.PocketFftDescriptorAddDtype(builder, pocketfft_dtype)
|
||||
pd.PocketFftDescriptorAddFftType(builder, pocketfft_type)
|
||||
pd.PocketFftDescriptorAddShape(builder, pocketfft_shape)
|
||||
pd.PocketFftDescriptorAddStridesIn(builder, strides_in)
|
||||
pd.PocketFftDescriptorAddStridesOut(builder, strides_out)
|
||||
pd.PocketFftDescriptorAddAxes(builder, axes)
|
||||
pd.PocketFftDescriptorAddForward(builder, forward)
|
||||
pd.PocketFftDescriptorAddScale(builder, scale)
|
||||
descriptor = pd.PocketFftDescriptorEnd(builder)
|
||||
builder.Finish(descriptor)
|
||||
descriptor_bytes = builder.Output()
|
||||
|
||||
return xla_client.ops.CustomCallWithLayout(
|
||||
c,
|
||||
b"pocketfft",
|
||||
operands=(
|
||||
xla_client.ops.Constant(
|
||||
c, np.frombuffer(descriptor_bytes, dtype=np.uint8)),
|
||||
a,
|
||||
),
|
||||
shape_with_layout=xla_client.Shape.array_shape(
|
||||
out_dtype, out_shape, tuple(range(n - 1, -1, -1))),
|
||||
operand_shapes_with_layout=(
|
||||
xla_client.Shape.array_shape(
|
||||
np.dtype(np.uint8), (len(descriptor_bytes),), (0,)),
|
||||
xla_client.Shape.array_shape(dtype, shape.dimensions(),
|
||||
tuple(range(n - 1, -1, -1))),
|
||||
))
|
@ -147,6 +147,10 @@ class FftTest(jtu.JaxTestCase):
|
||||
self.assertRaises(
|
||||
ValueError, lambda: func(rng([2, 3], dtype=np.float64), axes=[-3]))
|
||||
|
||||
def testFftEmpty(self):
|
||||
out = jnp.fft.fft(jnp.zeros((0,), jnp.complex64)).block_until_ready()
|
||||
self.assertArraysEqual(jnp.zeros((0,), jnp.complex64), out)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_inverse={}_real={}_hermitian={}_shape={}_axis={}".format(
|
||||
inverse, real, hermitian, jtu.format_shape_dtype_string(shape, dtype), axis),
|
||||
|
11
third_party/pocketfft/BUILD.bazel
vendored
Normal file
11
third_party/pocketfft/BUILD.bazel
vendored
Normal file
@ -0,0 +1,11 @@
|
||||
licenses(["notice"])
|
||||
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
cc_library(
|
||||
name = "pocketfft",
|
||||
hdrs = ["pocketfft_hdronly.h"],
|
||||
copts = ["-fexceptions"],
|
||||
features = ["-use_header_modules"],
|
||||
include_prefix = "pocketfft",
|
||||
)
|
30
third_party/pocketfft/workspace.bzl
vendored
Normal file
30
third_party/pocketfft/workspace.bzl
vendored
Normal file
@ -0,0 +1,30 @@
|
||||
# Copyright 2020 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Bazel workspace for PocketFFT."""
|
||||
|
||||
load("@org_tensorflow//third_party:repo.bzl", "third_party_http_archive")
|
||||
|
||||
def repo():
|
||||
third_party_http_archive(
|
||||
name = "pocketfft",
|
||||
sha256 = "bba6962b9f71a220b4873549bad5e6e5a2630bc465e3f9a9822c4ab2418709a7",
|
||||
strip_prefix = "pocketfft-53e9dd4d12f986207c96d97c5183f5a72239c76e",
|
||||
urls = [
|
||||
"https://gitlab.mpcdf.mpg.de/mtr/pocketfft/-/archive/53e9dd4d12f986207c96d97c5183f5a72239c76e/pocketfft-53e9dd4d12f986207c96d97c5183f5a72239c76e.tar.gz",
|
||||
# Repeat the URL to silence the Tensorflow third_party_http_archive mirror check.
|
||||
"https://gitlab.mpcdf.mpg.de/mtr/pocketfft/-/archive/53e9dd4d12f986207c96d97c5183f5a72239c76e/pocketfft-53e9dd4d12f986207c96d97c5183f5a72239c76e.tar.gz",
|
||||
],
|
||||
build_file = "@//third_party/pocketfft:BUILD.bazel",
|
||||
)
|
Loading…
x
Reference in New Issue
Block a user