rocm_jax/jax/_src/lax/fft.py
Roy Frostig d927a5dbf3 migrate internal dependencies from jax.core to jax._src.core
... in preparation for paring down `jax.core`'s exported symbols.

Also includes a few import fixups along the way, and a TODO comment to avoid an
import cycle in `_src/dtypes.py`.

PiperOrigin-RevId: 496024782
2022-12-16 21:00:14 -08:00

185 lines
7.0 KiB
Python

# Copyright 2019 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Union, Sequence
import numpy as np
from jax.interpreters import mlir
from jax.interpreters import xla
from jax import lax
from jax.interpreters import ad
from jax.interpreters import batching
from jax._src.api import jit, linear_transpose, ShapeDtypeStruct
from jax._src.core import Primitive, is_constant_shape
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib import xla_client
from jax._src.lib import ducc_fft
from jax._src.numpy.util import _promote_dtypes_complex, _promote_dtypes_inexact
from jax._src.util import prod
# Names exported by `from <this module> import *`.
__all__ = ["fft", "fft_p"]
def _str_to_fft_type(s: str) -> xla_client.FftType:
  """Maps an FFT kind name to the matching `xla_client.FftType` member.

  Only the all-lowercase and all-uppercase spellings are accepted:
  "fft"/"FFT", "ifft"/"IFFT", "rfft"/"RFFT" and "irfft"/"IRFFT". Mixed case
  (e.g. "Fft") is rejected.

  Raises:
    ValueError: if `s` is not one of the accepted spellings.
  """
  for canonical in ("FFT", "IFFT", "RFFT", "IRFFT"):
    if s in (canonical.lower(), canonical):
      return getattr(xla_client.FftType, canonical)
  raise ValueError(f"Unknown FFT type '{s}'")
@partial(jit, static_argnums=(1, 2))
def fft(x, fft_type: Union[xla_client.FftType, str], fft_lengths: Sequence[int]):
  """Applies an FFT of kind `fft_type` over the trailing axes of `x`.

  Args:
    x: the input array. For RFFT it must be real-valued and is promoted to an
      inexact dtype; for every other kind it is promoted to a complex dtype.
    fft_type: an `xla_client.FftType`, or one of the strings "fft"/"FFT",
      "ifft"/"IFFT", "rfft"/"RFFT", "irfft"/"IRFFT". Static under `jit`.
    fft_lengths: the transform length for each of the trailing
      `len(fft_lengths)` axes of `x`. Static under `jit` (jit static
      arguments must be hashable, e.g. a tuple). When empty, no transform is
      applied.

  Returns:
    The promoted `x` itself if `fft_lengths` is empty. Otherwise, the result
    of binding `fft_p`, whose shape and dtype are given by
    `fft_abstract_eval`.

  Raises:
    TypeError: if `fft_type` is neither a string nor an `xla_client.FftType`.
    ValueError: if `fft_type` is an unrecognized string, or if an RFFT is
      requested on complex input.
  """
  if not isinstance(fft_type, (str, xla_client.FftType)):
    raise TypeError(f"Unknown FFT type value '{fft_type}'")
  typ = _str_to_fft_type(fft_type) if isinstance(fft_type, str) else fft_type

  is_rfft = typ == xla_client.FftType.RFFT
  if is_rfft and np.iscomplexobj(x):
    raise ValueError("only real valued inputs supported for rfft")
  # RFFT keeps its input real (but floating point); every other kind needs a
  # complex input.
  promote = _promote_dtypes_inexact if is_rfft else _promote_dtypes_complex
  (x,) = promote(x)

  if len(fft_lengths) == 0:
    # Nothing to transform; XLA FFT doesn't support 0-rank.
    return x
  return fft_p.bind(x, fft_type=typ, fft_lengths=tuple(fft_lengths))
def _fft_impl(x, fft_type, fft_lengths):
  """Eager implementation of `fft_p`, dispatched through `xla.apply_primitive`."""
  params = dict(fft_type=fft_type, fft_lengths=fft_lengths)
  return xla.apply_primitive(fft_p, x, **params)
