diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index 64978fa97..d4bcc8d66 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -323,7 +323,7 @@ def checkpoint(fun: Callable, *, prevent_cse: bool = True, @wraps(fun) @api_boundary def fun_remat(*args, **kwargs): - debug = api_util.tracing_debug_info( + debug = api_util.debug_info( "checkpoint / remat", fun, args, kwargs, static_argnums=static_argnums) fun_, args = _remat_static_argnums(fun, static_argnums, args) @@ -418,7 +418,7 @@ _dyn_args_fun_cached = weakref_lru_cache(_dyn_args_fun_uncached) def _trace_to_jaxpr(fun: Callable, in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue], - debug: lu.TracingDebugInfo + debug: core.DebugInfo ) -> tuple[core.Jaxpr, Sequence[Any], PyTreeDef]: flat_fun, out_tree = api_util.flatten_fun(lu.wrap_init(fun), in_tree) try: @@ -447,7 +447,7 @@ def saved_residuals(f: Callable, args, kwargs = tree_unflatten(in_tree, args) return f(*args, **kwargs) - debug_info = api_util.tracing_debug_info("saved_residuals", f, args, kwargs) + debug_info = api_util.debug_info("saved_residuals", f, args, kwargs) out = api.make_jaxpr(lambda *args: api.linearize(f_, *args)[1], return_shape=True)(*in_leaves) assert isinstance(out, tuple) diff --git a/jax/_src/api.py b/jax/_src/api.py index 3c9c4d9f7..1c9a706a9 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -57,11 +57,11 @@ from jax._src import pjit from jax._src import xla_bridge as xb from jax._src.core import eval_jaxpr, shaped_abstractify, ShapedArray from jax._src.api_util import ( - flatten_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2, argnums_partial, - flatten_axes, donation_vector, - rebase_donate_argnums, _ensure_index, _ensure_index_tuple, - apply_flat_fun_nokwargs, check_callable, tracing_debug_info, - result_paths, flat_out_axes) + flatten_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2, argnums_partial, + flatten_axes, donation_vector, + rebase_donate_argnums, _ensure_index, _ensure_index_tuple, + apply_flat_fun_nokwargs, check_callable, debug_info, + result_paths, flat_out_axes) from jax._src.lax import lax as lax_internal from jax._src.lib import jax_jit from jax._src.lib import xla_client as xc @@ -452,7 +452,7 @@ def value_and_grad(fun: Callable, argnums: int | Sequence[int] = 0, raise TypeError(f"differentiating with respect to {argnums=} requires at least " f"{max_argnum + 1} positional arguments to be passed by the caller, " f"but got only {len(args)} positional arguments.") - dbg = tracing_debug_info('value_and_grad', fun, args, kwargs) + dbg = debug_info('value_and_grad', fun, args, kwargs) f = lu.wrap_init(fun, params=kwargs, debug_info=dbg) f_partial, dyn_args = argnums_partial(f, argnums, args, @@ -1426,7 +1426,7 @@ def _prepare_pmap(fun: Callable, in_axes, out_axes, static_broadcasted_tuple, if in_devices is not None and len(in_devices) == 0: raise ValueError("'devices' argument to pmap must be non-empty, or None.") - dbg = tracing_debug_info( + dbg = debug_info( "pmap", fun, args, kwargs, static_argnums=static_broadcasted_tuple) diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index 029131176..99bbe2886 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -31,7 +31,6 @@ from jax._src.tree_util import ( prefix_errors) from jax._src.tree_util import _replace_nones from jax._src import linear_util as lu -from jax._src.linear_util import TracingDebugInfo from jax._src.util import (safe_map, WrapKwArgs, Hashable, HashableFunction, Unhashable, safe_zip) from jax._src import traceback_util @@ -582,7 +581,7 @@ def api_hook(fun, tag: str): return fun -def tracing_debug_info( +def debug_info( traced_for: str, fun: Callable, args: Sequence[Any], @@ -594,14 +593,14 @@ def tracing_debug_info( # TODO(necula): check if we really need this, e.g., to speed up tracing. sourceinfo: str | None = None, signature: inspect.Signature | None = None, -) -> TracingDebugInfo: +) -> core.DebugInfo: if sourceinfo is None: sourceinfo = fun_sourceinfo(fun) if signature is None: signature = fun_signature(fun) arg_names = _non_static_arg_names(signature, args, kwargs, static_argnums, static_argnames) - return TracingDebugInfo(traced_for, sourceinfo, arg_names, result_paths_thunk) + return core.DebugInfo(traced_for, sourceinfo, arg_names, result_paths_thunk) def fun_signature(fun: Callable) -> inspect.Signature | None: @@ -619,7 +618,7 @@ _fun_name_re = re.compile(r"(?:)") # TODO(mattjj): make this function internal to this module def fun_sourceinfo(fun: Callable) -> str: - # See TracingDebugInfo.fun_src_info + # See DebugInfo.fun_src_info res = getattr(fun, "__fun_sourceinfo__", None) if res is not None: return res while isinstance(fun, partial): @@ -684,20 +683,19 @@ def result_paths(_fun, _store, *args, **kwargs): # TODO(necula): simplify this function, all it needs is to add the trace_debug to the Jaxpr def add_jaxpr_debug_info(jaxpr: core.Jaxpr, - trace_debug: TracingDebugInfo | None, + debug: core.DebugInfo | None, result_paths: tuple[str, ...] | None = None, ) -> core.Jaxpr: """Add debug info to jaxpr, given trace-time debug info and result paths.""" - if trace_debug is None: + if debug is None: return jaxpr # TODO(necula): re-enable this safety check # assert (result_paths is not None) ^ (trace_debug.result_paths_thunk is not None) - if result_paths is None: - result_paths = trace_debug.result_paths_thunk() # type: ignore - debug_info = core.JaxprDebugInfo( - trace_debug.traced_for, trace_debug.func_src_info, - trace_debug.arg_names, tuple(result_paths)) # type: ignore - return jaxpr.replace(debug_info=debug_info) + if result_paths is not None: + debug = debug._replace(result_paths=tuple(result_paths)) + else: + debug = debug.resolve_result_paths() + return jaxpr.replace(debug_info=debug) def hoist_obj_attrs(f, flat_args): idxs, objs, flat_args_ = [], [], [] diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index f4fe0edbf..80fd1b041 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -1202,7 +1202,7 @@ def checkify(f: Callable[..., Out], in_tree = jtu.tree_structure(((), {})) closed_f = lambda: f(*args, **kwargs) # stage: - debug = api_util.tracing_debug_info("checkify", f, args, kwargs) + debug = api_util.debug_info("checkify", f, args, kwargs) fun_, out_tree = api_util.flatten_fun(lu.wrap_init(closed_f, debug_info=debug), in_tree) diff --git a/jax/_src/core.py b/jax/_src/core.py index 25f7c4229..5f0d86305 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -82,31 +82,7 @@ EffectTypeSet = effects.EffectTypeSet no_effects: Effects = effects.no_effects -# TODO(necula): make this an extension of TracingDebugInfo -class JaxprDebugInfo(NamedTuple): - # An extension of lu.TracingDebugInfo; see comments there - traced_for: str - func_src_info: str - arg_names: tuple[str | None, ...] - # This is formed after tracing, when we have concrete `result_paths` - result_paths: tuple[str, ...] # e.g. ('[0]', '[1]', ...) - - def safe_arg_names(self, expected: int) -> tuple[str | None, ...]: - """Get the arg_names with a safety check.""" - if len(self.arg_names) == expected: - return self.arg_names - else: - # TODO(necula): this should not happen - return (None,) * expected - - def safe_result_paths(self, expected: int) -> tuple[str | None, ...]: - """Get the result_paths with a safety check.""" - if len(self.result_paths) == expected: - return self.result_paths - else: - # TODO(necula): this should not happen - return ("",) * expected - +DebugInfo = lu.DebugInfo class Jaxpr: __slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns', @@ -117,7 +93,7 @@ class Jaxpr: _outvars: list[Atom] _eqns: list[JaxprEqn] _effects: Effects - _debug_info: JaxprDebugInfo | None + _debug_info: DebugInfo | None @property def constvars(self) -> list[Var]: @@ -140,13 +116,13 @@ class Jaxpr: return self._effects @property - def debug_info(self) -> JaxprDebugInfo | None: + def debug_info(self) -> DebugInfo | None: return self._debug_info def __init__(self, constvars: Sequence[Var], invars: Sequence[Var], outvars: Sequence[Atom], eqns: Sequence[JaxprEqn], effects: Effects = no_effects, - debug_info: JaxprDebugInfo | None = None): + debug_info: DebugInfo | None = None): """ Args: constvars: list of variables introduced for constants. Array constants are @@ -157,14 +133,14 @@ class Jaxpr: eqns: list of equations. effects: set of effects. The effects on a jaxpr are a superset of the union of the effects for each equation. - debug_info: optional JaxprDebugInfo. + debug_info: optional DebugInfo. """ self._constvars = list(constvars) self._invars = list(invars) self._outvars = list(outvars) self._eqns = list(eqns) self._effects = effects - self._debug_info = debug_info + self._debug_info = debug_info and debug_info.resolve_result_paths() # TODO(necula): re-enable these safety checks # assert (not debug_info or len(debug_info.arg_names) == len(invars)), (debug_info, invars) # assert (not debug_info or len(debug_info.result_paths) == len(outvars)), (debug_info, outvars) diff --git a/jax/_src/custom_batching.py b/jax/_src/custom_batching.py index 175182403..040f86944 100644 --- a/jax/_src/custom_batching.py +++ b/jax/_src/custom_batching.py @@ -147,7 +147,7 @@ class custom_vmap: raise AttributeError( f"No batching rule defined for custom_vmap function {fun_name} " "using def_vmap.") - debug = api_util.tracing_debug_info("custom_vmap", self.fun, args, {}) + debug = api_util.debug_info("custom_vmap", self.fun, args, {}) args_flat, in_tree = tree_flatten(args) flat_fun, out_tree = api_util.flatten_fun_nokwargs( lu.wrap_init(self.fun, debug_info=debug), diff --git a/jax/_src/custom_dce.py b/jax/_src/custom_dce.py index 7e8bbc290..9e666f8f6 100644 --- a/jax/_src/custom_dce.py +++ b/jax/_src/custom_dce.py @@ -127,12 +127,12 @@ class custom_dce: "def_dce." ) rule_name = util.fun_name(self.dce_rule) - debug = api_util.tracing_debug_info("custom_dce", self.fun, - args, {}, - static_argnums=self.static_argnums) - debug_rule = api_util.tracing_debug_info("custom_dce_rule", self.dce_rule, - args, {}, - static_argnums=self.static_argnums) + debug = api_util.debug_info("custom_dce", self.fun, + args, {}, + static_argnums=self.static_argnums) + debug_rule = api_util.debug_info("custom_dce_rule", self.dce_rule, + args, {}, + static_argnums=self.static_argnums) args = api_util.resolve_kwargs(self.fun, args, kwargs) if self.static_argnums: static_argnums = set(self.static_argnums) diff --git a/jax/_src/custom_partitioning.py b/jax/_src/custom_partitioning.py index 3cbf27d8c..c94066725 100644 --- a/jax/_src/custom_partitioning.py +++ b/jax/_src/custom_partitioning.py @@ -468,9 +468,9 @@ class custom_partitioning: def __call__(self, *args, **kwargs): args = _resolve_kwargs(self.fun, args, kwargs) - debug = api_util.tracing_debug_info("custom_partitioning", self.fun, - args, kwargs, - static_argnums=self.static_argnums) + debug = api_util.debug_info("custom_partitioning", self.fun, + args, kwargs, + static_argnums=self.static_argnums) if self.static_argnums: static_argnums = set(self.static_argnums) args = tuple(x if i in static_argnums else x for i, x in enumerate(args)) diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 46aad30ac..668d2a078 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -147,7 +147,7 @@ def _linearize_jaxpr( jaxpr: core.ClosedJaxpr, nonzeros: tuple[bool, ...] ) -> tuple[core.ClosedJaxpr, int, Sequence[bool], core.ClosedJaxpr]: - dbg = lu.TracingDebugInfo.from_jaxpr(jaxpr) + dbg = jaxpr.jaxpr.debug_info primal_trace = pe.DynamicJaxprTrace(dbg) tangent_trace = pe.DynamicJaxprTrace(dbg) lin_trace = LinearizeTrace(primal_trace, tangent_trace) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 3777ce3d3..7dd137975 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -42,7 +42,6 @@ from jax._src.core import (Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval, mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx, InputType, OutputType, get_referent, JaxprEqnContext) from jax._src.state.types import AbstractRef -from jax._src import tree_util from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_flatten, tree_structure) from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list, @@ -932,7 +931,7 @@ def _partial_eval_jaxpr_nounits(jaxpr: ClosedJaxpr, in_unknowns: Sequence[bool], instantiate: bool | Sequence[bool]): f = lu.wrap_init(core.jaxpr_as_fun(jaxpr), - debug_info=lu.TracingDebugInfo.from_jaxpr(jaxpr)) + debug_info=jaxpr.jaxpr.debug_info) cell = [] def fun(*known_vals_in): @@ -1334,10 +1333,11 @@ def prune_jaxpr_outputs(jaxpr: Jaxpr, used_outputs: Sequence[bool]) -> Jaxpr: def _prune_jaxpr_outputs(jaxpr: Jaxpr, used_outputs: tuple[bool, ...]) -> Jaxpr: outvars = [v for v, b in zip(jaxpr.outvars, used_outputs) if b] - dbg = jaxpr.debug_info and core.JaxprDebugInfo( + dbg = jaxpr.debug_info and core.DebugInfo( jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info, jaxpr.debug_info.arg_names, - tuple(v for v, b in zip(jaxpr.debug_info.result_paths, used_outputs) if b)) + tuple(v for v, b in zip(jaxpr.debug_info.safe_result_paths(len(used_outputs)), + used_outputs) if b)) new_jaxpr = jaxpr.replace(outvars=outvars, debug_info=dbg) config.enable_checks.value and core.check_jaxpr(new_jaxpr) return new_jaxpr @@ -1422,10 +1422,12 @@ def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: tuple[bool, ...], eqns = new_eqns[::-1] jaxpr_effects = make_jaxpr_effects(jaxpr.constvars, invars, outvars, eqns) - dbg = jaxpr.debug_info and core.JaxprDebugInfo( + dbg = jaxpr.debug_info and core.DebugInfo( jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info, - tuple(v for v, b in zip(jaxpr.debug_info.arg_names, used_inputs) if b), - tuple(v for v, b in zip(jaxpr.debug_info.result_paths, used_outputs) if b)) + tuple(v for v, b in zip(jaxpr.debug_info.safe_arg_names(len(used_inputs)), + used_inputs) if b), + tuple(v for v, b in zip(jaxpr.debug_info.safe_result_paths(len(used_outputs)), + used_outputs) if b)) new_jaxpr = Jaxpr(jaxpr.constvars, invars, outvars, eqns, jaxpr_effects, dbg) config.enable_checks.value and core.check_jaxpr(new_jaxpr) @@ -1623,9 +1625,9 @@ class JaxprStackFrame: attrs_tracked: list[tuple[Any, str]] attrs_inits: list attrs_vars: list[Var] - debug_info: lu.TracingDebugInfo | None + debug_info: core.DebugInfo | None - def __init__(self, debug_info: lu.TracingDebugInfo | None): + def __init__(self, debug_info: core.DebugInfo | None): self.gensym = core.gensym() self.tracer_to_var = {} self.constid_to_tracer = {} @@ -1809,7 +1811,7 @@ def _inline_literals( class DynamicJaxprTrace(core.Trace): __slots__ = ("frame",) - def __init__(self, debug_info: lu.TracingDebugInfo | None): + def __init__(self, debug_info: core.DebugInfo | None): self.frame = JaxprStackFrame(debug_info) def invalidate(self): @@ -2114,7 +2116,7 @@ def _jvp_jaxpr_zeros(f, store, in_zeros, zero_avals, *primal_tangent_avals): def trace_to_jaxpr_dynamic( fun: lu.WrappedFun, in_avals: Sequence[AbstractValue], - debug_info: lu.TracingDebugInfo | None = None, + debug_info: core.DebugInfo | None = None, *, keep_inputs: list[bool] | None = None, ) -> tuple[Jaxpr, list[AbstractValue], list[Any], @@ -2137,7 +2139,7 @@ def trace_to_jaxpr_dynamic( return jaxpr, [v.aval for v in jaxpr.outvars], consts, attrs_tracked def _check_no_returned_refs( - dbg: lu.TracingDebugInfo | None, + dbg: core.DebugInfo | None, out_tracers: Sequence[DynamicJaxprTracer] ) -> None: if not config.mutable_array_checks.value: return @@ -2148,10 +2150,8 @@ def _check_no_returned_refs( raise ValueError( f"function returned a mutable array reference of type {a.str_short()}, " "but mutable array references cannot be returned.") - loc = (f' at output tree path {tree_util.keystr(ls[i])}' # type: ignore - if (dbg.result_paths_thunk and - (ls := dbg.result_paths_thunk()) and - ls[i]) else '') + result_paths = dbg.resolve_result_paths().safe_result_paths(len(out_tracers)) + loc = f' at output tree path {result_paths[i]}' frame = t._trace.frame v = frame.tracer_to_var.get(id(t)) eqn = next((e for e in frame.eqns if v in e.outvars), None) @@ -2172,7 +2172,7 @@ def _check_no_returned_refs( @profiler.annotate_function def trace_to_jaxpr_dynamic2( - fun: lu.WrappedFun, debug_info: lu.TracingDebugInfo | None = None + fun: lu.WrappedFun, debug_info: core.DebugInfo | None = None ) -> tuple[Jaxpr, OutputType, list[Any]]: trace = DynamicJaxprTrace(debug_info) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 389c20734..04c950aef 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -652,7 +652,7 @@ class ParallelCallableInfo: in_axes: Iterable[int | None] out_axes_thunk: Callable[[], Sequence[int | None]] avals: Sequence[core.AbstractValue] - debug_info: api_util.TracingDebugInfo | None + debug_info: core.DebugInfo | None @cached_property def local_devices(self): @@ -964,7 +964,7 @@ class UnloadedPmapExecutable: ordered_effects: list[core.Effect] keepalive: Sequence[Any] host_callbacks: Sequence[Any] - jaxpr_debug_info: core.JaxprDebugInfo + jaxpr_debug_info: core.DebugInfo def build_execute_fun(self): input_indices = [] @@ -1004,7 +1004,7 @@ class UnloadedPmapExecutable: ordered_effects: list[core.Effect], host_callbacks: list[Any], keepalive: Any, - jaxpr_debug_info: core.JaxprDebugInfo, + jaxpr_debug_info: core.DebugInfo, platforms: Sequence[str], shape_poly_state: mlir.ShapePolyLoweringState | None = None, compiler_options=None): @@ -2127,7 +2127,7 @@ MaybeLayout = Sequence[Union[DeviceLocalLayout, AutoLayout, None]] class AllArgsInfo(NamedTuple): """Avals and debug_info for all arguments prior to DCE.""" in_avals: Sequence[core.ShapedArray] - debug_info: core.JaxprDebugInfo | None + debug_info: core.DebugInfo | None @lru_cache(maxsize=2048) @@ -3199,7 +3199,7 @@ def cc_shard_arg(x, sharding, layout): def check_arg_avals_for_call(ref_avals, arg_avals, - jaxpr_debug_info: core.JaxprDebugInfo | None = None): + jaxpr_debug_info: core.DebugInfo | None = None): if len(ref_avals) != len(arg_avals): raise TypeError( f"Computation compiled for {len(ref_avals)} inputs " @@ -3258,7 +3258,7 @@ def check_array_xla_sharding_layout_match( args_after_dce, in_xla_shardings: Sequence[JSharding], in_xla_layouts: Sequence[DeviceLocalLayout], - jaxpr_debug_info: core.JaxprDebugInfo | None, + jaxpr_debug_info: core.DebugInfo | None, kept_var_idx: set[int]) -> None: from jax._src.array import ArrayImpl # jaxpr_debug_info.arg_names are before DCE, so need to DCE them. diff --git a/jax/_src/lax/control_flow/common.py b/jax/_src/lax/control_flow/common.py index 9b11fb930..8844e7f3d 100644 --- a/jax/_src/lax/control_flow/common.py +++ b/jax/_src/lax/control_flow/common.py @@ -53,7 +53,7 @@ def _typecheck_param(prim, param, name, msg_required, pred): def _initial_style_open_jaxpr(fun: Callable, in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue], - debug_info: api_util.TracingDebugInfo): + debug_info: core.DebugInfo): wrapped_fun, out_tree = api_util.flatten_fun_nokwargs( lu.wrap_init(fun, debug_info=debug_info), in_tree) @@ -65,7 +65,7 @@ def _initial_style_open_jaxpr(fun: Callable, def _initial_style_jaxpr(fun: Callable, in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue], - debug_info: api_util.TracingDebugInfo): + debug_info: core.DebugInfo): jaxpr, consts, out_tree, () = _initial_style_open_jaxpr( fun, in_tree, in_avals, debug_info) closed_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr)) @@ -74,7 +74,7 @@ def _initial_style_jaxpr(fun: Callable, def _initial_style_jaxpr_attrs(fun: Callable, in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue], - debug_info: api_util.TracingDebugInfo): + debug_info: core.DebugInfo): jaxpr, consts, out_tree, attrs_tracked = _initial_style_open_jaxpr( fun, in_tree, in_avals, debug_info) closed_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr)) @@ -83,7 +83,7 @@ def _initial_style_jaxpr_attrs(fun: Callable, def _initial_style_jaxprs_with_common_consts( funs: Sequence[Callable], in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue], - debug_infos: Sequence[api_util.TracingDebugInfo]): + debug_infos: Sequence[core.DebugInfo]): # When staging the branches of a conditional into jaxprs, constants are # extracted from each branch and converted to jaxpr arguments. To use the # staged jaxprs as the branches to a conditional *primitive*, we need for diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 4236a71a3..d0b681fca 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -134,7 +134,7 @@ def switch(index, branches: Sequence[Callable], *operands, if (config.disable_jit.value and core.is_concrete(index)): return branches[int(index)](*operands) - dbgs = [api_util.tracing_debug_info("switch", branch, operands, {}) + dbgs = [api_util.debug_info("switch", branch, operands, {}) for branch in branches] ops, ops_tree = tree_flatten(operands) ops_avals = tuple(map(core.get_aval, ops)) @@ -237,10 +237,10 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands, ops, ops_tree = tree_flatten(operands) ops_avals = tuple(map(core.get_aval, ops)) - dbg_true_fun = api_util.tracing_debug_info("cond", true_fun, operands, {}) + dbg_true_fun = api_util.debug_info("cond", true_fun, operands, {}) if config.mutable_array_checks.value: api_util._check_no_aliased_ref_args(dbg_true_fun, ops_avals, ops) - dbg_false_fun = api_util.tracing_debug_info("cond", false_fun, operands, {}) + dbg_false_fun = api_util.debug_info("cond", false_fun, operands, {}) jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts( (true_fun, false_fun), ops_tree, ops_avals, [dbg_true_fun, dbg_false_fun]) diff --git a/jax/_src/lax/control_flow/for_loop.py b/jax/_src/lax/control_flow/for_loop.py index d4d387ddb..4a3b6d0d7 100644 --- a/jax/_src/lax/control_flow/for_loop.py +++ b/jax/_src/lax/control_flow/for_loop.py @@ -195,7 +195,7 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]], def _create_jaxpr(init): init_flat = tree_leaves(init) _, in_tree = tree_flatten((init, xs)) - dbg = api_util.tracing_debug_info("scan", f, (init, xs), {}) + dbg = api_util.debug_info("scan", f, (init, xs), {}) carry_avals = tuple(map(core.get_aval, init_flat)) jaxpr, _, out_tree = _initial_style_jaxpr( f, in_tree, carry_avals + x_avals, dbg) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index cde578a58..9f7e52bc9 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -273,7 +273,7 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]], return carry, stacked_y x_avals = [core.mapped_aval(length, 0, aval) for aval in xs_avals] - dbg_body = api_util.tracing_debug_info("scan", f, (init, xs), {}) + dbg_body = api_util.debug_info("scan", f, (init, xs), {}) if config.mutable_array_checks.value: in_flat, in_tree = tree_flatten((init, xs)) @@ -1357,10 +1357,10 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric], def _create_jaxpr(init_val): init_vals, in_tree = tree_flatten((init_val,)) init_avals = tuple(_map(core.get_aval, init_vals)) - cond_dbg = api_util.tracing_debug_info("while_cond", cond_fun, (init_val,), {}) + cond_dbg = api_util.debug_info("while_cond", cond_fun, (init_val,), {}) cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr( cond_fun, in_tree, init_avals, cond_dbg) - body_dbg = api_util.tracing_debug_info("while_body", body_fun, (init_val,), {}) + body_dbg = api_util.debug_info("while_body", body_fun, (init_val,), {}) body_jaxpr, body_consts, body_tree = _initial_style_jaxpr( body_fun, in_tree, init_avals, body_dbg) if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1: diff --git a/jax/_src/lax/control_flow/solves.py b/jax/_src/lax/control_flow/solves.py index 29620cb50..5094a64f3 100644 --- a/jax/_src/lax/control_flow/solves.py +++ b/jax/_src/lax/control_flow/solves.py @@ -93,16 +93,16 @@ def custom_root(f: Callable, """ guess_flat, in_args_tree = tree_flatten((initial_guess,)) guess_avals = tuple(_map(core.get_aval, guess_flat)) - f_debug = api_util.tracing_debug_info("custom_root", f, (initial_guess,), {}) + f_debug = api_util.debug_info("custom_root", f, (initial_guess,), {}) f_jaxpr, f_consts, out_tree = _initial_style_jaxpr( f, in_args_tree, guess_avals, f_debug) in_tree, = treedef_children(in_args_tree) _check_tree("f", "initial_guess", out_tree, in_tree, False) - solve_debug = api_util.tracing_debug_info("custom_root solve", solve, - (f, initial_guess), {}, - static_argnums=(0,)) + solve_debug = api_util.debug_info("custom_root solve", solve, + (f, initial_guess), {}, + static_argnums=(0,)) solve_jaxpr, solve_consts, solution_tree = _initial_style_jaxpr( partial(solve, f), in_args_tree, guess_avals, solve_debug) _check_tree("solve", "initial_guess", solution_tree, in_tree, has_aux) @@ -111,10 +111,10 @@ def custom_root(f: Callable, unchecked_zeros, f_jvp = api.linearize(f, x) return tangent_solve(f_jvp, b) - tangent_solve_debug = api_util.tracing_debug_info("custom_root tangent_solve", - tangent_solve, - (f, initial_guess), {}, - static_argnums=(0,)) + tangent_solve_debug = api_util.debug_info("custom_root tangent_solve", + tangent_solve, + (f, initial_guess), {}, + static_argnums=(0,)) l_and_s_jaxpr, l_and_s_consts, out_tree = _initial_style_jaxpr( linearize_and_solve, treedef_tuple((in_tree,) * 2), guess_avals * 2, tangent_solve_debug) @@ -265,17 +265,17 @@ def custom_linear_solve( return f_aux if has_aux else f - matvec_debug = api_util.tracing_debug_info("custom_linear_solve", - matvec, (b,), {}) + matvec_debug = api_util.debug_info("custom_linear_solve", + matvec, (b,), {}) # no auxiliary data assumed for matvec matvec_jaxpr, matvec_consts, out_tree = _initial_style_jaxpr( _shape_checked(matvec, "matvec", False), in_args_tree, b_avals, matvec_debug) _check_tree("matvec", "b", out_tree, tree, False) - solve_debug = api_util.tracing_debug_info("custom_linear_solve solve", - solve, (matvec, b), {}, - static_argnums=(0,)) + solve_debug = api_util.debug_info("custom_linear_solve solve", + solve, (matvec, b), {}, + static_argnums=(0,)) solve_jaxpr, solve_consts, out_tree = _initial_style_jaxpr( _shape_checked(partial(solve, matvec), "solve", has_aux), in_args_tree, b_avals, solve_debug) @@ -285,7 +285,7 @@ def custom_linear_solve( vecmat_jaxpr = tr_solve_jaxpr = None vecmat_consts = tr_solve_consts = [] else: - transpose_solve_debug = api_util.tracing_debug_info( + transpose_solve_debug = api_util.debug_info( "custom_linear_solve transpose_solve", transpose_solve, (matvec, b), {}, static_argnums=(0,)) if symmetric: diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index a3cd5e3a5..9946a9a67 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -745,7 +745,7 @@ def _trace_composite_to_jaxpr(fun: Callable, in_tree: tree_util.PyTreeDef, in_avals: Sequence[core.AbstractValue], name: str, - debug_info: api_util.TracingDebugInfo): + debug_info: core.DebugInfo): flat_fun, out_tree = api_util.flatten_fun_nokwargs(lu.wrap_init(fun), in_tree) jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug_info) if any(isinstance(c, core.Tracer) for c in consts): @@ -822,8 +822,8 @@ def composite( """ @functools.wraps(decomposition) def _decorator(*args, **kwargs): - debug_info = api_util.tracing_debug_info("composite", decomposition, - args, kwargs) + debug_info = api_util.debug_info("composite", decomposition, + args, kwargs) flat_args, in_tree = tree_util.tree_flatten(args) in_avals = tuple(core.get_aval(x) for x in flat_args) closed_jaxpr, out_tree = _trace_composite_to_jaxpr( diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py index ac267de85..018ab41a4 100644 --- a/jax/_src/linear_util.py +++ b/jax/_src/linear_util.py @@ -156,7 +156,7 @@ class WrappedFun: f_transformed: Callable, transforms, stores: tuple[Store | EqualStore | None, ...], params, in_type, - debug_info: TracingDebugInfo | None): + debug_info: DebugInfo | None): self.f = f self.f_transformed = f_transformed self.transforms = transforms @@ -253,12 +253,10 @@ def fun_name(f): except: return str(f) -class TracingDebugInfo(NamedTuple): - """Tracing-time debugging info about a func and its arguments. - - Formed just before staging to a jaxpr and read in trace-time error messages. - """ +class DebugInfo(NamedTuple): + """Debugging info about a func, its arguments, and results.""" traced_for: str # e.g. 'jit', 'scan', etc + # e.g. f'{fun.__name__} at {filename}:{lineno}' or {fun.__name__} if we have # no source location information. The first word is always the function name, # which may be ''. @@ -270,23 +268,25 @@ class TracingDebugInfo(NamedTuple): # e.g., tangent args in jax.jvp. arg_names: tuple[str | None, ...] + # The result paths are not available while we are tracing the function, + # instead we keep a thunk. Once we are done tracing, we use + # `self.resolve_result_paths()` to execute the thunk and replace the + # actual result paths. # e.g. ('[0]', '[1]', ...) - result_paths_thunk: Callable[[], tuple[str, ...]] | None + result_paths: tuple[str, ...] | Callable[[], tuple[str, ...]] | None - @classmethod - def from_jaxpr(cls, jaxpr: core.ClosedJaxpr) -> TracingDebugInfo | None: - jaxpr_dbg = jaxpr.jaxpr._debug_info - if jaxpr_dbg is None: return None - return TracingDebugInfo(jaxpr_dbg.traced_for, - jaxpr_dbg.func_src_info, - jaxpr_dbg.arg_names, - lambda: jaxpr_dbg.result_paths) + def add_result_paths(self, + result_paths_thunk: Callable[[], tuple[str, ...]] + ) -> DebugInfo: + assert self.result_paths is None + return self._replace(result_paths=HashableFunction(result_paths_thunk, + closure=())) - def add_result_paths(self, result_paths_thunk: Callable[[], tuple[str, ...]] - ) -> TracingDebugInfo: - assert self.result_paths_thunk is None - return self._replace(result_paths_thunk=HashableFunction(result_paths_thunk, - closure=())) + def resolve_result_paths(self) -> DebugInfo: + """Return a debug info with resolved result paths.""" + if callable(self.result_paths): + return self._replace(result_paths=tuple(self.result_paths())) + return self def safe_arg_names(self, expected: int) -> tuple[str | None, ...]: """Get the arg_names with a safety check.""" @@ -296,9 +296,18 @@ class TracingDebugInfo(NamedTuple): # TODO(necula): this should not happen return (None,) * expected + def safe_result_paths(self, expected: int) -> tuple[str, ...]: + """Get the result paths with a safety check.""" + assert not callable(self.result_paths), self + if self.result_paths is not None and len(self.result_paths) == expected: + return self.result_paths + else: + # TODO(necula): this should not happen + return ("",) * expected + def wrap_init(f: Callable, params=None, *, - debug_info: TracingDebugInfo | None = None) -> WrappedFun: + debug_info: DebugInfo | None = None) -> WrappedFun: """Wraps function `f` as a `WrappedFun`, suitable for transformation.""" params_dict = {} if params is None else params params = () if params is None else tuple(sorted(params.items())) @@ -341,7 +350,7 @@ def _check_input_type(in_type: core.InputType) -> None: provided[d.val] = True assert all(provided) -def add_debug_info(f: WrappedFun, debug_info: TracingDebugInfo | None +def add_debug_info(f: WrappedFun, debug_info: DebugInfo | None ) -> WrappedFun: """Produce a new WrappedFun with debug_info attached.""" assert f.debug_info is None diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 267095da0..cc156f375 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -413,9 +413,9 @@ class BlockSpec: fake_index_map_args, fake_index_map_kwargs = \ index_map_tree.unflatten([False] * index_map_tree.num_leaves) - debug = api_util.tracing_debug_info("pallas_call index_map", - index_map_func, fake_index_map_args, - fake_index_map_kwargs) + debug = api_util.debug_info("pallas_call index_map", + index_map_func, fake_index_map_args, + fake_index_map_kwargs) flat_index_map_fun, index_map_out_tree_thunk = api_util.flatten_fun( lu.wrap_init(index_map_func, debug_info=debug), index_map_tree) index_map_src_info = NameAndSrcInfo.from_pallas_call( diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 4bffcd8d7..98fafce09 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -1100,8 +1100,8 @@ def pallas_call_checkify_rule(error: checkify.Error, jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(retrace_in_avals) wrapped_kernel_with_err, out_tree_thunk = api_util.flatten_fun_nokwargs( lu.wrap_init(checked_kernel_fn), jaxpr_in_tree) - debug = api_util.tracing_debug_info("checkify_pallas", checked_kernel_fn, - retrace_in_avals, {}) + debug = api_util.debug_info("checkify_pallas", checked_kernel_fn, + retrace_in_avals, {}) with pallas_core.tracing_grid_env(grid_mapping.grid, ()): final_jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( wrapped_kernel_with_err, jaxpr_flat_avals, debug) @@ -1167,7 +1167,7 @@ def _trace_kernel_to_jaxpr( wrapped_kernel_fun, kernel_in_transforms ) fake_kernel_args = kernel_in_tree.unflatten(kernel_avals) - debug = api_util.tracing_debug_info("pallas_call", fun, fake_kernel_args, {}) + debug = api_util.debug_info("pallas_call", fun, fake_kernel_args, {}) with grid_mapping.trace_env(): jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun, kernel_avals, debug) @@ -1568,7 +1568,7 @@ def pallas_call( kernel_fun_sig = api_util.fun_signature(kernel) arg_names = None if kernel_fun_sig: - kernel_debug_info = api_util.tracing_debug_info( + kernel_debug_info = api_util.debug_info( "pallas_call kernel", kernel, [1] * len(kernel_fun_sig.parameters), {}) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index c192168ce..b435bfdc9 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -49,7 +49,7 @@ from jax._src import xla_bridge as xb from jax._src.api_util import ( argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs, donation_vector, check_callable, resolve_argnums, - argnames_partial_except, tracing_debug_info, result_paths, add_jaxpr_debug_info, + argnames_partial_except, debug_info, result_paths, add_jaxpr_debug_info, hoist_obj_attrs, _check_no_aliased_ref_args, _check_no_aliased_closed_over_refs) from jax._src.interpreters import partial_eval as pe @@ -548,7 +548,7 @@ def _infer_params_impl( ji: PjitInfo, pjit_mesh: mesh_lib.Mesh | None, resource_env: mesh_lib.ResourceEnv | None, - dbg: lu.TracingDebugInfo, + dbg: core.DebugInfo, args: tuple[Any, ...], kwargs: dict[str, Any], in_avals: tuple[core.AbstractValue, ...] | None, @@ -733,7 +733,7 @@ def _infer_params( 'Using `with mesh:` context manager and `jax.sharding.use_mesh`' ' together is not allowed.') - dbg = tracing_debug_info( + dbg = debug_info( 'jit', fun, args, kwargs, static_argnums=ji.static_argnums, static_argnames=ji.static_argnames, sourceinfo=ji.fun_sourceinfo, signature=ji.fun_signature) @@ -756,7 +756,7 @@ def _infer_params( entry.pjit_params = p return entry.pjit_params, entry.pjit_params.consts + dynargs -def _infer_input_type(fun: Callable, dbg: lu.TracingDebugInfo | None, +def _infer_input_type(fun: Callable, dbg: core.DebugInfo | None, explicit_args) -> tuple[core.AbstractValue, ...]: avals = [] try: @@ -1302,7 +1302,7 @@ def _create_pjit_jaxpr( fun: lu.WrappedFun, in_type: core.InputType | Sequence[core.AbstractValue], attr_data: int, - debug_info: lu.TracingDebugInfo, + debug_info: core.DebugInfo, result_paths: Callable, ignored_inline: IgnoreKey ) -> tuple[core.ClosedJaxpr, list[Any], list[core.AbstractValue], @@ -1346,7 +1346,7 @@ def _create_pjit_jaxpr( def _check_and_canonicalize_out_shardings( out_shardings_treedef, out_shardings_leaves, out_layouts_treedef, out_layouts_leaves, out_tree, out_avals, - debug_info: core.JaxprDebugInfo | None, + debug_info: core.DebugInfo | None, device_or_backend_set): orig_out_shardings = tree_unflatten(out_shardings_treedef, out_shardings_leaves) if isinstance(orig_out_shardings, (UnspecifiedValue, Sharding)): diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index a95a4b74c..3f8d43707 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -989,7 +989,7 @@ def _run_state_discharge_rule(in_avals: Sequence[core.AbstractValue], def initial_style_jaxpr( fun: Callable, in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue], - dbg: api_util.TracingDebugInfo, + dbg: core.DebugInfo, ) -> tuple[core.Jaxpr, list[Any], PyTreeDef]: return _initial_style_jaxpr(fun, in_tree, tuple(in_avals), dbg) @@ -997,7 +997,7 @@ def initial_style_jaxpr( def _initial_style_jaxpr(fun: Callable, in_tree: api_util.PyTreeDef, in_avals: Sequence[core.AbstractValue], - debug: api_util.TracingDebugInfo): + debug: core.DebugInfo): fun_, out_tree_thunk = api_util.flatten_fun_nokwargs(lu.wrap_init(fun), tree_util.treedef_tuple((in_tree,))) jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, in_avals, debug) @@ -1007,7 +1007,7 @@ def _initial_style_jaxpr(fun: Callable, T = TypeVar('T') def run_state(f: Callable[..., None]) -> Callable[[T], T]: def wrapped(args): - dbg = api_util.tracing_debug_info("run_state", f, (args,), {}) + dbg = api_util.debug_info("run_state", f, (args,), {}) flat_args, in_tree = tree_util.tree_flatten(args) ref_avals, ref_args = unzip2(map(get_ref_aval_from_value, flat_args)) # There may be some uninitialized values here in ref_args. @@ -1027,7 +1027,7 @@ def run_state(f: Callable[..., None]) -> Callable[[T], T]: def run_state_reference(f: Callable[..., None]): def wrapped(args): - dbg = api_util.tracing_debug_info("run_state", f, (args,), {}) + dbg = api_util.debug_info("run_state", f, (args,), {}) flat_args, in_tree = tree_util.tree_flatten(args) ref_avals, ref_args = unzip2(map(get_ref_aval_from_value, flat_args)) jaxpr_, consts, _ = initial_style_jaxpr(f, in_tree, ref_avals, dbg) diff --git a/jax/core.py b/jax/core.py index ef1551b2f..3fd7af440 100644 --- a/jax/core.py +++ b/jax/core.py @@ -20,13 +20,13 @@ from jax._src.core import ( AbstractValue as AbstractValue, Atom as Atom, CallPrimitive as CallPrimitive, + DebugInfo as DebugInfo, DShapedArray as DShapedArray, DropVar as DropVar, Effect as Effect, Effects as Effects, get_opaque_trace_state as get_opaque_trace_state, InconclusiveDimensionOperation as InconclusiveDimensionOperation, - JaxprDebugInfo as JaxprDebugInfo, JaxprPpContext as JaxprPpContext, JaxprPpSettings as JaxprPpSettings, JaxprTypeError as JaxprTypeError, diff --git a/tests/debug_info_test.py b/tests/debug_info_test.py index 251fbba2d..4780c703c 100644 --- a/tests/debug_info_test.py +++ b/tests/debug_info_test.py @@ -69,13 +69,13 @@ def _collect_jaxprs(jaxpr: core.Jaxpr, return acc -def _debug_info_to_string(dbg: api_util.TracingDebugInfo | core.JaxprDebugInfo | None) -> list[str]: +def _debug_info_to_string(dbg: core.DebugInfo | None) -> list[str]: if dbg is None: return "None" # Strip the absolute path and the line number but check that it references # this file (to catch errors when the source info points in JAX internals) fun_src_info = re.sub(r"^(\S+)( at .*/debug_info_test.py:.*)?", "\\1", dbg.func_src_info) res = f"traced_for={dbg.traced_for}, fun={fun_src_info}, arg_names={','.join(dbg.arg_names)}" - if isinstance(dbg, core.JaxprDebugInfo): + if isinstance(dbg.result_paths, tuple): if dbg.result_paths: res += f", result_paths={','.join(dbg.result_paths)}" else: @@ -221,24 +221,24 @@ class DebugInfoTest(jtu.JaxTestCase): def my_f(x, y, z, w): pass - dbg = api_util.tracing_debug_info("jit", my_f, (1, 2), dict(z=3, w=4)) + dbg = api_util.debug_info("jit", my_f, (1, 2), dict(z=3, w=4)) self.assertRegex(dbg.func_src_info, r"^my_f at .*debug_info_test.py:\d+") self.assertEqual(dbg.arg_names, ("x", "y", "z", "w")) - self.assertIsNone(dbg.result_paths_thunk) + self.assertIsNone(dbg.result_paths) def test_debug_info_arg_passed_as_kwarg(self): def my_f(x, y, z): pass - dbg = api_util.tracing_debug_info("jit", my_f, (1, 2), dict(z=3)) + dbg = api_util.debug_info("jit", my_f, (1, 2), dict(z=3)) self.assertEqual(dbg.arg_names, ("x", "y", "z")) def test_debug_info_pytrees(self): def my_f(x_tree, *, y_tree): pass - dbg = api_util.tracing_debug_info("jit", my_f, ((1, 2),), - dict(y_tree=dict(z=3, w=4))) + dbg = api_util.debug_info("jit", my_f, ((1, 2),), + dict(y_tree=dict(z=3, w=4))) self.assertEqual(dbg.arg_names, ("x_tree[0]", "x_tree[1]", "y_tree['w']", "y_tree['z']")) @@ -246,43 +246,43 @@ class DebugInfoTest(jtu.JaxTestCase): def my_f(x, y, *, z, w): pass - dbg = api_util.tracing_debug_info("jit", my_f, (1, 2), dict(z=3, w=4), - static_argnums=(1,), - static_argnames=("w",)) + dbg = api_util.debug_info("jit", my_f, (1, 2), dict(z=3, w=4), + static_argnums=(1,), + static_argnames=("w",)) self.assertEqual(dbg.arg_names, ("x", "z")) def test_debug_info_with_pytrees_and_statics(self): def my_f(x, y, *, z, w): pass - dbg = api_util.tracing_debug_info("jit", my_f, ((1, 2), (2, 3)), - dict(z=(3, 4), w=(5, 6)), - static_argnums=(1,), - static_argnames=("w",)) + dbg = api_util.debug_info("jit", my_f, ((1, 2), (2, 3)), + dict(z=(3, 4), w=(5, 6)), + static_argnums=(1,), + static_argnames=("w",)) self.assertEqual(dbg.arg_names, ("x[0]", "x[1]", "z[0]", "z[1]")) def test_debug_info_too_many_args(self): def my_f(x): pass - dbg = api_util.tracing_debug_info("jit", my_f, (1, 2, 3), dict(z=3)) + dbg = api_util.debug_info("jit", my_f, (1, 2, 3), dict(z=3)) self.assertEqual(dbg.arg_names, ('args[0]', 'args[1]', 'args[2]', "kwargs['z']")) def test_debug_info_no_source_info_built_in(self): # built-in function "int" does not have an inspect.Signature - dbg = api_util.tracing_debug_info("jit", max, (1,), {}) + dbg = api_util.debug_info("jit", max, (1,), {}) self.assertEqual(dbg.func_src_info, "max") self.assertEqual(dbg.arg_names, ("args[0]",)) def test_debug_info_lambda(self): # built-in function "int" does not have an inspect.Signature - dbg = api_util.tracing_debug_info("jit", lambda my_arg: False, (1,), {}) + dbg = api_util.debug_info("jit", lambda my_arg: False, (1,), {}) self.assertRegex(dbg.func_src_info, r"^ at .*debug_info_test.py:\d+") self.assertEqual(dbg.arg_names, ("my_arg",)) def test_debug_info_no_source_info_not_callable(self): # built-in function "int" does not have an inspect.Signature - dbg = api_util.tracing_debug_info("jit", False, (1,), {}) + dbg = api_util.debug_info("jit", False, (1,), {}) self.assertEqual(dbg.func_src_info, "") self.assertEqual(dbg.arg_names, ("args[0]",)) @@ -293,7 +293,7 @@ class DebugInfoTest(jtu.JaxTestCase): def __call__(self, y): return self.x + y - dbg = api_util.tracing_debug_info("jit", Foo(), (1,), {}) + dbg = api_util.debug_info("jit", Foo(), (1,), {}) self.assertRegex(dbg.func_src_info, "") self.assertEqual(dbg.arg_names, ("y",)) @@ -307,7 +307,7 @@ class DebugInfoTest(jtu.JaxTestCase): def __repr__(self): raise NotImplementedError - dbg = api_util.tracing_debug_info("jit", Foo(), (1,), {}) + dbg = api_util.debug_info("jit", Foo(), (1,), {}) self.assertRegex(dbg.func_src_info, "") self.assertEqual(dbg.arg_names, ("y",))