[JAX] Stop using a custom ducc kernel: instead just emit an Fft HLO operation and let XLA emit the call to ducc.

XLA now calls ducc itself as of da67903a4c, so we don't need a custom call in JAX any more. In addition, the DUCC call from XLA receives a thread pool and is parallelized.

Fixes https://github.com/google/jax/issues/14664

PiperOrigin-RevId: 579829580
This commit is contained in:
Peter Hawkins 2023-11-06 06:54:36 -08:00 committed by jax authors
parent 1e810983fa
commit 390022a227

View File

@ -29,6 +29,7 @@ from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
from jax._src.lib import ducc_fft
from jax._src.lib import version as jaxlib_version
from jax._src.numpy.util import promote_dtypes_complex, promote_dtypes_inexact
@ -257,4 +258,8 @@ fft_p.def_abstract_eval(fft_abstract_eval)
mlir.register_lowering(fft_p, _fft_lowering)
ad.deflinear2(fft_p, _fft_transpose_rule)
batching.primitive_batchers[fft_p] = _fft_batching_rule
mlir.register_lowering(fft_p, _fft_lowering_cpu, platform='cpu')
# TODO(phawkins): when jaxlib 0.4.21 is the minimum, use XLA's FFT lowering
# always on CPU. At that point, we can also delete the DUCC FFT kernel from JAX.
if xla_extension_version < 211:
mlir.register_lowering(fft_p, _fft_lowering_cpu, platform='cpu')