mirror of
https://github.com/ROCm/jax.git
synced 2025-04-24 14:56:05 +00:00
157 lines
5.4 KiB
Python
157 lines
5.4 KiB
Python
# Copyright 2020 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from typing import List
|
|
|
|
from . import _pocketfft
|
|
from . import pocketfft_flatbuffers_py_generated as pd
|
|
import numpy as np
|
|
|
|
import flatbuffers
|
|
from jaxlib import xla_client
|
|
|
|
for _name, _value in _pocketfft.registrations().items():
|
|
xla_client.register_custom_call_target(_name, _value, platform="cpu")
|
|
|
|
FftType = xla_client.FftType
|
|
|
|
flatbuffers_version_2 = hasattr(flatbuffers, "__version__")
|
|
|
|
|
|
def pocketfft(c, a, *, fft_type: FftType, fft_lengths: List[int]):
|
|
"""PocketFFT kernel for CPU."""
|
|
shape = c.get_shape(a)
|
|
n = len(shape.dimensions())
|
|
dtype = shape.element_type()
|
|
builder = flatbuffers.Builder(128)
|
|
|
|
fft_lengths = list(fft_lengths)
|
|
assert len(fft_lengths) >= 1
|
|
assert len(fft_lengths) <= n, (fft_lengths, n)
|
|
|
|
forward = fft_type in (FftType.FFT, FftType.RFFT)
|
|
if fft_type == FftType.RFFT:
|
|
pocketfft_type = pd.PocketFftType.R2C
|
|
|
|
assert dtype in (np.float32, np.float64), dtype
|
|
out_dtype = np.dtype(np.complex64 if dtype == np.float32 else np.complex128)
|
|
pocketfft_dtype = (
|
|
pd.PocketFftDtype.COMPLEX64
|
|
if dtype == np.float32 else pd.PocketFftDtype.COMPLEX128)
|
|
|
|
assert list(shape.dimensions())[-len(fft_lengths):] == fft_lengths, (
|
|
shape, fft_lengths)
|
|
out_shape = list(shape.dimensions())
|
|
out_shape[-1] = out_shape[-1] // 2 + 1
|
|
|
|
elif fft_type == FftType.IRFFT:
|
|
pocketfft_type = pd.PocketFftType.C2R
|
|
assert np.issubdtype(dtype, np.complexfloating), dtype
|
|
|
|
out_dtype = np.dtype(np.float32 if dtype == np.complex64 else np.float64)
|
|
pocketfft_dtype = (
|
|
pd.PocketFftDtype.COMPLEX64
|
|
if dtype == np.complex64 else pd.PocketFftDtype.COMPLEX128)
|
|
|
|
assert list(shape.dimensions())[-len(fft_lengths):-1] == fft_lengths[:-1]
|
|
out_shape = list(shape.dimensions())
|
|
out_shape[-1] = fft_lengths[-1]
|
|
assert (out_shape[-1] // 2 + 1) == shape.dimensions()[-1]
|
|
else:
|
|
pocketfft_type = pd.PocketFftType.C2C
|
|
|
|
assert np.issubdtype(dtype, np.complexfloating), dtype
|
|
out_dtype = dtype
|
|
pocketfft_dtype = (
|
|
pd.PocketFftDtype.COMPLEX64
|
|
if dtype == np.complex64 else pd.PocketFftDtype.COMPLEX128)
|
|
|
|
assert list(shape.dimensions())[-len(fft_lengths):] == fft_lengths, (
|
|
shape, fft_lengths)
|
|
out_shape = shape.dimensions()
|
|
|
|
# PocketFft does not allow size 0 dimensions.
|
|
if 0 in shape.dimensions() or 0 in out_shape:
|
|
return xla_client.ops.Broadcast(
|
|
xla_client.ops.Constant(c, np.array(0, dtype=out_dtype)), out_shape)
|
|
|
|
# Builds a PocketFftDescriptor flatbuffer. This descriptor is passed to the
|
|
# C++ kernel to describe the FFT to perform.
|
|
pd.PocketFftDescriptorStartShapeVector(builder, n)
|
|
for d in reversed(
|
|
shape.dimensions() if fft_type != FftType.IRFFT else out_shape):
|
|
builder.PrependUint64(d)
|
|
if flatbuffers_version_2:
|
|
pocketfft_shape = builder.EndVector()
|
|
else:
|
|
pocketfft_shape = builder.EndVector(n)
|
|
|
|
pd.PocketFftDescriptorStartStridesInVector(builder, n)
|
|
stride = dtype.itemsize
|
|
for d in reversed(shape.dimensions()):
|
|
builder.PrependUint64(stride)
|
|
stride *= d
|
|
if flatbuffers_version_2:
|
|
strides_in = builder.EndVector()
|
|
else:
|
|
strides_in = builder.EndVector(n)
|
|
pd.PocketFftDescriptorStartStridesOutVector(builder, n)
|
|
stride = out_dtype.itemsize
|
|
for d in reversed(out_shape):
|
|
builder.PrependUint64(stride)
|
|
stride *= d
|
|
if flatbuffers_version_2:
|
|
strides_out = builder.EndVector()
|
|
else:
|
|
strides_out = builder.EndVector(n)
|
|
|
|
pd.PocketFftDescriptorStartAxesVector(builder, len(fft_lengths))
|
|
for d in range(len(fft_lengths)):
|
|
builder.PrependUint32(n - d - 1)
|
|
if flatbuffers_version_2:
|
|
axes = builder.EndVector()
|
|
else:
|
|
axes = builder.EndVector(len(fft_lengths))
|
|
|
|
scale = 1. if forward else (1. / np.prod(fft_lengths))
|
|
pd.PocketFftDescriptorStart(builder)
|
|
pd.PocketFftDescriptorAddDtype(builder, pocketfft_dtype)
|
|
pd.PocketFftDescriptorAddFftType(builder, pocketfft_type)
|
|
pd.PocketFftDescriptorAddShape(builder, pocketfft_shape)
|
|
pd.PocketFftDescriptorAddStridesIn(builder, strides_in)
|
|
pd.PocketFftDescriptorAddStridesOut(builder, strides_out)
|
|
pd.PocketFftDescriptorAddAxes(builder, axes)
|
|
pd.PocketFftDescriptorAddForward(builder, forward)
|
|
pd.PocketFftDescriptorAddScale(builder, scale)
|
|
descriptor = pd.PocketFftDescriptorEnd(builder)
|
|
builder.Finish(descriptor)
|
|
descriptor_bytes = builder.Output()
|
|
|
|
return xla_client.ops.CustomCallWithLayout(
|
|
c,
|
|
b"pocketfft",
|
|
operands=(
|
|
xla_client.ops.Constant(
|
|
c, np.frombuffer(descriptor_bytes, dtype=np.uint8)),
|
|
a,
|
|
),
|
|
shape_with_layout=xla_client.Shape.array_shape(
|
|
out_dtype, out_shape, tuple(range(n - 1, -1, -1))),
|
|
operand_shapes_with_layout=(
|
|
xla_client.Shape.array_shape(
|
|
np.dtype(np.uint8), (len(descriptor_bytes),), (0,)),
|
|
xla_client.Shape.array_shape(dtype, shape.dimensions(),
|
|
tuple(range(n - 1, -1, -1))),
|
|
))
|