[better_errors] Merge the JaxprDebugInfo and TracingDebugInfo into core.DebugInfo

Previously, we had two almost identical classes: `TracingDebugInfo` and
`JaxprDebugInfo`. The only difference was that `TracingDebugInfo` had
a thunk to return the result paths, while `JaxprDebugInfo` had the
result paths resolved to a tuple. The separation of these types
provided some clarity, but also led to code duplication and
required conversions as the debugging info goes from `WrappedFun`
to a `Jaxpr` and then to `WrappedFun` again.
This commit is contained in:
George Necula 2025-01-31 22:23:20 +02:00
parent de48ce2a4c
commit c70de6deed
24 changed files with 159 additions and 176 deletions

View File

@ -323,7 +323,7 @@ def checkpoint(fun: Callable, *, prevent_cse: bool = True,
@wraps(fun)
@api_boundary
def fun_remat(*args, **kwargs):
debug = api_util.tracing_debug_info(
debug = api_util.debug_info(
"checkpoint / remat", fun,
args, kwargs, static_argnums=static_argnums)
fun_, args = _remat_static_argnums(fun, static_argnums, args)
@ -418,7 +418,7 @@ _dyn_args_fun_cached = weakref_lru_cache(_dyn_args_fun_uncached)
def _trace_to_jaxpr(fun: Callable,
in_tree: PyTreeDef,
in_avals: Sequence[core.AbstractValue],
debug: lu.TracingDebugInfo
debug: core.DebugInfo
) -> tuple[core.Jaxpr, Sequence[Any], PyTreeDef]:
flat_fun, out_tree = api_util.flatten_fun(lu.wrap_init(fun), in_tree)
try:
@ -447,7 +447,7 @@ def saved_residuals(f: Callable,
args, kwargs = tree_unflatten(in_tree, args)
return f(*args, **kwargs)
debug_info = api_util.tracing_debug_info("saved_residuals", f, args, kwargs)
debug_info = api_util.debug_info("saved_residuals", f, args, kwargs)
out = api.make_jaxpr(lambda *args: api.linearize(f_, *args)[1],
return_shape=True)(*in_leaves)
assert isinstance(out, tuple)

View File

@ -57,11 +57,11 @@ from jax._src import pjit
from jax._src import xla_bridge as xb
from jax._src.core import eval_jaxpr, shaped_abstractify, ShapedArray
from jax._src.api_util import (
flatten_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2, argnums_partial,
flatten_axes, donation_vector,
rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
apply_flat_fun_nokwargs, check_callable, tracing_debug_info,
result_paths, flat_out_axes)
flatten_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2, argnums_partial,
flatten_axes, donation_vector,
rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
apply_flat_fun_nokwargs, check_callable, debug_info,
result_paths, flat_out_axes)
from jax._src.lax import lax as lax_internal
from jax._src.lib import jax_jit
from jax._src.lib import xla_client as xc
@ -452,7 +452,7 @@ def value_and_grad(fun: Callable, argnums: int | Sequence[int] = 0,
raise TypeError(f"differentiating with respect to {argnums=} requires at least "
f"{max_argnum + 1} positional arguments to be passed by the caller, "
f"but got only {len(args)} positional arguments.")
dbg = tracing_debug_info('value_and_grad', fun, args, kwargs)
dbg = debug_info('value_and_grad', fun, args, kwargs)
f = lu.wrap_init(fun, params=kwargs, debug_info=dbg)
f_partial, dyn_args = argnums_partial(f, argnums, args,
@ -1426,7 +1426,7 @@ def _prepare_pmap(fun: Callable, in_axes, out_axes, static_broadcasted_tuple,
if in_devices is not None and len(in_devices) == 0:
raise ValueError("'devices' argument to pmap must be non-empty, or None.")
dbg = tracing_debug_info(
dbg = debug_info(
"pmap", fun, args, kwargs,
static_argnums=static_broadcasted_tuple)

View File

@ -31,7 +31,6 @@ from jax._src.tree_util import (
prefix_errors)
from jax._src.tree_util import _replace_nones
from jax._src import linear_util as lu
from jax._src.linear_util import TracingDebugInfo
from jax._src.util import (safe_map, WrapKwArgs, Hashable, HashableFunction,
Unhashable, safe_zip)
from jax._src import traceback_util
@ -582,7 +581,7 @@ def api_hook(fun, tag: str):
return fun
def tracing_debug_info(
def debug_info(
traced_for: str,
fun: Callable,
args: Sequence[Any],
@ -594,14 +593,14 @@ def tracing_debug_info(
# TODO(necula): check if we really need this, e.g., to speed up tracing.
sourceinfo: str | None = None,
signature: inspect.Signature | None = None,
) -> TracingDebugInfo:
) -> core.DebugInfo:
if sourceinfo is None:
sourceinfo = fun_sourceinfo(fun)
if signature is None:
signature = fun_signature(fun)
arg_names = _non_static_arg_names(signature, args, kwargs, static_argnums,
static_argnames)
return TracingDebugInfo(traced_for, sourceinfo, arg_names, result_paths_thunk)
return core.DebugInfo(traced_for, sourceinfo, arg_names, result_paths_thunk)
def fun_signature(fun: Callable) -> inspect.Signature | None:
@ -619,7 +618,7 @@ _fun_name_re = re.compile(r"(?:<built-in function (\S+)>)")
# TODO(mattjj): make this function internal to this module
def fun_sourceinfo(fun: Callable) -> str:
# See TracingDebugInfo.fun_src_info
# See DebugInfo.fun_src_info
res = getattr(fun, "__fun_sourceinfo__", None)
if res is not None: return res
while isinstance(fun, partial):
@ -684,20 +683,19 @@ def result_paths(_fun, _store, *args, **kwargs):
# TODO(necula): simplify this function, all it needs is to add the trace_debug to the Jaxpr
def add_jaxpr_debug_info(jaxpr: core.Jaxpr,
trace_debug: TracingDebugInfo | None,
debug: core.DebugInfo | None,
result_paths: tuple[str, ...] | None = None,
) -> core.Jaxpr:
"""Add debug info to jaxpr, given trace-time debug info and result paths."""
if trace_debug is None:
if debug is None:
return jaxpr
# TODO(necula): re-enable this safety check
# assert (result_paths is not None) ^ (trace_debug.result_paths_thunk is not None)
if result_paths is None:
result_paths = trace_debug.result_paths_thunk() # type: ignore
debug_info = core.JaxprDebugInfo(
trace_debug.traced_for, trace_debug.func_src_info,
trace_debug.arg_names, tuple(result_paths)) # type: ignore
return jaxpr.replace(debug_info=debug_info)
if result_paths is not None:
debug = debug._replace(result_paths=tuple(result_paths))
else:
debug = debug.resolve_result_paths()
return jaxpr.replace(debug_info=debug)
def hoist_obj_attrs(f, flat_args):
idxs, objs, flat_args_ = [], [], []

View File

@ -1202,7 +1202,7 @@ def checkify(f: Callable[..., Out],
in_tree = jtu.tree_structure(((), {}))
closed_f = lambda: f(*args, **kwargs)
# stage:
debug = api_util.tracing_debug_info("checkify", f, args, kwargs)
debug = api_util.debug_info("checkify", f, args, kwargs)
fun_, out_tree = api_util.flatten_fun(lu.wrap_init(closed_f,
debug_info=debug),
in_tree)

View File

@ -82,31 +82,7 @@ EffectTypeSet = effects.EffectTypeSet
no_effects: Effects = effects.no_effects
# TODO(necula): make this an extension of TracingDebugInfo
class JaxprDebugInfo(NamedTuple):
# An extension of lu.TracingDebugInfo; see comments there
traced_for: str
func_src_info: str
arg_names: tuple[str | None, ...]
# This is formed after tracing, when we have concrete `result_paths`
result_paths: tuple[str, ...] # e.g. ('[0]', '[1]', ...)
def safe_arg_names(self, expected: int) -> tuple[str | None, ...]:
"""Get the arg_names with a safety check."""
if len(self.arg_names) == expected:
return self.arg_names
else:
# TODO(necula): this should not happen
return (None,) * expected
def safe_result_paths(self, expected: int) -> tuple[str | None, ...]:
"""Get the result_paths with a safety check."""
if len(self.result_paths) == expected:
return self.result_paths
else:
# TODO(necula): this should not happen
return ("",) * expected
DebugInfo = lu.DebugInfo
class Jaxpr:
__slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns',
@ -117,7 +93,7 @@ class Jaxpr:
_outvars: list[Atom]
_eqns: list[JaxprEqn]
_effects: Effects
_debug_info: JaxprDebugInfo | None
_debug_info: DebugInfo | None
@property
def constvars(self) -> list[Var]:
@ -140,13 +116,13 @@ class Jaxpr:
return self._effects
@property
def debug_info(self) -> JaxprDebugInfo | None:
def debug_info(self) -> DebugInfo | None:
return self._debug_info
def __init__(self, constvars: Sequence[Var], invars: Sequence[Var],
outvars: Sequence[Atom], eqns: Sequence[JaxprEqn],
effects: Effects = no_effects,
debug_info: JaxprDebugInfo | None = None):
debug_info: DebugInfo | None = None):
"""
Args:
constvars: list of variables introduced for constants. Array constants are
@ -157,14 +133,14 @@ class Jaxpr:
eqns: list of equations.
effects: set of effects. The effects on a jaxpr are a superset of the
union of the effects for each equation.
debug_info: optional JaxprDebugInfo.
debug_info: optional DebugInfo.
"""
self._constvars = list(constvars)
self._invars = list(invars)
self._outvars = list(outvars)
self._eqns = list(eqns)
self._effects = effects
self._debug_info = debug_info
self._debug_info = debug_info and debug_info.resolve_result_paths()
# TODO(necula): re-enable these safety checks
# assert (not debug_info or len(debug_info.arg_names) == len(invars)), (debug_info, invars)
# assert (not debug_info or len(debug_info.result_paths) == len(outvars)), (debug_info, outvars)

View File

@ -147,7 +147,7 @@ class custom_vmap:
raise AttributeError(
f"No batching rule defined for custom_vmap function {fun_name} "
"using def_vmap.")
debug = api_util.tracing_debug_info("custom_vmap", self.fun, args, {})
debug = api_util.debug_info("custom_vmap", self.fun, args, {})
args_flat, in_tree = tree_flatten(args)
flat_fun, out_tree = api_util.flatten_fun_nokwargs(
lu.wrap_init(self.fun, debug_info=debug),

View File

@ -127,12 +127,12 @@ class custom_dce:
"def_dce."
)
rule_name = util.fun_name(self.dce_rule)
debug = api_util.tracing_debug_info("custom_dce", self.fun,
args, {},
static_argnums=self.static_argnums)
debug_rule = api_util.tracing_debug_info("custom_dce_rule", self.dce_rule,
args, {},
static_argnums=self.static_argnums)
debug = api_util.debug_info("custom_dce", self.fun,
args, {},
static_argnums=self.static_argnums)
debug_rule = api_util.debug_info("custom_dce_rule", self.dce_rule,
args, {},
static_argnums=self.static_argnums)
args = api_util.resolve_kwargs(self.fun, args, kwargs)
if self.static_argnums:
static_argnums = set(self.static_argnums)

View File

@ -468,9 +468,9 @@ class custom_partitioning:
def __call__(self, *args, **kwargs):
args = _resolve_kwargs(self.fun, args, kwargs)
debug = api_util.tracing_debug_info("custom_partitioning", self.fun,
args, kwargs,
static_argnums=self.static_argnums)
debug = api_util.debug_info("custom_partitioning", self.fun,
args, kwargs,
static_argnums=self.static_argnums)
if self.static_argnums:
static_argnums = set(self.static_argnums)
args = tuple(x if i in static_argnums else x for i, x in enumerate(args))

View File

@ -147,7 +147,7 @@ def _linearize_jaxpr(
jaxpr: core.ClosedJaxpr,
nonzeros: tuple[bool, ...]
) -> tuple[core.ClosedJaxpr, int, Sequence[bool], core.ClosedJaxpr]:
dbg = lu.TracingDebugInfo.from_jaxpr(jaxpr)
dbg = jaxpr.jaxpr.debug_info
primal_trace = pe.DynamicJaxprTrace(dbg)
tangent_trace = pe.DynamicJaxprTrace(dbg)
lin_trace = LinearizeTrace(primal_trace, tangent_trace)

View File

@ -42,7 +42,6 @@ from jax._src.core import (Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval,
mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx,
InputType, OutputType, get_referent, JaxprEqnContext)
from jax._src.state.types import AbstractRef
from jax._src import tree_util
from jax._src.tree_util import (PyTreeDef, treedef_tuple,
tree_flatten, tree_structure)
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
@ -932,7 +931,7 @@ def _partial_eval_jaxpr_nounits(jaxpr: ClosedJaxpr,
in_unknowns: Sequence[bool],
instantiate: bool | Sequence[bool]):
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr),
debug_info=lu.TracingDebugInfo.from_jaxpr(jaxpr))
debug_info=jaxpr.jaxpr.debug_info)
cell = []
def fun(*known_vals_in):
@ -1334,10 +1333,11 @@ def prune_jaxpr_outputs(jaxpr: Jaxpr, used_outputs: Sequence[bool]) -> Jaxpr:
def _prune_jaxpr_outputs(jaxpr: Jaxpr, used_outputs: tuple[bool, ...]) -> Jaxpr:
outvars = [v for v, b in zip(jaxpr.outvars, used_outputs) if b]
dbg = jaxpr.debug_info and core.JaxprDebugInfo(
dbg = jaxpr.debug_info and core.DebugInfo(
jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info,
jaxpr.debug_info.arg_names,
tuple(v for v, b in zip(jaxpr.debug_info.result_paths, used_outputs) if b))
tuple(v for v, b in zip(jaxpr.debug_info.safe_result_paths(len(used_outputs)),
used_outputs) if b))
new_jaxpr = jaxpr.replace(outvars=outvars, debug_info=dbg)
config.enable_checks.value and core.check_jaxpr(new_jaxpr)
return new_jaxpr
@ -1422,10 +1422,12 @@ def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: tuple[bool, ...],
eqns = new_eqns[::-1]
jaxpr_effects = make_jaxpr_effects(jaxpr.constvars, invars, outvars, eqns)
dbg = jaxpr.debug_info and core.JaxprDebugInfo(
dbg = jaxpr.debug_info and core.DebugInfo(
jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info,
tuple(v for v, b in zip(jaxpr.debug_info.arg_names, used_inputs) if b),
tuple(v for v, b in zip(jaxpr.debug_info.result_paths, used_outputs) if b))
tuple(v for v, b in zip(jaxpr.debug_info.safe_arg_names(len(used_inputs)),
used_inputs) if b),
tuple(v for v, b in zip(jaxpr.debug_info.safe_result_paths(len(used_outputs)),
used_outputs) if b))
new_jaxpr = Jaxpr(jaxpr.constvars, invars, outvars, eqns, jaxpr_effects, dbg)
config.enable_checks.value and core.check_jaxpr(new_jaxpr)
@ -1623,9 +1625,9 @@ class JaxprStackFrame:
attrs_tracked: list[tuple[Any, str]]
attrs_inits: list
attrs_vars: list[Var]
debug_info: lu.TracingDebugInfo | None
debug_info: core.DebugInfo | None
def __init__(self, debug_info: lu.TracingDebugInfo | None):
def __init__(self, debug_info: core.DebugInfo | None):
self.gensym = core.gensym()
self.tracer_to_var = {}
self.constid_to_tracer = {}
@ -1809,7 +1811,7 @@ def _inline_literals(
class DynamicJaxprTrace(core.Trace):
__slots__ = ("frame",)
def __init__(self, debug_info: lu.TracingDebugInfo | None):
def __init__(self, debug_info: core.DebugInfo | None):
self.frame = JaxprStackFrame(debug_info)
def invalidate(self):
@ -2114,7 +2116,7 @@ def _jvp_jaxpr_zeros(f, store, in_zeros, zero_avals, *primal_tangent_avals):
def trace_to_jaxpr_dynamic(
fun: lu.WrappedFun,
in_avals: Sequence[AbstractValue],
debug_info: lu.TracingDebugInfo | None = None,
debug_info: core.DebugInfo | None = None,
*,
keep_inputs: list[bool] | None = None,
) -> tuple[Jaxpr, list[AbstractValue], list[Any],
@ -2137,7 +2139,7 @@ def trace_to_jaxpr_dynamic(
return jaxpr, [v.aval for v in jaxpr.outvars], consts, attrs_tracked
def _check_no_returned_refs(
dbg: lu.TracingDebugInfo | None,
dbg: core.DebugInfo | None,
out_tracers: Sequence[DynamicJaxprTracer]
) -> None:
if not config.mutable_array_checks.value: return
@ -2148,10 +2150,8 @@ def _check_no_returned_refs(
raise ValueError(
f"function returned a mutable array reference of type {a.str_short()}, "
"but mutable array references cannot be returned.")
loc = (f' at output tree path {tree_util.keystr(ls[i])}' # type: ignore
if (dbg.result_paths_thunk and
(ls := dbg.result_paths_thunk()) and
ls[i]) else '')
result_paths = dbg.resolve_result_paths().safe_result_paths(len(out_tracers))
loc = f' at output tree path {result_paths[i]}'
frame = t._trace.frame
v = frame.tracer_to_var.get(id(t))
eqn = next((e for e in frame.eqns if v in e.outvars), None)
@ -2172,7 +2172,7 @@ def _check_no_returned_refs(
@profiler.annotate_function
def trace_to_jaxpr_dynamic2(
fun: lu.WrappedFun, debug_info: lu.TracingDebugInfo | None = None
fun: lu.WrappedFun, debug_info: core.DebugInfo | None = None
) -> tuple[Jaxpr, OutputType, list[Any]]:
trace = DynamicJaxprTrace(debug_info)

View File

@ -652,7 +652,7 @@ class ParallelCallableInfo:
in_axes: Iterable[int | None]
out_axes_thunk: Callable[[], Sequence[int | None]]
avals: Sequence[core.AbstractValue]
debug_info: api_util.TracingDebugInfo | None
debug_info: core.DebugInfo | None
@cached_property
def local_devices(self):
@ -964,7 +964,7 @@ class UnloadedPmapExecutable:
ordered_effects: list[core.Effect]
keepalive: Sequence[Any]
host_callbacks: Sequence[Any]
jaxpr_debug_info: core.JaxprDebugInfo
jaxpr_debug_info: core.DebugInfo
def build_execute_fun(self):
input_indices = []
@ -1004,7 +1004,7 @@ class UnloadedPmapExecutable:
ordered_effects: list[core.Effect],
host_callbacks: list[Any],
keepalive: Any,
jaxpr_debug_info: core.JaxprDebugInfo,
jaxpr_debug_info: core.DebugInfo,
platforms: Sequence[str],
shape_poly_state: mlir.ShapePolyLoweringState | None = None,
compiler_options=None):
@ -2127,7 +2127,7 @@ MaybeLayout = Sequence[Union[DeviceLocalLayout, AutoLayout, None]]
class AllArgsInfo(NamedTuple):
"""Avals and debug_info for all arguments prior to DCE."""
in_avals: Sequence[core.ShapedArray]
debug_info: core.JaxprDebugInfo | None
debug_info: core.DebugInfo | None
@lru_cache(maxsize=2048)
@ -3199,7 +3199,7 @@ def cc_shard_arg(x, sharding, layout):
def check_arg_avals_for_call(ref_avals, arg_avals,
jaxpr_debug_info: core.JaxprDebugInfo | None = None):
jaxpr_debug_info: core.DebugInfo | None = None):
if len(ref_avals) != len(arg_avals):
raise TypeError(
f"Computation compiled for {len(ref_avals)} inputs "
@ -3258,7 +3258,7 @@ def check_array_xla_sharding_layout_match(
args_after_dce,
in_xla_shardings: Sequence[JSharding],
in_xla_layouts: Sequence[DeviceLocalLayout],
jaxpr_debug_info: core.JaxprDebugInfo | None,
jaxpr_debug_info: core.DebugInfo | None,
kept_var_idx: set[int]) -> None:
from jax._src.array import ArrayImpl
# jaxpr_debug_info.arg_names are before DCE, so need to DCE them.

View File

@ -53,7 +53,7 @@ def _typecheck_param(prim, param, name, msg_required, pred):
def _initial_style_open_jaxpr(fun: Callable,
in_tree: PyTreeDef,
in_avals: Sequence[core.AbstractValue],
debug_info: api_util.TracingDebugInfo):
debug_info: core.DebugInfo):
wrapped_fun, out_tree = api_util.flatten_fun_nokwargs(
lu.wrap_init(fun, debug_info=debug_info),
in_tree)
@ -65,7 +65,7 @@ def _initial_style_open_jaxpr(fun: Callable,
def _initial_style_jaxpr(fun: Callable,
in_tree: PyTreeDef,
in_avals: Sequence[core.AbstractValue],
debug_info: api_util.TracingDebugInfo):
debug_info: core.DebugInfo):
jaxpr, consts, out_tree, () = _initial_style_open_jaxpr(
fun, in_tree, in_avals, debug_info)
closed_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr))
@ -74,7 +74,7 @@ def _initial_style_jaxpr(fun: Callable,
def _initial_style_jaxpr_attrs(fun: Callable,
in_tree: PyTreeDef,
in_avals: Sequence[core.AbstractValue],
debug_info: api_util.TracingDebugInfo):
debug_info: core.DebugInfo):
jaxpr, consts, out_tree, attrs_tracked = _initial_style_open_jaxpr(
fun, in_tree, in_avals, debug_info)
closed_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr))
@ -83,7 +83,7 @@ def _initial_style_jaxpr_attrs(fun: Callable,
def _initial_style_jaxprs_with_common_consts(
funs: Sequence[Callable],
in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue],
debug_infos: Sequence[api_util.TracingDebugInfo]):
debug_infos: Sequence[core.DebugInfo]):
# When staging the branches of a conditional into jaxprs, constants are
# extracted from each branch and converted to jaxpr arguments. To use the
# staged jaxprs as the branches to a conditional *primitive*, we need for

View File

@ -134,7 +134,7 @@ def switch(index, branches: Sequence[Callable], *operands,
if (config.disable_jit.value and core.is_concrete(index)):
return branches[int(index)](*operands)
dbgs = [api_util.tracing_debug_info("switch", branch, operands, {})
dbgs = [api_util.debug_info("switch", branch, operands, {})
for branch in branches]
ops, ops_tree = tree_flatten(operands)
ops_avals = tuple(map(core.get_aval, ops))
@ -237,10 +237,10 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
ops, ops_tree = tree_flatten(operands)
ops_avals = tuple(map(core.get_aval, ops))
dbg_true_fun = api_util.tracing_debug_info("cond", true_fun, operands, {})
dbg_true_fun = api_util.debug_info("cond", true_fun, operands, {})
if config.mutable_array_checks.value:
api_util._check_no_aliased_ref_args(dbg_true_fun, ops_avals, ops)
dbg_false_fun = api_util.tracing_debug_info("cond", false_fun, operands, {})
dbg_false_fun = api_util.debug_info("cond", false_fun, operands, {})
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
(true_fun, false_fun), ops_tree, ops_avals,
[dbg_true_fun, dbg_false_fun])

View File

@ -195,7 +195,7 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
def _create_jaxpr(init):
init_flat = tree_leaves(init)
_, in_tree = tree_flatten((init, xs))
dbg = api_util.tracing_debug_info("scan", f, (init, xs), {})
dbg = api_util.debug_info("scan", f, (init, xs), {})
carry_avals = tuple(map(core.get_aval, init_flat))
jaxpr, _, out_tree = _initial_style_jaxpr(
f, in_tree, carry_avals + x_avals, dbg)

View File

@ -273,7 +273,7 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
return carry, stacked_y
x_avals = [core.mapped_aval(length, 0, aval) for aval in xs_avals]
dbg_body = api_util.tracing_debug_info("scan", f, (init, xs), {})
dbg_body = api_util.debug_info("scan", f, (init, xs), {})
if config.mutable_array_checks.value:
in_flat, in_tree = tree_flatten((init, xs))
@ -1357,10 +1357,10 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric],
def _create_jaxpr(init_val):
init_vals, in_tree = tree_flatten((init_val,))
init_avals = tuple(_map(core.get_aval, init_vals))
cond_dbg = api_util.tracing_debug_info("while_cond", cond_fun, (init_val,), {})
cond_dbg = api_util.debug_info("while_cond", cond_fun, (init_val,), {})
cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(
cond_fun, in_tree, init_avals, cond_dbg)
body_dbg = api_util.tracing_debug_info("while_body", body_fun, (init_val,), {})
body_dbg = api_util.debug_info("while_body", body_fun, (init_val,), {})
body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(
body_fun, in_tree, init_avals, body_dbg)
if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1:

View File

@ -93,16 +93,16 @@ def custom_root(f: Callable,
"""
guess_flat, in_args_tree = tree_flatten((initial_guess,))
guess_avals = tuple(_map(core.get_aval, guess_flat))
f_debug = api_util.tracing_debug_info("custom_root", f, (initial_guess,), {})
f_debug = api_util.debug_info("custom_root", f, (initial_guess,), {})
f_jaxpr, f_consts, out_tree = _initial_style_jaxpr(
f, in_args_tree, guess_avals, f_debug)
in_tree, = treedef_children(in_args_tree)
_check_tree("f", "initial_guess", out_tree, in_tree, False)
solve_debug = api_util.tracing_debug_info("custom_root solve", solve,
(f, initial_guess), {},
static_argnums=(0,))
solve_debug = api_util.debug_info("custom_root solve", solve,
(f, initial_guess), {},
static_argnums=(0,))
solve_jaxpr, solve_consts, solution_tree = _initial_style_jaxpr(
partial(solve, f), in_args_tree, guess_avals, solve_debug)
_check_tree("solve", "initial_guess", solution_tree, in_tree, has_aux)
@ -111,10 +111,10 @@ def custom_root(f: Callable,
unchecked_zeros, f_jvp = api.linearize(f, x)
return tangent_solve(f_jvp, b)
tangent_solve_debug = api_util.tracing_debug_info("custom_root tangent_solve",
tangent_solve,
(f, initial_guess), {},
static_argnums=(0,))
tangent_solve_debug = api_util.debug_info("custom_root tangent_solve",
tangent_solve,
(f, initial_guess), {},
static_argnums=(0,))
l_and_s_jaxpr, l_and_s_consts, out_tree = _initial_style_jaxpr(
linearize_and_solve, treedef_tuple((in_tree,) * 2), guess_avals * 2,
tangent_solve_debug)
@ -265,17 +265,17 @@ def custom_linear_solve(
return f_aux if has_aux else f
matvec_debug = api_util.tracing_debug_info("custom_linear_solve",
matvec, (b,), {})
matvec_debug = api_util.debug_info("custom_linear_solve",
matvec, (b,), {})
# no auxiliary data assumed for matvec
matvec_jaxpr, matvec_consts, out_tree = _initial_style_jaxpr(
_shape_checked(matvec, "matvec", False), in_args_tree, b_avals,
matvec_debug)
_check_tree("matvec", "b", out_tree, tree, False)
solve_debug = api_util.tracing_debug_info("custom_linear_solve solve",
solve, (matvec, b), {},
static_argnums=(0,))
solve_debug = api_util.debug_info("custom_linear_solve solve",
solve, (matvec, b), {},
static_argnums=(0,))
solve_jaxpr, solve_consts, out_tree = _initial_style_jaxpr(
_shape_checked(partial(solve, matvec), "solve", has_aux), in_args_tree, b_avals,
solve_debug)
@ -285,7 +285,7 @@ def custom_linear_solve(
vecmat_jaxpr = tr_solve_jaxpr = None
vecmat_consts = tr_solve_consts = []
else:
transpose_solve_debug = api_util.tracing_debug_info(
transpose_solve_debug = api_util.debug_info(
"custom_linear_solve transpose_solve", transpose_solve,
(matvec, b), {}, static_argnums=(0,))
if symmetric:

View File

@ -745,7 +745,7 @@ def _trace_composite_to_jaxpr(fun: Callable,
in_tree: tree_util.PyTreeDef,
in_avals: Sequence[core.AbstractValue],
name: str,
debug_info: api_util.TracingDebugInfo):
debug_info: core.DebugInfo):
flat_fun, out_tree = api_util.flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug_info)
if any(isinstance(c, core.Tracer) for c in consts):
@ -822,8 +822,8 @@ def composite(
"""
@functools.wraps(decomposition)
def _decorator(*args, **kwargs):
debug_info = api_util.tracing_debug_info("composite", decomposition,
args, kwargs)
debug_info = api_util.debug_info("composite", decomposition,
args, kwargs)
flat_args, in_tree = tree_util.tree_flatten(args)
in_avals = tuple(core.get_aval(x) for x in flat_args)
closed_jaxpr, out_tree = _trace_composite_to_jaxpr(

View File

@ -156,7 +156,7 @@ class WrappedFun:
f_transformed: Callable,
transforms,
stores: tuple[Store | EqualStore | None, ...], params, in_type,
debug_info: TracingDebugInfo | None):
debug_info: DebugInfo | None):
self.f = f
self.f_transformed = f_transformed
self.transforms = transforms
@ -253,12 +253,10 @@ def fun_name(f):
except:
return str(f)
class TracingDebugInfo(NamedTuple):
"""Tracing-time debugging info about a func and its arguments.
Formed just before staging to a jaxpr and read in trace-time error messages.
"""
class DebugInfo(NamedTuple):
"""Debugging info about a func, its arguments, and results."""
traced_for: str # e.g. 'jit', 'scan', etc
# e.g. f'{fun.__name__} at {filename}:{lineno}' or {fun.__name__} if we have
# no source location information. The first word is always the function name,
# which may be '<unknown>'.
@ -270,23 +268,25 @@ class TracingDebugInfo(NamedTuple):
# e.g., tangent args in jax.jvp.
arg_names: tuple[str | None, ...]
# The result paths are not available while we are tracing the function,
# instead we keep a thunk. Once we are done tracing, we use
# `self.resolve_result_paths()` to execute the thunk and replace the
# actual result paths.
# e.g. ('[0]', '[1]', ...)
result_paths_thunk: Callable[[], tuple[str, ...]] | None
result_paths: tuple[str, ...] | Callable[[], tuple[str, ...]] | None
@classmethod
def from_jaxpr(cls, jaxpr: core.ClosedJaxpr) -> TracingDebugInfo | None:
jaxpr_dbg = jaxpr.jaxpr._debug_info
if jaxpr_dbg is None: return None
return TracingDebugInfo(jaxpr_dbg.traced_for,
jaxpr_dbg.func_src_info,
jaxpr_dbg.arg_names,
lambda: jaxpr_dbg.result_paths)
def add_result_paths(self,
result_paths_thunk: Callable[[], tuple[str, ...]]
) -> DebugInfo:
assert self.result_paths is None
return self._replace(result_paths=HashableFunction(result_paths_thunk,
closure=()))
def add_result_paths(self, result_paths_thunk: Callable[[], tuple[str, ...]]
) -> TracingDebugInfo:
assert self.result_paths_thunk is None
return self._replace(result_paths_thunk=HashableFunction(result_paths_thunk,
closure=()))
def resolve_result_paths(self) -> DebugInfo:
"""Return a debug info with resolved result paths."""
if callable(self.result_paths):
return self._replace(result_paths=tuple(self.result_paths()))
return self
def safe_arg_names(self, expected: int) -> tuple[str | None, ...]:
"""Get the arg_names with a safety check."""
@ -296,9 +296,18 @@ class TracingDebugInfo(NamedTuple):
# TODO(necula): this should not happen
return (None,) * expected
def safe_result_paths(self, expected: int) -> tuple[str, ...]:
"""Get the result paths with a safety check."""
assert not callable(self.result_paths), self
if self.result_paths is not None and len(self.result_paths) == expected:
return self.result_paths
else:
# TODO(necula): this should not happen
return ("",) * expected
def wrap_init(f: Callable, params=None, *,
debug_info: TracingDebugInfo | None = None) -> WrappedFun:
debug_info: DebugInfo | None = None) -> WrappedFun:
"""Wraps function `f` as a `WrappedFun`, suitable for transformation."""
params_dict = {} if params is None else params
params = () if params is None else tuple(sorted(params.items()))
@ -341,7 +350,7 @@ def _check_input_type(in_type: core.InputType) -> None:
provided[d.val] = True
assert all(provided)
def add_debug_info(f: WrappedFun, debug_info: TracingDebugInfo | None
def add_debug_info(f: WrappedFun, debug_info: DebugInfo | None
) -> WrappedFun:
"""Produce a new WrappedFun with debug_info attached."""
assert f.debug_info is None

View File

@ -413,9 +413,9 @@ class BlockSpec:
fake_index_map_args, fake_index_map_kwargs = \
index_map_tree.unflatten([False] * index_map_tree.num_leaves)
debug = api_util.tracing_debug_info("pallas_call index_map",
index_map_func, fake_index_map_args,
fake_index_map_kwargs)
debug = api_util.debug_info("pallas_call index_map",
index_map_func, fake_index_map_args,
fake_index_map_kwargs)
flat_index_map_fun, index_map_out_tree_thunk = api_util.flatten_fun(
lu.wrap_init(index_map_func, debug_info=debug), index_map_tree)
index_map_src_info = NameAndSrcInfo.from_pallas_call(

View File

@ -1100,8 +1100,8 @@ def pallas_call_checkify_rule(error: checkify.Error,
jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(retrace_in_avals)
wrapped_kernel_with_err, out_tree_thunk = api_util.flatten_fun_nokwargs(
lu.wrap_init(checked_kernel_fn), jaxpr_in_tree)
debug = api_util.tracing_debug_info("checkify_pallas", checked_kernel_fn,
retrace_in_avals, {})
debug = api_util.debug_info("checkify_pallas", checked_kernel_fn,
retrace_in_avals, {})
with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
final_jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
wrapped_kernel_with_err, jaxpr_flat_avals, debug)
@ -1167,7 +1167,7 @@ def _trace_kernel_to_jaxpr(
wrapped_kernel_fun, kernel_in_transforms
)
fake_kernel_args = kernel_in_tree.unflatten(kernel_avals)
debug = api_util.tracing_debug_info("pallas_call", fun, fake_kernel_args, {})
debug = api_util.debug_info("pallas_call", fun, fake_kernel_args, {})
with grid_mapping.trace_env():
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun,
kernel_avals, debug)
@ -1568,7 +1568,7 @@ def pallas_call(
kernel_fun_sig = api_util.fun_signature(kernel)
arg_names = None
if kernel_fun_sig:
kernel_debug_info = api_util.tracing_debug_info(
kernel_debug_info = api_util.debug_info(
"pallas_call kernel",
kernel,
[1] * len(kernel_fun_sig.parameters), {})

View File

@ -49,7 +49,7 @@ from jax._src import xla_bridge as xb
from jax._src.api_util import (
argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs,
donation_vector, check_callable, resolve_argnums,
argnames_partial_except, tracing_debug_info, result_paths, add_jaxpr_debug_info,
argnames_partial_except, debug_info, result_paths, add_jaxpr_debug_info,
hoist_obj_attrs, _check_no_aliased_ref_args,
_check_no_aliased_closed_over_refs)
from jax._src.interpreters import partial_eval as pe
@ -548,7 +548,7 @@ def _infer_params_impl(
ji: PjitInfo,
pjit_mesh: mesh_lib.Mesh | None,
resource_env: mesh_lib.ResourceEnv | None,
dbg: lu.TracingDebugInfo,
dbg: core.DebugInfo,
args: tuple[Any, ...],
kwargs: dict[str, Any],
in_avals: tuple[core.AbstractValue, ...] | None,
@ -733,7 +733,7 @@ def _infer_params(
'Using `with mesh:` context manager and `jax.sharding.use_mesh`'
' together is not allowed.')
dbg = tracing_debug_info(
dbg = debug_info(
'jit', fun, args, kwargs, static_argnums=ji.static_argnums,
static_argnames=ji.static_argnames, sourceinfo=ji.fun_sourceinfo,
signature=ji.fun_signature)
@ -756,7 +756,7 @@ def _infer_params(
entry.pjit_params = p
return entry.pjit_params, entry.pjit_params.consts + dynargs
def _infer_input_type(fun: Callable, dbg: lu.TracingDebugInfo | None,
def _infer_input_type(fun: Callable, dbg: core.DebugInfo | None,
explicit_args) -> tuple[core.AbstractValue, ...]:
avals = []
try:
@ -1302,7 +1302,7 @@ def _create_pjit_jaxpr(
fun: lu.WrappedFun,
in_type: core.InputType | Sequence[core.AbstractValue],
attr_data: int,
debug_info: lu.TracingDebugInfo,
debug_info: core.DebugInfo,
result_paths: Callable,
ignored_inline: IgnoreKey
) -> tuple[core.ClosedJaxpr, list[Any], list[core.AbstractValue],
@ -1346,7 +1346,7 @@ def _create_pjit_jaxpr(
def _check_and_canonicalize_out_shardings(
out_shardings_treedef, out_shardings_leaves, out_layouts_treedef,
out_layouts_leaves, out_tree, out_avals,
debug_info: core.JaxprDebugInfo | None,
debug_info: core.DebugInfo | None,
device_or_backend_set):
orig_out_shardings = tree_unflatten(out_shardings_treedef, out_shardings_leaves)
if isinstance(orig_out_shardings, (UnspecifiedValue, Sharding)):

View File

@ -989,7 +989,7 @@ def _run_state_discharge_rule(in_avals: Sequence[core.AbstractValue],
def initial_style_jaxpr(
fun: Callable, in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue],
dbg: api_util.TracingDebugInfo,
dbg: core.DebugInfo,
) -> tuple[core.Jaxpr, list[Any], PyTreeDef]:
return _initial_style_jaxpr(fun, in_tree, tuple(in_avals), dbg)
@ -997,7 +997,7 @@ def initial_style_jaxpr(
def _initial_style_jaxpr(fun: Callable,
in_tree: api_util.PyTreeDef,
in_avals: Sequence[core.AbstractValue],
debug: api_util.TracingDebugInfo):
debug: core.DebugInfo):
fun_, out_tree_thunk = api_util.flatten_fun_nokwargs(lu.wrap_init(fun),
tree_util.treedef_tuple((in_tree,)))
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, in_avals, debug)
@ -1007,7 +1007,7 @@ def _initial_style_jaxpr(fun: Callable,
T = TypeVar('T')
def run_state(f: Callable[..., None]) -> Callable[[T], T]:
def wrapped(args):
dbg = api_util.tracing_debug_info("run_state", f, (args,), {})
dbg = api_util.debug_info("run_state", f, (args,), {})
flat_args, in_tree = tree_util.tree_flatten(args)
ref_avals, ref_args = unzip2(map(get_ref_aval_from_value, flat_args))
# There may be some uninitialized values here in ref_args.
@ -1027,7 +1027,7 @@ def run_state(f: Callable[..., None]) -> Callable[[T], T]:
def run_state_reference(f: Callable[..., None]):
def wrapped(args):
dbg = api_util.tracing_debug_info("run_state", f, (args,), {})
dbg = api_util.debug_info("run_state", f, (args,), {})
flat_args, in_tree = tree_util.tree_flatten(args)
ref_avals, ref_args = unzip2(map(get_ref_aval_from_value, flat_args))
jaxpr_, consts, _ = initial_style_jaxpr(f, in_tree, ref_avals, dbg)

View File

@ -20,13 +20,13 @@ from jax._src.core import (
AbstractValue as AbstractValue,
Atom as Atom,
CallPrimitive as CallPrimitive,
DebugInfo as DebugInfo,
DShapedArray as DShapedArray,
DropVar as DropVar,
Effect as Effect,
Effects as Effects,
get_opaque_trace_state as get_opaque_trace_state,
InconclusiveDimensionOperation as InconclusiveDimensionOperation,
JaxprDebugInfo as JaxprDebugInfo,
JaxprPpContext as JaxprPpContext,
JaxprPpSettings as JaxprPpSettings,
JaxprTypeError as JaxprTypeError,

View File

@ -69,13 +69,13 @@ def _collect_jaxprs(jaxpr: core.Jaxpr,
return acc
def _debug_info_to_string(dbg: api_util.TracingDebugInfo | core.JaxprDebugInfo | None) -> list[str]:
def _debug_info_to_string(dbg: core.DebugInfo | None) -> list[str]:
if dbg is None: return "None"
# Strip the absolute path and the line number but check that it references
# this file (to catch errors when the source info points in JAX internals)
fun_src_info = re.sub(r"^(\S+)( at .*/debug_info_test.py:.*)?", "\\1", dbg.func_src_info)
res = f"traced_for={dbg.traced_for}, fun={fun_src_info}, arg_names={','.join(dbg.arg_names)}"
if isinstance(dbg, core.JaxprDebugInfo):
if isinstance(dbg.result_paths, tuple):
if dbg.result_paths:
res += f", result_paths={','.join(dbg.result_paths)}"
else:
@ -221,24 +221,24 @@ class DebugInfoTest(jtu.JaxTestCase):
def my_f(x, y, z, w):
pass
dbg = api_util.tracing_debug_info("jit", my_f, (1, 2), dict(z=3, w=4))
dbg = api_util.debug_info("jit", my_f, (1, 2), dict(z=3, w=4))
self.assertRegex(dbg.func_src_info, r"^my_f at .*debug_info_test.py:\d+")
self.assertEqual(dbg.arg_names, ("x", "y", "z", "w"))
self.assertIsNone(dbg.result_paths_thunk)
self.assertIsNone(dbg.result_paths)
def test_debug_info_arg_passed_as_kwarg(self):
def my_f(x, y, z):
pass
dbg = api_util.tracing_debug_info("jit", my_f, (1, 2), dict(z=3))
dbg = api_util.debug_info("jit", my_f, (1, 2), dict(z=3))
self.assertEqual(dbg.arg_names, ("x", "y", "z"))
def test_debug_info_pytrees(self):
def my_f(x_tree, *, y_tree):
pass
dbg = api_util.tracing_debug_info("jit", my_f, ((1, 2),),
dict(y_tree=dict(z=3, w=4)))
dbg = api_util.debug_info("jit", my_f, ((1, 2),),
dict(y_tree=dict(z=3, w=4)))
self.assertEqual(dbg.arg_names, ("x_tree[0]", "x_tree[1]",
"y_tree['w']", "y_tree['z']"))
@ -246,43 +246,43 @@ class DebugInfoTest(jtu.JaxTestCase):
def my_f(x, y, *, z, w):
pass
dbg = api_util.tracing_debug_info("jit", my_f, (1, 2), dict(z=3, w=4),
static_argnums=(1,),
static_argnames=("w",))
dbg = api_util.debug_info("jit", my_f, (1, 2), dict(z=3, w=4),
static_argnums=(1,),
static_argnames=("w",))
self.assertEqual(dbg.arg_names, ("x", "z"))
def test_debug_info_with_pytrees_and_statics(self):
def my_f(x, y, *, z, w):
pass
dbg = api_util.tracing_debug_info("jit", my_f, ((1, 2), (2, 3)),
dict(z=(3, 4), w=(5, 6)),
static_argnums=(1,),
static_argnames=("w",))
dbg = api_util.debug_info("jit", my_f, ((1, 2), (2, 3)),
dict(z=(3, 4), w=(5, 6)),
static_argnums=(1,),
static_argnames=("w",))
self.assertEqual(dbg.arg_names, ("x[0]", "x[1]", "z[0]", "z[1]"))
def test_debug_info_too_many_args(self):
def my_f(x):
pass
dbg = api_util.tracing_debug_info("jit", my_f, (1, 2, 3), dict(z=3))
dbg = api_util.debug_info("jit", my_f, (1, 2, 3), dict(z=3))
self.assertEqual(dbg.arg_names, ('args[0]', 'args[1]', 'args[2]', "kwargs['z']"))
def test_debug_info_no_source_info_built_in(self):
# built-in function "int" does not have an inspect.Signature
dbg = api_util.tracing_debug_info("jit", max, (1,), {})
dbg = api_util.debug_info("jit", max, (1,), {})
self.assertEqual(dbg.func_src_info, "max")
self.assertEqual(dbg.arg_names, ("args[0]",))
def test_debug_info_lambda(self):
# built-in function "int" does not have an inspect.Signature
dbg = api_util.tracing_debug_info("jit", lambda my_arg: False, (1,), {})
dbg = api_util.debug_info("jit", lambda my_arg: False, (1,), {})
self.assertRegex(dbg.func_src_info, r"^<lambda> at .*debug_info_test.py:\d+")
self.assertEqual(dbg.arg_names, ("my_arg",))
def test_debug_info_no_source_info_not_callable(self):
# built-in function "int" does not have an inspect.Signature
dbg = api_util.tracing_debug_info("jit", False, (1,), {})
dbg = api_util.debug_info("jit", False, (1,), {})
self.assertEqual(dbg.func_src_info, "<unknown>")
self.assertEqual(dbg.arg_names, ("args[0]",))
@ -293,7 +293,7 @@ class DebugInfoTest(jtu.JaxTestCase):
def __call__(self, y):
return self.x + y
dbg = api_util.tracing_debug_info("jit", Foo(), (1,), {})
dbg = api_util.debug_info("jit", Foo(), (1,), {})
self.assertRegex(dbg.func_src_info, "<unknown>")
self.assertEqual(dbg.arg_names, ("y",))
@ -307,7 +307,7 @@ class DebugInfoTest(jtu.JaxTestCase):
def __repr__(self):
raise NotImplementedError
dbg = api_util.tracing_debug_info("jit", Foo(), (1,), {})
dbg = api_util.debug_info("jit", Foo(), (1,), {})
self.assertRegex(dbg.func_src_info, "<unknown>")
self.assertEqual(dbg.arg_names, ("y",))