mirror of
https://github.com/ROCm/jax.git
synced 2025-04-24 22:16:05 +00:00

Without this, the compiled graph still contains a node multiplying a complex number by the constant 1+0j (the 1 is cast to complex because the other operand is complex). This is problematic when converting to TFLite using jax2tf, because multiplying complex numbers is not supported in TFLite. With this change, the multiplication is removed from the graph altogether.

PiperOrigin-RevId: 459566727
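
A minimal sketch of the effect described above, not the actual code touched by this change: the helper names `scale_always`, `scale_if_needed`, and `has_mul` are hypothetical and exist only for illustration. It assumes the scale factor is a Python constant known at trace time, so the multiplication can be skipped before it is ever recorded in the graph.

```python
import jax
import jax.numpy as jnp


def scale_always(x, factor=1):
    # Hypothetical "before" behaviour: the multiply is always traced.
    # With a complex x, the Python scalar 1 is promoted to complex.
    return x * factor


def scale_if_needed(x, factor=1):
    # Hypothetical "after" behaviour: factor is a plain Python number,
    # so this check runs at trace time and no multiply is recorded.
    if factor == 1:
        return x
    return x * factor


def has_mul(fn, x):
    # Inspect the traced jaxpr text for a multiplication.
    return "mul" in str(jax.make_jaxpr(fn)(x))


x = jnp.ones((4,), dtype=jnp.complex64)
print("always:   ", has_mul(scale_always, x))
print("if needed:", has_mul(scale_if_needed, x))
```

The check only works because `factor` is a static Python value; a traced array could not be compared this way inside a transformed function.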