We replace uses of `pe.tracing_debug_info` with with `api_util.tracing_debug_info`,
which uses the actual args and kwargs, instead of `in_tree` to manufacture fake
args and kwargs. This ends up being more accurate, especially for `arg_names`;
see changes in debug_info_tests.py.
This means that we have to construct the debug info further upstream, before
flattening args. This will later help populate debug info in `WrappedFun` and
`Jaxpr`.
This is part 2 of a series (following #26097) for Pallas.
We replace those uses with api_util.tracing_debug_info, which means we
have to move the call further upstream. But this is better because we
have the actual args and kwargs, and we can do a better job, especially
for `arg_names`.
This is part 1 of a series, for: cond, switch, while, scan, composite,
custom_dce, custom_root, custom_linear_solve, saved_residuals.
* Implement the context manager as a context manager class, rather than using @contextlib.contextmanager. It turns out the contextlib contextmanagers are rather slow.
* Fuse the four child context managers into a single context manager. This saves us a bunch of allocations.
* While we are here, also simplify the xla_metadata context manager to avoid its dual representation of the current metadata.
PiperOrigin-RevId: 719918121
* We don't need to keep a separate thread-local stack of objects: the config state already has a thread local.
* We don't need to keep an explicit stack of contexts at all: we can maintain it in the context manager frames.
* When checking for incompatible nested compute_ons, we can just check the current state: no need to look higher in the stack!
PiperOrigin-RevId: 719892989
Try to cover the tracing of almost all JAX higher-order
primitives. Some of the tests added show missing debug info,
marked with TODO. Fixes will come separately.
Had to expand the helper functions _check_tracers_and_jaxprs to
use regular expressions for matching because some debug info
still contains non-deterministic elements.
Key idea: if the argument to the context manager is None, then we don't need to touch any context state.
Also clean up the API by separating the "set a dict" from the "set kwargs" use cases.
PiperOrigin-RevId: 719628089
We previously registered the pass in the :_mosaic_gpu_ext which didn't work
because the extension has its own pass registry. The fix instead is to move
the registration to :register_jax_dialects in jaxlib.
PiperOrigin-RevId: 719280601
This change adds a C++ implementation that uses `xla::ifrt::RemapArrays` to
reorder shards of an array. This avoids creating intermediate single-device
arrays and accelerates reordering shards within `jax.device_put()`
implementation.
PiperOrigin-RevId: 718998621
We use dead code elimination (DCE) throughout JAX core to remove unused computations from Jaxprs. This typically works transparently when we're just using `lax` primitives, but opaque calls to `pallas_call` or `ffi_call` can't be cleaned up this way. For many kernels however, the author will know how to generate a more efficient call for specific patterns of used outputs, so it is useful to provide a mechanism for customizing this behavior.
In https://github.com/jax-ml/jax/pull/22735, I attempted to automatically tackle one specific example of this that comes up frequently, but there have been feature requests for a more general API. This version is bare bones and probably rough around the edges, but it could be a useful starting point for iteration.
PiperOrigin-RevId: 718950828