The plan is for all `core.Jaxpr` and `lu.WrappedFun` to carry
non-None debug info.
We change `lu.wrap_init` to construct the result paths thunk
whenever it is passed a `debug_info`. The goal is to make sure that
all `WrappedFun` have a debug info with result paths support.
We change some calling conventions for internal functions to not
pass along a separate debug_info if we have a `WrappedFun` or
a `Jaxpr`.
We obtain several improvements in presence of debug infos
in debug_info_test.py
* When current_mesh is Manual and aval mesh is Auto
* When current mesh is set and aval mesh is unset
* Final style primitives skip this canonicalization and they are free to add it in their own `bind` method.
* `mesh_cast` is skipped from this canonicalization to avoid recursion errors.
This is required to make sure that after we hit abstract_eval rule and check_jaxpr, everything is properly typed in JAX's type system.
`Auto` right now is a bit more permissive because we need to keep the current code at HEAD working but `Explicit` and `Manual` are very strict.
PiperOrigin-RevId: 722868091
Previously, we had two almost identical classes: `TracingDebugInfo` and
`JaxprDebugInfo`. The only difference was that `TracingDebugInfo` had
a thunk to return the result paths, while `JaxprDebugInfo` had the
result paths resolved to a tuple. The separation of these types
provided some clarity, but also led to code duplication and
required conversions as the debugging info goes from `WrappedFun`
to a `Jaxpr` and then to `WrappedFun` again.
We replace uses of `pe.tracing_debug_info` with with `api_util.tracing_debug_info`,
which uses the actual args and kwargs, instead of `in_tree` to manufacture fake
args and kwargs. This ends up being more accurate, especially for `arg_names`;
see changes in debug_info_tests.py.
This means that we have to construct the debug info further upstream, before
flattening args. This will later help populate debug info in `WrappedFun` and
`Jaxpr`.
This is part 3 of a series (following #26097, #26099) for jit, pmap, checkify,
and the custom_partitioning (the last few uses).
In order to land this, I had to remove a safety check that the number of
`arg_names` and `result_paths` in a Jaxpr's debug info match the number
of Jaxpr invars and outvars, respectively. Additionally, I added two
accessors `safe_arg_names` and `safe_result_paths` to ensure that
the arg names and result paths match the expected length. These accessors
return no-op results when the lengths are not as expected.
From my testint, this happens only in Jaxprs that
are not used for lowering, hence there is no actual user-visible
change here. Simply, more internal Jaxprs are getting debug_info
and in some cases the `arg_names` and `result_paths` are not correct.
Still, this change is worth it because the `func_src_info` is the most
useful part of the debug info (used for leaked tracers), and that is
accurate. We will fix the `arg_names` and `result_paths` in a future change.
One can see in the changes in debug_info_test.py the improvements in the
user-visible debug info, including for `pjit` and `pmap` cases when
it was wrong.
* Track `explicit_mesh_axis` on `AxisData`.
* Modify `unmapped_aval` to the the above explicit mesh axis and insert it into the right place in the sharding so out_shardings are correct.
* Make `matchaxis` also handle shardings correctly
* All mapped dimensions should be sharded the same way
* spmd_axis_name and explicit sharded arrays cannot be used together
* `out_shardings` parameter on `dot_general`, `broadcast_in_dim`, `reshape`, `reshard` and `mesh_cast` is handled correctly in presence of vmap.
This should eventually help us get rid of `spmd_axis_name` from `vmap`.
PiperOrigin-RevId: 721007659
* Implement the context manager as a context manager class, rather than using @contextlib.contextmanager. It turns out the contextlib contextmanagers are rather slow.
* Fuse the four child context managers into a single context manager. This saves us a bunch of allocations.
* While we are here, also simplify the xla_metadata context manager to avoid its dual representation of the current metadata.
PiperOrigin-RevId: 719918121
Key idea: if the argument to the context manager is None, then we don't need to touch any context state.
Also clean up the API by separating the "set a dict" from the "set kwargs" use cases.
PiperOrigin-RevId: 719628089
Most places in the code assumed this already, but often
that usage is error reporting code, which is not yet well tested.
When we cannot get the `inspect.Signature` or when the
args and kwargs do not match the signature, we generate
the flattened argument names as: `args[0]`, `args[1]`,
`kwargs['foo']`, ... Previously, in these cases we
returned `arg_names` is None, and then the whole
debug_info ended up being `None`, throwing away even
available information.
We also add support for `api_util.fun_sourceinfo` even
for cases when the `fun.__code__` is not available. In
those cases we used to say that `fun_sourceinfo` is
`None`. Now, we use the string representation of `fun`
to get the name of built-in functions, or we use "<unknown>".
Added some documentation for `TracingDebugInfo` (docstring, comments
about `arg_names`, since it was not obvious to me that this would
flatten the non-static arguments).
Laying the ground for the unification of the old `api_util.debug_info`
and `partial_eval.tracing_debug_info`: we rename the former to
`api_util.tracing_debug_info`, we push inside the calls to
`fun_sourceinfo` and `fun_signature` (which were done by the callers
until now), and we rewrite the latter in terms
of the former. We leave for a future PR the actual replacing of the
latter with the former throughout.
In the process of above, cleaned up the one case when `partial_eval.tracing_debug_info`
received None for the `in_tree` and `out_tracer_thunk`. The function contained
catch-all exception clauses to handle those, but doing so it masked other places
where we fail to collect debug info due to programming mistakes. E.g., in
one place we passed a `WrappedFun` instead of a `Callable`, resulting in missing debugging info.
Added more type declarations.
Added a `state_test` with a failure to track debugging information, manifested
with a leaked tracer without function provenance. Fixing this in a subsequent PR.
Replace `with set_mesh(mesh):` with `with use_mesh(mesh):` context manager
Also expose `AxisTypes` and `use_mesh` into public API via `jax.sharding.AxisTypes` and `jax.sharding.use_mesh`.
PiperOrigin-RevId: 716446406
This is part of a sequence of changes to ensure that the debugging information
is propagated properly.
Additional cleanup:
* Rename `result_paths` to `result_paths_thunk` in `TracingDebugInfo` to clarify the
difference from the similar field in `JaxprDebugInfo`
* Added more type declarations
If PartitionSpec is passed, the mesh is read from the context. The primitives though take `NamedSharding` only. The conversion from `PartitionSpec` to `NamedSharding` happens above `.bind`.
We also raise an error if `PartitionSpec` contain mesh axis names that are of type Auto or Collective for the above functions.
PiperOrigin-RevId: 713352542
Also allow users to enter into `Auto`/`User` mode inside jit along all or some axes.
Add checks to make sure that avals inside a context match the surrounding context. This check happens inside `abstract_eval` rules but maybe we need a more central place for it which we can create later on.
PiperOrigin-RevId: 707128096
Also support `Auto` mode fully or mixed in with `User` mode. This works by overriding the sharding of `Auto` axes in the PartitionSpec with `Unconstrained` in `ShapedArray` constructor. The `ShapedArray` constructor is the central place where we can make such substitutions.
During lowering of shardings with auto axes, we mark the auto dims are `unspecifed_dims`. We don't mark all dims as unspecified because that would enable XLA to shard them even further which is not what we want if some of the dims are user sharded.
PiperOrigin-RevId: 704911253
Also adding a device_context so `set_mesh` sets the devices the computation should run on correctly. The device_context however enters concrete devices into tracing and lowering cache but this should be fixed with the other jax context work going on.
PiperOrigin-RevId: 700537898
* Set abstract_mesh context manager during pjit_p.bind at the top level too since scan builds jaxpr during it's lowering in `_scan_impl` (do the same for AOT path)
* Set the abstract mesh only once if it's not set. Don't override an already set context. This means that only top level jit sets the context manager.
* Add dynamic_slice and dynamic_update_slice sharding rules since scan calls into them.
* scan only allows `xs` where the 0th dim is full replicated i.e. None.
PiperOrigin-RevId: 699014167