rocm_jax/jax/_src/lax/fft.py
Peter Hawkins 8fb1fd318d Replace jax._src.util.prod with math.prod.
math.prod() was added in Python 3.8, so we can assume it is always present.

PiperOrigin-RevId: 513011144
2023-02-28 12:41:00 -08:00

182 lines
6.8 KiB
Python

# Copyright 2019 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import math
from typing import Union, Sequence
import numpy as np
from jax import lax
from jax.interpreters import mlir
from jax.interpreters import xla
from jax._src.api import jit, linear_transpose, ShapeDtypeStruct
from jax._src.core import Primitive, is_constant_shape
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib import xla_client
from jax._src.lib import ducc_fft
from jax._src.numpy.util import _promote_dtypes_complex, _promote_dtypes_inexact
__all__ = [
"fft",
"fft_p",
]
def _str_to_fft_type(s: str) -> xla_client.FftType:
if s in ("fft", "FFT"):
return xla_client.FftType.FFT
elif s in ("ifft", "IFFT"):
return xla_client.FftType.IFFT
elif s in ("rfft", "RFFT"):
return xla_client.FftType.RFFT
elif s in ("irfft", "IRFFT"):
return xla_client.FftType.IRFFT
else:
raise ValueError(f"Unknown FFT type '{s}'")
@partial(jit, static_argnums=(1, 2))
def fft(x, fft_type: Union[xla_client.FftType, str], fft_lengths: Sequence[int]):
if isinstance(fft_type, str):
typ = _str_to_fft_type(fft_type)
elif isinstance(fft_type, xla_client.FftType):
typ = fft_type
else:
raise TypeError(f"Unknown FFT type value '{fft_type}'")
if typ == xla_client.FftType.RFFT:
if np.iscomplexobj(x):
raise ValueError("only real valued inputs supported for rfft")
x, = _promote_dtypes_inexact(x)
else:
x, = _promote_dtypes_complex(x)
if len(fft_lengths) == 0:
# XLA FFT doesn't support 0-rank.
return x
fft_lengths = tuple(fft_lengths)
return fft_p.bind(x, fft_type=typ, fft_lengths=fft_lengths)
def _fft_impl(x, fft_type, fft_lengths):
return xla.apply_primitive(fft_p, x, fft_type=fft_type, fft_lengths=fft_lengths)
_complex_dtype = lambda dtype: (np.zeros((), dtype) + np.zeros((), np.complex64)).dtype
_real_dtype = lambda dtype: np.finfo(dtype).dtype
_is_even = lambda x: x % 2 == 0
def fft_abstract_eval(x, fft_type, fft_lengths):
if len(fft_lengths) > x.ndim:
raise ValueError(f"FFT input shape {x.shape} must have at least as many "
f"input dimensions as fft_lengths {fft_lengths}.")
if fft_type == xla_client.FftType.RFFT:
if x.shape[-len(fft_lengths):] != fft_lengths:
raise ValueError(f"RFFT input shape {x.shape} minor dimensions must "
f"be equal to fft_lengths {fft_lengths}")
shape = (x.shape[:-len(fft_lengths)] + fft_lengths[:-1]
+ (fft_lengths[-1] // 2 + 1,))
dtype = _complex_dtype(x.dtype)
elif fft_type == xla_client.FftType.IRFFT:
if x.shape[-len(fft_lengths):-1] != fft_lengths[:-1]:
raise ValueError(f"IRFFT input shape {x.shape} minor dimensions must "
"be equal to all except the last fft_length, got "
f"{fft_lengths=}")
shape = x.shape[:-len(fft_lengths)] + fft_lengths
dtype = _real_dtype(x.dtype)
else:
if x.shape[-len(fft_lengths):] != fft_lengths:
raise ValueError(f"FFT input shape {x.shape} minor dimensions must "
f"be equal to fft_lengths {fft_lengths}")
shape = x.shape
dtype = x.dtype
return x.update(shape=shape, dtype=dtype)
def _fft_lowering(ctx, x, *, fft_type, fft_lengths):
return [
hlo.FftOp(x, hlo.FftTypeAttr.get(fft_type.name),
mlir.dense_int_elements(fft_lengths)).result
]
def _fft_lowering_cpu(ctx, x, *, fft_type, fft_lengths):
if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
raise NotImplementedError("Shape polymorphism for custom call is not implemented (fft); b/261671778")
x_aval, = ctx.avals_in
return [ducc_fft.ducc_fft_hlo(x, x_aval.dtype, fft_type=fft_type,
fft_lengths=fft_lengths)]
def _naive_rfft(x, fft_lengths):
y = fft(x, xla_client.FftType.FFT, fft_lengths)
n = fft_lengths[-1]
return y[..., : n//2 + 1]
@partial(jit, static_argnums=1)
def _rfft_transpose(t, fft_lengths):
# The transpose of RFFT can't be expressed only in terms of irfft. Instead of
# manually building up larger twiddle matrices (which would increase the
# asymptotic complexity and is also rather complicated), we rely JAX to
# transpose a naive RFFT implementation.
dummy_shape = t.shape[:-len(fft_lengths)] + fft_lengths
dummy_primal = ShapeDtypeStruct(dummy_shape, _real_dtype(t.dtype))
transpose = linear_transpose(
partial(_naive_rfft, fft_lengths=fft_lengths), dummy_primal)
result, = transpose(t)
assert result.dtype == _real_dtype(t.dtype), (result.dtype, t.dtype)
return result
def _irfft_transpose(t, fft_lengths):
# The transpose of IRFFT is the RFFT of the cotangent times a scaling
# factor and a mask. The mask scales the cotangent for the Hermitian
# symmetric components of the RFFT by a factor of two, since these components
# are de-duplicated in the RFFT.
x = fft(t, xla_client.FftType.RFFT, fft_lengths)
n = x.shape[-1]
is_odd = fft_lengths[-1] % 2
full = partial(lax.full_like, t, dtype=x.dtype)
mask = lax.concatenate(
[full(1.0, shape=(1,)),
full(2.0, shape=(n - 2 + is_odd,)),
full(1.0, shape=(1 - is_odd,))],
dimension=0)
scale = 1 / math.prod(fft_lengths)
out = scale * lax.expand_dims(mask, range(x.ndim - 1)) * x
assert out.dtype == _complex_dtype(t.dtype), (out.dtype, t.dtype)
# Use JAX's convention for complex gradients
# https://github.com/google/jax/issues/6223#issuecomment-807740707
return lax.conj(out)
def _fft_transpose_rule(t, operand, fft_type, fft_lengths):
if fft_type == xla_client.FftType.RFFT:
result = _rfft_transpose(t, fft_lengths)
elif fft_type == xla_client.FftType.IRFFT:
result = _irfft_transpose(t, fft_lengths)
else:
result = fft(t, fft_type, fft_lengths)
return result,
def _fft_batching_rule(batched_args, batch_dims, fft_type, fft_lengths):
x, = batched_args
bd, = batch_dims
x = batching.moveaxis(x, bd, 0)
return fft(x, fft_type, fft_lengths), 0
fft_p = Primitive('fft')
fft_p.def_impl(_fft_impl)
fft_p.def_abstract_eval(fft_abstract_eval)
mlir.register_lowering(fft_p, _fft_lowering)
ad.deflinear2(fft_p, _fft_transpose_rule)
batching.primitive_batchers[fft_p] = _fft_batching_rule
mlir.register_lowering(fft_p, _fft_lowering_cpu, platform='cpu')