rocm_jax/jaxlib/ducc_fft.py
jax authors ad8c39ad7c Internal change
PiperOrigin-RevId: 513953876
2023-03-04 13:24:11 +00:00

150 lines
4.5 KiB
Python

# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple
import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.stablehlo as hlo
from .hlo_helpers import custom_call
from .cpu import _ducc_fft
import numpy as np
from jaxlib import xla_client
# Make every custom-call target exported by the `_ducc_fft` extension available
# to XLA on the CPU platform. This runs once, when the module is imported.
for _name, _value in _ducc_fft.registrations().items():
  xla_client.register_custom_call_target(
      _name,
      _value,
      platform="cpu",
  )

# Re-exported so that users of this module can name the transform kinds.
FftType = xla_client.FftType

# Transform-kind codes handed to `_ducc_fft.ducc_fft_descriptor` as `fft_type`:
# complex-to-complex, complex-to-real and real-to-complex respectively.
_C2C, _C2R, _R2C = 0, 1, 2
def _row_major_strides(dims: List[int]) -> List[int]:
  """Returns the element strides of a dense array of shape `dims`, last axis fastest."""
  strides = []
  stride = 1
  for d in reversed(dims):
    strides.append(stride)
    stride *= d
  strides.reverse()
  return strides


def _ducc_fft_descriptor(
    shape: List[int], dtype, fft_type: FftType, fft_lengths: List[int]
) -> Tuple[bytes, np.dtype, List[int]]:
  """Builds the descriptor for a DUCC FFT over the trailing axes of an array.

  Args:
    shape: dimensions of the input array.
    dtype: NumPy dtype of the input array. Must be float32/float64 for RFFT
      and a complex floating type for every other transform.
    fft_type: the transform to perform.
    fft_lengths: transform length for each of the last `len(fft_lengths)`
      axes. For IRFFT the final entry is the size of the last output axis.

  Returns:
    A `(descriptor, out_dtype, out_shape)` tuple. `descriptor` is the value
    produced by `_ducc_fft.ducc_fft_descriptor`, or `b""` when the input or
    output contains a zero-sized dimension. `out_dtype` is always an
    `np.dtype` and `out_shape` is always a fresh list, never `shape` itself.

  Raises:
    AssertionError: if `fft_lengths` is empty or longer than `shape`, if
      `dtype` does not suit `fft_type`, or if `shape` disagrees with
      `fft_lengths`.
  """
  n = len(shape)
  assert len(fft_lengths) >= 1, fft_lengths
  assert len(fft_lengths) <= n, (fft_lengths, n)

  forward = fft_type in (FftType.FFT, FftType.RFFT)
  is_double = np.finfo(dtype).dtype == np.float64
  if fft_type == FftType.RFFT:
    ducc_fft_type = _R2C

    assert dtype in (np.float32, np.float64), dtype
    out_dtype = np.dtype(np.complex64 if dtype == np.float32 else np.complex128)

    assert shape[-len(fft_lengths):] == fft_lengths, (shape, fft_lengths)
    out_shape = list(shape)
    # Only the first half of the spectrum (plus one entry) is produced along
    # the last axis.
    out_shape[-1] = out_shape[-1] // 2 + 1

  elif fft_type == FftType.IRFFT:
    ducc_fft_type = _C2R
    assert np.issubdtype(dtype, np.complexfloating), dtype

    out_dtype = np.dtype(np.float32 if dtype == np.complex64 else np.float64)

    assert shape[-len(fft_lengths):-1] == fft_lengths[:-1], (
        shape, fft_lengths)
    out_shape = list(shape)
    out_shape[-1] = fft_lengths[-1]
    # The complex input must have exactly the half-spectrum size that the
    # requested real output length implies.
    assert (out_shape[-1] // 2 + 1) == shape[-1], (shape, fft_lengths)
  else:
    ducc_fft_type = _C2C

    assert np.issubdtype(dtype, np.complexfloating), dtype
    # Normalise to np.dtype so that all three branches return the same kind of
    # object, as the return annotation promises.
    out_dtype = np.dtype(dtype)

    assert shape[-len(fft_lengths):] == fft_lengths, (shape, fft_lengths)
    # Copy so that the returned list never aliases the caller's `shape`.
    out_shape = list(shape)

  # PocketFft does not allow size 0 dimensions.
  if 0 in shape or 0 in out_shape:
    return b"", out_dtype, out_shape

  # Builds a PocketFftDescriptor flatbuffer. This descriptor is passed to the
  # C++ kernel to describe the FFT to perform.
  strides_in = _row_major_strides(shape)
  strides_out = _row_major_strides(out_shape)

  # The transform always runs over the trailing `len(fft_lengths)` axes.
  axes = list(range(n - len(fft_lengths), n))

  # Inverse transforms are normalised by the product of the FFT lengths.
  scale = 1. if forward else (1. / np.prod(fft_lengths))
  descriptor = _ducc_fft.ducc_fft_descriptor(
      # For IRFFT the real output shape is passed instead of the input shape.
      shape=shape if fft_type != FftType.IRFFT else out_shape,
      is_double=is_double,
      fft_type=ducc_fft_type,
      fft_lengths=fft_lengths,
      strides_in=strides_in,
      strides_out=strides_out,
      axes=axes,
      forward=forward,
      scale=scale)

  return descriptor, out_dtype, out_shape
def ducc_fft_hlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]):
  """Emits the ops that compute an FFT of `a` with the DUCC CPU kernel.

  Args:
    a: MLIR value whose type is a ranked tensor; the FFT input.
    dtype: NumPy dtype of `a`, forwarded to `_ducc_fft_descriptor`.
    fft_type: the transform to perform.
    fft_lengths: transform lengths, forwarded to `_ducc_fft_descriptor` as a
      list.

  Returns:
    When the input or output has a zero-sized dimension, the result of
    broadcasting a zero constant to the output shape. Otherwise, whatever
    `custom_call` returns for the "ducc_fft" target.

  Raises:
    ValueError: if the output dtype is not float32, float64, complex64 or
      complex128.
  """
  operand_type = ir.RankedTensorType(a.type)
  in_shape = list(operand_type.shape)
  rank = len(in_shape)
  fft_lengths = list(fft_lengths)

  descriptor_bytes, out_dtype, out_shape = _ducc_fft_descriptor(
      in_shape, dtype, fft_type, fft_lengths)

  # Map the NumPy output dtype to the matching MLIR element type. The MLIR
  # type is only constructed for the entry that matches.
  element_type_factories = (
      (np.float32, lambda: ir.F32Type.get()),
      (np.float64, lambda: ir.F64Type.get()),
      (np.complex64, lambda: ir.ComplexType.get(ir.F32Type.get())),
      (np.complex128, lambda: ir.ComplexType.get(ir.F64Type.get())),
  )
  for numpy_type, make_element_type in element_type_factories:
    if out_dtype == numpy_type:
      out_type = make_element_type()
      break
  else:
    raise ValueError(f"Unknown output type {out_dtype}")

  # No kernel call is emitted for empty inputs or outputs; the (empty) result
  # is produced by broadcasting a scalar zero instead.
  if 0 in in_shape or 0 in out_shape:
    zero_attr = ir.DenseElementsAttr.get(
        np.array(0, dtype=out_dtype), type=out_type)
    zero = hlo.ConstantOp(zero_attr)
    broadcast_sizes = ir.DenseElementsAttr.get(
        np.asarray(out_shape, np.int64))
    return hlo.BroadcastOp(zero, broadcast_sizes).result

  # The serialized descriptor is passed to the kernel as a uint8 constant
  # operand ahead of the data operand.
  descriptor_attr = ir.DenseElementsAttr.get(
      np.frombuffer(descriptor_bytes, dtype=np.uint8),
      type=ir.IntegerType.get_unsigned(8))
  descriptor = hlo.ConstantOp(descriptor_attr)

  # Layout lists the dimensions from last to first.
  layout = tuple(reversed(range(rank)))
  result_type = ir.RankedTensorType.get(out_shape, out_type)
  return custom_call(
      "ducc_fft",
      [result_type],
      [descriptor, a],
      operand_layouts=[[0], layout],
      result_layouts=[layout])