The plan is for all `core.Jaxpr` and `lu.WrappedFun` to carry
non-None debug info.
We change `lu.wrap_init` to construct the result paths thunk
whenever it is passed a `debug_info`. The goal is to make sure that
all `WrappedFun` have a debug info with result paths support.
We change some calling conventions for internal functions to not
pass along a separate debug_info if we have a `WrappedFun` or
a `Jaxpr`.
We obtain several improvements in presence of debug infos
in debug_info_test.py
We replace uses of `pe.tracing_debug_info` with with `api_util.tracing_debug_info`,
which uses the actual args and kwargs, instead of `in_tree` to manufacture fake
args and kwargs. This ends up being more accurate, especially for `arg_names`;
see changes in debug_info_tests.py.
This means that we have to construct the debug info further upstream, before
flattening args. This will later help populate debug info in `WrappedFun` and
`Jaxpr`.
This is part 3 of a series (following #26097, #26099) for jit, pmap, checkify,
and the custom_partitioning (the last few uses).
In order to land this, I had to remove a safety check that the number of
`arg_names` and `result_paths` in a Jaxpr's debug info match the number
of Jaxpr invars and outvars, respectively. Additionally, I added two
accessors `safe_arg_names` and `safe_result_paths` to ensure that
the arg names and result paths match the expected length. These accessors
return no-op results when the lengths are not as expected.
From my testint, this happens only in Jaxprs that
are not used for lowering, hence there is no actual user-visible
change here. Simply, more internal Jaxprs are getting debug_info
and in some cases the `arg_names` and `result_paths` are not correct.
Still, this change is worth it because the `func_src_info` is the most
useful part of the debug info (used for leaked tracers), and that is
accurate. We will fix the `arg_names` and `result_paths` in a future change.
One can see in the changes in debug_info_test.py the improvements in the
user-visible debug info, including for `pjit` and `pmap` cases when
it was wrong.
Previously we had a number of APIs in core.py that operated on dimensions
and shapes and delegated to instances of DimensionHandler. We remove most
of those APIs because by now they ended up doing very little, e.g.,
`core.sum_dim` was the same as `operator.add`, and `core.sum_shape` was
the same as `tuple(map(operator.add))`.
We also remove the whole `DimensionHandler` machinery because by now
the only other use of non-constant dimensions using this mechanism
are the symbolic dimensions used for shape polymorphism, and those
support now full operator overloading. (When we introduced `DimensionHandler`
we had the masking transformation around that needed it also.)
Previously we had a number of APIs in core.py that operated on dimensions
and shapes and delegated to instances of DimensionHandler. We remove most
of those APIs because by now they ended up doing very little, e.g.,
`core.sum_dim` was the same as `operator.add`, and `core.sum_shape` was
the same as `tuple(map(operator.add))`.
We also remove the whole `DimensionHandler` machinery because by now
the only other use of non-constant dimensions using this mechanism
are the symbolic dimensions used for shape polymorphism, and those
support now full operator overloading. (When we introduced `DimensionHandler`
we had the masking transformation around that needed it also.)
This is the first step in a revision to how we handle the debug info pertaining
to staged functions' parameter names and result pytree paths. To limit
complexity, this first step adds machinery required to make our MLIR lowerings'
parameter and result names work, but it does *not* yet unify it with existing
arg-name machinery used at tracing time (in partial_eval.py, e.g.
partial_eval.DebugInfo etc). That unification will come in a follow up commits.
(I wrote the unified version first, then broke it down into this sequence of
commits.)
Another thing that will arrive in follow-up commits is pmap support (handling
static_broadcasted_argnames). This PR doesn't include support for pmap because
pmap's final style implementation requires slightly different machinery than
jit/pjit's initial style implementation. Indeed this PR removes the previous
support for pmap arg/result info, and skips the corresponding tests, because
the previous support didn't handle pmap's static_broadcasted_argnums (and I
think it could even lead to silently incorrect annotations when pmap was not at
the top-level, though I didn't work out an example case to be sure that was
possible).
This commit includes the changes from PR #15079, so that PR should be merged first.
Here's the _why_ of this change:
* The pre-existing solution (from PRs #14702, #14764, and #14813) did not
handle static_argnums or static_argnames correctly. Instead it would fail,
resulting in debug info being dropped from the jaxpr and ultimately the MLIR
computation (but no Exception raised). We need to handle
static_argnums/argnames because while the corresponding parameters remain on
the Python callable signature, they are excluded from the args/kwargs
pytrees; the previous solution didn't account for that divergence.
* The best way to handle static_argnums/argnames is to work out this debug info
when we still have the original args/kwargs in hand, i.e. much earlier than
the previous mechanism. We then just have to pass this debug info to the
right places. Indeed we often already had to work out some debug-related
information at these call sites (e.g. whether the function is being staged
out for jit, or scan, or whatever), so after this change we're working out
all the debug info at the same time.
* A side benefit is that now to get this debug info we no longer need to
unflatten user pytree defs with dummy objects (to reconstruct dummy
args/kwargs trees so that we can call inspect.signature(fun).bind), since we
just use the original args/kwargs instead. Since some user pytree node types
are not fully polymorphic in their element types (e.g. their __init__ methods
sometimes contained assertions about their elements' shapes, expecting them
to be arrays), that means the new mechanism is fundamentally more compatible
with custom pytree node types.
More concretely, effecting those high-level changes led to:
* replacing the previous `core.DebugInfo` with a class `core.JaxprDebugInfo`,
which in addition to the more precise name has fields like
`arg_names: Tuple[Optional[str], ...]` and
`result_paths: Tuple[Optional[str], ...]`, rather than
`in_tree: Optional[PyTreeDef]`, reflecting the fact that we work out the
actual debug info more eagerly than before and we don't need pytrees for
dummy-unflattening;
* introducing the new `partial_eval.TracingDebugInfo` class representing the
debug info about inputs which we have available at tracing time; in a
follow-up PR, we'll adapt partial_eval.py to use this new class and we'll
delete `partial_eval.DebugInfo` and its corresponding helper methods (not
done in this commit just to reduce complexity of each change);
* moving the old `core.DebugInfo`, which before #14702 lived in
partial_eval.py, back to partial_eval.py pending cleanup (deletion) of that
partial_eval.py debug info code;
* making specific jaxpr-processing functions produce an appropriately updated
`core.JaxprDebugInfo` object for their output (e.g. `pe.dce_jaxpr` prunes
elements from the `arg_names` field), maintaining now-checked invariants like
a Jaxpr's `debug_info` should have the same number of argument names as the
jaxpr has invars (the jaxpr-processing functions updated here are enough for
top-level jit jaxprs to have debug info attached, handling the original
intended use case of jit(f).lower, but not e.g. grad-of-jit cases, which can
be handled later by updating `ad.jvp_jaxpr` and the like to produce updated
debug info on their outputs);
* add some tests for static_argnums/static_argnames.
Phew! Can't wait to land those follow-ups too :P
instantiate_const() must take and return a JaxprTracer.
Teach pytype that the Tracer returned by full_raise() must be an instance of the Tracer type associated with the Trace, using a Generic type.
PiperOrigin-RevId: 516554216
Add an "explicit_global_axis_size" arg. `global_axis` used to be set to `None`
when the user did not provide an explicit axis size. After this change,
`global_axis` should never be set to `None` internally, and always contain the
size of the global axis. It's still useful to thread the information that the
user has provided an explicit axis size so we can throw explicit errors in
`pxla` when explicit axis sizes are not allowed.
Why do we need to do this? We only go down the lowering path when calling
`pmap`s impl rule (while executing or final-style transforming), but not when
initial-style transforming. The global_axis size should be computed earlier,
such that it is available for initial-style transformations/primitives, e.g. if
we round-trip a multi-host pmap computation through make_jaxpr and eval_jaxpr.
We have tests for "initial-style transform of a `pmap`", but no such test for
_multi-host_ `pmap`! Alors, this bug went unnoticed.
#13545 makes `checkify` initial-style, and because `checkify-of-pmap` is a
valid way to check a `pmap`, an internal multi-host test uncovered this bug.
PiperOrigin-RevId: 499877003
... in preparation for paring down `jax.core`'s exported symbols.
Also includes a few import fixups along the way, and a TODO comment to avoid an
import cycle in `_src/dtypes.py`.
PiperOrigin-RevId: 496024782
The issue was that partial_eval.py's _memoize, used in custom_jvp, was made
into an identity function by enabling config.jax_check_tracer_leaks (from
references to the main trace (needed for the jvp_jaxpr thunk) and hence trigger
the leak checker (which would see if any references to the main trace persisted
after finishing tracing of the user function).
But after #7345, the leak checker should only trigger when actual Tracers are
leaked. So disabling the memoization when jax_check_tracer_leaks is no longer
active shouldn't be necessary. (These PR numbers seem out of order! We're not
sure why.)
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>