1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-26 07:16:07 +00:00

[better_errors] Merge the JaxprDebugInfo and TracingDebugInfo into core.DebugInfo

Previously, we had two almost identical classes: `TracingDebugInfo` and
`JaxprDebugInfo`. The only difference was that `TracingDebugInfo` had
a thunk to return the result paths, while `JaxprDebugInfo` had the
result paths resolved to a tuple. The separation of these types
provided some clarity, but also led to code duplication and
required conversions as the debugging info goes from `WrappedFun`
to a `Jaxpr` and then to `WrappedFun` again.
This commit is contained in:
George Necula 2025-01-31 22:23:20 +02:00
parent de48ce2a4c
commit c70de6deed
24 changed files with 159 additions and 176 deletions

@ -323,7 +323,7 @@ def checkpoint(fun: Callable, *, prevent_cse: bool = True,
@wraps(fun) @wraps(fun)
@api_boundary @api_boundary
def fun_remat(*args, **kwargs): def fun_remat(*args, **kwargs):
debug = api_util.tracing_debug_info( debug = api_util.debug_info(
"checkpoint / remat", fun, "checkpoint / remat", fun,
args, kwargs, static_argnums=static_argnums) args, kwargs, static_argnums=static_argnums)
fun_, args = _remat_static_argnums(fun, static_argnums, args) fun_, args = _remat_static_argnums(fun, static_argnums, args)
@ -418,7 +418,7 @@ _dyn_args_fun_cached = weakref_lru_cache(_dyn_args_fun_uncached)
def _trace_to_jaxpr(fun: Callable, def _trace_to_jaxpr(fun: Callable,
in_tree: PyTreeDef, in_tree: PyTreeDef,
in_avals: Sequence[core.AbstractValue], in_avals: Sequence[core.AbstractValue],
debug: lu.TracingDebugInfo debug: core.DebugInfo
) -> tuple[core.Jaxpr, Sequence[Any], PyTreeDef]: ) -> tuple[core.Jaxpr, Sequence[Any], PyTreeDef]:
flat_fun, out_tree = api_util.flatten_fun(lu.wrap_init(fun), in_tree) flat_fun, out_tree = api_util.flatten_fun(lu.wrap_init(fun), in_tree)
try: try:
@ -447,7 +447,7 @@ def saved_residuals(f: Callable,
args, kwargs = tree_unflatten(in_tree, args) args, kwargs = tree_unflatten(in_tree, args)
return f(*args, **kwargs) return f(*args, **kwargs)
debug_info = api_util.tracing_debug_info("saved_residuals", f, args, kwargs) debug_info = api_util.debug_info("saved_residuals", f, args, kwargs)
out = api.make_jaxpr(lambda *args: api.linearize(f_, *args)[1], out = api.make_jaxpr(lambda *args: api.linearize(f_, *args)[1],
return_shape=True)(*in_leaves) return_shape=True)(*in_leaves)
assert isinstance(out, tuple) assert isinstance(out, tuple)

@ -57,11 +57,11 @@ from jax._src import pjit
from jax._src import xla_bridge as xb from jax._src import xla_bridge as xb
from jax._src.core import eval_jaxpr, shaped_abstractify, ShapedArray from jax._src.core import eval_jaxpr, shaped_abstractify, ShapedArray
from jax._src.api_util import ( from jax._src.api_util import (
flatten_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2, argnums_partial, flatten_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2, argnums_partial,
flatten_axes, donation_vector, flatten_axes, donation_vector,
rebase_donate_argnums, _ensure_index, _ensure_index_tuple, rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
apply_flat_fun_nokwargs, check_callable, tracing_debug_info, apply_flat_fun_nokwargs, check_callable, debug_info,
result_paths, flat_out_axes) result_paths, flat_out_axes)
from jax._src.lax import lax as lax_internal from jax._src.lax import lax as lax_internal
from jax._src.lib import jax_jit from jax._src.lib import jax_jit
from jax._src.lib import xla_client as xc from jax._src.lib import xla_client as xc
@ -452,7 +452,7 @@ def value_and_grad(fun: Callable, argnums: int | Sequence[int] = 0,
raise TypeError(f"differentiating with respect to {argnums=} requires at least " raise TypeError(f"differentiating with respect to {argnums=} requires at least "
f"{max_argnum + 1} positional arguments to be passed by the caller, " f"{max_argnum + 1} positional arguments to be passed by the caller, "
f"but got only {len(args)} positional arguments.") f"but got only {len(args)} positional arguments.")
dbg = tracing_debug_info('value_and_grad', fun, args, kwargs) dbg = debug_info('value_and_grad', fun, args, kwargs)
f = lu.wrap_init(fun, params=kwargs, debug_info=dbg) f = lu.wrap_init(fun, params=kwargs, debug_info=dbg)
f_partial, dyn_args = argnums_partial(f, argnums, args, f_partial, dyn_args = argnums_partial(f, argnums, args,
@ -1426,7 +1426,7 @@ def _prepare_pmap(fun: Callable, in_axes, out_axes, static_broadcasted_tuple,
if in_devices is not None and len(in_devices) == 0: if in_devices is not None and len(in_devices) == 0:
raise ValueError("'devices' argument to pmap must be non-empty, or None.") raise ValueError("'devices' argument to pmap must be non-empty, or None.")
dbg = tracing_debug_info( dbg = debug_info(
"pmap", fun, args, kwargs, "pmap", fun, args, kwargs,
static_argnums=static_broadcasted_tuple) static_argnums=static_broadcasted_tuple)

@ -31,7 +31,6 @@ from jax._src.tree_util import (
prefix_errors) prefix_errors)
from jax._src.tree_util import _replace_nones from jax._src.tree_util import _replace_nones
from jax._src import linear_util as lu from jax._src import linear_util as lu
from jax._src.linear_util import TracingDebugInfo
from jax._src.util import (safe_map, WrapKwArgs, Hashable, HashableFunction, from jax._src.util import (safe_map, WrapKwArgs, Hashable, HashableFunction,
Unhashable, safe_zip) Unhashable, safe_zip)
from jax._src import traceback_util from jax._src import traceback_util
@ -582,7 +581,7 @@ def api_hook(fun, tag: str):
return fun return fun
def tracing_debug_info( def debug_info(
traced_for: str, traced_for: str,
fun: Callable, fun: Callable,
args: Sequence[Any], args: Sequence[Any],
@ -594,14 +593,14 @@ def tracing_debug_info(
# TODO(necula): check if we really need this, e.g., to speed up tracing. # TODO(necula): check if we really need this, e.g., to speed up tracing.
sourceinfo: str | None = None, sourceinfo: str | None = None,
signature: inspect.Signature | None = None, signature: inspect.Signature | None = None,
) -> TracingDebugInfo: ) -> core.DebugInfo:
if sourceinfo is None: if sourceinfo is None:
sourceinfo = fun_sourceinfo(fun) sourceinfo = fun_sourceinfo(fun)
if signature is None: if signature is None:
signature = fun_signature(fun) signature = fun_signature(fun)
arg_names = _non_static_arg_names(signature, args, kwargs, static_argnums, arg_names = _non_static_arg_names(signature, args, kwargs, static_argnums,
static_argnames) static_argnames)
return TracingDebugInfo(traced_for, sourceinfo, arg_names, result_paths_thunk) return core.DebugInfo(traced_for, sourceinfo, arg_names, result_paths_thunk)
def fun_signature(fun: Callable) -> inspect.Signature | None: def fun_signature(fun: Callable) -> inspect.Signature | None:
@ -619,7 +618,7 @@ _fun_name_re = re.compile(r"(?:<built-in function (\S+)>)")
# TODO(mattjj): make this function internal to this module # TODO(mattjj): make this function internal to this module
def fun_sourceinfo(fun: Callable) -> str: def fun_sourceinfo(fun: Callable) -> str:
# See TracingDebugInfo.fun_src_info # See DebugInfo.fun_src_info
res = getattr(fun, "__fun_sourceinfo__", None) res = getattr(fun, "__fun_sourceinfo__", None)
if res is not None: return res if res is not None: return res
while isinstance(fun, partial): while isinstance(fun, partial):
@ -684,20 +683,19 @@ def result_paths(_fun, _store, *args, **kwargs):
# TODO(necula): simplify this function, all it needs is to add the trace_debug to the Jaxpr # TODO(necula): simplify this function, all it needs is to add the trace_debug to the Jaxpr
def add_jaxpr_debug_info(jaxpr: core.Jaxpr, def add_jaxpr_debug_info(jaxpr: core.Jaxpr,
trace_debug: TracingDebugInfo | None, debug: core.DebugInfo | None,
result_paths: tuple[str, ...] | None = None, result_paths: tuple[str, ...] | None = None,
) -> core.Jaxpr: ) -> core.Jaxpr:
"""Add debug info to jaxpr, given trace-time debug info and result paths.""" """Add debug info to jaxpr, given trace-time debug info and result paths."""
if trace_debug is None: if debug is None:
return jaxpr return jaxpr
# TODO(necula): re-enable this safety check # TODO(necula): re-enable this safety check
# assert (result_paths is not None) ^ (trace_debug.result_paths_thunk is not None) # assert (result_paths is not None) ^ (trace_debug.result_paths_thunk is not None)
if result_paths is None: if result_paths is not None:
result_paths = trace_debug.result_paths_thunk() # type: ignore debug = debug._replace(result_paths=tuple(result_paths))
debug_info = core.JaxprDebugInfo( else:
trace_debug.traced_for, trace_debug.func_src_info, debug = debug.resolve_result_paths()
trace_debug.arg_names, tuple(result_paths)) # type: ignore return jaxpr.replace(debug_info=debug)
return jaxpr.replace(debug_info=debug_info)
def hoist_obj_attrs(f, flat_args): def hoist_obj_attrs(f, flat_args):
idxs, objs, flat_args_ = [], [], [] idxs, objs, flat_args_ = [], [], []

@ -1202,7 +1202,7 @@ def checkify(f: Callable[..., Out],
in_tree = jtu.tree_structure(((), {})) in_tree = jtu.tree_structure(((), {}))
closed_f = lambda: f(*args, **kwargs) closed_f = lambda: f(*args, **kwargs)
# stage: # stage:
debug = api_util.tracing_debug_info("checkify", f, args, kwargs) debug = api_util.debug_info("checkify", f, args, kwargs)
fun_, out_tree = api_util.flatten_fun(lu.wrap_init(closed_f, fun_, out_tree = api_util.flatten_fun(lu.wrap_init(closed_f,
debug_info=debug), debug_info=debug),
in_tree) in_tree)

@ -82,31 +82,7 @@ EffectTypeSet = effects.EffectTypeSet
no_effects: Effects = effects.no_effects no_effects: Effects = effects.no_effects
# TODO(necula): make this an extension of TracingDebugInfo DebugInfo = lu.DebugInfo
class JaxprDebugInfo(NamedTuple):
# An extension of lu.TracingDebugInfo; see comments there
traced_for: str
func_src_info: str
arg_names: tuple[str | None, ...]
# This is formed after tracing, when we have concrete `result_paths`
result_paths: tuple[str, ...] # e.g. ('[0]', '[1]', ...)
def safe_arg_names(self, expected: int) -> tuple[str | None, ...]:
"""Get the arg_names with a safety check."""
if len(self.arg_names) == expected:
return self.arg_names
else:
# TODO(necula): this should not happen
return (None,) * expected
def safe_result_paths(self, expected: int) -> tuple[str | None, ...]:
"""Get the result_paths with a safety check."""
if len(self.result_paths) == expected:
return self.result_paths
else:
# TODO(necula): this should not happen
return ("",) * expected
class Jaxpr: class Jaxpr:
__slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns', __slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns',
@ -117,7 +93,7 @@ class Jaxpr:
_outvars: list[Atom] _outvars: list[Atom]
_eqns: list[JaxprEqn] _eqns: list[JaxprEqn]
_effects: Effects _effects: Effects
_debug_info: JaxprDebugInfo | None _debug_info: DebugInfo | None
@property @property
def constvars(self) -> list[Var]: def constvars(self) -> list[Var]:
@ -140,13 +116,13 @@ class Jaxpr:
return self._effects return self._effects
@property @property
def debug_info(self) -> JaxprDebugInfo | None: def debug_info(self) -> DebugInfo | None:
return self._debug_info return self._debug_info
def __init__(self, constvars: Sequence[Var], invars: Sequence[Var], def __init__(self, constvars: Sequence[Var], invars: Sequence[Var],
outvars: Sequence[Atom], eqns: Sequence[JaxprEqn], outvars: Sequence[Atom], eqns: Sequence[JaxprEqn],
effects: Effects = no_effects, effects: Effects = no_effects,
debug_info: JaxprDebugInfo | None = None): debug_info: DebugInfo | None = None):
""" """
Args: Args:
constvars: list of variables introduced for constants. Array constants are constvars: list of variables introduced for constants. Array constants are
@ -157,14 +133,14 @@ class Jaxpr:
eqns: list of equations. eqns: list of equations.
effects: set of effects. The effects on a jaxpr are a superset of the effects: set of effects. The effects on a jaxpr are a superset of the
union of the effects for each equation. union of the effects for each equation.
debug_info: optional JaxprDebugInfo. debug_info: optional DebugInfo.
""" """
self._constvars = list(constvars) self._constvars = list(constvars)
self._invars = list(invars) self._invars = list(invars)
self._outvars = list(outvars) self._outvars = list(outvars)
self._eqns = list(eqns) self._eqns = list(eqns)
self._effects = effects self._effects = effects
self._debug_info = debug_info self._debug_info = debug_info and debug_info.resolve_result_paths()
# TODO(necula): re-enable these safety checks # TODO(necula): re-enable these safety checks
# assert (not debug_info or len(debug_info.arg_names) == len(invars)), (debug_info, invars) # assert (not debug_info or len(debug_info.arg_names) == len(invars)), (debug_info, invars)
# assert (not debug_info or len(debug_info.result_paths) == len(outvars)), (debug_info, outvars) # assert (not debug_info or len(debug_info.result_paths) == len(outvars)), (debug_info, outvars)

@ -147,7 +147,7 @@ class custom_vmap:
raise AttributeError( raise AttributeError(
f"No batching rule defined for custom_vmap function {fun_name} " f"No batching rule defined for custom_vmap function {fun_name} "
"using def_vmap.") "using def_vmap.")
debug = api_util.tracing_debug_info("custom_vmap", self.fun, args, {}) debug = api_util.debug_info("custom_vmap", self.fun, args, {})
args_flat, in_tree = tree_flatten(args) args_flat, in_tree = tree_flatten(args)
flat_fun, out_tree = api_util.flatten_fun_nokwargs( flat_fun, out_tree = api_util.flatten_fun_nokwargs(
lu.wrap_init(self.fun, debug_info=debug), lu.wrap_init(self.fun, debug_info=debug),

@ -127,12 +127,12 @@ class custom_dce:
"def_dce." "def_dce."
) )
rule_name = util.fun_name(self.dce_rule) rule_name = util.fun_name(self.dce_rule)
debug = api_util.tracing_debug_info("custom_dce", self.fun, debug = api_util.debug_info("custom_dce", self.fun,
args, {}, args, {},
static_argnums=self.static_argnums) static_argnums=self.static_argnums)
debug_rule = api_util.tracing_debug_info("custom_dce_rule", self.dce_rule, debug_rule = api_util.debug_info("custom_dce_rule", self.dce_rule,
args, {}, args, {},
static_argnums=self.static_argnums) static_argnums=self.static_argnums)
args = api_util.resolve_kwargs(self.fun, args, kwargs) args = api_util.resolve_kwargs(self.fun, args, kwargs)
if self.static_argnums: if self.static_argnums:
static_argnums = set(self.static_argnums) static_argnums = set(self.static_argnums)

@ -468,9 +468,9 @@ class custom_partitioning:
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
args = _resolve_kwargs(self.fun, args, kwargs) args = _resolve_kwargs(self.fun, args, kwargs)
debug = api_util.tracing_debug_info("custom_partitioning", self.fun, debug = api_util.debug_info("custom_partitioning", self.fun,
args, kwargs, args, kwargs,
static_argnums=self.static_argnums) static_argnums=self.static_argnums)
if self.static_argnums: if self.static_argnums:
static_argnums = set(self.static_argnums) static_argnums = set(self.static_argnums)
args = tuple(x if i in static_argnums else x for i, x in enumerate(args)) args = tuple(x if i in static_argnums else x for i, x in enumerate(args))

@ -147,7 +147,7 @@ def _linearize_jaxpr(
jaxpr: core.ClosedJaxpr, jaxpr: core.ClosedJaxpr,
nonzeros: tuple[bool, ...] nonzeros: tuple[bool, ...]
) -> tuple[core.ClosedJaxpr, int, Sequence[bool], core.ClosedJaxpr]: ) -> tuple[core.ClosedJaxpr, int, Sequence[bool], core.ClosedJaxpr]:
dbg = lu.TracingDebugInfo.from_jaxpr(jaxpr) dbg = jaxpr.jaxpr.debug_info
primal_trace = pe.DynamicJaxprTrace(dbg) primal_trace = pe.DynamicJaxprTrace(dbg)
tangent_trace = pe.DynamicJaxprTrace(dbg) tangent_trace = pe.DynamicJaxprTrace(dbg)
lin_trace = LinearizeTrace(primal_trace, tangent_trace) lin_trace = LinearizeTrace(primal_trace, tangent_trace)

@ -42,7 +42,6 @@ from jax._src.core import (Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval,
mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx, mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx,
InputType, OutputType, get_referent, JaxprEqnContext) InputType, OutputType, get_referent, JaxprEqnContext)
from jax._src.state.types import AbstractRef from jax._src.state.types import AbstractRef
from jax._src import tree_util
from jax._src.tree_util import (PyTreeDef, treedef_tuple, from jax._src.tree_util import (PyTreeDef, treedef_tuple,
tree_flatten, tree_structure) tree_flatten, tree_structure)
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list, from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
@ -932,7 +931,7 @@ def _partial_eval_jaxpr_nounits(jaxpr: ClosedJaxpr,
in_unknowns: Sequence[bool], in_unknowns: Sequence[bool],
instantiate: bool | Sequence[bool]): instantiate: bool | Sequence[bool]):
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr), f = lu.wrap_init(core.jaxpr_as_fun(jaxpr),
debug_info=lu.TracingDebugInfo.from_jaxpr(jaxpr)) debug_info=jaxpr.jaxpr.debug_info)
cell = [] cell = []
def fun(*known_vals_in): def fun(*known_vals_in):
@ -1334,10 +1333,11 @@ def prune_jaxpr_outputs(jaxpr: Jaxpr, used_outputs: Sequence[bool]) -> Jaxpr:
def _prune_jaxpr_outputs(jaxpr: Jaxpr, used_outputs: tuple[bool, ...]) -> Jaxpr: def _prune_jaxpr_outputs(jaxpr: Jaxpr, used_outputs: tuple[bool, ...]) -> Jaxpr:
outvars = [v for v, b in zip(jaxpr.outvars, used_outputs) if b] outvars = [v for v, b in zip(jaxpr.outvars, used_outputs) if b]
dbg = jaxpr.debug_info and core.JaxprDebugInfo( dbg = jaxpr.debug_info and core.DebugInfo(
jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info, jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info,
jaxpr.debug_info.arg_names, jaxpr.debug_info.arg_names,
tuple(v for v, b in zip(jaxpr.debug_info.result_paths, used_outputs) if b)) tuple(v for v, b in zip(jaxpr.debug_info.safe_result_paths(len(used_outputs)),
used_outputs) if b))
new_jaxpr = jaxpr.replace(outvars=outvars, debug_info=dbg) new_jaxpr = jaxpr.replace(outvars=outvars, debug_info=dbg)
config.enable_checks.value and core.check_jaxpr(new_jaxpr) config.enable_checks.value and core.check_jaxpr(new_jaxpr)
return new_jaxpr return new_jaxpr
@ -1422,10 +1422,12 @@ def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: tuple[bool, ...],
eqns = new_eqns[::-1] eqns = new_eqns[::-1]
jaxpr_effects = make_jaxpr_effects(jaxpr.constvars, invars, outvars, eqns) jaxpr_effects = make_jaxpr_effects(jaxpr.constvars, invars, outvars, eqns)
dbg = jaxpr.debug_info and core.JaxprDebugInfo( dbg = jaxpr.debug_info and core.DebugInfo(
jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info, jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info,
tuple(v for v, b in zip(jaxpr.debug_info.arg_names, used_inputs) if b), tuple(v for v, b in zip(jaxpr.debug_info.safe_arg_names(len(used_inputs)),
tuple(v for v, b in zip(jaxpr.debug_info.result_paths, used_outputs) if b)) used_inputs) if b),
tuple(v for v, b in zip(jaxpr.debug_info.safe_result_paths(len(used_outputs)),
used_outputs) if b))
new_jaxpr = Jaxpr(jaxpr.constvars, invars, outvars, eqns, jaxpr_effects, dbg) new_jaxpr = Jaxpr(jaxpr.constvars, invars, outvars, eqns, jaxpr_effects, dbg)
config.enable_checks.value and core.check_jaxpr(new_jaxpr) config.enable_checks.value and core.check_jaxpr(new_jaxpr)
@ -1623,9 +1625,9 @@ class JaxprStackFrame:
attrs_tracked: list[tuple[Any, str]] attrs_tracked: list[tuple[Any, str]]
attrs_inits: list attrs_inits: list
attrs_vars: list[Var] attrs_vars: list[Var]
debug_info: lu.TracingDebugInfo | None debug_info: core.DebugInfo | None
def __init__(self, debug_info: lu.TracingDebugInfo | None): def __init__(self, debug_info: core.DebugInfo | None):
self.gensym = core.gensym() self.gensym = core.gensym()
self.tracer_to_var = {} self.tracer_to_var = {}
self.constid_to_tracer = {} self.constid_to_tracer = {}
@ -1809,7 +1811,7 @@ def _inline_literals(
class DynamicJaxprTrace(core.Trace): class DynamicJaxprTrace(core.Trace):
__slots__ = ("frame",) __slots__ = ("frame",)
def __init__(self, debug_info: lu.TracingDebugInfo | None): def __init__(self, debug_info: core.DebugInfo | None):
self.frame = JaxprStackFrame(debug_info) self.frame = JaxprStackFrame(debug_info)
def invalidate(self): def invalidate(self):
@ -2114,7 +2116,7 @@ def _jvp_jaxpr_zeros(f, store, in_zeros, zero_avals, *primal_tangent_avals):
def trace_to_jaxpr_dynamic( def trace_to_jaxpr_dynamic(
fun: lu.WrappedFun, fun: lu.WrappedFun,
in_avals: Sequence[AbstractValue], in_avals: Sequence[AbstractValue],
debug_info: lu.TracingDebugInfo | None = None, debug_info: core.DebugInfo | None = None,
*, *,
keep_inputs: list[bool] | None = None, keep_inputs: list[bool] | None = None,
) -> tuple[Jaxpr, list[AbstractValue], list[Any], ) -> tuple[Jaxpr, list[AbstractValue], list[Any],
@ -2137,7 +2139,7 @@ def trace_to_jaxpr_dynamic(
return jaxpr, [v.aval for v in jaxpr.outvars], consts, attrs_tracked return jaxpr, [v.aval for v in jaxpr.outvars], consts, attrs_tracked
def _check_no_returned_refs( def _check_no_returned_refs(
dbg: lu.TracingDebugInfo | None, dbg: core.DebugInfo | None,
out_tracers: Sequence[DynamicJaxprTracer] out_tracers: Sequence[DynamicJaxprTracer]
) -> None: ) -> None:
if not config.mutable_array_checks.value: return if not config.mutable_array_checks.value: return
@ -2148,10 +2150,8 @@ def _check_no_returned_refs(
raise ValueError( raise ValueError(
f"function returned a mutable array reference of type {a.str_short()}, " f"function returned a mutable array reference of type {a.str_short()}, "
"but mutable array references cannot be returned.") "but mutable array references cannot be returned.")
loc = (f' at output tree path {tree_util.keystr(ls[i])}' # type: ignore result_paths = dbg.resolve_result_paths().safe_result_paths(len(out_tracers))
if (dbg.result_paths_thunk and loc = f' at output tree path {result_paths[i]}'
(ls := dbg.result_paths_thunk()) and
ls[i]) else '')
frame = t._trace.frame frame = t._trace.frame
v = frame.tracer_to_var.get(id(t)) v = frame.tracer_to_var.get(id(t))
eqn = next((e for e in frame.eqns if v in e.outvars), None) eqn = next((e for e in frame.eqns if v in e.outvars), None)
@ -2172,7 +2172,7 @@ def _check_no_returned_refs(
@profiler.annotate_function @profiler.annotate_function
def trace_to_jaxpr_dynamic2( def trace_to_jaxpr_dynamic2(
fun: lu.WrappedFun, debug_info: lu.TracingDebugInfo | None = None fun: lu.WrappedFun, debug_info: core.DebugInfo | None = None
) -> tuple[Jaxpr, OutputType, list[Any]]: ) -> tuple[Jaxpr, OutputType, list[Any]]:
trace = DynamicJaxprTrace(debug_info) trace = DynamicJaxprTrace(debug_info)

@ -652,7 +652,7 @@ class ParallelCallableInfo:
in_axes: Iterable[int | None] in_axes: Iterable[int | None]
out_axes_thunk: Callable[[], Sequence[int | None]] out_axes_thunk: Callable[[], Sequence[int | None]]
avals: Sequence[core.AbstractValue] avals: Sequence[core.AbstractValue]
debug_info: api_util.TracingDebugInfo | None debug_info: core.DebugInfo | None
@cached_property @cached_property
def local_devices(self): def local_devices(self):
@ -964,7 +964,7 @@ class UnloadedPmapExecutable:
ordered_effects: list[core.Effect] ordered_effects: list[core.Effect]
keepalive: Sequence[Any] keepalive: Sequence[Any]
host_callbacks: Sequence[Any] host_callbacks: Sequence[Any]
jaxpr_debug_info: core.JaxprDebugInfo jaxpr_debug_info: core.DebugInfo
def build_execute_fun(self): def build_execute_fun(self):
input_indices = [] input_indices = []
@ -1004,7 +1004,7 @@ class UnloadedPmapExecutable:
ordered_effects: list[core.Effect], ordered_effects: list[core.Effect],
host_callbacks: list[Any], host_callbacks: list[Any],
keepalive: Any, keepalive: Any,
jaxpr_debug_info: core.JaxprDebugInfo, jaxpr_debug_info: core.DebugInfo,
platforms: Sequence[str], platforms: Sequence[str],
shape_poly_state: mlir.ShapePolyLoweringState | None = None, shape_poly_state: mlir.ShapePolyLoweringState | None = None,
compiler_options=None): compiler_options=None):
@ -2127,7 +2127,7 @@ MaybeLayout = Sequence[Union[DeviceLocalLayout, AutoLayout, None]]
class AllArgsInfo(NamedTuple): class AllArgsInfo(NamedTuple):
"""Avals and debug_info for all arguments prior to DCE.""" """Avals and debug_info for all arguments prior to DCE."""
in_avals: Sequence[core.ShapedArray] in_avals: Sequence[core.ShapedArray]
debug_info: core.JaxprDebugInfo | None debug_info: core.DebugInfo | None
@lru_cache(maxsize=2048) @lru_cache(maxsize=2048)
@ -3199,7 +3199,7 @@ def cc_shard_arg(x, sharding, layout):
def check_arg_avals_for_call(ref_avals, arg_avals, def check_arg_avals_for_call(ref_avals, arg_avals,
jaxpr_debug_info: core.JaxprDebugInfo | None = None): jaxpr_debug_info: core.DebugInfo | None = None):
if len(ref_avals) != len(arg_avals): if len(ref_avals) != len(arg_avals):
raise TypeError( raise TypeError(
f"Computation compiled for {len(ref_avals)} inputs " f"Computation compiled for {len(ref_avals)} inputs "
@ -3258,7 +3258,7 @@ def check_array_xla_sharding_layout_match(
args_after_dce, args_after_dce,
in_xla_shardings: Sequence[JSharding], in_xla_shardings: Sequence[JSharding],
in_xla_layouts: Sequence[DeviceLocalLayout], in_xla_layouts: Sequence[DeviceLocalLayout],
jaxpr_debug_info: core.JaxprDebugInfo | None, jaxpr_debug_info: core.DebugInfo | None,
kept_var_idx: set[int]) -> None: kept_var_idx: set[int]) -> None:
from jax._src.array import ArrayImpl from jax._src.array import ArrayImpl
# jaxpr_debug_info.arg_names are before DCE, so need to DCE them. # jaxpr_debug_info.arg_names are before DCE, so need to DCE them.

@ -53,7 +53,7 @@ def _typecheck_param(prim, param, name, msg_required, pred):
def _initial_style_open_jaxpr(fun: Callable, def _initial_style_open_jaxpr(fun: Callable,
in_tree: PyTreeDef, in_tree: PyTreeDef,
in_avals: Sequence[core.AbstractValue], in_avals: Sequence[core.AbstractValue],
debug_info: api_util.TracingDebugInfo): debug_info: core.DebugInfo):
wrapped_fun, out_tree = api_util.flatten_fun_nokwargs( wrapped_fun, out_tree = api_util.flatten_fun_nokwargs(
lu.wrap_init(fun, debug_info=debug_info), lu.wrap_init(fun, debug_info=debug_info),
in_tree) in_tree)
@ -65,7 +65,7 @@ def _initial_style_open_jaxpr(fun: Callable,
def _initial_style_jaxpr(fun: Callable, def _initial_style_jaxpr(fun: Callable,
in_tree: PyTreeDef, in_tree: PyTreeDef,
in_avals: Sequence[core.AbstractValue], in_avals: Sequence[core.AbstractValue],
debug_info: api_util.TracingDebugInfo): debug_info: core.DebugInfo):
jaxpr, consts, out_tree, () = _initial_style_open_jaxpr( jaxpr, consts, out_tree, () = _initial_style_open_jaxpr(
fun, in_tree, in_avals, debug_info) fun, in_tree, in_avals, debug_info)
closed_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr)) closed_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr))
@ -74,7 +74,7 @@ def _initial_style_jaxpr(fun: Callable,
def _initial_style_jaxpr_attrs(fun: Callable, def _initial_style_jaxpr_attrs(fun: Callable,
in_tree: PyTreeDef, in_tree: PyTreeDef,
in_avals: Sequence[core.AbstractValue], in_avals: Sequence[core.AbstractValue],
debug_info: api_util.TracingDebugInfo): debug_info: core.DebugInfo):
jaxpr, consts, out_tree, attrs_tracked = _initial_style_open_jaxpr( jaxpr, consts, out_tree, attrs_tracked = _initial_style_open_jaxpr(
fun, in_tree, in_avals, debug_info) fun, in_tree, in_avals, debug_info)
closed_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr)) closed_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr))
@ -83,7 +83,7 @@ def _initial_style_jaxpr_attrs(fun: Callable,
def _initial_style_jaxprs_with_common_consts( def _initial_style_jaxprs_with_common_consts(
funs: Sequence[Callable], funs: Sequence[Callable],
in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue], in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue],
debug_infos: Sequence[api_util.TracingDebugInfo]): debug_infos: Sequence[core.DebugInfo]):
# When staging the branches of a conditional into jaxprs, constants are # When staging the branches of a conditional into jaxprs, constants are
# extracted from each branch and converted to jaxpr arguments. To use the # extracted from each branch and converted to jaxpr arguments. To use the
# staged jaxprs as the branches to a conditional *primitive*, we need for # staged jaxprs as the branches to a conditional *primitive*, we need for

