[JAX] Use PocketFFT for FFTs on CPU instead of Eigen.

PocketFFT is the same FFT library used by NumPy (although we are using the C++ variant rather than the C variant).

For the benchmark in #2952 on my workstation:

Before:
```
907.3490574884647
max:     4.362646594533903e-08
mean:    6.237288307614869e-09
min:     0.0
numpy fft execution time [ms]:   37.088446617126465
jax fft execution time [ms]:     74.93342399597168
```

After:
```
907.3490574884647
max:     1.9057386696477137e-12
mean:    3.9326737908882566e-13
min:     0.0
numpy fft execution time [ms]:   37.756404876708984
jax fft execution time [ms]:     28.128278255462646
```

Fixes https://github.com/google/jax/issues/2952

PiperOrigin-RevId: 338743753
This commit is contained in:
Peter Hawkins 2020-10-23 14:20:06 -07:00 committed by jax authors
parent 8121255d7b
commit f58f1ee456
16 changed files with 395 additions and 8 deletions

View File

@ -50,6 +50,9 @@ tf_workspace(
tf_bind()
load("//third_party/pocketfft:workspace.bzl", pocketfft = "repo")
pocketfft()
# Required for TensorFlow dependency on @com_github_grpc_grpc
load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")

View File

@ -28,6 +28,8 @@ sh_binary(
"@org_tensorflow//tensorflow/compiler/xla/python/tpu_driver/client:py_tpu_client",
"//jaxlib",
"//jaxlib:lapack.so",
"//jaxlib:_pocketfft.so",
"//jaxlib:pocketfft_flatbuffers_py",
] + if_cuda([
"//jaxlib:cublas_kernels",
"//jaxlib:cusolver_kernels",
@ -35,4 +37,3 @@ sh_binary(
]),
deps = ["@bazel_tools//tools/bash/runfiles"],
)

View File

@ -53,6 +53,9 @@ fi
# Copy the XLA dependencies into jax/lib, fixing up some imports to point to the
# new location.
cp -f "$(rlocation __main__/jaxlib/lapack.so)" "${TARGET}/jaxlib"
cp -f "$(rlocation __main__/jaxlib/_pocketfft.so)" "${TARGET}/jaxlib"
cp -f "$(rlocation __main__/jaxlib/pocketfft_flatbuffers_py_generated.py)" "${TARGET}/jaxlib"
cp -f "$(rlocation __main__/jaxlib/pocketfft.py)" "${TARGET}/jaxlib"
if [[ -x "$(rlocation __main__/jaxlib/cusolver_kernels.so)" ]]; then
cp -f "$(rlocation __main__/jaxlib/cublas_kernels.so)" "${TARGET}/jaxlib"
cp -f "$(rlocation __main__/jaxlib/cusolver_kernels.so)" "${TARGET}/jaxlib"

View File

@ -35,7 +35,7 @@ setup(
author_email='jax-dev@google.com',
packages=['jaxlib'],
python_requires='>=3.6',
install_requires=['scipy', 'numpy>=1.12', 'absl-py'],
install_requires=['scipy', 'numpy>=1.12', 'absl-py', 'flatbuffers'],
url='https://github.com/google/jax',
license='Apache-2.0',
package_data={'jaxlib': binary_libs},

View File

@ -44,7 +44,10 @@ pytype_library(
],
),
srcs_version = "PY3",
deps = ["@org_tensorflow//tensorflow/compiler/xla/python:xla_client"],
deps = [
"//third_party/py/jax/jaxlib:_pocketfft",
"@org_tensorflow//tensorflow/compiler/xla/python:xla_client",
],
)
pytype_library(

View File

@ -27,6 +27,7 @@ from jax import lib
from jax.lib import xla_client
from jax.interpreters import ad
from jax.interpreters import batching
from jax.lib import pocketfft
xops = xla_client.ops
@ -146,3 +147,5 @@ fft_p.def_abstract_eval(fft_abstract_eval)
xla.translations[fft_p] = fft_translation_rule
ad.deflinear(fft_p, fft_transpose_rule)
batching.primitive_batchers[fft_p] = fft_batching_rule
# Route CPU FFTs through the PocketFFT custom call when jaxlib provides it.
# `pocketfft` is None when the jaxlib import failed, in which case the generic
# `fft_translation_rule` registered above remains the only translation.
if pocketfft:
xla.backend_specific_translations['cpu'][fft_p] = pocketfft.pocketfft

View File

@ -144,8 +144,12 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def test_fft(self, harness: primitive_harness.Harness):
if len(harness.params["fft_lengths"]) > 3:
with self.assertRaisesRegex(RuntimeError, "FFT only supports ranks 1-3"):
harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
if jtu.device_under_test() == "gpu":
with self.assertRaisesRegex(RuntimeError,
"FFT only supports ranks 1-3"):
harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
else:
raise unittest.SkipTest("TF does not support >3D FFTs.")
elif (jtu.device_under_test() == "tpu" and
len(harness.params["fft_lengths"]) > 1):
# TODO(b/140351181): FFT is mostly unimplemented on TPU, even for JAX
@ -154,7 +158,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
else:
tol = None
if jtu.device_under_test() == "gpu":
if jtu.device_under_test() in ("cpu", "gpu"):
if harness.params["dtype"] in jtu.dtypes.boolean:
tol = 0.01
else:

View File

@ -67,3 +67,10 @@ try:
from jaxlib import tpu_client # pytype: disable=import-error
except:
tpu_client = None
# TODO(phawkins): Make this import unconditional once the minimum jaxlib version
# is 0.1.57 or greater.
try:
  from jaxlib import pocketfft  # pytype: disable=import-error
except ImportError:
  # jaxlib builds that predate the PocketFFT kernel do not ship this module.
  # Only a failed import is tolerated here; anything else (KeyboardInterrupt,
  # SystemExit, genuine bugs raised while importing) must propagate instead of
  # being silently turned into "PocketFFT unavailable".
  pocketfft = None

View File

@ -17,6 +17,7 @@
load("@org_tensorflow//tensorflow/core/platform/default:build_config.bzl", "pyx_library")
load("@org_tensorflow//tensorflow:tensorflow.bzl", "pybind_extension")
load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library")
load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library", "flatbuffer_py_library")
licenses(["notice"])
@ -32,6 +33,7 @@ cc_library(
features = ["-use_header_modules"],
deps = [
":kernel_helpers",
"@com_google_absl//absl/base",
"@pybind11",
],
)
@ -76,8 +78,10 @@ py_library(
srcs = [
"cuda_prng.py",
"cusolver.py",
"pocketfft.py",
"version.py",
],
deps = [":pocketfft_flatbuffers_py"],
)
py_library(
@ -174,3 +178,32 @@ pybind_extension(
"@pybind11",
],
)
# C++ bindings generated from the descriptor schema in pocketfft.fbs; the
# _pocketfft extension depends on this target.
flatbuffer_cc_library(
name = "pocketfft_flatbuffers_cc",
srcs = ["pocketfft.fbs"],
)
# Python bindings generated from the same schema; pocketfft.py imports the
# output as `pocketfft_flatbuffers_py_generated`.
flatbuffer_py_library(
name = "pocketfft_flatbuffers_py",
srcs = ["pocketfft.fbs"],
)
# Python extension module built from pocketfft.cc. It exposes the PocketFFT
# kernel so that pocketfft.py can register it as an XLA custom-call target.
pybind_extension(
name = "_pocketfft",
srcs = ["pocketfft.cc"],
# NOTE(review): presumably -fexceptions is required because the PocketFFT
# header reports errors by throwing, and -fno-strict-aliasing guards its
# pointer casts -- confirm against the library before changing either flag.
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
module_name = "_pocketfft",
deps = [
":kernel_pybind11_helpers",
":pocketfft_flatbuffers_cc",
"@com_google_absl//absl/strings",
"@flatbuffers//:runtime_cc",
"@pocketfft",
"@pybind11",
],
)

View File

@ -16,8 +16,9 @@ limitations under the License.
#ifndef JAXLIB_KERNEL_PYBIND11_HELPERS_H_
#define JAXLIB_KERNEL_PYBIND11_HELPERS_H_
#include "include/pybind11/pybind11.h"
#include "absl/base/casts.h"
#include "jaxlib/kernel_helpers.h"
#include "include/pybind11/pybind11.h"
namespace jax {
@ -36,7 +37,8 @@ pybind11::bytes PackDescriptor(const T& descriptor) {
template <typename T>
pybind11::capsule EncapsulateFunction(T* fn) {
return pybind11::capsule(absl::bit_cast<void*>(fn), "xla._CUSTOM_CALL_TARGET");
return pybind11::capsule(absl::bit_cast<void*>(fn),
"xla._CUSTOM_CALL_TARGET");
}
} // namespace jax

101
jaxlib/pocketfft.cc Normal file
View File

@ -0,0 +1,101 @@
/* Copyright 2020 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <complex>
#include "flatbuffers/flatbuffers.h"
#include "pocketfft/pocketfft_hdronly.h"
#include "jaxlib/kernel_pybind11_helpers.h"
#include "jaxlib/pocketfft_generated.h"
#include "include/pybind11/pybind11.h"
namespace jax {
namespace {
void PocketFft(void* out, void** in) {
const PocketFftDescriptor* descriptor = GetPocketFftDescriptor(in[0]);
pocketfft::shape_t shape(descriptor->shape()->begin(),
descriptor->shape()->end());
pocketfft::stride_t stride_in(descriptor->strides_in()->begin(),
descriptor->strides_in()->end());
pocketfft::stride_t stride_out(descriptor->strides_out()->begin(),
descriptor->strides_out()->end());
pocketfft::shape_t axes(descriptor->axes()->begin(),
descriptor->axes()->end());
switch (descriptor->fft_type()) {
case PocketFftType_C2C:
if (descriptor->dtype() == PocketFftDtype_COMPLEX64) {
std::complex<float>* a_in =
reinterpret_cast<std::complex<float>*>(in[1]);
std::complex<float>* a_out =
reinterpret_cast<std::complex<float>*>(out);
pocketfft::c2c(shape, stride_in, stride_out, axes,
descriptor->forward(), a_in, a_out,
static_cast<float>(descriptor->scale()));
} else {
std::complex<double>* a_in =
reinterpret_cast<std::complex<double>*>(in[1]);
std::complex<double>* a_out =
reinterpret_cast<std::complex<double>*>(out);
pocketfft::c2c(shape, stride_in, stride_out, axes,
descriptor->forward(), a_in, a_out, descriptor->scale());
}
break;
case PocketFftType_C2R:
if (descriptor->dtype() == PocketFftDtype_COMPLEX64) {
std::complex<float>* a_in =
reinterpret_cast<std::complex<float>*>(in[1]);
float* a_out = reinterpret_cast<float*>(out);
pocketfft::c2r(shape, stride_in, stride_out, axes,
descriptor->forward(), a_in, a_out,
static_cast<float>(descriptor->scale()));
} else {
std::complex<double>* a_in =
reinterpret_cast<std::complex<double>*>(in[1]);
double* a_out = reinterpret_cast<double*>(out);
pocketfft::c2r(shape, stride_in, stride_out, axes,
descriptor->forward(), a_in, a_out, descriptor->scale());
}
break;
case PocketFftType_R2C:
if (descriptor->dtype() == PocketFftDtype_COMPLEX64) {
float* a_in = reinterpret_cast<float*>(in[1]);
std::complex<float>* a_out =
reinterpret_cast<std::complex<float>*>(out);
pocketfft::r2c(shape, stride_in, stride_out, axes,
descriptor->forward(), a_in, a_out,
static_cast<float>(descriptor->scale()));
} else {
double* a_in = reinterpret_cast<double*>(in[1]);
std::complex<double>* a_out =
reinterpret_cast<std::complex<double>*>(out);
pocketfft::r2c(shape, stride_in, stride_out, axes,
descriptor->forward(), a_in, a_out, descriptor->scale());
}
break;
}
}
// Builds the dict returned to Python: custom-call target name -> capsule
// wrapping the corresponding kernel function pointer.
pybind11::dict Registrations() {
  pybind11::dict targets;
  targets["pocketfft"] = EncapsulateFunction(PocketFft);
  return targets;
}
// Defines the `_pocketfft` Python extension module. Its single function,
// `registrations()`, returns the dict built by Registrations().
PYBIND11_MODULE(_pocketfft, m) { m.def("registrations", &Registrations); }
} // namespace
} // namespace jax

40
jaxlib/pocketfft.fbs Normal file
View File

@ -0,0 +1,40 @@
/* Copyright 2020 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
namespace jax;
// Precision of the transform. The C++ kernel uses float / std::complex<float>
// buffers for COMPLEX64 and double / std::complex<double> for any other value;
// this also covers the real-valued side of R2C and C2R transforms.
enum PocketFftDtype : byte {
COMPLEX64 = 0,
COMPLEX128 = 1,
}
// Kind of transform the kernel dispatches on: complex-to-complex,
// complex-to-real, or real-to-complex.
enum PocketFftType : byte {
C2C = 0,
C2R = 1,
R2C = 2,
}
// Describes a single FFT invocation. pocketfft.py serializes one of these and
// passes it as the first operand of the "pocketfft" custom call; the C++
// kernel reads it back to decide which PocketFFT routine to run.
table PocketFftDescriptor {
dtype:PocketFftDtype;
fft_type:PocketFftType;
// Array dimensions handed to PocketFFT: the input dimensions, except for C2R
// where pocketfft.py stores the output dimensions instead.
shape:[uint64];
// Per-dimension strides of the input and output arrays. pocketfft.py builds
// them starting from the element size, so they are expressed in bytes.
strides_in:[uint64];
strides_out:[uint64];
// Axes to transform; pocketfft.py fills in the trailing len(fft_lengths) axes.
axes:[uint32];
// Direction flag forwarded to PocketFFT; pocketfft.py sets it for FFT and RFFT.
forward:bool;
// Scaling factor forwarded to PocketFFT: 1 for forward transforms and
// 1 / prod(fft_lengths) otherwise.
scale:double;
}
root_type PocketFftDescriptor;

142
jaxlib/pocketfft.py Normal file
View File

@ -0,0 +1,142 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from jaxlib import _pocketfft
from jaxlib import pocketfft_flatbuffers_py_generated as pd
import numpy as np
import flatbuffers
from jaxlib import xla_client
# Register every custom-call target exported by the C++ extension with XLA for
# the CPU platform, so the "pocketfft" custom call emitted by this module can be
# resolved.
for _name, _value in _pocketfft.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="cpu")
# Local alias for xla_client.FftType, used in the annotations and comparisons
# in this module.
FftType = xla_client.FftType
def pocketfft(c, a, *, fft_type: FftType, fft_lengths: List[int]):
  """Builds the XLA CPU custom call that runs an FFT through PocketFFT.

  Args:
    c: builder that the emitted ops are attached to; also used to look up the
      shape of `a`.
    a: operand to transform.
    fft_type: transform to perform. FFT and RFFT are treated as forward
      transforms; RFFT is real-to-complex, IRFFT is complex-to-real and every
      other value is handled as complex-to-complex.
    fft_lengths: lengths of the transformed dimensions, which are the trailing
      dimensions of the operand.

  Returns:
    A broadcast of zeros with the output shape when the input or the output has
    a zero-sized dimension, otherwise the "pocketfft" custom-call op.
  """
  in_shape = c.get_shape(a)
  in_dims = in_shape.dimensions()
  n = len(in_dims)
  dtype = in_shape.element_type()

  fft_lengths = list(fft_lengths)
  num_fft_axes = len(fft_lengths)
  assert num_fft_axes >= 1
  assert num_fft_axes <= n, (fft_lengths, n)

  forward = fft_type in (FftType.FFT, FftType.RFFT)

  if fft_type == FftType.RFFT:
    pocketfft_type = pd.PocketFftType.R2C
    assert dtype in (np.float32, np.float64), dtype
    single_precision = dtype == np.float32
    out_dtype = np.dtype(np.complex64 if single_precision else np.complex128)
    assert list(in_dims)[-num_fft_axes:] == fft_lengths, (
        in_shape, fft_lengths)
    # A real-to-complex transform keeps only the non-redundant half of the
    # last axis.
    out_shape = list(in_dims)
    out_shape[-1] = out_shape[-1] // 2 + 1
  elif fft_type == FftType.IRFFT:
    pocketfft_type = pd.PocketFftType.C2R
    assert np.issubdtype(dtype, np.complexfloating), dtype
    single_precision = dtype == np.complex64
    out_dtype = np.dtype(np.float32 if single_precision else np.float64)
    assert list(in_dims)[-num_fft_axes:-1] == fft_lengths[:-1]
    # The last output dimension comes from fft_lengths and must be consistent
    # with the halved input dimension.
    out_shape = list(in_dims)
    out_shape[-1] = fft_lengths[-1]
    assert (out_shape[-1] // 2 + 1) == in_dims[-1]
  else:
    pocketfft_type = pd.PocketFftType.C2C
    assert np.issubdtype(dtype, np.complexfloating), dtype
    single_precision = dtype == np.complex64
    out_dtype = dtype
    assert list(in_dims)[-num_fft_axes:] == fft_lengths, (
        in_shape, fft_lengths)
    out_shape = in_dims

  if single_precision:
    pocketfft_dtype = pd.PocketFftDtype.COMPLEX64
  else:
    pocketfft_dtype = pd.PocketFftDtype.COMPLEX128

  # PocketFft does not allow size 0 dimensions.
  if 0 in in_dims or 0 in out_shape:
    zero = xla_client.ops.Constant(c, np.array(0, dtype=out_dtype))
    return xla_client.ops.Broadcast(zero, out_shape)

  # Builds a PocketFftDescriptor flatbuffer. This descriptor is passed to the
  # C++ kernel to describe the FFT to perform. All vectors are written before
  # the table is started, and elements are prepended, hence the reversed
  # iteration order.
  builder = flatbuffers.Builder(128)

  # The descriptor carries the input dimensions, except for IRFFT where the
  # output dimensions are stored instead.
  fft_dims = out_shape if fft_type == FftType.IRFFT else in_dims
  pd.PocketFftDescriptorStartShapeVector(builder, n)
  for dim in reversed(fft_dims):
    builder.PrependUint64(dim)
  pocketfft_shape = builder.EndVector(n)

  # Byte strides, with the last dimension varying fastest.
  pd.PocketFftDescriptorStartStridesInVector(builder, n)
  byte_stride = dtype.itemsize
  for dim in reversed(in_dims):
    builder.PrependUint64(byte_stride)
    byte_stride *= dim
  strides_in = builder.EndVector(n)

  pd.PocketFftDescriptorStartStridesOutVector(builder, n)
  byte_stride = out_dtype.itemsize
  for dim in reversed(out_shape):
    builder.PrependUint64(byte_stride)
    byte_stride *= dim
  strides_out = builder.EndVector(n)

  # The transform runs over the trailing `num_fft_axes` axes.
  pd.PocketFftDescriptorStartAxesVector(builder, num_fft_axes)
  for axis in range(n - 1, n - num_fft_axes - 1, -1):
    builder.PrependUint32(axis)
  axes = builder.EndVector(num_fft_axes)

  if forward:
    scale = 1.
  else:
    scale = 1. / np.prod(fft_lengths)

  pd.PocketFftDescriptorStart(builder)
  pd.PocketFftDescriptorAddDtype(builder, pocketfft_dtype)
  pd.PocketFftDescriptorAddFftType(builder, pocketfft_type)
  pd.PocketFftDescriptorAddShape(builder, pocketfft_shape)
  pd.PocketFftDescriptorAddStridesIn(builder, strides_in)
  pd.PocketFftDescriptorAddStridesOut(builder, strides_out)
  pd.PocketFftDescriptorAddAxes(builder, axes)
  pd.PocketFftDescriptorAddForward(builder, forward)
  pd.PocketFftDescriptorAddScale(builder, scale)
  descriptor = pd.PocketFftDescriptorEnd(builder)
  builder.Finish(descriptor)
  descriptor_bytes = builder.Output()

  # Minor-to-major layout with the last dimension most minor, matching the
  # byte strides written into the descriptor.
  minor_to_major = tuple(range(n - 1, -1, -1))
  descriptor_operand = xla_client.ops.Constant(
      c, np.frombuffer(descriptor_bytes, dtype=np.uint8))
  return xla_client.ops.CustomCallWithLayout(
      c,
      b"pocketfft",
      operands=(descriptor_operand, a),
      shape_with_layout=xla_client.Shape.array_shape(
          out_dtype, out_shape, minor_to_major),
      operand_shapes_with_layout=(
          xla_client.Shape.array_shape(
              np.dtype(np.uint8), (len(descriptor_bytes),), (0,)),
          xla_client.Shape.array_shape(dtype, in_dims, minor_to_major),
      ))

View File

@ -147,6 +147,10 @@ class FftTest(jtu.JaxTestCase):
self.assertRaises(
ValueError, lambda: func(rng([2, 3], dtype=np.float64), axes=[-3]))
# Checks that the FFT of a zero-length complex64 array evaluates successfully
# and yields an equally empty complex64 array.
def testFftEmpty(self):
out = jnp.fft.fft(jnp.zeros((0,), jnp.complex64)).block_until_ready()
self.assertArraysEqual(jnp.zeros((0,), jnp.complex64), out)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inverse={}_real={}_hermitian={}_shape={}_axis={}".format(
inverse, real, hermitian, jtu.format_shape_dtype_string(shape, dtype), axis),

11
third_party/pocketfft/BUILD.bazel vendored Normal file
View File

@ -0,0 +1,11 @@
licenses(["notice"])
package(default_visibility = ["//visibility:public"])
# Header-only PocketFFT library. `include_prefix` makes the header available
# as "pocketfft/pocketfft_hdronly.h", which is how pocketfft.cc includes it.
cc_library(
name = "pocketfft",
hdrs = ["pocketfft_hdronly.h"],
copts = ["-fexceptions"],
features = ["-use_header_modules"],
include_prefix = "pocketfft",
)

30
third_party/pocketfft/workspace.bzl vendored Normal file
View File

@ -0,0 +1,30 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Bazel workspace for PocketFFT."""
load("@org_tensorflow//third_party:repo.bzl", "third_party_http_archive")
# Declares the external "pocketfft" repository: a PocketFFT source archive
# pinned to a specific commit and checked against a sha256, built with the
# BUILD file kept in //third_party/pocketfft.
def repo():
third_party_http_archive(
name = "pocketfft",
sha256 = "bba6962b9f71a220b4873549bad5e6e5a2630bc465e3f9a9822c4ab2418709a7",
strip_prefix = "pocketfft-53e9dd4d12f986207c96d97c5183f5a72239c76e",
urls = [
"https://gitlab.mpcdf.mpg.de/mtr/pocketfft/-/archive/53e9dd4d12f986207c96d97c5183f5a72239c76e/pocketfft-53e9dd4d12f986207c96d97c5183f5a72239c76e.tar.gz",
# Repeat the URL to silence the Tensorflow third_party_http_archive mirror check.
"https://gitlab.mpcdf.mpg.de/mtr/pocketfft/-/archive/53e9dd4d12f986207c96d97c5183f5a72239c76e/pocketfft-53e9dd4d12f986207c96d97c5183f5a72239c76e.tar.gz",
],
build_file = "@//third_party/pocketfft:BUILD.bazel",
)