_complex_dtype = lambda dtype: (np.zeros((), dtype) + np.zeros((), np.complex64)).dtype
_real_dtype = lambda dtype: np.finfo(dtype).dtype
_is_even = lambda x: x % 2 == 0
def fft_abstract_eval(x, fft_type, fft_lengths):
  """Shape/dtype rule for `fft_p`.

  Args:
    x: abstract value of the operand; uses its `.shape`, `.ndim`, `.dtype` and
      `.update`.
    fft_type: an `xla_client.FftType`. Anything other than RFFT or IRFFT is
      treated as a complex-to-complex transform.
    fft_lengths: transform lengths of the trailing axes. It is compared
      directly with slices of `x.shape`, so it presumably must be a tuple
      (`fft` passes one).

  Returns:
    `x.update(shape=..., dtype=...)`, where:
      RFFT:  the last transformed axis becomes `fft_lengths[-1] // 2 + 1` and
             the dtype becomes complex;
      IRFFT: the trailing axes become `fft_lengths` and the dtype becomes real;
      other: shape and dtype are unchanged.

  Raises:
    ValueError: if `x` has fewer dimensions than `fft_lengths`, or if its
      trailing dimensions don't match `fft_lengths` (all but the last one, for
      IRFFT).
  """
  rank = len(fft_lengths)
  if rank > x.ndim:
    raise ValueError(f"FFT input shape {x.shape} must have at least as many "
                     f"input dimensions as fft_lengths {fft_lengths}.")
  batch_shape, minor_shape = x.shape[:-rank], x.shape[-rank:]

  if fft_type == xla_client.FftType.RFFT:
    if minor_shape != fft_lengths:
      raise ValueError(f"RFFT input shape {x.shape} minor dimensions must "
                       f"be equal to fft_lengths {fft_lengths}")
    # RFFT keeps only `n // 2 + 1` bins along the last transformed axis.
    out_shape = batch_shape + fft_lengths[:-1] + (fft_lengths[-1] // 2 + 1,)
    out_dtype = _complex_dtype(x.dtype)
  elif fft_type == xla_client.FftType.IRFFT:
    # The last input axis holds the half spectrum, so only the others must
    # match `fft_lengths`.
    if minor_shape[:-1] != fft_lengths[:-1]:
      raise ValueError(f"IRFFT input shape {x.shape} minor dimensions must "
                       "be equal to all except the last fft_length, got "
                       f"{fft_lengths=}")
    out_shape = batch_shape + fft_lengths
    out_dtype = _real_dtype(x.dtype)
  else:
    if minor_shape != fft_lengths:
      raise ValueError(f"FFT input shape {x.shape} minor dimensions must "
                       f"be equal to fft_lengths {fft_lengths}")
    out_shape, out_dtype = x.shape, x.dtype
  return x.update(shape=out_shape, dtype=out_dtype)
def _fft_lowering(ctx, x, *, fft_type, fft_lengths):
  """Lowers `fft_p` to a single `hlo.FftOp`.

  `ctx` is unused. Returns a one-element list holding the op's result value.
  """
  type_attr = hlo.FftTypeAttr.get(fft_type.name)
  lengths_attr = mlir.dense_int_elements(fft_lengths)
  fft_op = hlo.FftOp(x, type_attr, lengths_attr)
  return [fft_op.result]
def _fft_lowering_cpu(ctx, x, *, fft_type, fft_lengths):
  """CPU lowering of `fft_p` through the `ducc_fft` custom call.

  Returns a one-element list with the value built by `ducc_fft`. When
  `xla_client.mlir_api_version < 41`, `ducc_fft.ducc_fft_mhlo` is used;
  otherwise `ducc_fft.ducc_fft_hlo` is used.

  Raises:
    NotImplementedError: if any input or output aval has a non-constant
      (polymorphic) shape.
  """
  all_avals = ctx.avals_in + ctx.avals_out
  if not all(is_constant_shape(aval.shape) for aval in all_avals):
    raise NotImplementedError("Shape polymorphism for custom call is not implemented (fft); b/261671778")
  (operand_aval,) = ctx.avals_in
  if xla_client.mlir_api_version < 41:
    build = ducc_fft.ducc_fft_mhlo
  else:
    build = ducc_fft.ducc_fft_hlo
  return [build(x, operand_aval.dtype, fft_type=fft_type,
                fft_lengths=fft_lengths)]
def _naive_rfft(x, fft_lengths):
  """Reference RFFT: a full complex FFT truncated along the last axis.

  Keeps the first `fft_lengths[-1] // 2 + 1` bins of the last axis.
  `_rfft_transpose` transposes this function with `linear_transpose`.
  """
  full_spectrum = fft(x, xla_client.FftType.FFT, fft_lengths)
  num_bins = fft_lengths[-1] // 2 + 1
  return full_spectrum[..., :num_bins]
@partial(jit, static_argnums=1)
def _rfft_transpose(t, fft_lengths):
  """Transposes an RFFT: maps a complex cotangent `t` to a real array.

  The result has shape `t.shape[:-len(fft_lengths)] + fft_lengths` and the
  real dtype corresponding to `t.dtype`.
  """
  # The transpose of RFFT can't be expressed in terms of IRFFT alone. Rather
  # than building larger twiddle matrices by hand (which would raise the
  # asymptotic complexity and is fairly involved), we rely on JAX to transpose
  # a naive RFFT implementation (a full FFT followed by truncation).
  real_dtype = _real_dtype(t.dtype)
  primal_shape = t.shape[:-len(fft_lengths)] + fft_lengths
  naive_rfft = partial(_naive_rfft, fft_lengths=fft_lengths)
  transpose_fn = linear_transpose(
      naive_rfft, ShapeDtypeStruct(primal_shape, real_dtype))
  (out,) = transpose_fn(t)
  assert out.dtype == real_dtype, (out.dtype, t.dtype)
  return out
def _irfft_transpose(t, fft_lengths):
  """Transposes an IRFFT: maps a real cotangent `t` to a complex array.

  Computes `conj(scale * mask * rfft(t))`, where:
    - `scale` is `1 / prod(fft_lengths)`;
    - `mask` is applied along the last axis and is 1 for the first bin and 2
      for all other bins, except that the last bin is 1 when
      `fft_lengths[-1]` is even.
  """
  # RFFT stores only half of a Hermitian-symmetric spectrum. The bins that
  # also stand in for their mirror image get their cotangent doubled.
  spectrum = fft(t, xla_client.FftType.RFFT, fft_lengths)
  num_bins = spectrum.shape[-1]
  odd = fft_lengths[-1] % 2

  def filled(value, size):
    # 1-D constant of `size` elements in the spectrum's (complex) dtype.
    return lax.full_like(t, value, dtype=spectrum.dtype, shape=(size,))

  mask = lax.concatenate(
      [filled(1.0, 1), filled(2.0, num_bins - 2 + odd), filled(1.0, 1 - odd)],
      dimension=0)
  scale = 1 / prod(fft_lengths)
  out = scale * lax.expand_dims(mask, range(spectrum.ndim - 1)) * spectrum
  assert out.dtype == _complex_dtype(t.dtype), (out.dtype, t.dtype)
  # Follow JAX's convention for complex gradients:
  # https://github.com/google/jax/issues/6223#issuecomment-807740707
  return lax.conj(out)
def _fft_transpose_rule(t, operand, fft_type, fft_lengths):
  """Transpose rule for `fft_p` (registered with `ad.deflinear2`).

  `operand` is not read; the transpose depends only on the cotangent `t` and
  the static parameters. Returns a 1-tuple holding the operand's cotangent.
  """
  if fft_type == xla_client.FftType.RFFT:
    return (_rfft_transpose(t, fft_lengths),)
  if fft_type == xla_client.FftType.IRFFT:
    return (_irfft_transpose(t, fft_lengths),)
  # Complex FFT/IFFT: the DFT matrix is symmetric, so the same transform is
  # simply reapplied to the cotangent.
  return (fft(t, fft_type, fft_lengths),)
def _fft_batching_rule(batched_args, batch_dims, fft_type, fft_lengths):
  """Batching rule for `fft_p`.

  Moves the single operand's batch axis to position 0, which keeps it out of
  the trailing axes the FFT acts on. It then reapplies `fft` and reports the
  output batch axis as 0.
  """
  (operand,), (batch_axis,) = batched_args, batch_dims
  operand = batching.moveaxis(operand, batch_axis, 0)
  return fft(operand, fft_type, fft_lengths), 0
# The `fft` primitive and the rules that make it work under JAX's
# interpreters: eager impl, abstract eval, lowering, autodiff and batching.
fft_p = Primitive('fft')
fft_p.def_impl(_fft_impl)
fft_p.def_abstract_eval(fft_abstract_eval)
# Lowering registered without `platform=`: emits an `hlo.FftOp`.
mlir.register_lowering(fft_p, _fft_lowering)
# FFTs are linear in their operand, so autodiff only needs a transpose rule.
ad.deflinear2(fft_p, _fft_transpose_rule)
batching.primitive_batchers[fft_p] = _fft_batching_rule
# CPU-specific lowering through the ducc_fft custom call.
mlir.register_lowering(fft_p, _fft_lowering_cpu, platform='cpu')