@ -134,7 +134,7 @@ def switch(index, branches: Sequence[Callable], *operands,
if (config.disable_jit.value and core.is_concrete(index)): if (config.disable_jit.value and core.is_concrete(index)):
return branches[int(index)](*operands) return branches[int(index)](*operands)
dbgs = [api_util.tracing_debug_info("switch", branch, operands, {}) dbgs = [api_util.debug_info("switch", branch, operands, {})
for branch in branches] for branch in branches]
ops, ops_tree = tree_flatten(operands) ops, ops_tree = tree_flatten(operands)
ops_avals = tuple(map(core.get_aval, ops)) ops_avals = tuple(map(core.get_aval, ops))
@ -237,10 +237,10 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
ops, ops_tree = tree_flatten(operands) ops, ops_tree = tree_flatten(operands)
ops_avals = tuple(map(core.get_aval, ops)) ops_avals = tuple(map(core.get_aval, ops))
dbg_true_fun = api_util.tracing_debug_info("cond", true_fun, operands, {}) dbg_true_fun = api_util.debug_info("cond", true_fun, operands, {})
if config.mutable_array_checks.value: if config.mutable_array_checks.value:
api_util._check_no_aliased_ref_args(dbg_true_fun, ops_avals, ops) api_util._check_no_aliased_ref_args(dbg_true_fun, ops_avals, ops)
dbg_false_fun = api_util.tracing_debug_info("cond", false_fun, operands, {}) dbg_false_fun = api_util.debug_info("cond", false_fun, operands, {})
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts( jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
(true_fun, false_fun), ops_tree, ops_avals, (true_fun, false_fun), ops_tree, ops_avals,
[dbg_true_fun, dbg_false_fun]) [dbg_true_fun, dbg_false_fun])

