rocm_jax/jaxlib/pocketfft.py
Peter Hawkins f58f1ee456 [JAX] Use PocketFFT for FFTs on CPU instead of Eigen.
PocketFFT is the same FFT library used by NumPy (although we are using the C++ variant rather than the C variant).

For the benchmark in #2952 on my workstation:

Before:
```
907.3490574884647
max:     4.362646594533903e-08
mean:    6.237288307614869e-09
min:     0.0
numpy fft execution time [ms]:   37.088446617126465
jax fft execution time [ms]:     74.93342399597168
```

After:
```
907.3490574884647
max:     1.9057386696477137e-12
mean:    3.9326737908882566e-13
min:     0.0
numpy fft execution time [ms]:   37.756404876708984
jax fft execution time [ms]:     28.128278255462646
```

Fixes https://github.com/google/jax/issues/2952

PiperOrigin-RevId: 338743753
2020-10-23 14:20:32 -07:00

143 lines
5.1 KiB
Python

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from jaxlib import _pocketfft
from jaxlib import pocketfft_flatbuffers_py_generated as pd
import numpy as np
import flatbuffers
from jaxlib import xla_client
# Register every custom-call target exported by the native `_pocketfft`
# extension with XLA. The targets are registered for the "cpu" platform only.
for _name, _value in _pocketfft.registrations().items():
  xla_client.register_custom_call_target(_name, _value, platform="cpu")

# Alias of XLA's FFT-kind enum, used in this module's public signature.
FftType = xla_client.FftType
def _end_vector(builder, num_elems):
  """Finishes the flatbuffers vector under construction; returns its offset.

  flatbuffers 2.0 changed `Builder.EndVector` to take no arguments (the element
  count is remembered from `StartVector`), while 1.x requires the count to be
  passed explicitly. Trying the argument-free form first lets this module work
  against either API.

  Args:
    builder: the `flatbuffers.Builder` with a vector currently open.
    num_elems: number of elements that were prepended to the vector.

  Returns:
    The offset value returned by `builder.EndVector`.
  """
  try:
    return builder.EndVector()
  except TypeError:
    # Pre-2.0 signature. A missing-argument TypeError is raised while binding
    # the call, before the method body runs, so retrying with the count is
    # safe.
    return builder.EndVector(num_elems)


