I suspect in the past lack of source info meant that the function also has
no signature, but this is no longer the case.
I also removed an unused parameter from ``explain_tracing_cache_miss`` as
a drive by change.
This is a follow up to #22269.
The new naming highlights that we have two kinds of configuration options:
flags, set at most once, and states, which can be changed locally per thread
via a context manager.
The renames are
* FlagHolder -> Flag
* DEFINE_<type> -> <type>_flag
* _StateContextManager -> State
* define_<type>_state -> <type>_state
Before this information was lost in the roundtrip via `mlir.lower_fun` -> `jaxpr_subcomp`. But now since it's on the jaxpr equations, the information is preserved in jaxpr_subcomp as we enter into each eqn's ctx.
Fixes: https://github.com/google/jax/issues/21061
PiperOrigin-RevId: 636940742
The current supported values for compute type is `device_host`, `device`. `device_sparse` will be allowed in follow up CL. Using `device_host` means that the device's PJRT client will be orchestrating the execution of the computation on the host.
`cpu` as a compute_type is reserved for pure CPU only computations without a device's pjrt client orchestrating the computation.
PiperOrigin-RevId: 634909918
This change is in preparation for adding support for emitting https://tc39.es/source-map/ source map information for jaxprs, so the relationship between a jaxpr and its Python code can be visualized using tooling for that purpose.
This change adds a new `source_map()` pretty printer document, which causes the pretty-printer to populate a source_map side output during pretty printing.
The change also teaches the core jaxpr pretty printer to populate source map information on each equation.
The method has been emitting an DeprecationWarning since JAX v0.4.21, released December 2023. Existing uses can be replaced with `arr.devices()` or `arr.sharding`, depending on the context.
PiperOrigin-RevId: 623015500
This is an attempt to re-land #19819 aka cl/607570860 after a small number of
performance regressions.
As before, the main changes are:
1. simplify the scan impl that we trace through to get the lowering, and
2. ensure that when tracing it to a jaxpr, we don't rebuild the scan body
jaxpr we already have in hand.
The main motivation was (2), but (1) seems like a useful win too.
The way we achieve (2) is with a new trick: in our scan_impl function, which is
only ever traced to a jaxpr, instead of calling
`core.jaxpr_as_fun(jaxpr)(*args)` we call a new primitive
`eval_jaxpr_p.bind(*args, jaxpr=jaxpr)`. This new primitive only has a staging
rule defined for it (i.e. all we can do with it is stage it into a jaxpr), and
that rule just generates a call into the jaxpr of interest. Therefore we will
not traverse into the jaxpr just to rebuild it inline (as before).
The code in #19819 was simpler in that it avoided reshapes, concats, and
un-concats. But it caused at least one apparent performance regression (an XLA
bug?) and it was unrelated to the original goal of reducing tracing time. So
here we just land the trace time improvement.