@ -195,7 +195,7 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
def _create_jaxpr(init): def _create_jaxpr(init):
init_flat = tree_leaves(init) init_flat = tree_leaves(init)
_, in_tree = tree_flatten((init, xs)) _, in_tree = tree_flatten((init, xs))
dbg = api_util.tracing_debug_info("scan", f, (init, xs), {}) dbg = api_util.debug_info("scan", f, (init, xs), {})
carry_avals = tuple(map(core.get_aval, init_flat)) carry_avals = tuple(map(core.get_aval, init_flat))
jaxpr, _, out_tree = _initial_style_jaxpr( jaxpr, _, out_tree = _initial_style_jaxpr(
f, in_tree, carry_avals + x_avals, dbg) f, in_tree, carry_avals + x_avals, dbg)

@ -273,7 +273,7 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
return carry, stacked_y return carry, stacked_y
x_avals = [core.mapped_aval(length, 0, aval) for aval in xs_avals] x_avals = [core.mapped_aval(length, 0, aval) for aval in xs_avals]
dbg_body = api_util.tracing_debug_info("scan", f, (init, xs), {}) dbg_body = api_util.debug_info("scan", f, (init, xs), {})
if config.mutable_array_checks.value: if config.mutable_array_checks.value:
in_flat, in_tree = tree_flatten((init, xs)) in_flat, in_tree = tree_flatten((init, xs))
@ -1357,10 +1357,10 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric],
def _create_jaxpr(init_val): def _create_jaxpr(init_val):
init_vals, in_tree = tree_flatten((init_val,)) init_vals, in_tree = tree_flatten((init_val,))
init_avals = tuple(_map(core.get_aval, init_vals)) init_avals = tuple(_map(core.get_aval, init_vals))
cond_dbg = api_util.tracing_debug_info("while_cond", cond_fun, (init_val,), {}) cond_dbg = api_util.debug_info("while_cond", cond_fun, (init_val,), {})
cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr( cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(
cond_fun, in_tree, init_avals, cond_dbg) cond_fun, in_tree, init_avals, cond_dbg)
body_dbg = api_util.tracing_debug_info("while_body", body_fun, (init_val,), {}) body_dbg = api_util.debug_info("while_body", body_fun, (init_val,), {})
body_jaxpr, body_consts, body_tree = _initial_style_jaxpr( body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(
body_fun, in_tree, init_avals, body_dbg) body_fun, in_tree, init_avals, body_dbg)
if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1: if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1:

