PocketFFT is the same FFT library used by NumPy (although we are using the C++ variant rather than the C variant).

For the benchmark in #2952 on my workstation:

Before:
```
907.3490574884647
max: 4.362646594533903e-08
mean: 6.237288307614869e-09
min: 0.0
numpy fft execution time [ms]: 37.088446617126465
jax fft execution time [ms]: 74.93342399597168
```

After:
```
907.3490574884647
max: 1.9057386696477137e-12
mean: 3.9326737908882566e-13
min: 0.0
numpy fft execution time [ms]: 37.756404876708984
jax fft execution time [ms]: 28.128278255462646
```

Fixes https://github.com/google/jax/issues/2952

PiperOrigin-RevId: 338743753
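The benchmark script from #2952 is not reproduced here. The following is a minimal sketch of a comparable accuracy-and-timing check. It assumes the `max`/`mean`/`min` lines are statistics of the absolute element-wise difference between the two FFT outputs. `compare_fft` and `naive_dft` are hypothetical helpers written for illustration; they are not part of JAX or NumPy.

```python
import time

import numpy as np


def compare_fft(fft_a, fft_b, x, repeats=10):
    """Compare two FFT implementations on the same input (hypothetical helper).

    Returns max/mean/min of |fft_a(x) - fft_b(x)| and the best-of-`repeats`
    wall-clock time of each implementation, in milliseconds.
    """
    # These first calls also serve as warm-up, so any one-time compilation
    # (e.g. in JAX) is excluded from the timings below.
    out_a = np.asarray(fft_a(x))
    out_b = np.asarray(fft_b(x))
    diff = np.abs(out_a - out_b)

    def best_time_ms(fn):
        best = float("inf")
        for _ in range(repeats):
            start = time.perf_counter()
            # Converting to a NumPy array waits for the result, which matters
            # for libraries with asynchronous dispatch such as JAX.
            np.asarray(fn(x))
            best = min(best, time.perf_counter() - start)
        return best * 1e3

    return {
        "max": float(diff.max()),
        "mean": float(diff.mean()),
        "min": float(diff.min()),
        "time_a_ms": best_time_ms(fft_a),
        "time_b_ms": best_time_ms(fft_b),
    }


def naive_dft(x):
    """O(n^2) reference DFT using NumPy's sign convention exp(-2*pi*i*k*m/n)."""
    n = len(x)
    k = np.arange(n)
    return np.exp(-2j * np.pi * np.outer(k, k) / n) @ x


# Usage: compare NumPy's FFT against the naive reference. To compare against
# JAX instead, pass jax.numpy.fft.fft as the second function.
x = np.random.default_rng(0).standard_normal(64)
result = compare_fft(np.fft.fft, naive_dft, x, repeats=3)
```

Taking the best of several runs, rather than the mean, reduces noise from other activity on the machine. The sketch makes no claim about how the original benchmark measured time.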