Previously, we had two almost identical classes: `TracingDebugInfo` and
`JaxprDebugInfo`. The only difference was that `TracingDebugInfo` had
a thunk to return the result paths, while `JaxprDebugInfo` had the
result paths resolved to a tuple. The separation of these types
provided some clarity, but also led to code duplication and
required conversions as the debugging info goes from `WrappedFun`
to a `Jaxpr` and then to `WrappedFun` again.
When we do run_scoped[jaxpr, R1,R2], it can't be assumed that references
corresponding to R1 and R2 can be safely discharged. Sometimes they can (eg
Accumulator) but sometimes they can't (eg SMEM scratch). It should be up to the
lowering rule to do such discharging.
This further means that during lowering there is no guarantee that the
references will not be used/returned by nested scoped blocks so we also remove
that check.
PiperOrigin-RevId: 722137352
This CL lays the ground for a future CL that makes run_scoped discharge to not request the discharge of the temporary buffers it creates. This causes issues becausa
a) dma_start can't discharge some but not all its references
b) run_scoped() lowering depends on run_scoped discharge to remove the run_scoped operation (or it goes in an infinite loop).
PiperOrigin-RevId: 722126566
In a singleton group case, unlike regular all_to_all, the ragged op becomes a generic equivalent of DynamicUpdateSlice, except update size is not statically known. This operation can't be expressed with standard HLO instructions -- the backend will handle this case separately.
Added small improvement to error messages.
PiperOrigin-RevId: 721473063
This is in preparation for a larger change, so that input buffers can be checked before Array creation in XLA and the user gets more helpful JAX error messages instead of XLA errors.
PiperOrigin-RevId: 721412760
We replace uses of `pe.tracing_debug_info` with with `api_util.tracing_debug_info`,
which uses the actual args and kwargs, instead of `in_tree` to manufacture fake
args and kwargs. This ends up being more accurate, especially for `arg_names`;
see changes in debug_info_tests.py.
This means that we have to construct the debug info further upstream, before
flattening args. This will later help populate debug info in `WrappedFun` and
`Jaxpr`.
This is part 3 of a series (following #26097, #26099) for jit, pmap, checkify,
and the custom_partitioning (the last few uses).
In order to land this, I had to remove a safety check that the number of
`arg_names` and `result_paths` in a Jaxpr's debug info match the number
of Jaxpr invars and outvars, respectively. Additionally, I added two
accessors `safe_arg_names` and `safe_result_paths` to ensure that
the arg names and result paths match the expected length. These accessors
return no-op results when the lengths are not as expected.
From my testint, this happens only in Jaxprs that
are not used for lowering, hence there is no actual user-visible
change here. Simply, more internal Jaxprs are getting debug_info
and in some cases the `arg_names` and `result_paths` are not correct.
Still, this change is worth it because the `func_src_info` is the most
useful part of the debug info (used for leaked tracers), and that is
accurate. We will fix the `arg_names` and `result_paths` in a future change.
One can see in the changes in debug_info_test.py the improvements in the
user-visible debug info, including for `pjit` and `pmap` cases when
it was wrong.
I also updated `to_dlpack` and `from_dlpack` to handle `KeyError` instead of `TypeError`, because I think `TypeError` was never actually raised.
PiperOrigin-RevId: 721052736
* Track `explicit_mesh_axis` on `AxisData`.
* Modify `unmapped_aval` to the the above explicit mesh axis and insert it into the right place in the sharding so out_shardings are correct.
* Make `matchaxis` also handle shardings correctly
* All mapped dimensions should be sharded the same way
* spmd_axis_name and explicit sharded arrays cannot be used together
* `out_shardings` parameter on `dot_general`, `broadcast_in_dim`, `reshape`, `reshard` and `mesh_cast` is handled correctly in presence of vmap.
This should eventually help us get rid of `spmd_axis_name` from `vmap`.
PiperOrigin-RevId: 721007659
This causes problems because internal code assumes it will not be modified. We replace this with an internal registration mechanism.
PiperOrigin-RevId: 721000907
Made several improvements to the debug info tests:
* added support for eager mode, which sometimes uses
different code paths for the debug info, e.g., for
`jvp(pmap)`. To check the debugging info in these cases we add
instrumentation to collect the lowered Jaxprs and MLIR modules right
after lowering, and we check the debugging information there.
* added support for checking for the presence of regular expressions
and strings in the lowered module, to check that the location
information and arg_names and result_paths is present. This
is now enabled only for a subset of the tests.
* simplified the pretty-printing of the arg_names and result_paths
in the debug info, to remove a layer of parentheses and string,
so that instead of `arg_names=("x", "y")` we now pretty-print
just `arg_names=x,y"
* added support for checking the provenance information in
leaked tracers
I came across this when working on an unrelated issue, but the explicit use of `finfo` was causing some `UserWarning`s, and it was really unnecessary.
PiperOrigin-RevId: 720691470
PR #25834 intended to dynamically choose the the partitioner API, but
it still applies the configuration value too early (it should only be
applied in __call__, not in def_partition and __call__).
We replace uses of `pe.tracing_debug_info` with with `api_util.tracing_debug_info`,
which uses the actual args and kwargs, instead of `in_tree` to manufacture fake
args and kwargs. This ends up being more accurate, especially for `arg_names`;
see changes in debug_info_tests.py.
This means that we have to construct the debug info further upstream, before
flattening args. This will later help populate debug info in `WrappedFun` and
`Jaxpr`.
This is part 2 of a series (following #26097) for Pallas.