@ -93,16 +93,16 @@ def custom_root(f: Callable,
""" """
guess_flat, in_args_tree = tree_flatten((initial_guess,)) guess_flat, in_args_tree = tree_flatten((initial_guess,))
guess_avals = tuple(_map(core.get_aval, guess_flat)) guess_avals = tuple(_map(core.get_aval, guess_flat))
f_debug = api_util.tracing_debug_info("custom_root", f, (initial_guess,), {}) f_debug = api_util.debug_info("custom_root", f, (initial_guess,), {})
f_jaxpr, f_consts, out_tree = _initial_style_jaxpr( f_jaxpr, f_consts, out_tree = _initial_style_jaxpr(
f, in_args_tree, guess_avals, f_debug) f, in_args_tree, guess_avals, f_debug)
in_tree, = treedef_children(in_args_tree) in_tree, = treedef_children(in_args_tree)
_check_tree("f", "initial_guess", out_tree, in_tree, False) _check_tree("f", "initial_guess", out_tree, in_tree, False)
solve_debug = api_util.tracing_debug_info("custom_root solve", solve, solve_debug = api_util.debug_info("custom_root solve", solve,
(f, initial_guess), {}, (f, initial_guess), {},
static_argnums=(0,)) static_argnums=(0,))
solve_jaxpr, solve_consts, solution_tree = _initial_style_jaxpr( solve_jaxpr, solve_consts, solution_tree = _initial_style_jaxpr(
partial(solve, f), in_args_tree, guess_avals, solve_debug) partial(solve, f), in_args_tree, guess_avals, solve_debug)
_check_tree("solve", "initial_guess", solution_tree, in_tree, has_aux) _check_tree("solve", "initial_guess", solution_tree, in_tree, has_aux)
@ -111,10 +111,10 @@ def custom_root(f: Callable,
unchecked_zeros, f_jvp = api.linearize(f, x) unchecked_zeros, f_jvp = api.linearize(f, x)
return tangent_solve(f_jvp, b) return tangent_solve(f_jvp, b)
tangent_solve_debug = api_util.tracing_debug_info("custom_root tangent_solve", tangent_solve_debug = api_util.debug_info("custom_root tangent_solve",
tangent_solve, tangent_solve,
(f, initial_guess), {}, (f, initial_guess), {},
static_argnums=(0,)) static_argnums=(0,))
l_and_s_jaxpr, l_and_s_consts, out_tree = _initial_style_jaxpr( l_and_s_jaxpr, l_and_s_consts, out_tree = _initial_style_jaxpr(
linearize_and_solve, treedef_tuple((in_tree,) * 2), guess_avals * 2, linearize_and_solve, treedef_tuple((in_tree,) * 2), guess_avals * 2,
tangent_solve_debug) tangent_solve_debug)
@ -265,17 +265,17 @@ def custom_linear_solve(
return f_aux if has_aux else f return f_aux if has_aux else f
matvec_debug = api_util.tracing_debug_info("custom_linear_solve", matvec_debug = api_util.debug_info("custom_linear_solve",
matvec, (b,), {}) matvec, (b,), {})
# no auxiliary data assumed for matvec # no auxiliary data assumed for matvec
matvec_jaxpr, matvec_consts, out_tree = _initial_style_jaxpr( matvec_jaxpr, matvec_consts, out_tree = _initial_style_jaxpr(
_shape_checked(matvec, "matvec", False), in_args_tree, b_avals, _shape_checked(matvec, "matvec", False), in_args_tree, b_avals,
matvec_debug) matvec_debug)
_check_tree("matvec", "b", out_tree, tree, False) _check_tree("matvec", "b", out_tree, tree, False)
solve_debug = api_util.tracing_debug_info("custom_linear_solve solve", solve_debug = api_util.debug_info("custom_linear_solve solve",
solve, (matvec, b), {}, solve, (matvec, b), {},
static_argnums=(0,)) static_argnums=(0,))
solve_jaxpr, solve_consts, out_tree = _initial_style_jaxpr( solve_jaxpr, solve_consts, out_tree = _initial_style_jaxpr(
_shape_checked(partial(solve, matvec), "solve", has_aux), in_args_tree, b_avals, _shape_checked(partial(solve, matvec), "solve", has_aux), in_args_tree, b_avals,
solve_debug) solve_debug)
@ -285,7 +285,7 @@ def custom_linear_solve(
vecmat_jaxpr = tr_solve_jaxpr = None vecmat_jaxpr = tr_solve_jaxpr = None
vecmat_consts = tr_solve_consts = [] vecmat_consts = tr_solve_consts = []
else: else:
transpose_solve_debug = api_util.tracing_debug_info( transpose_solve_debug = api_util.debug_info(
"custom_linear_solve transpose_solve", transpose_solve, "custom_linear_solve transpose_solve", transpose_solve,
(matvec, b), {}, static_argnums=(0,)) (matvec, b), {}, static_argnums=(0,))
if symmetric: if symmetric:

@ -745,7 +745,7 @@ def _trace_composite_to_jaxpr(fun: Callable,
in_tree: tree_util.PyTreeDef, in_tree: tree_util.PyTreeDef,
in_avals: Sequence[core.AbstractValue], in_avals: Sequence[core.AbstractValue],
name: str, name: str,
debug_info: api_util.TracingDebugInfo): debug_info: core.DebugInfo):
flat_fun, out_tree = api_util.flatten_fun_nokwargs(lu.wrap_init(fun), in_tree) flat_fun, out_tree = api_util.flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug_info) jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug_info)
if any(isinstance(c, core.Tracer) for c in consts): if any(isinstance(c, core.Tracer) for c in consts):
@ -822,8 +822,8 @@ def composite(
""" """
@functools.wraps(decomposition) @functools.wraps(decomposition)
def _decorator(*args, **kwargs): def _decorator(*args, **kwargs):
debug_info = api_util.tracing_debug_info("composite", decomposition, debug_info = api_util.debug_info("composite", decomposition,
args, kwargs) args, kwargs)
flat_args, in_tree = tree_util.tree_flatten(args) flat_args, in_tree = tree_util.tree_flatten(args)
in_avals = tuple(core.get_aval(x) for x in flat_args) in_avals = tuple(core.get_aval(x) for x in flat_args)
closed_jaxpr, out_tree = _trace_composite_to_jaxpr( closed_jaxpr, out_tree = _trace_composite_to_jaxpr(

@ -156,7 +156,7 @@ class WrappedFun:
f_transformed: Callable, f_transformed: Callable,
transforms, transforms,
stores: tuple[Store | EqualStore | None, ...], params, in_type, stores: tuple[Store | EqualStore | None, ...], params, in_type,
debug_info: TracingDebugInfo | None): debug_info: DebugInfo | None):
self.f = f self.f = f
self.f_transformed = f_transformed self.f_transformed = f_transformed
self.transforms = transforms self.transforms = transforms
@ -253,12 +253,10 @@ def fun_name(f):
except: except:
return str(f) return str(f)
class TracingDebugInfo(NamedTuple): class DebugInfo(NamedTuple):
"""Tracing-time debugging info about a func and its arguments. """Debugging info about a func, its arguments, and results."""
Formed just before staging to a jaxpr and read in trace-time error messages.
"""
traced_for: str # e.g. 'jit', 'scan', etc traced_for: str # e.g. 'jit', 'scan', etc
# e.g. f'{fun.__name__} at {filename}:{lineno}' or {fun.__name__} if we have # e.g. f'{fun.__name__} at {filename}:{lineno}' or {fun.__name__} if we have
# no source location information. The first word is always the function name, # no source location information. The first word is always the function name,
# which may be '<unknown>'. # which may be '<unknown>'.
@ -270,23 +268,25 @@ class TracingDebugInfo(NamedTuple):
# e.g., tangent args in jax.jvp. # e.g., tangent args in jax.jvp.
arg_names: tuple[str | None, ...] arg_names: tuple[str | None, ...]
# The result paths are not available while we are tracing the function,
# instead we keep a thunk. Once we are done tracing, we use
# `self.resolve_result_paths()` to execute the thunk and replace the
# actual result paths.
# e.g. ('[0]', '[1]', ...) # e.g. ('[0]', '[1]', ...)
result_paths_thunk: Callable[[], tuple[str, ...]] | None result_paths: tuple[str, ...] | Callable[[], tuple[str, ...]] | None
@classmethod def add_result_paths(self,
def from_jaxpr(cls, jaxpr: core.ClosedJaxpr) -> TracingDebugInfo | None: result_paths_thunk: Callable[[], tuple[str, ...]]
jaxpr_dbg = jaxpr.jaxpr._debug_info ) -> DebugInfo:
if jaxpr_dbg is None: return None assert self.result_paths is None
return TracingDebugInfo(jaxpr_dbg.traced_for, return self._replace(result_paths=HashableFunction(result_paths_thunk,
jaxpr_dbg.func_src_info, closure=()))
jaxpr_dbg.arg_names,
lambda: jaxpr_dbg.result_paths)
def add_result_paths(self, result_paths_thunk: Callable[[], tuple[str, ...]] def resolve_result_paths(self) -> DebugInfo:
) -> TracingDebugInfo: """Return a debug info with resolved result paths."""
assert self.result_paths_thunk is None if callable(self.result_paths):
return self._replace(result_paths_thunk=HashableFunction(result_paths_thunk, return self._replace(result_paths=tuple(self.result_paths()))
closure=())) return self
def safe_arg_names(self, expected: int) -> tuple[str | None, ...]: def safe_arg_names(self, expected: int) -> tuple[str | None, ...]:
"""Get the arg_names with a safety check.""" """Get the arg_names with a safety check."""
@ -296,9 +296,18 @@ class TracingDebugInfo(NamedTuple):
# TODO(necula): this should not happen # TODO(necula): this should not happen
return (None,) * expected return (None,) * expected
def safe_result_paths(self, expected: int) -> tuple[str, ...]:
"""Get the result paths with a safety check."""
assert not callable(self.result_paths), self
if self.result_paths is not None and len(self.result_paths) == expected:
return self.result_paths
else:
# TODO(necula): this should not happen
return ("",) * expected
def wrap_init(f: Callable, params=None, *, def wrap_init(f: Callable, params=None, *,
debug_info: TracingDebugInfo | None = None) -> WrappedFun: debug_info: DebugInfo | None = None) -> WrappedFun:
"""Wraps function `f` as a `WrappedFun`, suitable for transformation.""" """Wraps function `f` as a `WrappedFun`, suitable for transformation."""
params_dict = {} if params is None else params params_dict = {} if params is None else params
params = () if params is None else tuple(sorted(params.items())) params = () if params is None else tuple(sorted(params.items()))
@ -341,7 +350,7 @@ def _check_input_type(in_type: core.InputType) -> None:
provided[d.val] = True provided[d.val] = True
assert all(provided) assert all(provided)
def add_debug_info(f: WrappedFun, debug_info: TracingDebugInfo | None def add_debug_info(f: WrappedFun, debug_info: DebugInfo | None
) -> WrappedFun: ) -> WrappedFun:
"""Produce a new WrappedFun with debug_info attached.""" """Produce a new WrappedFun with debug_info attached."""
assert f.debug_info is None assert f.debug_info is None

@ -413,9 +413,9 @@ class BlockSpec:
fake_index_map_args, fake_index_map_kwargs = \ fake_index_map_args, fake_index_map_kwargs = \
index_map_tree.unflatten([False] * index_map_tree.num_leaves) index_map_tree.unflatten([False] * index_map_tree.num_leaves)
debug = api_util.tracing_debug_info("pallas_call index_map", debug = api_util.debug_info("pallas_call index_map",
index_map_func, fake_index_map_args, index_map_func, fake_index_map_args,
fake_index_map_kwargs) fake_index_map_kwargs)
flat_index_map_fun, index_map_out_tree_thunk = api_util.flatten_fun( flat_index_map_fun, index_map_out_tree_thunk = api_util.flatten_fun(
lu.wrap_init(index_map_func, debug_info=debug), index_map_tree) lu.wrap_init(index_map_func, debug_info=debug), index_map_tree)
index_map_src_info = NameAndSrcInfo.from_pallas_call( index_map_src_info = NameAndSrcInfo.from_pallas_call(

@ -1100,8 +1100,8 @@ def pallas_call_checkify_rule(error: checkify.Error,
jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(retrace_in_avals) jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(retrace_in_avals)
wrapped_kernel_with_err, out_tree_thunk = api_util.flatten_fun_nokwargs( wrapped_kernel_with_err, out_tree_thunk = api_util.flatten_fun_nokwargs(
lu.wrap_init(checked_kernel_fn), jaxpr_in_tree) lu.wrap_init(checked_kernel_fn), jaxpr_in_tree)
debug = api_util.tracing_debug_info("checkify_pallas", checked_kernel_fn, debug = api_util.debug_info("checkify_pallas", checked_kernel_fn,
retrace_in_avals, {}) retrace_in_avals, {})
with pallas_core.tracing_grid_env(grid_mapping.grid, ()): with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
final_jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( final_jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
wrapped_kernel_with_err, jaxpr_flat_avals, debug) wrapped_kernel_with_err, jaxpr_flat_avals, debug)
@ -1167,7 +1167,7 @@ def _trace_kernel_to_jaxpr(
wrapped_kernel_fun, kernel_in_transforms wrapped_kernel_fun, kernel_in_transforms
) )
fake_kernel_args = kernel_in_tree.unflatten(kernel_avals) fake_kernel_args = kernel_in_tree.unflatten(kernel_avals)
debug = api_util.tracing_debug_info("pallas_call", fun, fake_kernel_args, {}) debug = api_util.debug_info("pallas_call", fun, fake_kernel_args, {})
with grid_mapping.trace_env(): with grid_mapping.trace_env():
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun, jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun,
kernel_avals, debug) kernel_avals, debug)
@ -1568,7 +1568,7 @@ def pallas_call(
kernel_fun_sig = api_util.fun_signature(kernel) kernel_fun_sig = api_util.fun_signature(kernel)
arg_names = None arg_names = None
if kernel_fun_sig: if kernel_fun_sig:
kernel_debug_info = api_util.tracing_debug_info( kernel_debug_info = api_util.debug_info(
"pallas_call kernel", "pallas_call kernel",
kernel, kernel,
[1] * len(kernel_fun_sig.parameters), {}) [1] * len(kernel_fun_sig.parameters), {})

@ -49,7 +49,7 @@ from jax._src import xla_bridge as xb
from jax._src.api_util import ( from jax._src.api_util import (
argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs, argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs,
donation_vector, check_callable, resolve_argnums, donation_vector, check_callable, resolve_argnums,
argnames_partial_except, tracing_debug_info, result_paths, add_jaxpr_debug_info, argnames_partial_except, debug_info, result_paths, add_jaxpr_debug_info,
hoist_obj_attrs, _check_no_aliased_ref_args, hoist_obj_attrs, _check_no_aliased_ref_args,
_check_no_aliased_closed_over_refs) _check_no_aliased_closed_over_refs)
from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import partial_eval as pe
@ -548,7 +548,7 @@ def _infer_params_impl(
ji: PjitInfo, ji: PjitInfo,
pjit_mesh: mesh_lib.Mesh | None, pjit_mesh: mesh_lib.Mesh | None,
resource_env: mesh_lib.ResourceEnv | None, resource_env: mesh_lib.ResourceEnv | None,
dbg: lu.TracingDebugInfo, dbg: core.DebugInfo,
args: tuple[Any, ...], args: tuple[Any, ...],
kwargs: dict[str, Any], kwargs: dict[str, Any],
in_avals: tuple[core.AbstractValue, ...] | None, in_avals: tuple[core.AbstractValue, ...] | None,
@ -733,7 +733,7 @@ def _infer_params(
'Using `with mesh:` context manager and `jax.sharding.use_mesh`' 'Using `with mesh:` context manager and `jax.sharding.use_mesh`'
' together is not allowed.') ' together is not allowed.')
dbg = tracing_debug_info( dbg = debug_info(
'jit', fun, args, kwargs, static_argnums=ji.static_argnums, 'jit', fun, args, kwargs, static_argnums=ji.static_argnums,
static_argnames=ji.static_argnames, sourceinfo=ji.fun_sourceinfo, static_argnames=ji.static_argnames, sourceinfo=ji.fun_sourceinfo,
signature=ji.fun_signature) signature=ji.fun_signature)
@ -756,7 +756,7 @@ def _infer_params(
entry.pjit_params = p entry.pjit_params = p
return entry.pjit_params, entry.pjit_params.consts + dynargs return entry.pjit_params, entry.pjit_params.consts + dynargs
def _infer_input_type(fun: Callable, dbg: lu.TracingDebugInfo | None, def _infer_input_type(fun: Callable, dbg: core.DebugInfo | None,
explicit_args) -> tuple[core.AbstractValue, ...]: explicit_args) -> tuple[core.AbstractValue, ...]:
avals = [] avals = []
try: try:
@ -1302,7 +1302,7 @@ def _create_pjit_jaxpr(
fun: lu.WrappedFun, fun: lu.WrappedFun,
in_type: core.InputType | Sequence[core.AbstractValue], in_type: core.InputType | Sequence[core.AbstractValue],
attr_data: int, attr_data: int,
debug_info: lu.TracingDebugInfo, debug_info: core.DebugInfo,
result_paths: Callable, result_paths: Callable,
ignored_inline: IgnoreKey ignored_inline: IgnoreKey
) -> tuple[core.ClosedJaxpr, list[Any], list[core.AbstractValue], ) -> tuple[core.ClosedJaxpr, list[Any], list[core.AbstractValue],
@ -1346,7 +1346,7 @@ def _create_pjit_jaxpr(
def _check_and_canonicalize_out_shardings( def _check_and_canonicalize_out_shardings(
out_shardings_treedef, out_shardings_leaves, out_layouts_treedef, out_shardings_treedef, out_shardings_leaves, out_layouts_treedef,
out_layouts_leaves, out_tree, out_avals, out_layouts_leaves, out_tree, out_avals,
debug_info: core.JaxprDebugInfo | None, debug_info: core.DebugInfo | None,
device_or_backend_set): device_or_backend_set):
orig_out_shardings = tree_unflatten(out_shardings_treedef, out_shardings_leaves) orig_out_shardings = tree_unflatten(out_shardings_treedef, out_shardings_leaves)
if isinstance(orig_out_shardings, (UnspecifiedValue, Sharding)): if isinstance(orig_out_shardings, (UnspecifiedValue, Sharding)):

@ -989,7 +989,7 @@ def _run_state_discharge_rule(in_avals: Sequence[core.AbstractValue],
def initial_style_jaxpr( def initial_style_jaxpr(
fun: Callable, in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue], fun: Callable, in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue],
dbg: api_util.TracingDebugInfo, dbg: core.DebugInfo,
) -> tuple[core.Jaxpr, list[Any], PyTreeDef]: ) -> tuple[core.Jaxpr, list[Any], PyTreeDef]:
return _initial_style_jaxpr(fun, in_tree, tuple(in_avals), dbg) return _initial_style_jaxpr(fun, in_tree, tuple(in_avals), dbg)
@ -997,7 +997,7 @@ def initial_style_jaxpr(
def _initial_style_jaxpr(fun: Callable, def _initial_style_jaxpr(fun: Callable,
in_tree: api_util.PyTreeDef, in_tree: api_util.PyTreeDef,
in_avals: Sequence[core.AbstractValue], in_avals: Sequence[core.AbstractValue],
debug: api_util.TracingDebugInfo): debug: core.DebugInfo):
fun_, out_tree_thunk = api_util.flatten_fun_nokwargs(lu.wrap_init(fun), fun_, out_tree_thunk = api_util.flatten_fun_nokwargs(lu.wrap_init(fun),
tree_util.treedef_tuple((in_tree,))) tree_util.treedef_tuple((in_tree,)))
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, in_avals, debug) jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, in_avals, debug)
@ -1007,7 +1007,7 @@ def _initial_style_jaxpr(fun: Callable,
T = TypeVar('T') T = TypeVar('T')
def run_state(f: Callable[..., None]) -> Callable[[T], T]: def run_state(f: Callable[..., None]) -> Callable[[T], T]:
def wrapped(args): def wrapped(args):
dbg = api_util.tracing_debug_info("run_state", f, (args,), {}) dbg = api_util.debug_info("run_state", f, (args,), {})
flat_args, in_tree = tree_util.tree_flatten(args) flat_args, in_tree = tree_util.tree_flatten(args)
ref_avals, ref_args = unzip2(map(get_ref_aval_from_value, flat_args)) ref_avals, ref_args = unzip2(map(get_ref_aval_from_value, flat_args))
# There may be some uninitialized values here in ref_args. # There may be some uninitialized values here in ref_args.
@ -1027,7 +1027,7 @@ def run_state(f: Callable[..., None]) -> Callable[[T], T]:
def run_state_reference(f: Callable[..., None]): def run_state_reference(f: Callable[..., None]):
def wrapped(args): def wrapped(args):
dbg = api_util.tracing_debug_info("run_state", f, (args,), {}) dbg = api_util.debug_info("run_state", f, (args,), {})
flat_args, in_tree = tree_util.tree_flatten(args) flat_args, in_tree = tree_util.tree_flatten(args)
ref_avals, ref_args = unzip2(map(get_ref_aval_from_value, flat_args)) ref_avals, ref_args = unzip2(map(get_ref_aval_from_value, flat_args))
jaxpr_, consts, _ = initial_style_jaxpr(f, in_tree, ref_avals, dbg) jaxpr_, consts, _ = initial_style_jaxpr(f, in_tree, ref_avals, dbg)

@ -20,13 +20,13 @@ from jax._src.core import (
AbstractValue as AbstractValue, AbstractValue as AbstractValue,
Atom as Atom, Atom as Atom,
CallPrimitive as CallPrimitive, CallPrimitive as CallPrimitive,
DebugInfo as DebugInfo,
DShapedArray as DShapedArray, DShapedArray as DShapedArray,
DropVar as DropVar, DropVar as DropVar,
Effect as Effect, Effect as Effect,
Effects as Effects, Effects as Effects,
get_opaque_trace_state as get_opaque_trace_state, get_opaque_trace_state as get_opaque_trace_state,
InconclusiveDimensionOperation as InconclusiveDimensionOperation, InconclusiveDimensionOperation as InconclusiveDimensionOperation,
JaxprDebugInfo as JaxprDebugInfo,
JaxprPpContext as JaxprPpContext, JaxprPpContext as JaxprPpContext,
JaxprPpSettings as JaxprPpSettings, JaxprPpSettings as JaxprPpSettings,
JaxprTypeError as JaxprTypeError, JaxprTypeError as JaxprTypeError,

@ -69,13 +69,13 @@ def _collect_jaxprs(jaxpr: core.Jaxpr,
return acc return acc
def _debug_info_to_string(dbg: api_util.TracingDebugInfo | core.JaxprDebugInfo | None) -> list[str]: def _debug_info_to_string(dbg: core.DebugInfo | None) -> list[str]:
if dbg is None: return "None" if dbg is None: return "None"
# Strip the absolute path and the line number but check that it references # Strip the absolute path and the line number but check that it references
# this file (to catch errors when the source info points in JAX internals) # this file (to catch errors when the source info points in JAX internals)
fun_src_info = re.sub(r"^(\S+)( at .*/debug_info_test.py:.*)?", "\\1", dbg.func_src_info) fun_src_info = re.sub(r"^(\S+)( at .*/debug_info_test.py:.*)?", "\\1", dbg.func_src_info)
res = f"traced_for={dbg.traced_for}, fun={fun_src_info}, arg_names={','.join(dbg.arg_names)}" res = f"traced_for={dbg.traced_for}, fun={fun_src_info}, arg_names={','.join(dbg.arg_names)}"
if isinstance(dbg, core.JaxprDebugInfo): if isinstance(dbg.result_paths, tuple):
if dbg.result_paths: if dbg.result_paths:
res += f", result_paths={','.join(dbg.result_paths)}" res += f", result_paths={','.join(dbg.result_paths)}"
else: else:
@ -221,24 +221,24 @@ class DebugInfoTest(jtu.JaxTestCase):
def my_f(x, y, z, w): def my_f(x, y, z, w):
pass pass
dbg = api_util.tracing_debug_info("jit", my_f, (1, 2), dict(z=3, w=4)) dbg = api_util.debug_info("jit", my_f, (1, 2), dict(z=3, w=4))
self.assertRegex(dbg.func_src_info, r"^my_f at .*debug_info_test.py:\d+") self.assertRegex(dbg.func_src_info, r"^my_f at .*debug_info_test.py:\d+")
self.assertEqual(dbg.arg_names, ("x", "y", "z", "w")) self.assertEqual(dbg.arg_names, ("x", "y", "z", "w"))
self.assertIsNone(dbg.result_paths_thunk) self.assertIsNone(dbg.result_paths)
def test_debug_info_arg_passed_as_kwarg(self): def test_debug_info_arg_passed_as_kwarg(self):
def my_f(x, y, z): def my_f(x, y, z):
pass pass
dbg = api_util.tracing_debug_info("jit", my_f, (1, 2), dict(z=3)) dbg = api_util.debug_info("jit", my_f, (1, 2), dict(z=3))
self.assertEqual(dbg.arg_names, ("x", "y", "z")) self.assertEqual(dbg.arg_names, ("x", "y", "z"))
def test_debug_info_pytrees(self): def test_debug_info_pytrees(self):
def my_f(x_tree, *, y_tree): def my_f(x_tree, *, y_tree):
pass pass
dbg = api_util.tracing_debug_info("jit", my_f, ((1, 2),), dbg = api_util.debug_info("jit", my_f, ((1, 2),),
dict(y_tree=dict(z=3, w=4))) dict(y_tree=dict(z=3, w=4)))
self.assertEqual(dbg.arg_names, ("x_tree[0]", "x_tree[1]", self.assertEqual(dbg.arg_names, ("x_tree[0]", "x_tree[1]",
"y_tree['w']", "y_tree['z']")) "y_tree['w']", "y_tree['z']"))
@ -246,43 +246,43 @@ class DebugInfoTest(jtu.JaxTestCase):
def my_f(x, y, *, z, w): def my_f(x, y, *, z, w):
pass pass
dbg = api_util.tracing_debug_info("jit", my_f, (1, 2), dict(z=3, w=4), dbg = api_util.debug_info("jit", my_f, (1, 2), dict(z=3, w=4),
static_argnums=(1,), static_argnums=(1,),
static_argnames=("w",)) static_argnames=("w",))
self.assertEqual(dbg.arg_names, ("x", "z")) self.assertEqual(dbg.arg_names, ("x", "z"))
def test_debug_info_with_pytrees_and_statics(self): def test_debug_info_with_pytrees_and_statics(self):
def my_f(x, y, *, z, w): def my_f(x, y, *, z, w):
pass pass
dbg = api_util.tracing_debug_info("jit", my_f, ((1, 2), (2, 3)), dbg = api_util.debug_info("jit", my_f, ((1, 2), (2, 3)),
dict(z=(3, 4), w=(5, 6)), dict(z=(3, 4), w=(5, 6)),
static_argnums=(1,), static_argnums=(1,),
static_argnames=("w",)) static_argnames=("w",))
self.assertEqual(dbg.arg_names, ("x[0]", "x[1]", "z[0]", "z[1]")) self.assertEqual(dbg.arg_names, ("x[0]", "x[1]", "z[0]", "z[1]"))
def test_debug_info_too_many_args(self): def test_debug_info_too_many_args(self):
def my_f(x): def my_f(x):
pass pass
dbg = api_util.tracing_debug_info("jit", my_f, (1, 2, 3), dict(z=3)) dbg = api_util.debug_info("jit", my_f, (1, 2, 3), dict(z=3))
self.assertEqual(dbg.arg_names, ('args[0]', 'args[1]', 'args[2]', "kwargs['z']")) self.assertEqual(dbg.arg_names, ('args[0]', 'args[1]', 'args[2]', "kwargs['z']"))
def test_debug_info_no_source_info_built_in(self): def test_debug_info_no_source_info_built_in(self):
# built-in function "int" does not have an inspect.Signature # built-in function "int" does not have an inspect.Signature
dbg = api_util.tracing_debug_info("jit", max, (1,), {}) dbg = api_util.debug_info("jit", max, (1,), {})
self.assertEqual(dbg.func_src_info, "max") self.assertEqual(dbg.func_src_info, "max")
self.assertEqual(dbg.arg_names, ("args[0]",)) self.assertEqual(dbg.arg_names, ("args[0]",))
def test_debug_info_lambda(self): def test_debug_info_lambda(self):
# built-in function "int" does not have an inspect.Signature # built-in function "int" does not have an inspect.Signature
dbg = api_util.tracing_debug_info("jit", lambda my_arg: False, (1,), {}) dbg = api_util.debug_info("jit", lambda my_arg: False, (1,), {})
self.assertRegex(dbg.func_src_info, r"^<lambda> at .*debug_info_test.py:\d+") self.assertRegex(dbg.func_src_info, r"^<lambda> at .*debug_info_test.py:\d+")
self.assertEqual(dbg.arg_names, ("my_arg",)) self.assertEqual(dbg.arg_names, ("my_arg",))
def test_debug_info_no_source_info_not_callable(self): def test_debug_info_no_source_info_not_callable(self):
# built-in function "int" does not have an inspect.Signature # built-in function "int" does not have an inspect.Signature
dbg = api_util.tracing_debug_info("jit", False, (1,), {}) dbg = api_util.debug_info("jit", False, (1,), {})
self.assertEqual(dbg.func_src_info, "<unknown>") self.assertEqual(dbg.func_src_info, "<unknown>")
self.assertEqual(dbg.arg_names, ("args[0]",)) self.assertEqual(dbg.arg_names, ("args[0]",))
@ -293,7 +293,7 @@ class DebugInfoTest(jtu.JaxTestCase):
def __call__(self, y): def __call__(self, y):
return self.x + y return self.x + y
dbg = api_util.tracing_debug_info("jit", Foo(), (1,), {}) dbg = api_util.debug_info("jit", Foo(), (1,), {})
self.assertRegex(dbg.func_src_info, "<unknown>") self.assertRegex(dbg.func_src_info, "<unknown>")
self.assertEqual(dbg.arg_names, ("y",)) self.assertEqual(dbg.arg_names, ("y",))
@ -307,7 +307,7 @@ class DebugInfoTest(jtu.JaxTestCase):
def __repr__(self): def __repr__(self):
raise NotImplementedError raise NotImplementedError
dbg = api_util.tracing_debug_info("jit", Foo(), (1,), {}) dbg = api_util.debug_info("jit", Foo(), (1,), {})
self.assertRegex(dbg.func_src_info, "<unknown>") self.assertRegex(dbg.func_src_info, "<unknown>")
self.assertEqual(dbg.arg_names, ("y",)) self.assertEqual(dbg.arg_names, ("y",))