From abcaec70811d1a12bcc1717f2aef2afd488ce84f Mon Sep 17 00:00:00 2001 From: George Necula Date: Fri, 24 Jan 2025 12:53:51 +0200 Subject: [PATCH] [better_errors] Add debug info to the Jaxprs formed for AD Following #26078 , we add debug info to more calls of lu.wrap_init. --- jax/BUILD | 1 + jax/_src/ad_checkpoint.py | 10 +- jax/_src/api.py | 22 ++- jax/_src/api_util.py | 7 +- jax/_src/checkify.py | 7 +- jax/_src/core.py | 5 +- jax/_src/custom_derivatives.py | 131 ++++++++----- jax/_src/interpreters/ad.py | 19 +- jax/_src/interpreters/mlir.py | 4 +- jax/_src/interpreters/partial_eval.py | 59 +++--- jax/_src/lax/control_flow/common.py | 3 +- jax/_src/lax/control_flow/conditionals.py | 37 ++-- jax/_src/lax/control_flow/for_loop.py | 35 ++-- jax/_src/lax/control_flow/loops.py | 2 +- jax/_src/lax/control_flow/solves.py | 9 +- jax/_src/lax/lax.py | 10 +- jax/_src/linear_util.py | 12 +- jax/_src/pjit.py | 9 +- jax/_src/state/discharge.py | 3 +- jax/experimental/shard_map.py | 27 +-- tests/api_test.py | 22 ++- tests/debug_info_test.py | 215 ++++++++++++++++++++-- 22 files changed, 480 insertions(+), 169 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index 657dbf179..1225993ad 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -564,6 +564,7 @@ pytype_strict_library( srcs = ["_src/interpreters/mlir.py"], deps = [ ":ad_util", + ":api_util", ":config", ":core", ":dtypes", diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index 9a6c78fa5..61576bb50 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -682,12 +682,13 @@ def transpose_jaxpr(jaxpr: core.ClosedJaxpr, in_linear: bool | Sequence[bool], return _transpose_jaxpr(jaxpr, tuple(in_linear), tuple(out_zeros)) @weakref_lru_cache -def _transpose_jaxpr(jaxpr, in_lin, out_zeros): +def _transpose_jaxpr(jaxpr: core.ClosedJaxpr, + in_lin: Sequence[bool], + out_zeros: Sequence[bool]): in_avals = ([a for a, lin in zip(jaxpr.in_avals, in_lin ) if not lin] + [a for a, zero in zip(jaxpr.out_avals, out_zeros) if not zero]) cell = lambda: None - @lu.wrap_init def transposed(*args_flat): ins_flat, out_cts_flat = split_list(args_flat, [len(in_lin) - sum(in_lin)]) @@ -715,7 +716,10 @@ def _transpose_jaxpr(jaxpr, in_lin, out_zeros): in_cts_nz, _ = partition_list(in_zeros, in_cts) return in_cts_nz - transposed_jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(transposed, in_avals) + transposed_wrapped = lu.wrap_init(transposed, + debug_info=jaxpr.jaxpr.debug_info) + transposed_jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic( + transposed_wrapped, in_avals) transposed_jaxpr = core.ClosedJaxpr(transposed_jaxpr_, consts) return transposed_jaxpr, cell.in_cts_zero # pytype: disable=attribute-error diff --git a/jax/_src/api.py b/jax/_src/api.py index 224ad2160..581b2b512 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -983,7 +983,7 @@ def vmap(fun: F, "to the positional arguments passed to the function, " f"but got {len(in_axes)=}, {len(args)=}") args_flat, in_tree = tree_flatten((args, kwargs), is_leaf=batching.is_vmappable) - f = lu.wrap_init(fun) + f = lu.wrap_init(fun, debug_info=debug_info("vmap", fun, args, kwargs)) flat_fun, out_tree = batching.flatten_fun_for_vmap(f, in_tree) in_axes_flat = flatten_axes("vmap in_axes", in_tree, (in_axes, 0), kws=True) axis_size_ = (axis_size if axis_size is not None else @@ -1715,15 +1715,15 @@ def jvp( 0.19900084 """ check_callable(fun) - return _jvp(lu.wrap_init(fun), primals, tangents, has_aux=has_aux) - -def _jvp(fun: lu.WrappedFun, primals, tangents, has_aux=False): - """Variant of jvp() that takes an lu.WrappedFun.""" if (not isinstance(primals, (tuple, list)) or not isinstance(tangents, (tuple, list))): raise TypeError("primal and tangent arguments to jax.jvp must be tuples or lists; " f"found {type(primals).__name__} and {type(tangents).__name__}.") + return _jvp(lu.wrap_init(fun, debug_info=debug_info("jvp", fun, primals, {})), + primals, tangents, has_aux=has_aux) +def _jvp(fun: lu.WrappedFun, primals, tangents, has_aux=False): + """Variant of jvp() that takes an lu.WrappedFun.""" ps_flat, tree_def = tree_flatten(primals) ts_flat, tree_def_2 = tree_flatten(tangents) if tree_def != tree_def_2: @@ -1835,7 +1835,7 @@ def linearize(fun: Callable, *primals, has_aux: bool = False -6.676704 """ check_callable(fun) - f = lu.wrap_init(fun) + f = lu.wrap_init(fun, debug_info=debug_info("linearize", fun, primals, {})) primals_flat, in_tree = tree_flatten(primals) if has_aux: jaxtree_fun, out_tree = flatten_fun_nokwargs2(f, in_tree) @@ -1983,8 +1983,9 @@ def vjp( raise NotImplementedError("reduce_axes argument to vjp is deprecated") del reduce_axes check_callable(fun) - return _vjp( - lu.wrap_init(fun), *primals, has_aux=has_aux) + wrapped_fun = lu.wrap_init(fun, + debug_info=debug_info("vjp", fun, primals, {})) + return _vjp(wrapped_fun, *primals, has_aux=has_aux) def _vjp(fun: lu.WrappedFun, *primals, has_aux=False): """Variant of vjp() that takes an lu.WrappedFun.""" @@ -2049,7 +2050,10 @@ def linear_transpose(fun: Callable, *primals, reduce_axes=()) -> Callable: raise NotImplementedError("reduce_axes argument to transpose is deprecated") del reduce_axes primals_flat, in_tree = tree_flatten(primals) - flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree) + flat_fun, out_tree = flatten_fun_nokwargs( + lu.wrap_init(fun, + debug_info=debug_info("linear_transpose", fun, primals, {})), + in_tree) in_avals = map(shaped_abstractify, primals_flat) in_dtypes = map(dtypes.dtype, in_avals) diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index 9e4478245..906c40542 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -67,7 +67,8 @@ def _ensure_str_tuple(x: str | Iterable[str]) -> tuple[str, ...]: return tuple(map(_ensure_str, x)) @lu.transformation_with_aux2 -def flatten_fun(f, store, in_tree, *args_flat): +def flatten_fun(f: Callable, store: lu.Store, + in_tree: PyTreeDef, *args_flat): py_args, py_kwargs = tree_unflatten(in_tree, args_flat) ans = f(*py_args, **py_kwargs) ans, out_tree = tree_flatten(ans) @@ -587,8 +588,8 @@ def debug_info( args: Sequence[Any], kwargs: dict[str, Any], *, - static_argnums: tuple[int, ...] = (), - static_argnames: tuple[str, ...] = (), + static_argnums: Sequence[int] = (), + static_argnames: Sequence[str] = (), result_paths_thunk: Callable[[], tuple[str, ...]] | None = None, # TODO(necula): check if we really need this, e.g., to speed up tracing? sourceinfo: str | None = None, diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 6480fb981..9eea81b9d 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -361,8 +361,9 @@ def default_checkify_rule(primitive: core.Primitive, error: Error, else: jaxpr, consts = call_jaxpr, () consts_ = tuple(HashableWrapper(c) for c in consts) - partial_checkify = lu.hashable_partial(lu.wrap_init( - checkify_jaxpr_flat_hashable), jaxpr, consts_, enabled_errors, err_tree) + partial_checkify = lu.hashable_partial( + lu.wrap_init(checkify_jaxpr_flat_hashable, debug_info=jaxpr.debug_info), + jaxpr, consts_, enabled_errors, err_tree) partial_checkify, metadata = _flatten_and_get_error_metadata_thunk( partial_checkify) @@ -746,7 +747,7 @@ def jaxpr_to_checkify_jaxpr( checkify_jaxpr_partial = functools.partial(checkify_jaxpr_flat, jaxpr.jaxpr, jaxpr.consts, enabled_errors, err_tree) - fun = lu.wrap_init(checkify_jaxpr_partial) + fun = lu.wrap_init(checkify_jaxpr_partial, debug_info=jaxpr.jaxpr.debug_info) fun, metadata = _flatten_and_get_error_metadata_thunk(fun) new_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun, flat_err_and_in_vals) diff --git a/jax/_src/core.py b/jax/_src/core.py index 2bd1de761..2aefe1544 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -2416,8 +2416,9 @@ call_p.def_impl(call_impl) class ClosedCallPrimitive(CallPrimitive): def get_bind_params(self, params): new_params = dict(params) - jaxpr = new_params.pop('call_jaxpr') - subfun = lu.wrap_init(partial(eval_jaxpr, jaxpr.jaxpr, jaxpr.consts)) + jaxpr: ClosedJaxpr = new_params.pop('call_jaxpr') + subfun = lu.wrap_init(partial(eval_jaxpr, jaxpr.jaxpr, jaxpr.consts), + debug_info=jaxpr.jaxpr.debug_info) return [subfun], new_params closed_call_p: ClosedCallPrimitive = ClosedCallPrimitive('closed_call') diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 6e4864fa8..c8f8cc7a3 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -31,7 +31,7 @@ from jax._src.ad_util import ( stop_gradient_p, SymbolicZero, Zero, zeros_like_aval) from jax._src.api_util import ( argnums_partial, flatten_fun_nokwargs, resolve_kwargs, fun_signature, - _non_static_arg_names, prepend_static_args) + _non_static_arg_names, prepend_static_args, debug_info) from jax._src.errors import UnexpectedTracerError from jax._src.state.types import AbstractRef from jax._src.interpreters import ad @@ -44,7 +44,7 @@ from jax._src.lax import lax from jax._src.tree_util import ( tree_flatten, tree_unflatten, tree_map, treedef_is_leaf, treedef_tuple, register_pytree_node_class, tree_leaves, tree_flatten_with_path, - tree_leaves_with_path, keystr, treedef_children) + tree_leaves_with_path, keystr, treedef_children, PyTreeDef) from jax._src.util import (cache, safe_zip, safe_map, split_list, unzip2, weakref_lru_cache) @@ -78,7 +78,9 @@ _stop_gradient = partial( # like the api_util.py function, but also grabs output avals for error checking @lu.transformation_with_aux2 -def _flatten_fun_nokwargs(f, store, in_tree, *args_flat): +def _flatten_fun_nokwargs(f: Callable, + store: lu.Store, in_tree: PyTreeDef, + *args_flat): py_args = tree_unflatten(in_tree, args_flat) ans = f(*py_args) ans_flat, ans_tree = tree_flatten(ans) @@ -204,7 +206,7 @@ class custom_jvp(Generic[ReturnValue]): *jvps: a sequence of functions, one for each positional argument of the :class:`~jax.custom_jvp` function. Each function takes as arguments the tangent value for the corresponding primal input, the primal - output, and the ßprimal inputs. See the example below. + output, and the primal inputs. See the example below. Returns: None. @@ -239,28 +241,40 @@ class custom_jvp(Generic[ReturnValue]): @traceback_util.api_boundary def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable=invalid-annotation - primal_name = getattr(self.fun, '__name__', str(self.fun)) + debug = debug_info("custom_jvp fun", self.fun, args, kwargs, + static_argnums=self.nondiff_argnums) + primal_name = debug.func_name if not self.jvp: msg = f"No JVP defined for custom_jvp function {primal_name} using defjvp." raise AttributeError(msg) - jvp_name = getattr(self.jvp, '__name__', str(self.jvp)) + args = resolve_kwargs(self.fun, args, kwargs) if self.nondiff_argnums: nondiff_argnums = set(self.nondiff_argnums) args = tuple(_stop_gradient(x) if i in nondiff_argnums else x for i, x in enumerate(args)) diff_argnums = [i for i in range(len(args)) if i not in nondiff_argnums] - f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), diff_argnums, args, + f_, dyn_args = argnums_partial(lu.wrap_init(self.fun, debug_info=debug), + diff_argnums, args, require_static_args_hashable=False) static_args = [args[i] for i in self.nondiff_argnums] - jvp = prepend_static_args(lu.wrap_init(self.jvp), static_args) + diff_args = [args[i] for i, a in enumerate(args) if i not in self.nondiff_argnums] + debug_jvp = debug_info("custom_jvp jvp", self.jvp, + (*static_args, diff_args, diff_args), + {}, + static_argnums=tuple(range(len(static_args)))) + jvp = prepend_static_args(lu.wrap_init(self.jvp, + debug_info=debug_jvp), static_args) else: - f_, dyn_args = lu.wrap_init(self.fun), args - jvp = lu.wrap_init(self.jvp) + f_, dyn_args = lu.wrap_init(self.fun, debug_info=debug), args + debug_jvp = debug_info("custom_jvp jvp", self.jvp, + (args, args), + {}) + jvp = lu.wrap_init(self.jvp, debug_info=debug_jvp) args_flat, in_tree = tree_flatten(dyn_args) flat_fun, out_type1 = _flatten_fun_nokwargs(f_, in_tree) - flat_jvp, out_type2 = _flatten_jvp(jvp, primal_name, jvp_name, in_tree, - out_type1) + flat_jvp, out_type2 = _flatten_jvp(jvp, primal_name, debug_jvp.func_name, + in_tree, out_type1) out_flat = custom_jvp_call_p.bind(flat_fun, flat_jvp, *args_flat, symbolic_zeros=self.symbolic_zeros) _, (out_tree, _) = lu.merge_linear_aux(out_type1, out_type2) @@ -611,15 +625,20 @@ class custom_vjp(Generic[ReturnValue]): @traceback_util.api_boundary def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable=invalid-annotation - primal_name = getattr(self.fun, '__name__', str(self.fun)) + debug_fun = debug_info("custom_vjp fun", self.fun, args, kwargs, + static_argnums=self.nondiff_argnums) if not self.fwd or not self.bwd: - msg = f"No VJP defined for custom_vjp function {primal_name} using defvjp." + msg = f"No VJP defined for custom_vjp function {debug_fun.func_name} using defvjp." raise AttributeError(msg) - fwd_name = getattr(self.fwd, '__name__', str(self.fwd)) args = resolve_kwargs(self.fun, args, kwargs) + debug_fwd = debug_info("custom_vjp fwd", self.fwd, args, kwargs, + static_argnums=self.nondiff_argnums) + # TODO(necula): figure out how to construct the debug_bwd args + debug_bwd = debug_info("custom_vjp bwd", self.bwd, args, {}) if self.optimize_remat: fwd = optimize_remat_of_custom_vjp_fwd( - self.fun, self.fwd, nondiff_argnums=self.nondiff_argnums, + self.fun, debug_fun, self.fwd, debug_fwd, + nondiff_argnums=self.nondiff_argnums, symbolic_zeros=self.symbolic_zeros) else: fwd = self.fwd @@ -633,23 +652,27 @@ class custom_vjp(Generic[ReturnValue]): for i in self.nondiff_argnums: _check_for_tracers(args[i]) nondiff_argnums = set(self.nondiff_argnums) dyn_argnums = [i for i in range(len(args)) if i not in nondiff_argnums] - f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), dyn_argnums, - args, require_static_args_hashable=False) + f_, dyn_args = argnums_partial( + lu.wrap_init(self.fun, debug_info=debug_fun), dyn_argnums, + args, require_static_args_hashable=False) static_args = [args[i] for i in self.nondiff_argnums] - fwd_, _ = argnums_partial(lu.wrap_init(fwd), dyn_argnums, args, - require_static_args_hashable=False) - bwd = prepend_static_args(lu.wrap_init(self.bwd), static_args) + fwd_, _ = argnums_partial(lu.wrap_init(fwd, debug_info=debug_fwd), + dyn_argnums, args, + require_static_args_hashable=False) + bwd = prepend_static_args(lu.wrap_init(self.bwd, debug_info=debug_bwd), + static_args) else: - f_, dyn_args = lu.wrap_init(self.fun), args - fwd_, bwd = lu.wrap_init(fwd), lu.wrap_init(self.bwd) + f_, dyn_args = lu.wrap_init(self.fun, debug_info=debug_fun), args + fwd_ = lu.wrap_init(fwd, debug_info=debug_fwd) + bwd = lu.wrap_init(self.bwd, debug_info=debug_bwd) args_flat, in_tree = tree_flatten(dyn_args) in_avals = [core.get_aval(x) for x in args_flat] if config.mutable_array_checks.value: - f_ = _check_primal_refs(f_, self.nondiff_argnums) + f_ = _check_primal_refs(f_, self.nondiff_argnums, f_.debug_info) flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree) flat_fwd, out_trees = _flatten_fwd( - fwd_, self.nondiff_argnums, self.symbolic_zeros, primal_name, - fwd_name, in_tree, out_type) + fwd_, self.nondiff_argnums, self.symbolic_zeros, debug_fun, + debug_fwd, in_tree, out_type) flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees).call_wrapped out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd, *args_flat, out_trees=out_trees, @@ -658,19 +681,27 @@ class custom_vjp(Generic[ReturnValue]): return tree_unflatten(out_tree, out_flat) @lu.transformation2 -def _check_primal_refs(f, nondiff_argnums, *args): - _check_for_aliased_refs(f, nondiff_argnums, args) +def _check_primal_refs(f: Callable, nondiff_argnums: Sequence[int], + debug_info: core.DebugInfo | None, *args): + _check_for_aliased_refs(f, nondiff_argnums, debug_info, args) out = f(*args) _check_for_returned_refs(f, out, 'primal') return out -def _check_for_aliased_refs(f, nondiff_argnums, args): +def _check_for_aliased_refs(f: Callable, + nondiff_argnums: Sequence[int], + debug: core.DebugInfo | None, + args): leaves = tree_leaves(args) refs: dict[int, int] = {} for i, x in enumerate(leaves): if (isinstance((a := core.get_aval(x)), AbstractRef) and (dup_idx := refs.setdefault(id(core.get_referent(x)), i)) != i): - arg_names = _non_static_arg_names(fun_signature(f), args, {}, nondiff_argnums, ()) + if debug is not None: + arg_names = debug.safe_arg_names(len(leaves)) + else: + # TODO(necula): drop this branch + arg_names = _non_static_arg_names(fun_signature(f), args, {}, nondiff_argnums, ()) if arg_names is None: arg_names = [f'flat index {j}' for j in range(len(leaves))] raise ValueError( @@ -725,18 +756,24 @@ def _check_for_tracers(x): raise UnexpectedTracerError(msg) @partial(lu.transformation_with_aux2, use_eq_store=True) -def _flatten_fwd(f, store, nondiff_argnums, symbolic_zeros, primal_name, - fwd_name, in_tree, maybe_out_type, *args): +def _flatten_fwd(f: Callable, store: lu.EqualStore, + nondiff_argnums: Sequence[int], + symbolic_zeros: bool, + debug_primal: core.DebugInfo | None, + debug_fwd: core.DebugInfo | None, + in_tree: PyTreeDef, maybe_out_type, *args): + primal_name = debug_primal.func_name if debug_primal else str(f) + fwd_name = debug_fwd.func_name if debug_fwd else "" if symbolic_zeros: - args = [CustomVJPPrimal(x, z) for x, z in zip(args[::2], args[1::2])] + args = tuple(CustomVJPPrimal(x, z) for x, z in zip(args[::2], args[1::2])) else: args = args[::2] py_args = tree_unflatten(in_tree, args) if config.mutable_array_checks.value: - _check_for_aliased_refs(f, nondiff_argnums, py_args) + _check_for_aliased_refs(f, nondiff_argnums, debug_primal, py_args) pair_out = f(*py_args) if config.mutable_array_checks.value: - _check_for_returned_refs(f, pair_out, 'fwd') + _check_for_returned_refs(f, pair_out, "fwd") if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2: msg = (f"Custom VJP fwd rule {fwd_name} for function {primal_name} " "must produce a pair (list or tuple of length two) where the first " @@ -790,7 +827,10 @@ def _flatten_fwd(f, store, nondiff_argnums, symbolic_zeros, primal_name, return (*res, *primals_out) @lu.transformation2 -def _flatten_bwd(f, in_tree, in_avals, out_trees, *args): +def _flatten_bwd(f: Callable, + in_tree: PyTreeDef, + in_avals: Sequence[core.AbstractValue], + out_trees: Callable[[], Sequence[PyTreeDef]], *args): out_tree, res_tree = out_trees() assert len(args) == res_tree.num_leaves + out_tree.num_leaves res, cts_out = split_list(args, [res_tree.num_leaves]) @@ -1494,7 +1534,9 @@ custom_jvp_call_jaxpr_p = core.Primitive("custom_jvp_call_jaxpr") # simpler, but it would be worth revisiting this. def optimize_remat_of_custom_vjp_fwd( fun: Callable[..., ReturnValue], + debug_fun: core.DebugInfo | None, fwd: Callable[..., tuple[ReturnValue, Any]], + debug_fwd: core.DebugInfo | None, nondiff_argnums: Sequence[int] = (), symbolic_zeros: bool = False, ) -> Callable[..., tuple[ReturnValue, Any]]: @@ -1507,8 +1549,7 @@ def optimize_remat_of_custom_vjp_fwd( def wrapped_fwd(*args, **kwargs) -> tuple[ReturnValue, Any]: # TODO(dfm): This initial logic is duplicated from custom_vjp.__call__ # above and it would be good to consolidate it. - primal_name = getattr(fun, "__name__", str(fun)) - fwd_name = getattr(fwd, "__name__", str(fwd)) + fwd_name = debug_fwd.func_name if debug_fwd else str(fwd) # Note: we use `fun` instead of `fwd` here for consistency with # custom_vjp.__call__ above. args = resolve_kwargs(fun, args, kwargs) @@ -1516,17 +1557,19 @@ def optimize_remat_of_custom_vjp_fwd( for i in nondiff_argnums: _check_for_tracers(args[i]) nondiff_argnums_ = set(nondiff_argnums) dyn_argnums = [i for i in range(len(args)) if i not in nondiff_argnums_] - f_, dyn_args = argnums_partial(lu.wrap_init(fun), dyn_argnums, - args, require_static_args_hashable=False) - fwd_, _ = argnums_partial(lu.wrap_init(fwd), dyn_argnums, args, + f_, dyn_args = argnums_partial(lu.wrap_init(fun, debug_info=debug_fun), + dyn_argnums, + args, require_static_args_hashable=False) + fwd_, _ = argnums_partial(lu.wrap_init(fwd, debug_info=debug_fwd), + dyn_argnums, args, require_static_args_hashable=False) else: - f_, dyn_args = lu.wrap_init(fun), args - fwd_ = lu.wrap_init(fwd) + f_, dyn_args = lu.wrap_init(fun, debug_info=debug_fun), args + fwd_ = lu.wrap_init(fwd, debug_info=debug_fwd) args_flat, in_tree = tree_flatten(dyn_args) flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree) flat_fwd, out_trees = _flatten_fwd(fwd_, nondiff_argnums, False, - primal_name, fwd_name, in_tree, out_type) + debug_fun, debug_fwd, in_tree, out_type) flat_fwd = _fix_fwd_args(flat_fwd) in_avals = [core.get_aval(x) for x in args_flat] diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 6e622b9c4..8d0f0e7ca 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -68,7 +68,7 @@ def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True, return jvpfun(fun, instantiate, transform_stack), aux @lu.transformation2 -def jvpfun(f, instantiate, transform_stack, primals, tangents): +def jvpfun(f: Callable, instantiate, transform_stack, primals, tangents): tag = core.TraceTag() tangents = [Zero.from_primal_value(t) if not isinstance(t, Zero) and dtype(t) == float0 else t for t in tangents] @@ -106,7 +106,7 @@ def linearize_subtrace(_f: Callable, _store, _tag, nzs_in, *primals, **params): return tuple(consts) + tuple(out_primals) @lu.transformation2 -def jvp_subtrace(f, tag, primals, tangents): +def jvp_subtrace(f: Callable, tag: core.TraceTag, primals, tangents): with core.take_current_trace() as parent_trace: trace = JVPTrace(parent_trace, tag) in_tracers = [maybe_jvp_tracer(trace, x, t) @@ -778,7 +778,8 @@ def linearize_from_jvp(jvp, multiple_results, nonzeros, out_nz_tracers = [trace.to_jaxpr_tracer(r) for (r, nz) in zip(out_tangents, out_nzs) if nz] in_tracers = [t for t, nz in zip(tangent_args, nonzeros) if nz] - jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers) + # TODO(necula): pass debug_info here + jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers, None) def linearized(residuals, *tangents): nz_tangents_in = [t for (t, nz) in zip(tangents, nonzeros) if nz] @@ -932,13 +933,15 @@ def traceable(f, store, in_tree, *primals_and_tangents): return out_flat -def call_transpose(primitive, params, call_jaxpr, args, ct, _): +def call_transpose(primitive, params, call_jaxpr: core.Jaxpr, args, ct, _): if isinstance(call_jaxpr, core.ClosedJaxpr): call_jaxpr, consts = call_jaxpr.jaxpr, call_jaxpr.consts else: consts = () all_args, in_tree_def = tree_flatten((consts, args, ct)) - fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr, False) + fun = lu.hashable_partial(lu.wrap_init(backward_pass, + debug_info=call_jaxpr.debug_info), + call_jaxpr, False) fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def) update_params = call_transpose_param_updaters.get(primitive) if update_params: @@ -950,7 +953,7 @@ def call_transpose(primitive, params, call_jaxpr, args, ct, _): res_invars, _ = partition_list(which_lin, call_jaxpr.invars) new_invars = [*res_invars, *call_jaxpr.outvars] dbidx_map = {v: core.DBIdx(i) for i, v in enumerate(new_invars)} - in_type = [(v.aval.update(shape=tuple(dbidx_map.get(d, d) for d in v.aval.shape)) + in_type = [(v.aval.update(shape=tuple(dbidx_map.get(d, d) for d in v.aval.shape)) # type: ignore[arg-type] if type(v.aval) is core.DShapedArray else v.aval, True) for v in new_invars] fun = lu.annotate(fun, tuple(in_type)) out_flat = primitive.bind(fun, *all_args, **params) @@ -1027,8 +1030,8 @@ def jvp_jaxpr(jaxpr: core.ClosedJaxpr, nonzeros: Sequence[bool], def _jvp_jaxpr(jaxpr: core.ClosedJaxpr, nonzeros: Sequence[bool], instantiate: Sequence[bool]): assert len(jaxpr.in_avals) == len(nonzeros) - debug_info = jaxpr.jaxpr.debug_info - f = lu.wrap_init(core.jaxpr_as_fun(jaxpr), debug_info=debug_info) + f = lu.wrap_init(core.jaxpr_as_fun(jaxpr), + debug_info=jaxpr.jaxpr.debug_info) f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate, transform_stack=False), nonzeros) tangent_avals = [aval.to_tangent_aval() for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz] diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index b2c0792e5..9500e5c1f 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -34,6 +34,7 @@ import warnings import numpy as np from jax._src import ad_util +from jax._src import api_util from jax._src import config from jax._src import core from jax._src import dtypes @@ -2156,7 +2157,8 @@ def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable: as `avals_out`.""" def f_lowered(ctx: LoweringRuleContext, *args, **params): f = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),) - wrapped_fun = lu.wrap_init(f, params) + wrapped_fun = lu.wrap_init(f, params, + debug_info=api_util.debug_info("lower_fun", fun, args, params)) manager = (contextlib.nullcontext() if ctx.jaxpr_eqn_ctx is None else ctx.jaxpr_eqn_ctx.manager) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 98c5c97a6..721386565 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -236,7 +236,7 @@ class JaxprTrace(Trace['JaxprTracer']): params, effects, source) return out_tracer - def process_call(self, primitive, f, tracers, params): + def process_call(self, primitive, f: lu.WrappedFun, tracers, params): tracers = map(self.to_jaxpr_tracer, tracers) rule = call_partial_eval_rules.get(primitive) if rule: @@ -254,7 +254,7 @@ class JaxprTrace(Trace['JaxprTracer']): # which were unknown to the first call (corresponding to in_avals). # Wrap f to perform the partial evaluation and plumb out aux data. - f_ = trace_to_subjaxpr_nounits_fwd(f, self.tag, False) + f_ = trace_to_subjaxpr_nounits_fwd(f, self.tag, f.debug_info, False) f_, aux = partial_eval_wrapper_nounits(f_, tuple(in_knowns), tuple(in_avals)) # Adjust parameters (e.g. donated_invars) for the call to be evaluated now. @@ -330,7 +330,7 @@ class JaxprTrace(Trace['JaxprTracer']): for ax, aval in zip(unk_in_axes, in_avals)] # Wrap f to perform partial evaluation and plumb out aux data. - f = trace_to_subjaxpr_nounits2(f, self.tag, False) + f = trace_to_subjaxpr_nounits2(f, self.tag, f.debug_info, False) f, aux = partial_eval_wrapper_nounits(f, tuple(in_knowns), tuple(in_avals_mapped)) # Adjust params for knowns (e.g. donated_invars, in_axes, out_axes_thunk) @@ -429,7 +429,7 @@ class JaxprTrace(Trace['JaxprTracer']): tracers = map(self.instantiate_const_abstracted, tracers) # Because we instantiate all tracers, in_knowns is all False. in_knowns, in_avals, () = partition_pvals([t.pval for t in tracers]) - f = trace_to_subjaxpr_nounits(f, self, True) + f = trace_to_subjaxpr_nounits(f, self, True, f.debug_info) f, aux = partial_eval_wrapper_nounits(f, (*in_knowns,), (*in_avals,)) with core.set_current_trace(self.parent_trace): out_flat = prim.bind(f, fwd, bwd, out_trees=out_trees, @@ -445,7 +445,7 @@ class JaxprTrace(Trace['JaxprTracer']): @_memoize def fwd_jaxpr_thunk(*zeros): fwd_ = _interleave_fun(fwd, zeros) - fwd_ = trace_to_subjaxpr_nounits(fwd_, self, True) + fwd_ = trace_to_subjaxpr_nounits(fwd_, self, True, fwd_.debug_info) fwd_, aux = partial_eval_wrapper_nounits(fwd_, (*in_knowns,), (*in_avals,)) out_flat = fwd_.call_wrapped() out_knowns, out_avals, jaxpr, env = aux() @@ -476,7 +476,10 @@ def partition_pvals( @lu.transformation_with_aux2 def partial_eval_wrapper_nounits( - f, store, in_knowns: Sequence[bool], in_avals: Sequence[AbstractValue], + f: Callable, + store: lu.Store, + in_knowns: Sequence[bool], + in_avals: Sequence[AbstractValue], *in_consts: Any): in_avals_, in_consts_ = iter(in_avals), iter(in_consts) in_pvals = [PartialVal.known(next(in_consts_)) if known else @@ -566,7 +569,7 @@ def trace_to_jaxpr_nounits( with core.take_current_trace() as parent_trace: trace = JaxprTrace(parent_trace, current_name_stack, TraceTag()) with core.ensure_no_leaks(trace): - fun = trace_to_subjaxpr_nounits(fun, trace, instantiate) + fun = trace_to_subjaxpr_nounits(fun, trace, instantiate, fun.debug_info) with core.set_current_trace(trace): jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals) assert not env @@ -579,10 +582,11 @@ def trace_to_subjaxpr_nounits( f: Callable, trace: JaxprTrace, instantiate: Sequence[bool] | bool, + debug_info: core.DebugInfo | None, in_pvals: Sequence[PartialVal]): assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals out_tracers, jaxpr, out_consts, env = _trace_to_subjaxpr_nounits( - f, trace, instantiate, in_pvals) + f, trace, instantiate, in_pvals, debug_info) out_pvals = [t.pval for t in out_tracers] del out_tracers return jaxpr, (out_pvals, out_consts, env) @@ -591,6 +595,7 @@ def trace_to_subjaxpr_nounits( def trace_to_subjaxpr_nounits2( f: Callable, tag: TraceTag, + debug_info: core.DebugInfo | None, instantiate: bool | Sequence[bool], in_pvals: Sequence[PartialVal]): assert isinstance(tag, TraceTag) @@ -599,14 +604,15 @@ def trace_to_subjaxpr_nounits2( with core.take_current_trace() as parent_trace: trace = JaxprTrace(parent_trace, current_name_stack, tag) out_tracers, jaxpr, out_consts, env = _trace_to_subjaxpr_nounits( - f, trace, instantiate, in_pvals) + f, trace, instantiate, in_pvals, debug_info) out_pvals = [t.pval for t in out_tracers] del out_tracers return jaxpr, (out_pvals, out_consts, env) def _trace_to_subjaxpr_nounits(f: Callable, trace: JaxprTrace, instantiate: Sequence[bool] | bool, - in_pvals: Sequence[PartialVal]): + in_pvals: Sequence[PartialVal], + debug_info: core.DebugInfo | None): in_knowns = [pval.is_known() for pval in in_pvals] in_consts = [pval.get_known() for pval in in_pvals if pval.is_known()] in_tracers = [trace.new_arg(pval) for pval in in_pvals if not pval.is_known()] @@ -623,7 +629,7 @@ def _trace_to_subjaxpr_nounits(f: Callable, trace: JaxprTrace, out_tracers = [trace.instantiate_const(t) if inst else t for inst, t in zip(instantiate, out_tracers)] out_tracers_ = [t for t in out_tracers if not t.is_known()] - jaxpr, out_consts, env = tracers_to_jaxpr(in_tracers, out_tracers_) + jaxpr, out_consts, env = tracers_to_jaxpr(in_tracers, out_tracers_, debug_info) return out_tracers, jaxpr, out_consts, env # The below variant implements an optimization where residuals which are also @@ -631,8 +637,9 @@ def _trace_to_subjaxpr_nounits(f: Callable, trace: JaxprTrace, # TODO(mattjj): update all callers to use this version, delete other version. @lu.transformation2 def trace_to_subjaxpr_nounits_fwd( - f, + f: Callable, tag: TraceTag, + debug_info: core.DebugInfo | None, instantiate: bool | Sequence[bool], in_pvals: Sequence[PartialVal]): assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals @@ -641,7 +648,7 @@ def trace_to_subjaxpr_nounits_fwd( trace = JaxprTrace(parent_trace, current_name_stack, tag) with core.set_current_trace(trace): out_tracers, jaxpr, out_consts, env = _trace_to_subjaxpr_nounits( - f, trace, instantiate, in_pvals) + f, trace, instantiate, in_pvals, debug_info) out_pvals = [t.pval for t in out_tracers] # Which out_consts (aka residuals) are just forwarded inputs? Check obj id. @@ -660,8 +667,9 @@ def trace_to_subjaxpr_nounits_fwd( # than passed as redundant outputs. @lu.transformation2 def trace_to_subjaxpr_nounits_fwd2( - f, + f: Callable, tag: TraceTag, + debug_info: core.DebugInfo | None, instantiate: bool | Sequence[bool], in_pvals: Sequence[PartialVal]): assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals @@ -669,7 +677,7 @@ def trace_to_subjaxpr_nounits_fwd2( with core.take_current_trace() as parent_trace: trace = JaxprTrace(parent_trace, current_name_stack, tag) out_tracers, jaxpr, consts, env = _trace_to_subjaxpr_nounits( - f, trace, instantiate, in_pvals) + f, trace, instantiate, in_pvals, debug_info) out_pvals = [t.pval for t in out_tracers] # Which consts (aka residuals) are just forwarded inputs? Check obj id. @@ -743,7 +751,8 @@ def recipe_to_eqn(getvar: Callable[[JaxprTracer], Atom], def tracers_to_jaxpr( in_tracers: Sequence[JaxprTracer], - out_tracers: Sequence[JaxprTracer] + out_tracers: Sequence[JaxprTracer], + debug_info: core.DebugInfo | None, ) -> tuple[Jaxpr, tuple[Any, ...], tuple[Any, ...]]: """Constructs Jaxpr given tracers for inputs and outputs. @@ -819,7 +828,8 @@ def tracers_to_jaxpr( outvars = map(get_atom, out_tracers) # type: ignore[arg-type] jaxpr_effects = make_jaxpr_effects(const_vars, invars, outvars, eqns) jaxpr = Jaxpr(const_vars, invars, # type: ignore[arg-type] - outvars, eqns, jaxpr_effects) + outvars, eqns, jaxpr_effects, + debug_info) config.enable_checks.value and core.check_jaxpr(jaxpr) # del getvar # needed to avoid cyclic-reference closure, apparently! return jaxpr, const_vals, env_vals @@ -1135,7 +1145,8 @@ def _partial_eval_jaxpr_custom_cached( staged_effects = make_jaxpr_effects(jaxpr.constvars, staged_invars, outs_staged, staged_eqns) jaxpr_staged = Jaxpr(jaxpr.constvars, staged_invars, - outs_staged, staged_eqns, staged_effects) + outs_staged, staged_eqns, staged_effects, + jaxpr.debug_info) config.enable_checks.value and core.check_jaxpr(jaxpr_staged) return (jaxpr_known, jaxpr_staged, out_unknowns, out_inst, len(residuals), @@ -1504,7 +1515,8 @@ def _move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: tuple[bool, ...] new_invars = _move_to_front(closed_jaxpr.jaxpr.invars, to_move) new_jaxpr = Jaxpr(closed_jaxpr.jaxpr.constvars, new_invars, closed_jaxpr.jaxpr.outvars, closed_jaxpr.jaxpr.eqns, - closed_jaxpr.jaxpr.effects) + closed_jaxpr.jaxpr.effects, + closed_jaxpr.jaxpr.debug_info) new_closed_jaxpr = core.ClosedJaxpr(new_jaxpr, closed_jaxpr.consts) return new_closed_jaxpr @@ -1666,7 +1678,8 @@ class JaxprStackFrame: set_states(self.attrs_tracked, self.attrs_inits) return jaxpr, list(constvals), zip(init_trees, end_trees, self.attrs_tracked) - def to_jaxpr2(self, out_tracers): + def to_jaxpr2(self, out_tracers: Sequence[core.Tracer], + debug_info: core.DebugInfo | None): # It's not necessary, but we keep the tracer-to-var mapping injective: assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values())) constvars, constvals = unzip2(self.constvar_to_val.items()) @@ -1674,10 +1687,10 @@ class JaxprStackFrame: jaxpr_effects = make_jaxpr_effects(constvars, self.invars, expl_outvars, self.eqns) jaxpr = Jaxpr(constvars, self.invars, expl_outvars, self.eqns, - jaxpr_effects) + jaxpr_effects, debug_info) # We can't run check_jaxpr until after we normalize. jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals) - jaxpr, constvals = _inline_literals(jaxpr, constvals) + jaxpr, constvals = _inline_literals(jaxpr, constvals) # type: ignore[assignment] jaxpr, out_type = _add_implicit_outputs(jaxpr) config.enable_checks.value and core.check_jaxpr(jaxpr) return jaxpr, out_type, constvals @@ -2184,7 +2197,7 @@ def trace_to_jaxpr_dynamic2( with core.set_current_trace(trace): ans = fun.call_wrapped(*in_tracers) out_tracers = map(trace.to_jaxpr_tracer, ans) - jaxpr = trace.frame.to_jaxpr2(out_tracers) + jaxpr = trace.frame.to_jaxpr2(out_tracers, fun.debug_info) del trace, in_tracers, out_tracers, ans return jaxpr diff --git a/jax/_src/lax/control_flow/common.py b/jax/_src/lax/control_flow/common.py index 73827f1bc..cecd1cdc5 100644 --- a/jax/_src/lax/control_flow/common.py +++ b/jax/_src/lax/control_flow/common.py @@ -240,7 +240,8 @@ def _check_tree(func_name, expected_name, actual_tree, expected_tree, has_aux=Fa def _prune_zeros(ts): return [t for t in ts if type(t) is not ad_util.Zero] -def _make_closed_jaxpr(traceable: lu.WrappedFun, in_avals: Sequence[core.AbstractValue]): +def _make_closed_jaxpr(traceable: lu.WrappedFun, + in_avals: Sequence[core.AbstractValue]): jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(traceable, in_avals) return core.ClosedJaxpr(jaxpr, consts) diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 511442798..cdd74ab38 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -327,7 +327,7 @@ def _cond_with_per_branch_args(pred, lambda op: false_fun(op[1]), (true_operand, false_operand)) -def _join_cond_effects(branches: Sequence[core.Jaxpr]) -> effects.Effects: +def _join_cond_effects(branches: Sequence[core.ClosedJaxpr]) -> effects.Effects: joined_effects = set() for b in branches: for eff in b.effects: @@ -337,7 +337,8 @@ def _join_cond_effects(branches: Sequence[core.Jaxpr]) -> effects.Effects: joined_effects.add(eff) return joined_effects -def _cond_abstract_eval(*avals, branches, **_): +def _cond_abstract_eval(*avals: core.AbstractValue, + branches: Sequence[core.ClosedJaxpr], **_): joined_effects = _join_cond_effects(branches) disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects) if disallowed_effects: @@ -614,10 +615,11 @@ def _merge_branch_residuals(branch_res_avals): # This function augments branch outputs to agree with the merged residual # format: each branch is made to return zero-filled values in the places of # residual outputs that it does not populate. -def _join_cond_outputs(jaxprs, all_res_avals, res_aval_indices_per_jaxpr, +def _join_cond_outputs(jaxprs: Sequence[core.ClosedJaxpr], + all_res_avals, res_aval_indices_per_jaxpr, num_non_res_outputs): - def augment_jaxpr(jaxpr, res_indices): - @lu.wrap_init + def augment_jaxpr(jaxpr: core.ClosedJaxpr, + res_indices): def f_aug(*args): outs_and_residuals = core.jaxpr_as_fun(jaxpr)(*args) outs, residuals = split_list(outs_and_residuals, [num_non_res_outputs]) @@ -625,19 +627,21 @@ def _join_cond_outputs(jaxprs, all_res_avals, res_aval_indices_per_jaxpr, aug_residuals = util.subvals(aug_residuals, zip(res_indices, residuals)) return outs + list(aug_residuals) - return _make_closed_jaxpr(f_aug, jaxpr.in_avals) + wrapped_f_aug = lu.wrap_init(f_aug, debug_info=jaxpr.jaxpr.debug_info) + return _make_closed_jaxpr(wrapped_f_aug, jaxpr.in_avals) return tuple(map(augment_jaxpr, jaxprs, res_aval_indices_per_jaxpr)) # This function augments branch inputs to agree with the merged residual format: # each branch is made to accept all residuals, even though it will ignore those # that it does not read. -def _join_cond_pe_staged_jaxpr_inputs(jaxprs, all_res_avals, +def _join_cond_pe_staged_jaxpr_inputs(jaxprs: Sequence[core.ClosedJaxpr], + all_res_avals, res_aval_indices_per_jaxpr): newvar = core.gensym(suffix='_') all_res_vars = map(newvar, all_res_avals) - def augment_jaxpr(jaxpr, res_indices): + def augment_jaxpr(jaxpr: core.ClosedJaxpr, res_indices) -> core.ClosedJaxpr: num_res = len(res_indices) res_vars = jaxpr.jaxpr.invars[:num_res] non_res_vars = jaxpr.jaxpr.invars[num_res:] @@ -646,9 +650,9 @@ def _join_cond_pe_staged_jaxpr_inputs(jaxprs, all_res_avals, aug_invars = aug_res_vars + non_res_vars jaxpr_aug = core.Jaxpr(jaxpr.jaxpr.constvars, aug_invars, jaxpr.jaxpr.outvars, jaxpr.jaxpr.eqns, - jaxpr.jaxpr.effects) - jaxpr_aug = core.ClosedJaxpr(jaxpr_aug, jaxpr.consts) - return jaxpr_aug + jaxpr.jaxpr.effects, + jaxpr.jaxpr.debug_info) + return core.ClosedJaxpr(jaxpr_aug, jaxpr.consts) return tuple(map(augment_jaxpr, jaxprs, res_aval_indices_per_jaxpr)) @@ -679,8 +683,7 @@ def _cond_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn, # Finally, update parameters and form the new eqn. new_params = dict(eqn.params, branches=tuple(dce_branches)) - new_effects = core.join_effects(*(b.effects for b in dce_branches)) - new_effects = _join_cond_effects(dce_branches_) + new_effects = _join_cond_effects(dce_branches) new_eqn = pe.new_jaxpr_eqn( [v for v, used in zip(eqn.invars, [True, *used_inputs]) if used], [v for v, used in zip(eqn.outvars, used_outputs) if used], @@ -693,10 +696,10 @@ def _cond_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn, return [True, *used_inputs], new_eqn -def _transpose_cond_jaxpr(jaxpr, num_res): +def _transpose_cond_jaxpr(jaxpr: core.ClosedJaxpr, + num_res: int): res_avals, primal_avals = split_list(jaxpr.in_avals, [num_res]) - @lu.wrap_init def transposed(*args): res, cts_out = split_list(args, [num_res]) primals = res + [ad.UndefinedPrimal(aval) for aval in primal_avals] @@ -705,7 +708,9 @@ def _transpose_cond_jaxpr(jaxpr, num_res): _, cts_in = split_list(cts_in, [num_res]) return map(ad.instantiate_zeros, cts_in) - return _make_closed_jaxpr(transposed, res_avals + jaxpr.out_avals) + return _make_closed_jaxpr(lu.wrap_init(transposed, + debug_info=jaxpr.jaxpr.debug_info), + res_avals + jaxpr.out_avals) def _cond_transpose(cts, *args, branches): index, *ops = args diff --git a/jax/_src/lax/control_flow/for_loop.py b/jax/_src/lax/control_flow/for_loop.py index 4d1063ff9..c96ff1ba2 100644 --- a/jax/_src/lax/control_flow/for_loop.py +++ b/jax/_src/lax/control_flow/for_loop.py @@ -70,11 +70,13 @@ for_p.multiple_results = True ### Tracing utilities -def _trace_to_jaxpr_with_refs(f, state_tree: PyTreeDef, - state_avals: Sequence[core.AbstractValue] +def _trace_to_jaxpr_with_refs(f: Callable, state_tree: PyTreeDef, + state_avals: Sequence[core.AbstractValue], + debug_info: core.DebugInfo | None, ) -> tuple[core.Jaxpr, list[Any], PyTreeDef]: f, out_tree_thunk = api_util.flatten_fun_nokwargs( - lu.wrap_init(f), treedef_tuple((tree_structure(0), state_tree))) + lu.wrap_init(f, debug_info=debug_info), + treedef_tuple((tree_structure(0), state_tree))) jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( f, state_avals) return jaxpr, consts, out_tree_thunk() @@ -129,12 +131,13 @@ def for_loop(nsteps: int | Sequence[int], rest_steps, functools.partial(body, i), vals, unroll=unroll) tree_map(lambda ref, val: ref_set(ref, (), val), refs, vals) return for_loop(outer_step, wrapped_body, init_state, unroll=unroll) + dbg = api_util.debug_info("for_loop", body, (0, init_state), {}) nsteps, = nsteps flat_state, state_tree = tree_flatten(init_state) state_avals = map(state_utils.val_to_ref_aval, flat_state) idx_aval = core.ShapedArray((), dtypes.canonicalize_dtype(np.int64)) jaxpr, consts, out_tree = _trace_to_jaxpr_with_refs( - body, state_tree, [idx_aval, *state_avals]) + body, state_tree, [idx_aval, *state_avals], dbg) if out_tree != tree_structure(None): raise Exception("`body` should not return anything.") jaxpr = state_utils.hoist_consts_to_refs(jaxpr, index=1) @@ -212,6 +215,7 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]], tree_map(lambda c_ref, c: ref_set(c_ref, (), c), carry_refs, carry) tree_map(lambda y_ref, y: ref_set(y_ref, (i,), y), ys_refs, y) assert isinstance(length, int) + api_util.save_wrapped_fun_sourceinfo(for_body, f) init, _, ys = for_loop(length, for_body, (init, xs, ys), reverse=reverse, unroll=unroll) return init, ys @@ -572,7 +576,6 @@ def _for_partial_eval_custom(saveable, in_unknowns, in_inst, eqn): newvar = core.gensym() resvars = map(newvar, res_avals) - @lu.wrap_init def known(*known_vals): empty_res = map(ad_util.zeros_like_aval, res_avals) jaxpr_known_args = [*known_vals, *empty_res] @@ -581,7 +584,8 @@ def _for_partial_eval_custom(saveable, in_unknowns, in_inst, eqn): reverse=reverse, which_linear=jaxpr_known_which_linear, unroll=unroll) call_jaxpr_, _, call_jaxpr_consts, () = pe.trace_to_jaxpr_dynamic( - known, [v.aval for v in known_invars]) + lu.wrap_init(known, debug_info=jaxpr.debug_info), + [v.aval for v in known_invars]) call_jaxpr = core.ClosedJaxpr(call_jaxpr_, call_jaxpr_consts) eqn_known = pe.new_jaxpr_eqn(known_invars, [*known_outvars, *resvars], core.closed_call_p, dict(call_jaxpr=call_jaxpr), @@ -596,14 +600,14 @@ def _for_partial_eval_custom(saveable, in_unknowns, in_inst, eqn): which_linear=which_linear_unknown, unroll=unroll) - @lu.wrap_init def staged(*res_and_refs): out_flat = for_p.bind(*res_and_refs, **params_staged) _, ans = split_list(out_flat, [num_res]) _, ans = partition_list(out_inst, ans) return ans call_jaxpr_, _, call_jaxpr_consts, () = pe.trace_to_jaxpr_dynamic( - staged, [v.aval for v in [*resvars, *eqn.invars]]) + lu.wrap_init(staged, debug_info=jaxpr_staged.debug_info), + [v.aval for v in [*resvars, *eqn.invars]]) assert len(jaxpr_staged.invars) - 1 == len(call_jaxpr_.invars) call_jaxpr = core.ClosedJaxpr(call_jaxpr_, call_jaxpr_consts) _, outvars = partition_list(out_inst, eqn.outvars) @@ -621,7 +625,7 @@ def _convert_outputs_to_writes( assert not jaxpr.constvars, "Jaxpr shouldn't have constvars." in_avals = [v.aval for v in jaxpr.invars] # [i, *orig_ref_avals] - @lu.wrap_init + def eval_jaxpr(i, *refs): # We split the refs into the original input refs and the dummy residual # refs. @@ -641,7 +645,8 @@ def _convert_outputs_to_writes( v.aval.dtype)) # pytype: disable=attribute-error for v, loop_invar in zip(jaxpr.outvars, loop_invar_res)] jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( - eval_jaxpr, [*in_avals, *res_ref_avals]) + lu.wrap_init(eval_jaxpr, debug_info=jaxpr.debug_info), + [*in_avals, *res_ref_avals]) assert not consts return jaxpr, [core.ShapedArray(a.shape, a.dtype) for a in res_ref_avals] # pytype: disable=attribute-error @@ -650,7 +655,6 @@ def _convert_inputs_to_reads( loop_invar_res: Sequence[bool]) -> core.Jaxpr: assert not jaxpr.constvars, "Jaxpr should not have constvars" - @lu.wrap_init def eval_jaxpr(i, *refs): residual_refs, orig_refs = split_list(refs, [num_res]) residual_vals = [r[()] if loop_invar else r[i] for r, loop_invar @@ -667,7 +671,8 @@ def _convert_inputs_to_reads( for aval, loop_invar in zip(res_val_avals, loop_invar_res)] jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic( - eval_jaxpr, [i_aval, *res_ref_avals, *orig_ref_avals]) + lu.wrap_init(eval_jaxpr, debug_info=jaxpr.debug_info), + [i_aval, *res_ref_avals, *orig_ref_avals]) return jaxpr def transpose_jaxpr(jaxpr: core.Jaxpr, which_linear: list[bool]) -> core.Jaxpr: @@ -693,7 +698,8 @@ def transpose_jaxpr(jaxpr: core.Jaxpr, which_linear: list[bool]) -> core.Jaxpr: ad.backward_pass(tangent_jaxpr, False, (), (*primals_args, *ct_args), ()) return [] jaxpr_trans, _, _, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(trans), [v.aval for v in jaxpr.invars]) + lu.wrap_init(trans, debug_info=jaxpr.debug_info), + [v.aval for v in jaxpr.invars]) return jaxpr_trans def _for_transpose(in_cts, *args, jaxpr, nsteps, reverse, which_linear, unroll): @@ -746,8 +752,9 @@ def discharged_for_loop(nsteps, body, init_state, *, reverse: bool = False): flat_state, state_tree = tree_flatten(init_state) state_avals = map(state_utils.val_to_ref_aval, flat_state) idx_aval = core.ShapedArray((), dtypes.canonicalize_dtype(np.int64)) + debug = api_util.debug_info("discharged_for_loop", body, (0, init_state), {}) jaxpr, consts, out_tree = _trace_to_jaxpr_with_refs( - body, state_tree, [idx_aval, *state_avals]) + body, state_tree, [idx_aval, *state_avals], debug) if out_tree != tree_structure(None): raise Exception("`body` should not return anything.") discharged_jaxpr, discharged_consts = discharge_state(jaxpr, consts) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 33f70e216..60f96dec8 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -511,7 +511,7 @@ def _empty_array(prefix, length_spec, aval): eval_jaxpr_p = core.Primitive('eval_jaxpr') eval_jaxpr_p.multiple_results = True -def _stage_jaxpr(trace, *tracers, jaxpr): +def _stage_jaxpr(trace: pe.JaxprTrace, *tracers, jaxpr: core.ClosedJaxpr): params = dict(call_jaxpr=jaxpr) return trace.default_process_primitive(core.closed_call_p, tracers, params) pe.custom_staging_rules[eval_jaxpr_p] = _stage_jaxpr diff --git a/jax/_src/lax/control_flow/solves.py b/jax/_src/lax/control_flow/solves.py index f141a331b..ef1871f40 100644 --- a/jax/_src/lax/control_flow/solves.py +++ b/jax/_src/lax/control_flow/solves.py @@ -334,7 +334,9 @@ def _custom_linear_solve_impl(*args, const_lengths, jaxprs): return x -def _tangent_linear_map(func, params, params_dot, *x): +def _tangent_linear_map(func: Callable, params, params_dot, + debug_info: core.DebugInfo | None, + *x): """Compute the tangent of a linear map. Assuming ``func(*params, *x)`` is linear in ``x`` and computes ``A @ x``, @@ -342,7 +344,7 @@ def _tangent_linear_map(func, params, params_dot, *x): """ assert any(type(p) is not ad_util.Zero for p in params_dot) zeros = _map(ad_util.Zero.from_primal_value, x) - _, out_tangent = ad.jvp(lu.wrap_init(func)).call_wrapped( + _, out_tangent = ad.jvp(lu.wrap_init(func, debug_info=debug_info)).call_wrapped( params + list(x), params_dot + zeros) return out_tangent @@ -369,7 +371,8 @@ def _custom_linear_solve_jvp(primals, tangents, const_lengths, jaxprs): rhs = b_dot else: matvec_tangents = _tangent_linear_map( - core.jaxpr_as_fun(jaxprs.matvec), params.matvec, params_dot.matvec, *x_leaves) + core.jaxpr_as_fun(jaxprs.matvec), params.matvec, params_dot.matvec, + jaxprs.matvec.jaxpr.debug_info, *x_leaves) rhs = _map(ad.add_tangents, b_dot, _map(operator.neg, matvec_tangents)) x_dot = linear_solve_p.bind(*(_flatten(params) + rhs), **kwargs) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index b749aff28..823e971b6 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1847,8 +1847,8 @@ def reduce(operands: Any, return tree_util.tree_unflatten(out_tree, out) @cache() -def _reduction_jaxpr(computation, aval): - @lu.wrap_init +def _reduction_jaxpr(computation: Callable, + aval: core.AbstractValue): def comp(x, y): result = computation(x, y) if not (isinstance(result, core.Tracer) or core.valid_jaxtype(result)): @@ -1857,7 +1857,11 @@ def _reduction_jaxpr(computation, aval): f"Reduction functions should only return an array.\n" f"Full return value: {result}") return (result,) - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(comp, (aval, aval)) + comp_wrapped = lu.wrap_init( + comp, + debug_info=api_util.debug_info("reduction_jaxpr", computation, + (aval, aval), {})) + jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(comp_wrapped, (aval, aval)) if any(isinstance(c, core.Tracer) for c in consts): raise NotImplementedError( "Reduction computations can't close over Tracers. Please open an issue " diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py index bee179d08..2f8d7b329 100644 --- a/jax/_src/linear_util.py +++ b/jax/_src/linear_util.py @@ -282,6 +282,10 @@ class DebugInfo(NamedTuple): return self._replace(result_paths=tuple(self.result_paths())) return self + @property + def func_name(self) -> str: + return self.func_src_info.split(" ")[0] + def safe_arg_names(self, expected: int) -> tuple[str | None, ...]: """Get the arg_names with a safety check.""" if len(self.arg_names) == expected: @@ -328,7 +332,13 @@ def wrap_init(f: Callable, params=None, *, @transformation_with_aux2 def _get_result_paths_thunk(_fun: Callable, _store: Store, *args, **kwargs): ans = _fun(*args, **kwargs) - _store.store([keystr(path) for path, _ in generate_key_paths(ans)]) + result_paths = [keystr(path) for path, _ in generate_key_paths(ans)] + if _store: + # In some instances a lu.WrappedFun is called multiple times, e.g., + # the bwd function in a custom_vjp + assert _store.val == result_paths, (_store, result_paths) + else: + _store.store(result_paths) return ans def annotate(f: WrappedFun, in_type: core.InputType | None) -> WrappedFun: diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 14893f34c..faf70200d 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -2352,7 +2352,8 @@ pe.partial_eval_jaxpr_custom_rules[pjit_p] = \ @lu.cache -def _pjit_transpose_trace(fun, in_avals): +def _pjit_transpose_trace(fun: lu.WrappedFun, + in_avals: Sequence[core.AbstractValue]): transpose_jaxpr, _, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic( fun, in_avals) transpose_jaxpr = core.ClosedJaxpr(transpose_jaxpr, consts) @@ -2360,13 +2361,15 @@ def _pjit_transpose_trace(fun, in_avals): def _pjit_transpose(cts_in, *primals_in, - jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, + jaxpr: core.ClosedJaxpr, + in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline, compiler_options_kvs): def prune_type(ty, xs, maybe_zeros): return tuple(x for x, mz in zip(xs, maybe_zeros) if type(mz) is not ty) - body = lu.wrap_init(ad.closed_backward_pass) + body = lu.wrap_init(ad.closed_backward_pass, + debug_info=jaxpr.jaxpr._debug_info) body = lu.hashable_partial(body, jaxpr, False) primals_and_nz_cts_in, in_treedef = tree_flatten((primals_in, cts_in)) body, cts_out_treedef_thunk = flatten_fun_nokwargs(body, in_treedef) diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index dc781c424..048217807 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -79,7 +79,8 @@ def discharge_state(jaxpr: core.Jaxpr, consts: Sequence[Any], * , if isinstance(v.aval, AbstractRef) and d else v.aval for v, d in zip(jaxpr.invars, should_discharge)] eval_jaxpr = lu.wrap_init(partial(_eval_jaxpr_discharge_state, jaxpr, - should_discharge, consts)) + should_discharge, consts), + debug_info=jaxpr.debug_info) new_jaxpr, _ , new_consts, () = pe.trace_to_jaxpr_dynamic(eval_jaxpr, in_avals) return new_jaxpr, new_consts diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 79618c2d1..63b3f09ad 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -29,6 +29,7 @@ import jax.numpy as jnp from jax.sharding import NamedSharding, PartitionSpec from jax._src import ad_checkpoint from jax._src import ad_util +from jax._src import api_util from jax._src import callback from jax._src import config from jax._src import core @@ -58,7 +59,6 @@ from jax._src.lib.mlir.dialects import sdy from jax._src.util import (HashableFunction, HashablePartial, unzip2, as_hashable_function, memoize, partition_list, split_list, subs_list2) -from jax.api_util import flatten_fun_nokwargs, argnums_partial from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe @@ -160,9 +160,10 @@ def _shard_map(f: Callable, mesh: Mesh | AbstractMesh, in_specs: Specs, @util.wraps(f) @traceback_util.api_boundary def wrapped(*args): - fun = lu.wrap_init(f) + fun = lu.wrap_init(f, + debug_info=api_util.debug_info("shard_map", f, args, {})) args_flat, in_tree = tree_flatten(args) - fun, out_tree = flatten_fun_nokwargs(fun, in_tree) + fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree) try: in_specs_flat = broadcast_prefix(in_specs, args, is_leaf=lambda x: x is None) except ValueError: @@ -170,7 +171,7 @@ def _shard_map(f: Callable, mesh: Mesh | AbstractMesh, in_specs: Specs, raise e('shard_map in_specs') from None dyn_argnums, in_specs_flat = unzip2((i, s) for i, s in enumerate(in_specs_flat) if s is not None) - fun, args_flat = argnums_partial(fun, dyn_argnums, args_flat, False) + fun, args_flat = api_util.argnums_partial(fun, dyn_argnums, args_flat, False) _check_specs_vs_args(f, mesh, in_tree, in_specs, dyn_argnums, in_specs_flat, args_flat) in_names_flat = tuple(map(_canonicalize_spec, in_specs_flat)) @@ -460,13 +461,16 @@ class ShardMapPrimitive(core.Primitive): return self._true_bind(*args, **params) def bind_with_trace(self, trace, fun_and_args, params): + fun: lu.WrappedFun fun, *args = fun_and_args return trace.process_shard_map(shard_map_p, fun, args, **params) def get_bind_params(self, params): new_params = dict(params) - jaxpr = new_params.pop('jaxpr') - subfun = lu.hashable_partial(lu.wrap_init(core.eval_jaxpr), jaxpr, ()) + jaxpr: core.Jaxpr = new_params.pop('jaxpr') + subfun = lu.hashable_partial(lu.wrap_init(core.eval_jaxpr, + debug_info=jaxpr.debug_info), + jaxpr, ()) axes = new_params.pop('out_names') new_params['out_names_thunk'] = HashableFunction(lambda: axes, closure=axes) return [subfun], new_params @@ -1511,7 +1515,8 @@ def _shard_map_jvp(trace, shard_map_p, f, tracers, mesh, in_names, return [ad.JVPTracer(trace, p, t) for p, t in zip(primal_out, tangent_out)] ad.JVPTrace.process_shard_map = _shard_map_jvp -def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names, +def _shard_map_partial_eval(trace: pe.JaxprTrace, shard_map_p, + f: lu.WrappedFun, tracers, mesh, in_names, out_names_thunk, check_rep, rewrite, auto): tracers = map(trace.to_jaxpr_tracer, tracers) in_pvals = [t.pval for t in tracers] @@ -1519,7 +1524,7 @@ def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names, unk_in_names, known_in_names = pe.partition_list(in_knowns, in_names) all_names = _all_mesh_names_except_spmd(mesh, trace) in_avals_sharded = map(partial(_shard_aval, mesh), unk_in_names, in_avals) - f = pe.trace_to_subjaxpr_nounits_fwd2(f, trace.tag, False) + f = pe.trace_to_subjaxpr_nounits_fwd2(f, trace.tag, f.debug_info, False) f = _promote_scalar_residuals(f) f_known, aux = pe.partial_eval_wrapper_nounits( f, (*in_knowns,), (*in_avals_sharded,)) @@ -1545,7 +1550,7 @@ def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names, res_names = [known_in_names[f1] if f1 is not None else known_out_names_[f2] if f2 is not None else {0: all_names} for f1, f2 in zip(in_fwd, out_fwd)] - unk_in_names = (*res_names,) + ({},) * len(env) + (*unk_in_names,) + unk_in_names = (*res_names,) + ({},) * len(env) + (*unk_in_names,) # type: ignore[assignment] const_tracers = map(trace.new_instantiated_const, res) env_tracers = map(trace.to_jaxpr_tracer, env) unk_arg_tracers = [t for t in tracers if not t.is_known()] @@ -1628,7 +1633,7 @@ def _promote_scalar_residuals_lin(f, linearize_outs_thunk, *args, **kwargs): return residuals + primals @lu.transformation2 -def _promote_scalar_residuals(f, *args, **kwargs): +def _promote_scalar_residuals(f: Callable, *args, **kwargs): jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) = f(*args, **kwargs) which = [f1 is None and f2 is None and not v.aval.shape for f1, f2, v in zip(in_fwds, out_fwds, jaxpr.constvars)] @@ -1686,7 +1691,7 @@ def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names, return out fun_trans, nz_arg_cts = ad.nonzero_outputs(fun_trans) - fun_trans_flat, out_tree = flatten_fun_nokwargs(fun_trans, in_tree) + fun_trans_flat, out_tree = api_util.flatten_fun_nokwargs(fun_trans, in_tree) new_in_names = \ [n for n, x in zip(out_names, out_cts) if type(x) is not ad.Zero] + \ diff --git a/tests/api_test.py b/tests/api_test.py index 918f81367..23c0e13fd 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -9520,8 +9520,11 @@ class CustomVJPTest(jtu.JaxTestCase): def fwd(x): return np.array([2.0])*x*x/np.array([1.0]), (x,) - fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(fun, fwd) x = jnp.linspace(0, 5.0, 10) + fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd( + fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}), + fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {})) + self.assertAllClose(jax.jit(fwd)(x)[0], 2*x*x) # Shouldn't hit custom DCE self.assertAllClose(jax.jit(lambda x: fwd(x)[0])(x), x) # Should be DCEed @@ -9530,8 +9533,10 @@ class CustomVJPTest(jtu.JaxTestCase): return (np.array([1.0])*x)[0] def fwd(x): return (np.array([2.0])*x*x/np.array([1.0]))[0], (x,) - fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(fun, fwd) x = jnp.linspace(0, 5.0, 10) + fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd( + fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}), + fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {})) self.assertAllClose(jax.jit(jax.vmap(fwd))(x)[0], 2*x*x) self.assertAllClose(jax.jit(lambda x: jax.vmap(fwd)(x)[0])(x), x) @@ -9540,11 +9545,15 @@ class CustomVJPTest(jtu.JaxTestCase): return x def fwd(x): return x*x, (x,) - fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(fun, fwd) + + x = jnp.linspace(0, 5.0, 10) + fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd( + fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}), + fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {})) def g(x): return jax.lax.cond(True, fwd, lambda x: (2.0 * x, (x,)), x) - x = jnp.linspace(0, 5.0, 10) + self.assertAllClose(jax.jit(g)(x)[0], x*x) self.assertAllClose(jax.jit(lambda x: g(x)[0])(x), x) @@ -9553,7 +9562,10 @@ class CustomVJPTest(jtu.JaxTestCase): return x**2 def fwd_(x): return x*x, (x,) - fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(fun, fwd_) + + fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd( + fun, api_util.debug_info("custom_vjp fun", fun, (3.2,), {}), + fwd_, api_util.debug_info("custom_vjp fwd", fwd_, (3.2,), {})) calc = jax.jvp(fwd, (3.2,), (1.0,)) expected = jax.jvp(fwd_, (3.2,), (1.0,)) self.assertAllClose(calc, expected) diff --git a/tests/debug_info_test.py b/tests/debug_info_test.py index da98d81d5..93ddfe61b 100644 --- a/tests/debug_info_test.py +++ b/tests/debug_info_test.py @@ -31,9 +31,13 @@ import jax.custom_transpose from jax.experimental import checkify import jax.experimental.custom_dce from jax.experimental import pallas as pl +from jax.experimental.shard_map import shard_map import jax.numpy as jnp import jax.scipy as jsp +from jax.sharding import Mesh +from jax.sharding import PartitionSpec as P + from jax._src import api_util from jax._src.ad_checkpoint import saved_residuals from jax._src import config @@ -159,14 +163,17 @@ class DebugInfoTest(jtu.JaxTestCase): t_debug_info = _debug_info_to_string(t._debug_info) if check_tracer_arg_name: msg = str(exc) - m = re.match(r".* while tracing the function (.+) for (.+)\. .* depends on the value of the argument ([^\n]+)\.", + m = re.match(r".* while tracing the function (.+) for ([^.]+)\.", msg, re.DOTALL) self.assertIsNotNone(m, msg) self.assertEqual(t._debug_info.func_src_info, m.group(1)) self.assertEqual(t._debug_info.traced_for, m.group(2)) + m = re.match(r".* depends on the value of the argument ([^\n]+)\.", + msg, + re.DOTALL) found_tracer_debug_infos.append( - f"{t_debug_info}, from {m.group(3)}") + f"{t_debug_info}, from {m.group(1) if m else None}") else: found_tracer_debug_infos.append(t_debug_info) else: @@ -804,7 +811,8 @@ class DebugInfoTest(jtu.JaxTestCase): jnp.float32(1.), (jnp.float32(2.),), [jnp.float32(3.)], expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x,y[0], result_paths=", - "None", # TODO(necula): missing debug info + # TODO(necula): arg_names? + "traced_for=jit, fun=my_f, arg_names=None,None,None,None, result_paths=['a'],['b'][0][0]", ], tracer_spy=tracer_spy, check_tracer_arg_name=True, @@ -837,8 +845,8 @@ class DebugInfoTest(jtu.JaxTestCase): "traced_for=jit, fun=, arg_names=x,y,res_ct, result_paths=[0],[1]", # TODO(necula): result_paths "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=", - # TODO(necula): missing debug info - "None", + # TODO(necula): arg_names + "traced_for=jit, fun=my_g, arg_names=None,None,u,v, result_paths=['c'],['d']", ], check_tracer_arg_name=True, expected_tracer_debug_infos=[ @@ -854,6 +862,138 @@ class DebugInfoTest(jtu.JaxTestCase): re.compile(r".*func.func public @main\(.*jax.result_info = \"\[1\]\"}"), ]) + def test_custom_jvp(self): + tracer_spy = TracerSpy() + @jax.custom_jvp + def my_fun(x, y, c=1.): + tracer_spy.append(y) + return c * (x + y) + def my_jvp(primals, tangents): + x, y, c = primals + t_x, t_y, t_c = tangents + tracer_spy.append(t_y) + return my_fun(x, y, c), t_c + my_fun.defjvp(my_jvp) + + def top_f(x, y): + return jnp.square(my_fun(x, y, c=2.)).sum() + + + self._check_tracers_and_jaxprs( + jax.jit(lambda a: jax.jvp(top_f, (a, a), + (jnp.ones_like(a), jnp.ones_like(a)))), + 42., + tracer_spy=tracer_spy, + expected_jaxpr_debug_infos=[ + "traced_for=jit, fun=, arg_names=a, result_paths=[0],[1]", + "traced_for=custom_jvp fun, fun=my_fun, arg_names=x,y,c, result_paths=", + ], + check_tracer_arg_name=True, + expected_tracer_debug_infos=[ + "traced_for=jit, fun=, arg_names=a, from None", + "traced_for=custom_jvp fun, fun=my_fun, arg_names=x,y,c, from y", + ]) + + def test_custom_jvp_nondiff_args(self): + tracer_spy = TracerSpy() + def top_f(xy): + tracer_spy.append(xy[0]) + @functools.partial(jax.custom_jvp, nondiff_argnums=(0,)) + def my_g(h, xy): + x, y = xy + tracer_spy.append(x) + return h(x) + @my_g.defjvp + def my_g_jvp(h, primals, tangents): + (x, y), = primals + (xt, yt), = tangents + tracer_spy.append(xt) + return my_g(h, (x, y)), 2. * xt + h = lambda y: xy[0] + y # capture x + return my_g(h, xy) + + self._check_tracers_and_jaxprs( + jax.jit(lambda a, b: jax.jvp(top_f, ((a, b),), + ((jnp.ones_like(a), jnp.ones_like(b)),))), + 42., 43., + tracer_spy=tracer_spy, + expected_jaxpr_debug_infos=[ + # TODO(necula): arg_names + "traced_for=jit, fun=, arg_names=None,a,b, result_paths=[0],[1]", + "traced_for=custom_jvp fun, fun=my_g, arg_names=None,xy[0],xy[1], result_paths=", + ], + check_tracer_arg_name=True, + expected_tracer_debug_infos=[ + "traced_for=custom_jvp fun, fun=my_g, arg_names=xy[0],xy[1], from xy[0]", + # TODO(necula): from None + "traced_for=jit, fun=, arg_names=a,b, from None", + "None", # TODO(necula): None + ]) + + def test_custom_vjp(self): + tracer_spy = TracerSpy() + @jax.custom_vjp + def my_f(x): + tracer_spy.append(x["a"]) + return {"b": jnp.sin(x["a"])} + def my_f_fwd(x): + tracer_spy.append(x["a"]) + return my_f(x), {"r": jnp.cos(x["a"])} + def my_f_bwd(res, g): + tracer_spy.append(g["b"]) + cos_x = res["r"] + return ({"a": 2 * cos_x * g["b"]},) + my_f.defvjp(my_f_fwd, my_f_bwd) + + def to_diff(x): + return my_f(x)["b"] + + self._check_tracers_and_jaxprs( + jax.jit(jax.grad(to_diff)), + {"a" : 3.}, + tracer_spy=tracer_spy, + expected_jaxpr_debug_infos=[ + "traced_for=jit, fun=to_diff, arg_names=x['a'], result_paths=['a']", + "traced_for=custom_vjp fun, fun=my_f, arg_names=x['a'], result_paths=['b']", + ], + check_tracer_arg_name=True, + expected_tracer_debug_infos=[ + "traced_for=custom_vjp fun, fun=my_f, arg_names=x['a'], from x['a']", + # TODO(necula): from None + "traced_for=jit, fun=to_diff, arg_names=x['a'], from None", + "traced_for=jit, fun=to_diff, arg_names=x['a'], from x['a']", + ]) + + def test_custom_vjp_nondiff_args(self): + tracer_spy = TracerSpy() + @functools.partial(jax.custom_vjp, nondiff_argnums=(0,)) + def app(f, xy): + tracer_spy.append(xy[0]) + return f(xy) + def app_fwd(f, xy): + tracer_spy.append(xy[0]) + return app(f, xy), jnp.cos(xy[0]) + def app_rev(f, cos_x0, g): + tracer_spy.append(cos_x0) + tracer_spy.append(g) + return ((cos_x0 * g, cos_x0),) + app.defvjp(app_fwd, app_rev) + + self._check_tracers_and_jaxprs( + jax.jit(jax.grad(lambda xy: app(lambda x: 2 * x[0], xy))), + (3., 3.), + tracer_spy=tracer_spy, + expected_jaxpr_debug_infos=[ + "traced_for=jit, fun=, arg_names=xy[0],xy[1], result_paths=[0],[1]", + "traced_for=custom_vjp fun, fun=app, arg_names=xy[0],xy[1], result_paths=", + ], + check_tracer_arg_name=True, + expected_tracer_debug_infos=[ + "traced_for=jit, fun=, arg_names=xy[0],xy[1], from xy[0]", + "traced_for=custom_vjp fun, fun=app, arg_names=xy[0],xy[1], from xy[0]", + # TODO(necula): from None + "traced_for=jit, fun=, arg_names=xy[0],xy[1], from None", + ]) def test_vmap_of_nested_jit(self): tracer_spy = TracerSpy() @@ -995,8 +1135,12 @@ class DebugInfoTest(jtu.JaxTestCase): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ 'traced_for=jit, fun=my_f, arg_names=x,y, result_paths=', - # TODO(necula): some Jaxprs without debug info - 'None', + # TODO(necula): arg_names? result_paths? + "traced_for=cond, fun=my_true_branch, arg_names=None, result_paths=,", + "traced_for=cond, fun=my_false_branch, arg_names=None, result_paths=,", + "traced_for=cond, fun=my_true_branch, arg_names=a,b, result_paths=[0],[1]", + "traced_for=cond, fun=my_false_branch, arg_names=c,d, result_paths=[0],[1]", + "traced_for=checkpoint / remat, fun=my_g, arg_names=None,None, result_paths=,", ], expected_tracer_debug_infos=[ 'traced_for=cond, fun=my_true_branch, arg_names=a,b', @@ -1020,7 +1164,9 @@ class DebugInfoTest(jtu.JaxTestCase): @jax.jit def my_f(x, as_): tracer_spy.append(x) - return jax.remat(lambda *args: for_loop.scan(f, *args))(c, as_) + def to_remat(a, b): + return for_loop.scan(f, a, b) + return jax.remat(to_remat)(c, as_) def the_grad(c, as_): tracer_spy.append(c) @@ -1035,13 +1181,21 @@ class DebugInfoTest(jtu.JaxTestCase): "traced_for=jit, fun=the_grad, arg_names=c,as_, result_paths=[0],[1]", # TODO(necula): bad result paths "traced_for=jit, fun=my_f, arg_names=x,as_, result_paths=,,", - 'None', # TODO(necula): some Jaxprs without debug info + # TODO(necula): arg_names? + "traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2], result_paths=", + "traced_for=for_loop, fun=f, arg_names=None,None,None, result_paths=,", + "traced_for=for_loop, fun=f, arg_names=None,None,None,None,None,None, result_paths=,", + "traced_for=for_loop, fun=f, arg_names=None,None,None,None,None,None,None,None,None,None,None, result_paths=", + "traced_for=for_loop, fun=f, arg_names=None,None,None,None,None,None,None,None,None,None,None,None,None,None,None, result_paths=,", + "traced_for=checkpoint / remat, fun=to_remat, arg_names=None,None,None, result_paths=,", + "traced_for=jit, fun=my_f, arg_names=None,None,x,as_, result_paths=", ], expected_tracer_debug_infos=[ "traced_for=jit, fun=the_grad, arg_names=c,as_", "traced_for=scan, fun=f, arg_names=c,a", "traced_for=jit, fun=my_f, arg_names=x,as_", - 'None', # TODO(necula): some missing debug info + # TODO(necula): arg_names + "traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2]", ], expected_lowering_lines=[ re.compile(r".*func.func public @main\(%arg0: tensor loc\(\"c\"\)"), @@ -1321,12 +1475,44 @@ class DebugInfoTest(jtu.JaxTestCase): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x, result_paths=", - # TODO(necula): some Jaxprs without debug info - "None"], + # TODO(necula): arg_names? + "traced_for=checkpoint / remat, fun=my_g, arg_names=None,None, result_paths=", + ], expected_tracer_debug_infos=[ "traced_for=checkpoint / remat, fun=my_g, arg_names=y" ]) + def test_remat_shard_map(self): + tracer_spy = TracerSpy() + if len(jax.devices()) < 2: + self.skipTest("requires at least 2 devices") + # this tests remat-of-shmap + mesh = Mesh(np.array(jax.devices()[:2]), ('x',)) + + # check param updating is handled + @jax.remat + @functools.partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x')) + def my_f(x): + tracer_spy.append(x) + return jnp.sin(jnp.sin(x)) + + self._check_tracers_and_jaxprs( + jax.jit(jax.grad(lambda x: my_f(x).sum())), + jnp.arange(2, dtype=np.float32), + tracer_spy=tracer_spy, + expected_jaxpr_debug_infos=[ + # TODO(necula): arg_names + "traced_for=jit, fun=, arg_names=x, result_paths=", + "traced_for=checkpoint / remat, fun=my_f, arg_names=None,None, result_paths=", + "traced_for=shard_map, fun=my_f, arg_names=x, result_paths=", + "traced_for=shard_map, fun=my_f, arg_names=None,None, result_paths=", + "None", # TODO(necula): missing + ], + check_tracer_arg_name=True, + expected_tracer_debug_infos=[ + "None" # TODO(necula): missing + ]) + def test_remat_saved_residuals(self): @functools.partial(jax.remat, static_argnums=(1,), @@ -1359,7 +1545,7 @@ class DebugInfoTest(jtu.JaxTestCase): # TODO(necula): this should not be pointing into the JAX internals re.compile(r"traced_for=jit, fun=checked_fun at .*jax/_src/checkify.py:.*, arg_names=args\[0\]"), re.compile(r"traced_for=jit, fun=argsort at .*numpy/lax_numpy.py:.*, arg_names=a, result_paths="), - "None", # TODO(necula): missing tracer debug info + "traced_for=pmap, fun=my_f, arg_names=my_x, result_paths=[0]", ], expected_tracer_debug_infos=[ "traced_for=pmap, fun=my_f, arg_names=my_x", @@ -1495,7 +1681,8 @@ class DebugInfoTest(jtu.JaxTestCase): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=, arg_names=x, result_paths=[0],[1]", - "None", # TODO(necula): there are missing Jaxpr debug info + # TODO(necula): internal function? + re.compile(r"traced_for=custom_jvp fun, fun=_custom_root at .*/control_flow/solves.py:.*, arg_names=args\[0\], result_paths=\[0\]"), ], expected_tracer_debug_infos=[ "traced_for=custom_root, fun=my_f, arg_names=x",