This change improves the stability and backward compatibility of Pallas Triton calls. Unlike PTX, the Triton dialect has no stability guarantees and does change in practice. See #25196.

A few notes:

* Pallas Triton no longer relies on XLA:GPU to compile Triton to PTX. Instead, compilation is done via a new PjRt extension, which uses its own compilation pipeline modeled after the one in the Triton Python bindings.
* The implementation of the old custom call used by Pallas Triton is deprecated and will be removed after 6 months, as per the [compatibility guarantees][*].

[*]: https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees

PiperOrigin-RevId: 722773884