def pocketfft(c, a, *, fft_type: FftType, fft_lengths: List[int]):
  """PocketFFT kernel for CPU.

  Builds a `PocketFftDescriptor` flatbuffer describing the transform and emits
  a "pocketfft" custom call whose operands are that descriptor (as a uint8
  constant) and `a`. The transform is applied over the trailing
  `len(fft_lengths)` dimensions of `a`.

  Args:
    c: the XLA builder that owns `a`; used to query `a`'s shape and to create
      the constant operands.
    a: the XLA operand to transform.
    fft_type: which transform to perform. `RFFT` maps to real-to-complex,
      `IRFFT` to complex-to-real, and any other value to complex-to-complex.
      `FFT` and `RFFT` are treated as forward transforms.
    fft_lengths: lengths of the transformed (trailing) dimensions. For `IRFFT`
      the last entry is the length of the real-valued output's last dimension.

  Returns:
    The XLA op producing the transformed array. If the input or output shape
    contains a zero-sized dimension, a zero-filled array of the output shape
    and dtype is returned instead of a custom call.

  Raises:
    AssertionError: if `fft_lengths` is empty or longer than the operand's
      rank, if the operand dtype does not suit `fft_type`, or if `fft_lengths`
      is inconsistent with the operand's trailing dimensions.
  """
  shape = c.get_shape(a)
  in_dims = shape.dimensions()
  n = len(in_dims)
  dtype = shape.element_type()

  fft_lengths = list(fft_lengths)
  num_axes = len(fft_lengths)
  assert num_axes >= 1
  assert num_axes <= n, (fft_lengths, n)

  forward = fft_type in (FftType.FFT, FftType.RFFT)
  if fft_type == FftType.RFFT:
    pocketfft_type = pd.PocketFftType.R2C

    assert dtype in (np.float32, np.float64), dtype
    out_dtype = np.dtype(np.complex64 if dtype == np.float32 else np.complex128)
    # The descriptor dtype is named after the complex side of the transform.
    pocketfft_dtype = (
        pd.PocketFftDtype.COMPLEX64
        if dtype == np.float32 else pd.PocketFftDtype.COMPLEX128)

    assert list(in_dims)[-num_axes:] == fft_lengths, (shape, fft_lengths)
    # Only the last dimension shrinks, to size // 2 + 1.
    out_shape = list(in_dims)
    out_shape[-1] = out_shape[-1] // 2 + 1

  elif fft_type == FftType.IRFFT:
    pocketfft_type = pd.PocketFftType.C2R

    assert np.issubdtype(dtype, np.complexfloating), dtype
    out_dtype = np.dtype(np.float32 if dtype == np.complex64 else np.float64)
    pocketfft_dtype = (
        pd.PocketFftDtype.COMPLEX64
        if dtype == np.complex64 else pd.PocketFftDtype.COMPLEX128)

    # All transformed dimensions except the last must match the input; the
    # last one is checked below through the size // 2 + 1 relation.
    assert list(in_dims)[-num_axes:-1] == fft_lengths[:-1]
    out_shape = list(in_dims)
    out_shape[-1] = fft_lengths[-1]
    assert (out_shape[-1] // 2 + 1) == in_dims[-1]

  else:
    pocketfft_type = pd.PocketFftType.C2C

    assert np.issubdtype(dtype, np.complexfloating), dtype
    out_dtype = dtype
    pocketfft_dtype = (
        pd.PocketFftDtype.COMPLEX64
        if dtype == np.complex64 else pd.PocketFftDtype.COMPLEX128)

    assert list(in_dims)[-num_axes:] == fft_lengths, (shape, fft_lengths)
    out_shape = in_dims

  # PocketFft does not allow size 0 dimensions.
  if 0 in in_dims or 0 in out_shape:
    return xla_client.ops.Broadcast(
        xla_client.ops.Constant(c, np.array(0, dtype=out_dtype)), out_shape)

  # Builds a PocketFftDescriptor flatbuffer. This descriptor is passed to the
  # C++ kernel to describe the FFT to perform.
  builder = flatbuffers.Builder(128)

  # Elements are prepended, so each vector is filled from its last entry to
  # its first. For IRFFT the shape vector holds the (real) output shape; for
  # every other transform it holds the input shape.
  pd.PocketFftDescriptorStartShapeVector(builder, n)
  for d in reversed(in_dims if fft_type != FftType.IRFFT else out_shape):
    builder.PrependUint64(d)
  pocketfft_shape = _end_vector(builder, n)

  # Input strides in bytes: the last dimension is contiguous and each earlier
  # dimension strides over everything after it.
  pd.PocketFftDescriptorStartStridesInVector(builder, n)
  stride = dtype.itemsize
  for d in reversed(in_dims):
    builder.PrependUint64(stride)
    stride *= d
  strides_in = _end_vector(builder, n)

  # Output strides in bytes, computed the same way from the output shape.
  pd.PocketFftDescriptorStartStridesOutVector(builder, n)
  stride = out_dtype.itemsize
  for d in reversed(out_shape):
    builder.PrependUint64(stride)
    stride *= d
  strides_out = _end_vector(builder, n)

  # The transform runs over the trailing `num_axes` dimensions.
  pd.PocketFftDescriptorStartAxesVector(builder, num_axes)
  for d in range(num_axes):
    builder.PrependUint32(n - d - 1)
  axes = _end_vector(builder, num_axes)

  # Forward transforms are unscaled; inverse transforms are divided by the
  # product of the transformed lengths.
  scale = 1. if forward else (1. / np.prod(fft_lengths))

  pd.PocketFftDescriptorStart(builder)
  pd.PocketFftDescriptorAddDtype(builder, pocketfft_dtype)
  pd.PocketFftDescriptorAddFftType(builder, pocketfft_type)
  pd.PocketFftDescriptorAddShape(builder, pocketfft_shape)
  pd.PocketFftDescriptorAddStridesIn(builder, strides_in)
  pd.PocketFftDescriptorAddStridesOut(builder, strides_out)
  pd.PocketFftDescriptorAddAxes(builder, axes)
  pd.PocketFftDescriptorAddForward(builder, forward)
  pd.PocketFftDescriptorAddScale(builder, scale)
  descriptor = pd.PocketFftDescriptorEnd(builder)
  builder.Finish(descriptor)
  descriptor_bytes = builder.Output()

  # Minor-to-major order (n-1, ..., 0): the last dimension varies fastest,
  # which matches the byte strides written into the descriptor above.
  layout = tuple(range(n - 1, -1, -1))
  return xla_client.ops.CustomCallWithLayout(
      c,
      b"pocketfft",
      operands=(
          xla_client.ops.Constant(
              c, np.frombuffer(descriptor_bytes, dtype=np.uint8)),
          a,
      ),
      shape_with_layout=xla_client.Shape.array_shape(
          out_dtype, out_shape, layout),
      operand_shapes_with_layout=(
          xla_client.Shape.array_shape(
              np.dtype(np.uint8), (len(descriptor_bytes),), (0,)),
          xla_client.Shape.array_shape(dtype, in_dims, layout),
      ))