diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index a9f348fe7..a55432c5f 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -32,8 +32,7 @@ from jax._src import linear_util as lu from jax._src import effects from jax._src import source_info_util from jax._src import traceback_util -from jax._src.api_util import ( - flatten_fun, debug_info, fun_sourceinfo, fun_signature) +from jax._src import api_util from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -42,7 +41,7 @@ from jax._src.lax import lax as lax_internal from jax._src.lax import convolution as lax_convolution from jax._src.lib.mlir.dialects import hlo from jax._src.traceback_util import api_boundary -from jax._src.tree_util import tree_flatten, tree_unflatten, tree_structure +from jax._src.tree_util import PyTreeDef, tree_flatten, tree_unflatten, tree_structure from jax._src.util import (unzip2, wraps, split_list, partition_list, safe_map, safe_zip, merge_lists, weakref_lru_cache) @@ -324,8 +323,9 @@ def checkpoint(fun: Callable, *, prevent_cse: bool = True, @wraps(fun) @api_boundary def fun_remat(*args, **kwargs): - debug = debug_info("checkpoint / remat", fun_sourceinfo(fun), - fun_signature(fun), args, kwargs, static_argnums, ()) + debug = api_util.tracing_debug_info( + "checkpoint / remat", fun, + args, kwargs, static_argnums=static_argnums) fun_, args = _remat_static_argnums(fun, static_argnums, args) args_flat, in_tree = tree_flatten((args, kwargs)) in_avals = [core.shaped_abstractify(x) for x in args_flat] @@ -415,8 +415,12 @@ _dyn_args_fun_cached = weakref_lru_cache(_dyn_args_fun_uncached) # This helper is similar to those in control_flow/common.py, but with # remat-specific errors. @weakref_lru_cache -def _trace_to_jaxpr(fun, in_tree, in_avals, debug): - flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree) +def _trace_to_jaxpr(fun: Callable, + in_tree: PyTreeDef, + in_avals: Sequence[core.AbstractValue], + debug: lu.TracingDebugInfo + ) -> tuple[core.Jaxpr, Sequence[Any], PyTreeDef]: + flat_fun, out_tree = api_util.flatten_fun(lu.wrap_init(fun), in_tree) try: jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug) except core.ConcretizationTypeError as e: @@ -530,7 +534,8 @@ ad.primitive_jvps[remat_p] = remat_jvp effects.remat_allowed_effects.add_type(lax_internal.InOutFeedEffect) -def remat_partial_eval(trace, *tracers, jaxpr, **params): +def remat_partial_eval(trace: pe.JaxprTrace, *tracers: core.Tracer, + jaxpr: core.Jaxpr, **params): assert not jaxpr.constvars disallowed_effects = effects.remat_allowed_effects.filter_not_in(jaxpr.effects) if disallowed_effects: @@ -567,7 +572,7 @@ def remat_partial_eval(trace, *tracers, jaxpr, **params): # set up unknown outputs with a recipe to call remat res_tracers = map(trace.new_instantiated_const, residuals) _, tracers_staged = partition_list(in_used_staged, tracers) - in_jaxpr_tracers = res_tracers + map(trace.instantiate_const, tracers_staged) + in_jaxpr_tracers = res_tracers + map(trace.instantiate_const, tracers_staged) # type: ignore out_jaxpr_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(x.aval), None) for x in jaxpr_unknown.outvars] new_params = dict(params, jaxpr=jaxpr_unknown, differentiated=True) diff --git a/jax/_src/api.py b/jax/_src/api.py index ce916597f..3ed68c850 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -42,7 +42,6 @@ from jax._src.tree_util import ( tree_map, tree_flatten, tree_unflatten, tree_structure, tree_transpose, tree_leaves, Partial, PyTreeDef, all_leaves, keystr, broadcast_prefix, prefix_errors, generate_key_paths, tree_flatten_with_path) -from jax._src import api_util from jax._src import config from jax._src import core from jax._src import dispatch @@ -61,8 +60,8 @@ from jax._src.api_util import ( flatten_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2, argnums_partial, flatten_axes, donation_vector, rebase_donate_argnums, _ensure_index, _ensure_index_tuple, - apply_flat_fun_nokwargs, check_callable, debug_info, - result_paths, flat_out_axes, debug_info_final, fun_sourceinfo) + apply_flat_fun_nokwargs, check_callable, tracing_debug_info, + result_paths, flat_out_axes, debug_info_final) from jax._src.lax import lax as lax_internal from jax._src.lib import jax_jit from jax._src.lib import xla_client as xc @@ -456,10 +455,7 @@ def value_and_grad(fun: Callable, argnums: int | Sequence[int] = 0, raise TypeError(f"differentiating with respect to {argnums=} requires at least " f"{max_argnum + 1} positional arguments to be passed by the caller, " f"but got only {len(args)} positional arguments.") - fun_src_info = fun_sourceinfo(fun) - fun_signature = api_util.fun_signature(fun) - dbg = debug_info('value_and_grad', fun_src_info, fun_signature, - args, kwargs, (), ()) + dbg = tracing_debug_info('value_and_grad', fun, args, kwargs) f = lu.wrap_init(fun, params=kwargs, debug_info=dbg) f_partial, dyn_args = argnums_partial(f, argnums, args, @@ -1405,11 +1401,9 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple, if in_devices is not None and len(in_devices) == 0: raise ValueError("'devices' argument to pmap must be non-empty, or None.") - src = fun_sourceinfo(fun) - signature = api_util.fun_signature(fun) - - dbg = debug_info('pmap', src, signature, args, kwargs, - static_broadcasted_tuple, ()) + dbg = tracing_debug_info( + 'pmap', fun, args, kwargs, + static_argnums=static_broadcasted_tuple) f = lu.wrap_init(fun) if static_broadcasted_tuple: diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index 930011fe2..d1ea2396d 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -116,6 +116,7 @@ def flattened_fun_in_tree( (args, store, f is flatten_fun.args[0]) for (f, args), store in zip(fn.transforms, fn.stores) if f in flattens) except ValueError: + # When `fn` is not the result of flatten_fun or flatten_fun_nokwargs return None else: return in_tree, lambda: out_tree_store.val, has_kwargs # type: ignore[union-attr] @@ -589,7 +590,7 @@ def _dtype(x): def api_hook(fun, tag: str): return fun - +# TODO(necula): replace usage with tracing_debug_info def debug_info( traced_for: str, fun_src_info: str | None, fun_signature: inspect.Signature | None, @@ -598,12 +599,36 @@ def debug_info( static_argnames: tuple[str, ...] ) -> TracingDebugInfo | None: """Try to build trace-time debug info for fun when applied to args/kwargs.""" - arg_names = _arg_names(fun_signature, args, kwargs, static_argnums, + arg_names = _non_static_arg_names(fun_signature, args, kwargs, static_argnums, static_argnames) if arg_names is None: return None return TracingDebugInfo(traced_for, fun_src_info, arg_names, None) + +def tracing_debug_info( + traced_for: str, + fun: Callable, + args: Sequence[Any], + kwargs: dict[str, Any], + *, + static_argnums: tuple[int, ...] = (), + static_argnames: tuple[str, ...] = (), + result_paths_thunk: Callable[[], tuple[str, ...]] | None = None, + # TODO(necula): check if we really need this, e.g., to speed up tracing. + sourceinfo: str | None = None, + signature: inspect.Signature | None = None, +) -> TracingDebugInfo: + if sourceinfo is None: + sourceinfo = fun_sourceinfo(fun) + if signature is None: + signature = fun_signature(fun) + arg_names = _non_static_arg_names(signature, args, kwargs, static_argnums, + static_argnames) + # TODO(necula): remove type: ignore once we fix arg_names to never be None + return TracingDebugInfo(traced_for, sourceinfo, arg_names, result_paths_thunk) # type: ignore + + def fun_signature(fun: Callable) -> inspect.Signature | None: try: return inspect.signature(fun) @@ -631,8 +656,11 @@ def fun_sourceinfo(fun: Callable) -> str | None: except AttributeError: return None -def _arg_names(fn_signature, args, kwargs, static_argnums, static_argnames, - ) -> tuple[str, ...] | None: +def _non_static_arg_names(fn_signature: inspect.Signature | None, + args: Sequence[Any], kwargs: dict[str, Any], + static_argnums: Sequence[int], + static_argnames: Sequence[str], + ) -> tuple[str | None, ...] | None: if fn_signature is None: return None static = object() static_argnums_ = _ensure_inbounds(True, len(args), static_argnums) @@ -665,7 +693,7 @@ def add_jaxpr_debug_info(jaxpr: core.Jaxpr, result_paths = trace_debug.result_paths_thunk() # type: ignore debug_info = core.JaxprDebugInfo( trace_debug.traced_for, trace_debug.func_src_info, - trace_debug.arg_names, tuple(result_paths)) + trace_debug.arg_names, tuple(result_paths)) # type: ignore return jaxpr.replace(debug_info=debug_info) def debug_info_final(f: lu.WrappedFun, dbg: TracingDebugInfo | None, diff --git a/jax/_src/core.py b/jax/_src/core.py index e83ce6546..df061d5f8 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -140,8 +140,8 @@ class Jaxpr: self._eqns = list(eqns) self._effects = effects self._debug_info = debug_info - assert (not debug_info or len(debug_info.arg_names) == len(invars) and - len(debug_info.result_paths) == len(outvars)) + assert (not debug_info or debug_info.arg_names is None or len(debug_info.arg_names) == len(invars)), (debug_info, invars) + assert (not debug_info or len(debug_info.result_paths) == len(outvars)), (debug_info, outvars) def __str__(self): return str(self.pretty_print()) diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index ad1285219..4470cfe51 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -30,8 +30,8 @@ from jax._src import traceback_util from jax._src.ad_util import ( stop_gradient_p, SymbolicZero, Zero, zeros_like_aval) from jax._src.api_util import ( - argnums_partial, flatten_fun_nokwargs, resolve_kwargs, fun_signature, - _arg_names) + argnums_partial, flatten_fun_nokwargs, resolve_kwargs, fun_signature, + _non_static_arg_names) from jax._src.errors import UnexpectedTracerError from jax._src.state.types import AbstractRef from jax._src.interpreters import ad @@ -636,7 +636,7 @@ def _check_for_aliased_refs(f, nondiff_argnums, args): for i, x in enumerate(leaves): if (isinstance((a := core.get_aval(x)), AbstractRef) and (dup_idx := refs.setdefault(id(core.get_referent(x)), i)) != i): - arg_names = _arg_names(fun_signature(f), args, {}, nondiff_argnums, ()) + arg_names = _non_static_arg_names(fun_signature(f), args, {}, nondiff_argnums, ()) if arg_names is None: arg_names = [f'flat index {j}' for j in range(len(leaves))] raise ValueError( diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 43a771cd2..c0d20d1b2 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -35,8 +35,7 @@ from jax._src import profiler from jax._src import source_info_util from jax._src import compute_on from jax._src import xla_metadata as xla_metadata_lib -from jax._src.api_util import (flattened_fun_in_tree, flatten_fun_nokwargs, - fun_sourceinfo) +from jax._src.api_util import (flattened_fun_in_tree, flatten_fun_nokwargs) from jax._src.core import (Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval, AbstractValue, ClosedJaxpr, new_jaxpr_eqn, Var, DropVar, Atom, @@ -44,9 +43,9 @@ from jax._src.core import (Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval, mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx, InputType, OutputType, get_referent, JaxprEqnContext) from jax._src.state.types import AbstractRef +from jax._src import tree_util from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_unflatten, - tree_flatten, tree_structure, generate_key_paths, - keystr) + tree_flatten, tree_structure) from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list, merge_lists, partition_list, OrderedSet, as_hashable_function, weakref_lru_cache, subs_list) @@ -579,9 +578,9 @@ def trace_to_jaxpr_nounits( # TODO(mattjj): superfluous wrapper...? @lu.transformation2 def trace_to_subjaxpr_nounits( - f, + f: Callable, trace: JaxprTrace, - instantiate: bool | Sequence[bool], + instantiate: Sequence[bool] | bool, in_pvals: Sequence[PartialVal]): assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals out_tracers, jaxpr, out_consts, env = _trace_to_subjaxpr_nounits( @@ -607,7 +606,9 @@ def trace_to_subjaxpr_nounits2( del out_tracers return jaxpr, (out_pvals, out_consts, env) -def _trace_to_subjaxpr_nounits(f, trace:JaxprTrace, instantiate, in_pvals): +def _trace_to_subjaxpr_nounits(f: Callable, trace: JaxprTrace, + instantiate: Sequence[bool] | bool, + in_pvals: Sequence[PartialVal]): in_knowns = [pval.is_known() for pval in in_pvals] in_consts = [pval.get_known() for pval in in_pvals if pval.is_known()] in_tracers = [trace.new_arg(pval) for pval in in_pvals if not pval.is_known()] @@ -1903,7 +1904,8 @@ class DynamicJaxprTrace(core.Trace): self.frame.add_eqn(eqn) return out_tracers if primitive.multiple_results else out_tracers.pop() - def process_call(self, call_primitive, f, explicit_tracers, params): + def process_call(self, call_primitive, f: lu.WrappedFun, + explicit_tracers, params): if f.in_type is None: f = lu.annotate(f, tuple((get_aval(t), True) for t in explicit_tracers)) implicit_tracers = _extract_implicit_args(self, f.in_type, explicit_tracers) @@ -1915,7 +1917,7 @@ class DynamicJaxprTrace(core.Trace): return core.eval_jaxpr(jaxpr, consts, *in_tracers, propagate_source_info=False) source_info = source_info_util.current() - out_tracers = [] + out_tracers: list[Tracer] = [] for aval, _ in out_type: if type(aval) is DShapedArray: shape = [[*consts, *in_tracers][d.val] if type(d) is InDBIdx else @@ -2110,35 +2112,39 @@ def _jvp_jaxpr_zeros(f, store, in_zeros, zero_avals, *primal_tangent_avals): # Callers should be using linear_util.debug_info instead! def tracing_debug_info( fn: Callable, - in_tree: PyTreeDef | None, - out_tree_thunk: Callable[[], PyTreeDef] | None, + in_tree: PyTreeDef, + out_tree_thunk: Callable[[], PyTreeDef], has_kwargs: bool, traced_for: str ) -> lu.TracingDebugInfo: - src_info = fun_sourceinfo(fn) - arg_names: tuple[str | None, ...] | None + # TODO(necula): we should not need this function, and can use api_util.tracing_debug_info instead + # We just have to make sure we grad the debugging information when we have + # the unflattened args + # TODO(necula): in general we can just pretend the leaves are booleans, but + # when we use custom pytrees, the flattening functions may check the type + # of the argument try: dummy_args = tree_unflatten(in_tree, [False] * in_tree.num_leaves) # type: ignore - args, kwargs = dummy_args if has_kwargs else (dummy_args, {}) - ba = api_util.fun_signature(fn).bind(*args, **kwargs) # type: ignore - arg_names = tuple(f'{name}{keystr(path)}' for name, dummy in ba.arguments.items() - for path, _ in generate_key_paths(dummy)) except: - arg_names = None # TODO(necula): we should not need this - def result_paths(): - try: - out_tree = out_tree_thunk() - dummy_result = tree_unflatten(out_tree, [False] * out_tree.num_leaves) - except: - return None # TODO(necula): this does not seem to be needed - return tuple(path for path, _ in generate_key_paths(dummy_result)) - # TODO(necula): clean up the type: ignore below - return lu.TracingDebugInfo(traced_for, src_info, arg_names, result_paths) # type: ignore[arg-type] + # TODO(necula): remove this catch-all. Repro in batching_test:test_basic_jit + dummy_args = ([False], {}) if has_kwargs else [False] + args, kwargs = dummy_args if has_kwargs else (dummy_args, {}) # type: ignore + def res_paths_thunk() -> tuple[str, ...]: + out_tree = out_tree_thunk() + dummy_result = tree_unflatten(out_tree, [False] * out_tree.num_leaves) + return tuple(tree_util.keystr(path) + for path, _ in tree_util.generate_key_paths(dummy_result)) + return api_util.tracing_debug_info(traced_for, fn, args, kwargs, + result_paths_thunk=res_paths_thunk) def tracing_debug_info_final(fn: lu.WrappedFun, traced_for: str) -> lu.TracingDebugInfo: - in_tree, out_tree, has_kws = flattened_fun_in_tree(fn) or (None, None, False) - return tracing_debug_info(fn.f, in_tree, out_tree, has_kws, traced_for) - + fn_trees = flattened_fun_in_tree(fn) + if fn_trees is None: + # TODO(necula): eliminate this branch + return lu.TracingDebugInfo(traced_for, api_util.fun_sourceinfo(fn.f), + (None,), None) + in_tree, out_tree_thunk, has_kws = fn_trees + return tracing_debug_info(fn.f, in_tree, out_tree_thunk, has_kws, traced_for) @profiler.annotate_function def trace_to_jaxpr_dynamic( @@ -2178,7 +2184,7 @@ def _check_no_returned_refs( raise ValueError( f"function returned a mutable array reference of type {a.str_short()}, " "but mutable array references cannot be returned.") - loc = (f' at output tree path {keystr(ls[i])}' # type: ignore + loc = (f' at output tree path {tree_util.keystr(ls[i])}' # type: ignore if (dbg.result_paths_thunk and (ls := dbg.result_paths_thunk()) and ls[i]) else '') @@ -2190,7 +2196,7 @@ def _check_no_returned_refs( origin_info = ('\n\nThe returned mutable array was created on line ' f'{source_info_util.summarize(eqn.source_info)}.') elif v in frame.invars: - arg_name = dbg.arg_names[frame.invars.index(v)] + arg_name = dbg.arg_names[frame.invars.index(v)] # type: ignore origin_info = ('\n\nThe returned mutable array was passed in as the ' f'argument {arg_name}.') else: diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index a9bcfd24d..17116e857 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -139,7 +139,7 @@ def switch(index, branches: Sequence[Callable], *operands, ops_avals = tuple(map(core.get_aval, ops)) if config.mutable_array_checks.value: - dbg = pe.tracing_debug_info(branches[0], ops_tree, None, False, 'switch') + dbg = pe.tracing_debug_info(branches[0], ops_tree, None, False, 'switch') # type: ignore _check_no_aliased_ref_args(dbg, ops_avals, ops) jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts( @@ -238,7 +238,7 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands, ops_avals = tuple(map(core.get_aval, ops)) if config.mutable_array_checks.value: - dbg = pe.tracing_debug_info(true_fun, ops_tree, None, False, 'cond') + dbg = pe.tracing_debug_info(true_fun, ops_tree, None, False, 'cond') # type: ignore _check_no_aliased_ref_args(dbg, ops_avals, ops) jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts( (true_fun, false_fun), ops_tree, ops_avals, 'cond') diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 3017e734a..8e18d1455 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -275,7 +275,7 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]], if config.mutable_array_checks.value: in_flat, in_tree = tree_flatten((init, xs)) - dbg = pe.tracing_debug_info(f, in_tree, None, False, 'scan') + dbg = pe.tracing_debug_info(f, in_tree, None, False, 'scan') # type: ignore in_avals = tuple(_map(core.get_aval, in_flat)) _check_no_aliased_ref_args(dbg, in_avals, in_flat) diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py index 9de9809bf..919cd90c3 100644 --- a/jax/_src/linear_util.py +++ b/jax/_src/linear_util.py @@ -254,11 +254,19 @@ def fun_name(f): return str(f) class TracingDebugInfo(NamedTuple): - # Packages up trace/staging-time debug info about a func and its parameters, - # formed just before staging to a jaxpr and read in trace-time error messages. + """Tracing-time debugging info about a func and its arguments. + + Formed just before staging to a jaxpr and read in trace-time error messages. + """ traced_for: str # e.g. 'jit', 'scan', etc func_src_info: str | None # e.g. f'{fun.__name__} at {filename}:{lineno}' - arg_names: tuple[str | None, ...] # e.g. ('args[0]', ... ) + + # The paths of the flattened non-static argnames, + # e.g. ('x', 'dict_arg["a"]', ... ). + # Uses `None` for the args that do not correspond to user-named arguments, + # e.g., tangent args in jax.jvp. + arg_names: tuple[str | None, ...] + # e.g. ('[0]', '[1]', ...) result_paths_thunk: Callable[[], tuple[str, ...]] | None diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index d0fab8613..87a63db39 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -1810,13 +1810,11 @@ def pallas_call( kernel_fun_sig = api_util.fun_signature(kernel) arg_names = None if kernel_fun_sig: - kernel_debug_info = api_util.debug_info( + kernel_debug_info = api_util.tracing_debug_info( "pallas_call kernel", - kernel_src_info, - kernel_fun_sig, - [1] * len(kernel_fun_sig.parameters), {}, (), ()) - if kernel_debug_info: - arg_names = kernel_debug_info.arg_names + kernel, + [1] * len(kernel_fun_sig.parameters), {}) + arg_names = kernel_debug_info.arg_names del kernel_debug_info in_origins = tuple(in_path_to_input_origin(p, arg_names) for p in in_paths) @@ -1909,6 +1907,10 @@ def in_path_to_input_origin( if isinstance(arg_idx, tree_util.SequenceKey) and arg_idx.idx < len( arg_names ): + if arg_names[arg_idx.idx] is None: + # TODO(necula): when is this needed? + # Repro: pallas_test:test_with_input_output_aliasing + return f"args{tree_util.keystr(in_path)}" return arg_names[arg_idx.idx] + tree_util.keystr(tuple(rest_path)) else: return f"args{tree_util.keystr(tuple(in_path))}" diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 7b2ae2b79..7c9449062 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -49,7 +49,7 @@ from jax._src import xla_bridge as xb from jax._src.api_util import ( argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs, donation_vector, check_callable, resolve_argnums, - argnames_partial_except, debug_info, result_paths, add_jaxpr_debug_info, + argnames_partial_except, debug_info, tracing_debug_info, result_paths, add_jaxpr_debug_info, hoist_obj_attrs, _check_no_aliased_ref_args, _check_no_aliased_closed_over_refs) from jax._src.interpreters import partial_eval as pe @@ -176,7 +176,7 @@ class PjitInfo(NamedTuple): return self is other -def _python_pjit_helper(fun, jit_info, *args, **kwargs): +def _python_pjit_helper(fun: Callable, jit_info: PjitInfo, *args, **kwargs): p, args_flat = _infer_params(fun, jit_info, args, kwargs) for arg in args_flat: @@ -568,6 +568,14 @@ def _infer_params_impl( dbg = debug_info('jit', ji.fun_sourceinfo, ji.fun_signature, args, kwargs, ji.static_argnums, ji.static_argnames) + # TODO(necula): replace the above with below. + # haiku/_src/integration:hk_transforms_test fails + # dbg = tracing_debug_info('jit', fun, args, kwargs, + # static_argnums=ji.static_argnums, + # static_argnames=ji.static_argnames, + # TODO(necula): do we really need this, e.g., for tracing speed + # sourceinfo = ji.fun_sourceinfo, + # signature = ji.fun_signature) f = lu.wrap_init(fun) f, res_paths = result_paths(f) f, dyn_args = argnums_partial_except(f, ji.static_argnums, args, allow_invalid=True) @@ -732,8 +740,12 @@ def _infer_params( signature, dynargs = jax_jit.parse_arguments( args, tuple(kwargs.values()), tuple(kwargs.keys()), ji.static_argnums, ji.static_argnames, tree_util.default_registry) - dbg = debug_info('jit', ji.fun_sourceinfo, ji.fun_signature, args, kwargs, - ji.static_argnums, ji.static_argnames) + dbg = tracing_debug_info('jit', fun, args, kwargs, + static_argnums=ji.static_argnums, + static_argnames=ji.static_argnames, + # TODO(necula): do we really need this, e.g., for tracing speed + sourceinfo=ji.fun_sourceinfo, + signature=ji.fun_signature) avals = _infer_input_type(fun, dbg, dynargs) entry = _infer_params_cached(fun, ji, signature, avals, pjit_mesh, resource_env) if entry.pjit_params is None: @@ -744,7 +756,9 @@ def _infer_params( entry.pjit_params = p return entry.pjit_params, entry.pjit_params.consts + dynargs -def _infer_input_type(fun, dbg, explicit_args) -> tuple[core.AbstractValue, ...]: +def _infer_input_type(fun: Callable, + dbg: lu.TracingDebugInfo | None, + explicit_args) -> tuple[core.AbstractValue, ...]: avals = [] try: for i, x in enumerate(explicit_args): @@ -1672,7 +1686,7 @@ def _pjit_call_impl_python( if compiled._auto_spmd_lowering and config.enable_checks.value: pxla.check_array_xla_sharding_layout_match( args, compiled._in_shardings, compiled._in_layouts, - jaxpr.jaxpr.tracing_debug_info, compiled._kept_var_idx) + jaxpr.jaxpr._debug_info, compiled._kept_var_idx) if config.distributed_debug.value: # Defensively only perform fingerprint logic if debug logging is enabled # NOTE(skyewm): I didn't benchmark this diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index 95792bd00..32b56a332 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -985,15 +985,18 @@ def _run_state_discharge_rule(in_avals: Sequence[core.AbstractValue], return new_invals, out_vals def initial_style_jaxpr( - fun: Callable, in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue] + fun: Callable, in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue], + dbg: api_util.TracingDebugInfo, ) -> tuple[core.Jaxpr, list[Any], PyTreeDef]: - return _initial_style_jaxpr(fun, in_tree, tuple(in_avals)) + return _initial_style_jaxpr(fun, in_tree, tuple(in_avals), dbg) @weakref_lru_cache -def _initial_style_jaxpr(fun, in_tree, in_avals): +def _initial_style_jaxpr(fun: Callable, + in_tree: api_util.PyTreeDef, + in_avals: Sequence[core.AbstractValue], + debug: api_util.TracingDebugInfo): fun_, out_tree_thunk = api_util.flatten_fun_nokwargs(lu.wrap_init(fun), tree_util.treedef_tuple((in_tree,))) - debug = pe.tracing_debug_info(fun_, in_tree, out_tree_thunk, False, 'run_state') jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, in_avals, debug) return jaxpr, consts, out_tree_thunk() @@ -1001,10 +1004,11 @@ def _initial_style_jaxpr(fun, in_tree, in_avals): T = TypeVar('T') def run_state(f: Callable[..., None]) -> Callable[[T], T]: def wrapped(args): + dbg = api_util.tracing_debug_info("run_state", f, (args,), {}) flat_args, in_tree = tree_util.tree_flatten(args) ref_avals, ref_args = unzip2(map(get_ref_aval_from_value, flat_args)) # There may be some uninitialized values here in ref_args. - jaxpr_, consts, _ = initial_style_jaxpr(f, in_tree, ref_avals) + jaxpr_, consts, _ = initial_style_jaxpr(f, in_tree, ref_avals, dbg) jaxpr = hoist_consts_to_refs(jaxpr_) which_linear = (False,) * (len(consts) + len(ref_args)) refs_is_initialized = tuple(r is not uninitialized for r in ref_args) @@ -1020,9 +1024,10 @@ def run_state(f: Callable[..., None]) -> Callable[[T], T]: def run_state_reference(f: Callable[..., None]): def wrapped(args): + dbg = api_util.tracing_debug_info("run_state", f, (args,), {}) flat_args, in_tree = tree_util.tree_flatten(args) ref_avals, ref_args = unzip2(map(get_ref_aval_from_value, flat_args)) - jaxpr_, consts, _ = initial_style_jaxpr(f, in_tree, ref_avals) + jaxpr_, consts, _ = initial_style_jaxpr(f, in_tree, ref_avals, dbg) jaxpr = hoist_consts_to_refs(jaxpr_) discharged_jaxpr, discharged_consts = discharge_state(jaxpr, ()) diff --git a/tests/state_test.py b/tests/state_test.py index 0e60dc287..0fc37ba47 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -1496,8 +1496,23 @@ class RunStateTest(jtu.JaxTestCase): def test_can_stage_run_state(self): def f(x): return run_state(lambda _: None)(x) + jaxpr = jax.make_jaxpr(f)(2) + self.assertIsNotNone(jaxpr.jaxpr.debug_info) + self.assertIsNotNone(jaxpr.jaxpr.debug_info.func_src_info) + + def test_can_stage_run_state_leaked_tracer_error(self): + leaks = [] + def f(x): + def my_fun(x): + leaks.append(x) + return None + return run_state(my_fun)(x) _ = jax.make_jaxpr(f)(2) + with self.assertRaisesRegex(jax.errors.UnexpectedTracerError, + "The function being traced when the value leaked was .*my_fun"): + jax.jit(lambda _: leaks[0])(1) + def test_nested_run_state_captures_effects(self): def f(x): def body(x_ref):