mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[JAX] Stop using a custom ducc kernel: instead just emit an Fft HLO operation and let XLA emit the call to ducc.
XLA now calls ducc itself as of da67903a4c
, so we don't need a custom call in JAX any more. In addition, the DUCC call from XLA receives a thread pool and is parallelized.
Fixes https://github.com/google/jax/issues/14664
PiperOrigin-RevId: 579829580
This commit is contained in:
parent
1e810983fa
commit
390022a227
@ -29,6 +29,7 @@ from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.lib import ducc_fft
|
||||
from jax._src.lib import version as jaxlib_version
|
||||
from jax._src.numpy.util import promote_dtypes_complex, promote_dtypes_inexact
|
||||
@ -257,4 +258,8 @@ fft_p.def_abstract_eval(fft_abstract_eval)
|
||||
mlir.register_lowering(fft_p, _fft_lowering)
|
||||
ad.deflinear2(fft_p, _fft_transpose_rule)
|
||||
batching.primitive_batchers[fft_p] = _fft_batching_rule
|
||||
mlir.register_lowering(fft_p, _fft_lowering_cpu, platform='cpu')
|
||||
|
||||
# TODO(phawkins): when jaxlib 0.4.21 is the minimum, use XLA's FFT lowering
|
||||
# always on CPU. At that point, we can also delete the DUCC FFT kernel from JAX.
|
||||
if xla_extension_version < 211:
|
||||
mlir.register_lowering(fft_p, _fft_lowering_cpu, platform='cpu')
|
||||
|
Loading…
x
Reference in New Issue
Block a user