This change improves the stability and backward compatibility of Pallas Triton
calls, because unlike PTX, the Triton dialect has no stability guarantees
and does change in practice.
See #25196.
A few notes:
* Pallas Triton no longer relies on XLA:GPU to compile to PTX. Instead,
compilation is done via a new PjRt extension, which uses its own compilation
pipeline modeled after the one in the Triton Python bindings.
* The implementation of the old custom call used by Pallas Triton is
deprecated and will be removed after 6 months as per
[compatibility guarantees][*].
[*]: https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees
PiperOrigin-RevId: 722773884
Previously, it was necessary to list all dtypes explicitly, which is why
we had separate fallback rules for float16 and bfloat16 for some functions.
PiperOrigin-RevId: 722729554
The user has access only to accumulator references and can't pass them as carries to loops. However, when they are discharged, these accumulators become values and become part of the carry. Before this CL, this would surprise the loop lowering code.
This was never a problem for Pallas MGPU until we added pipelining loops instead of sequential block axes.
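A toy model of the situation described above, in plain Python rather than Pallas: `Ref`, `loop_with_ref` and `loop_discharged` are hypothetical names made up for illustration and do not correspond to anything in the lowering code. It only shows how a reference that the user never passes as a carry ends up in the carry once it is discharged to a value.

```python
class Ref:
    """Toy mutable reference standing in for an accumulator ref (hypothetical)."""

    def __init__(self, value):
        self.value = value


def loop_with_ref(n, acc_ref, xs):
    # User-facing form: the body mutates the ref in place,
    # so the loop itself has no carry.
    for i in range(n):
        acc_ref.value += xs[i]


def loop_discharged(n, acc, xs):
    # Discharged form: the ref has become a plain value, so it has to be
    # threaded through the loop as part of the carry.
    def body(i, carry):
        (acc_value,) = carry
        return (acc_value + xs[i],)

    carry = (acc,)
    for i in range(n):
        carry = body(i, carry)
    return carry[0]
```

Both forms compute the same result; the difference is only where the accumulator lives, which is what the loop lowering has to account for.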
PiperOrigin-RevId: 722495749
Previously, we had two almost identical classes: `TracingDebugInfo` and
`JaxprDebugInfo`. The only difference was that `TracingDebugInfo` had
a thunk to return the result paths, while `JaxprDebugInfo` had the
result paths resolved to a tuple. The separation of these types
provided some clarity, but also led to code duplication and
required conversions as the debugging info goes from `WrappedFun`
to a `Jaxpr` and then to `WrappedFun` again.
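As a rough sketch of what a single merged type could look like, assuming one field that holds either the thunk or the resolved tuple: `DebugInfoSketch`, its fields, and `resolve_result_paths` are hypothetical names chosen for illustration, not the actual JAX classes or API.

```python
from dataclasses import dataclass, replace
from typing import Callable, Tuple, Union


@dataclass(frozen=True)
class DebugInfoSketch:
    """Hypothetical merged debug-info record; not the real JAX class."""

    func_name: str
    arg_names: Tuple[str, ...]
    # A thunk while tracing, or a resolved tuple once the jaxpr exists.
    result_paths: Union[Callable[[], Tuple[str, ...]], Tuple[str, ...]]

    def resolve_result_paths(self) -> "DebugInfoSketch":
        """Return a copy with the thunk forced; return self if already resolved."""
        if callable(self.result_paths):
            return replace(self, result_paths=self.result_paths())
        return self
```

With one type, moving from a `WrappedFun` to a `Jaxpr` and back only needs the thunk to be forced at most once, rather than converting between two classes.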
When we do run_scoped[jaxpr, R1, R2], it can't be assumed that references
corresponding to R1 and R2 can be safely discharged. Sometimes they can (e.g.
Accumulator), but sometimes they can't (e.g. SMEM scratch). It should be up to the
lowering rule to do such discharging.
This further means that during lowering there is no guarantee that the
references will not be used or returned by nested scoped blocks, so we also remove
that check.
PiperOrigin-RevId: 722137352
This CL lays the groundwork for a future CL that stops run_scoped discharge from requesting the discharge of the temporary buffers it creates. That change currently causes issues because:
a) dma_start can't discharge some but not all of its references.
b) run_scoped() lowering depends on run_scoped discharge to remove the run_scoped operation (otherwise it goes into an infinite loop).
PiperOrigin-RevId: 722126566