mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00

Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead of Python.

Surprisingly, this code is much more readable in C++, because, unlike the C++ API, the flatbuffers Python API does not offer a more readable (if less efficient) interface. Incompatible changes to the flatbuffers Python API have broken JAX in the past, and we can drop that dependency completely without much work.

PiperOrigin-RevId: 457460347