This change is in preparation for adding support for emitting https://tc39.es/source-map/ source map information for jaxprs, so the relationship between a jaxpr and its Python code can be visualized using tooling for that purpose.
This change adds a new `source_map()` pretty printer document, which causes the pretty-printer to populate a source_map side output during pretty printing.
The change also teaches the core jaxpr pretty printer to populate source map information on each equation.
The fix in #21032 was not correct because it assumed that the set of all mesh
axis names appearing in in_specs was an upper bound on the set of mesh axes
over which residuals could be device-varying. But collectives can introduce
device variance! So it's not an upper bound.
We track device variance when check_rep=True, but often people set
check_rep=False (e.g. when using pallas_call in a shard_map). So relying on our
device variance tracking would be limiting. That may be a decent long term
solution, if we can make it easy to annotate pallas_calls with device variance
information. But it's not a great short term one to unblock things.
So instead I temporrarily went with context sensitivity: instead of making
residuals sharded over all mesh.axis_names (as we did before these patches), we
make them sharded over all mesh axis names _excluding_ any spmd_axis_names in
our dynamic context (by looking at the traces in our trace stack). It's illegal
to mention any spmd_axis_names in collectives (indeed anywhere in the body of
the function being vmapped), but I don't think we check it.
TODO(mattjj): add more testing (maybe in follow-ups)
This handles three cases that came up when adopting this snippet (for
finding source code corresponding to API docs) [for neuralgcm](https://github.com/google-research/neuralgcm/pull/58):
1. documenting class method or attributes
2. documenting properties
3. documenting `jit` decorated methods
I'm not sure if case (1) or (2) comes up in the JAX docs, but case (3)
definitely does -- `jit` decorated functions like `jax.nn.relu`
[currently do not](https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.relu.html)
have source code link.