Now all internal uses of lu.wrap_init and core.Jaxpr are with actual
debug info. This enables us to clean up the type declarations and
to remove the checks whether debug_info is present.
For usage outside of the JAX internals, we change
`jax.extend.linear_util.wrap_init` to be usable without debug_info,
for temporary backwards compatibility. We emit a deprecation
warning and fill-in some fake debugging info.
See https://github.com/jax-ml/jax/issues/26480 for more details.
PiperOrigin-RevId: 726770483
For future reference, this can be done via
python -m mypy jax --warn-unused-ignores > /tmp/unused.txt
while IFS=: read file line rest; do
echo "$file:$line";
gsed -i "${line}s/ *\# type: ignore\(\[[^]]*\]\)*//" "$file"
done < /tmp/unused.txt
* get_aval is not context dependent
* canonicalization does not happen for avals on an empty mesh
* jax.jit does not set abstract mesh context anymore before tracing
* sharding checks have been relaxed for all modes (`Auto`, `Explicit` and `Manual`). This means that `f = lambda x, y: x * y; f(explicit_sharded_arr, np_array)` will be allowed without inserting any mesh_casts even in `Explicit` sharding mode
* Even if use_mesh is not used in explicit sharding mode, computation follows data works!
* Higher order primitives skip canonicalization (pjit_p, while_p, cond_p, for_loop_p, scan_p)
* Check in partial_eval which compares jaxpr_known.outvars == jaxpr.out_avals has been relaxed to not check shardings if any one of the aval has an empty mesh.
As mentioned in https://github.com/jax-ml/jax/issues/26474 we need to relax the typing and sharding rule checks because if we insert `mesh_cast`s, those lead to creation of unnecessary residuals (for literals, numpy arrays, basically anything that has an empty mesh) which is not good.
PiperOrigin-RevId: 726097292
The plan is for all `core.Jaxpr` and `lu.WrappedFun` to carry
non-None debug info.
We change `lu.wrap_init` to construct the result paths thunk
whenever it is passed a `debug_info`. The goal is to make sure that
all `WrappedFun` have a debug info with result paths support.
We change some calling conventions for internal functions to not
pass along a separate debug_info if we have a `WrappedFun` or
a `Jaxpr`.
We obtain several improvements in presence of debug infos
in debug_info_test.py
* When current_mesh is Manual and aval mesh is Auto
* When current mesh is set and aval mesh is unset
* Final style primitives skip this canonicalization and they are free to add it in their own `bind` method.
* `mesh_cast` is skipped from this canonicalization to avoid recursion errors.
This is required to make sure that after we hit abstract_eval rule and check_jaxpr, everything is properly typed in JAX's type system.
`Auto` right now is a bit more permissive because we need to keep the current code at HEAD working but `Explicit` and `Manual` are very strict.
PiperOrigin-RevId: 722868091
Previously, we had two almost identical classes: `TracingDebugInfo` and
`JaxprDebugInfo`. The only difference was that `TracingDebugInfo` had
a thunk to return the result paths, while `JaxprDebugInfo` had the
result paths resolved to a tuple. The separation of these types
provided some clarity, but also led to code duplication and
required conversions as the debugging info goes from `WrappedFun`
to a `Jaxpr` and then to `WrappedFun` again.
We replace uses of `pe.tracing_debug_info` with with `api_util.tracing_debug_info`,
which uses the actual args and kwargs, instead of `in_tree` to manufacture fake
args and kwargs. This ends up being more accurate, especially for `arg_names`;
see changes in debug_info_tests.py.
This means that we have to construct the debug info further upstream, before
flattening args. This will later help populate debug info in `WrappedFun` and
`Jaxpr`.
This is part 3 of a series (following #26097, #26099) for jit, pmap, checkify,
and the custom_partitioning (the last few uses).
In order to land this, I had to remove a safety check that the number of
`arg_names` and `result_paths` in a Jaxpr's debug info match the number
of Jaxpr invars and outvars, respectively. Additionally, I added two
accessors `safe_arg_names` and `safe_result_paths` to ensure that
the arg names and result paths match the expected length. These accessors
return no-op results when the lengths are not as expected.
From my testint, this happens only in Jaxprs that
are not used for lowering, hence there is no actual user-visible
change here. Simply, more internal Jaxprs are getting debug_info
and in some cases the `arg_names` and `result_paths` are not correct.
Still, this change is worth it because the `func_src_info` is the most
useful part of the debug info (used for leaked tracers), and that is
accurate. We will fix the `arg_names` and `result_paths` in a future change.
One can see in the changes in debug_info_test.py the improvements in the
user-visible debug info, including for `pjit` and `pmap` cases when
it was wrong.
* Track `explicit_mesh_axis` on `AxisData`.
* Modify `unmapped_aval` to the the above explicit mesh axis and insert it into the right place in the sharding so out_shardings are correct.
* Make `matchaxis` also handle shardings correctly
* All mapped dimensions should be sharded the same way
* spmd_axis_name and explicit sharded arrays cannot be used together
* `out_shardings` parameter on `dot_general`, `broadcast_in_dim`, `reshape`, `reshard` and `mesh_cast` is handled correctly in presence of vmap.
This should eventually help us get rid of `spmd_axis_name` from `vmap`.
PiperOrigin-RevId: 721007659
Most places in the code assumed this already, but often
that usage is error reporting code, which is not yet well tested.
When we cannot get the `inspect.Signature` or when the
args and kwargs do not match the signature, we generate
the flattened argument names as: `args[0]`, `args[1]`,
`kwargs['foo']`, ... Previously, in these cases we
returned `arg_names` is None, and then the whole
debug_info ended up being `None`, throwing away even
available information.
We also add support for `api_util.fun_sourceinfo` even
for cases when the `fun.__code__` is not available. In
those cases we used to say that `fun_sourceinfo` is
`None`. Now, we use the string representation of `fun`
to get the name of built-in functions, or we use "<unknown>".
* `mesh_cast`: AxisTypes between src and dst mesh **must** differ. There should be **no "visible" data movement**. The shape of the aval doesn't change.
* `reshard`: Mesh should be the **same** between src and dst (same axis_names, axis_sizes and axis_types). **Data movement is allowed**. The shape of the aval doesn't change.
We might make `reshard` == `device_put`, hence the API is in experimental. This decision can be taken at a later point in time. The reason not to just give `device_put` this power is because `device_put` does a lot of stuff right now (and is going to get even more powers in the near future like cross-host transfers) and it's semantics would be very confusing if we keep piling sharding-in-types stuff on it.
PiperOrigin-RevId: 717588253
Following https://github.com/jax-ml/jax/pull/25916 there were a few TODOs
left in the code to remove api_util.debug_info and replace the
one remaining use with api_util.tracing_debug_info.
PiperOrigin-RevId: 717583667
Added some documentation for `TracingDebugInfo` (docstring, comments
about `arg_names`, since it was not obvious to me that this would
flatten the non-static arguments).
Laying the ground for the unification of the old `api_util.debug_info`
and `partial_eval.tracing_debug_info`: we rename the former to
`api_util.tracing_debug_info`, we push inside the calls to
`fun_sourceinfo` and `fun_signature` (which were done by the callers
until now), and we rewrite the latter in terms
of the former. We leave for a future PR the actual replacing of the
latter with the former throughout.
In the process of above, cleaned up the one case when `partial_eval.tracing_debug_info`
received None for the `in_tree` and `out_tracer_thunk`. The function contained
catch-all exception clauses to handle those, but doing so it masked other places
where we fail to collect debug info due to programming mistakes. E.g., in
one place we passed a `WrappedFun` instead of a `Callable`, resulting in missing debugging info.
Added more type declarations.
Added a `state_test` with a failure to track debugging information, manifested
with a leaked tracer without function provenance. Fixing this in a subsequent PR.
Also make the `axes` parameter optional of hidden_axes and visible_axes functions. If axes is optional, you drop into full hidden/visible mode.
PiperOrigin-RevId: 716771872
* mesh_cast only works when the axis types between src and dst mesh changes. Hence the name!
* No explicit data movement is allowed. Specs containing axes that are visible cannot be different between src and dst shardings.
* src and dst mesh axis_names and axis_sizes should be the same.
TODO: Make `shardings` parameter to `mesh_cast` optional.
PiperOrigin-RevId: 716727084
Replace `with set_mesh(mesh):` with `with use_mesh(mesh):` context manager
Also expose `AxisTypes` and `use_mesh` into public API via `jax.sharding.AxisTypes` and `jax.sharding.use_mesh`.
PiperOrigin-RevId: 716446406
This is part of a sequence of changes to ensure that the debugging information
is propagated properly.
Additional cleanup:
* Rename `result_paths` to `result_paths_thunk` in `TracingDebugInfo` to clarify the
difference from the similar field in `JaxprDebugInfo`
* Added more type declarations
If PartitionSpec is passed, the mesh is read from the context. The primitives though take `NamedSharding` only. The conversion from `PartitionSpec` to `NamedSharding` happens above `.bind`.
We also raise an error if `PartitionSpec` contain mesh axis names that are of type Auto or Collective for the above functions.
PiperOrigin-RevId: 713352542
Also allow users to enter into `Auto`/`User` mode inside jit along all or some axes.
Add checks to make sure that avals inside a context match the surrounding context. This check happens inside `abstract_eval` rules but maybe we need a more central place for it which we can create later on.
PiperOrigin-RevId: 707128096
This is useful for `jax.export`, e.g., for cross-platform export when we do not have access to the actual devices for which this computation is lowered.
PiperOrigin-RevId: 705764178
This has two problems:
* it's not thread-safe, which will become problematic if we run tests with thread-parallelism.
* it's not very maintainable.
Instead, add a new util.test_event(...) function that can be called at points of interest in the program. test_utils registers a callback that is invoked when an event is received. This avoids the need to make thread-unsafe global monkey patches.