Previously, we represented a missing arg name with `None`,
and a missing result path with the empty string. We now
adopt the same convention for arg names and use empty strings.
This simplifies the typing, and prevents the string "None" from
appearing in error messages.
I changed how we encode the result paths. Previously for a
function that returns a single array the path was the empty
string (the same as for an unknown path). And for a function
that returns a pair of arrays it was `([0], [1])`. Now we
add the "result" prefix: `("result",)` for a function returning a
single array and `(result[0], result[1])` for a function returning
a pair of arrays.
Finally, in debug_info_test, I removed the `check_tracer_arg_name`
so that all spied tracers are printed with the argument name they
depend on.
Now all internal uses of lu.wrap_init and core.Jaxpr are with actual
debug info. This enables us to clean up the type declarations and
to remove the checks whether debug_info is present.
For usage outside of the JAX internals, we change
`jax.extend.linear_util.wrap_init` to be usable without debug_info,
for temporary backwards compatibility. We emit a deprecation
warning and fill-in some fake debugging info.
See https://github.com/jax-ml/jax/issues/26480 for more details.
PiperOrigin-RevId: 726770483
This follows in a series, starting with #26078 and #26313, adding debug_info to more calls to lu.wrap_init.
Here I changed the `custom_jvp_call` to replace the parameter
`jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun`
that can carry debug info).
Also fixed uses in shard_map, checkify, sparse, attrs, and jax2tf.
This follows after #26078, #26313, #26348, adding `debug_info` to more calls to `lu.wrap_init`.
As part of this I have changed the primitive `custom_transpose` to take the `transpose` parameter as a `lu.WrappedFun`, which carries debug info. Previously, this was a `Callable`.
These changes ensure that all the `lu.wrap_init` and `Jaxpr` are called with debug_info in the `api_test.py:CustomTransposeTest`.
This follows in a series, starting with #26078 and #26313, adding debug_info to more calls to lu.wrap_init.
These changes ensure that all the lu.wrap_init and Jaxpr are called with debug_info in the api_test.py:CustomTransposeTest,
api_test.py:CustomVmapTest and api_test.py:RematTest.
This follows after #26078, and #26313, adding `debug_info` to
more calls to `lu.wrap_init`.
As part of this I have changed the primitives `custom_vjp_call_jaxpr`
and `custom_lin` to take the `bwd` parameter as a `lu.WrappedFun`,
which carries debug info. Previously, this was a `Callable`, but in
almost all cases if was really ` lu.WrappedFun.call_wrapped`.
The plan is for all `core.Jaxpr` and `lu.WrappedFun` to carry
non-None debug info.
We change `lu.wrap_init` to construct the result paths thunk
whenever it is passed a `debug_info`. The goal is to make sure that
all `WrappedFun` have a debug info with result paths support.
We change some calling conventions for internal functions to not
pass along a separate debug_info if we have a `WrappedFun` or
a `Jaxpr`.
We obtain several improvements in presence of debug infos
in debug_info_test.py
Previously, we had two almost identical classes: `TracingDebugInfo` and
`JaxprDebugInfo`. The only difference was that `TracingDebugInfo` had
a thunk to return the result paths, while `JaxprDebugInfo` had the
result paths resolved to a tuple. The separation of these types
provided some clarity, but also led to code duplication and
required conversions as the debugging info goes from `WrappedFun`
to a `Jaxpr` and then to `WrappedFun` again.
This test is sometimes reporting 4 warnings, probably because of tracing cache hits. To be correct, this test probably needs to use its own unique functions that are not shared with other test cases.
PiperOrigin-RevId: 721571459
We replace uses of `pe.tracing_debug_info` with with `api_util.tracing_debug_info`,
which uses the actual args and kwargs, instead of `in_tree` to manufacture fake
args and kwargs. This ends up being more accurate, especially for `arg_names`;
see changes in debug_info_tests.py.
This means that we have to construct the debug info further upstream, before
flattening args. This will later help populate debug info in `WrappedFun` and
`Jaxpr`.
This is part 3 of a series (following #26097, #26099) for jit, pmap, checkify,
and the custom_partitioning (the last few uses).
In order to land this, I had to remove a safety check that the number of
`arg_names` and `result_paths` in a Jaxpr's debug info match the number
of Jaxpr invars and outvars, respectively. Additionally, I added two
accessors `safe_arg_names` and `safe_result_paths` to ensure that
the arg names and result paths match the expected length. These accessors
return no-op results when the lengths are not as expected.
From my testint, this happens only in Jaxprs that
are not used for lowering, hence there is no actual user-visible
change here. Simply, more internal Jaxprs are getting debug_info
and in some cases the `arg_names` and `result_paths` are not correct.
Still, this change is worth it because the `func_src_info` is the most
useful part of the debug info (used for leaked tracers), and that is
accurate. We will fix the `arg_names` and `result_paths` in a future change.
One can see in the changes in debug_info_test.py the improvements in the
user-visible debug info, including for `pjit` and `pmap` cases when
it was wrong.
Made several improvements to the debug info tests:
* added support for eager mode, which sometimes uses
different code paths for the debug info, e.g., for
`jvp(pmap)`. To check the debugging info in these cases we add
instrumentation to collect the lowered Jaxprs and MLIR modules right
after lowering, and we check the debugging information there.
* added support for checking for the presence of regular expressions
and strings in the lowered module, to check that the location
information and arg_names and result_paths is present. This
is now enabled only for a subset of the tests.
* simplified the pretty-printing of the arg_names and result_paths
in the debug info, to remove a layer of parentheses and string,
so that instead of `arg_names=("x", "y")` we now pretty-print
just `arg_names=x,y"
* added support for checking the provenance information in
leaked tracers
We replace uses of `pe.tracing_debug_info` with with `api_util.tracing_debug_info`,
which uses the actual args and kwargs, instead of `in_tree` to manufacture fake
args and kwargs. This ends up being more accurate, especially for `arg_names`;
see changes in debug_info_tests.py.
This means that we have to construct the debug info further upstream, before
flattening args. This will later help populate debug info in `WrappedFun` and
`Jaxpr`.
This is part 2 of a series (following #26097) for Pallas.
We replace those uses with api_util.tracing_debug_info, which means we
have to move the call further upstream. But this is better because we
have the actual args and kwargs, and we can do a better job, especially
for `arg_names`.
This is part 1 of a series, for: cond, switch, while, scan, composite,
custom_dce, custom_root, custom_linear_solve, saved_residuals.
Try to cover the tracing of almost all JAX higher-order
primitives. Some of the tests added show missing debug info,
marked with TODO. Fixes will come separately.
Had to expand the helper functions _check_tracers_and_jaxprs to
use regular expressions for matching because some debug info
still contains non-deterministic elements.
Debugging info is needed for error messages, and for
lowering. For the former, we need debug info inside
tracers. For the latter, inside Jaxprs. We add a
new set of tests that intentionally leak tracers while
tracing and then we check that the tracers have the
expected debug info. We also form Jaxprs and we
check that they have the expected debug info.
We uncovered a few missing debug infos, those are
marked with TODO.
Most places in the code assumed this already, but often
that usage is error reporting code, which is not yet well tested.
When we cannot get the `inspect.Signature` or when the
args and kwargs do not match the signature, we generate
the flattened argument names as: `args[0]`, `args[1]`,
`kwargs['foo']`, ... Previously, in these cases we
returned `arg_names` is None, and then the whole
debug_info ended up being `None`, throwing away even
available information.
We also add support for `api_util.fun_sourceinfo` even
for cases when the `fun.__code__` is not available. In
those cases we used to say that `fun_sourceinfo` is
`None`. Now, we use the string representation of `fun`
to get the name of built-in functions, or we use "<unknown>".
Created debug_info_test.py and moved there some of the
tests involving debug_info. In the future we will put here
more tests for debugging info, and their helper functions.