From 904b74860ce74723ec2c04a1ecf1b2112327ad3d Mon Sep 17 00:00:00 2001 From: George Necula Date: Wed, 5 Feb 2025 19:17:47 +0200 Subject: [PATCH] [better_errors] Continue adding debug info to Jaxprs (step 3) This follows after #26078, and #26313, adding `debug_info` to more calls to `lu.wrap_init`. As part of this I have changed the primitives `custom_vjp_call_jaxpr` and `custom_lin` to take the `bwd` parameter as a `lu.WrappedFun`, which carries debug info. Previously, this was a `Callable`, but in almost all cases if was really ` lu.WrappedFun.call_wrapped`. --- jax/_src/api.py | 12 +++- jax/_src/checkify.py | 23 ++++--- jax/_src/core.py | 2 +- jax/_src/custom_derivatives.py | 64 +++++++++++-------- jax/_src/interpreters/ad.py | 45 +++++++++---- jax/_src/interpreters/batching.py | 16 +++-- jax/_src/interpreters/partial_eval.py | 10 ++- jax/_src/lax/control_flow/loops.py | 29 ++++++--- jax/_src/lax/lax.py | 13 ++-- jax/_src/lax/windowed_reductions.py | 5 +- jax/_src/pallas/hlo_interpreter.py | 2 +- jax/experimental/jax2tf/BUILD | 2 - .../jax2tf/tests/back_compat_testdata/BUILD | 1 - .../jax2tf/tests/flax_models/BUILD | 1 - jax/experimental/shard_map.py | 41 +++++++----- jax/experimental/sparse/transform.py | 4 +- tests/debug_info_test.py | 63 +++++++++++++++++- 17 files changed, 234 insertions(+), 99 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 581b2b512..e531c96f4 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -573,7 +573,11 @@ def jacfwd(fun: Callable, argnums: int | Sequence[int] = 0, @wraps(fun, docstr=docstr, argnums=argnums) def jacfun(*args, **kwargs): - f = lu.wrap_init(fun, kwargs) + f = lu.wrap_init( + fun, kwargs, + debug_info=debug_info( + "jacfwd", fun, args, kwargs, + static_argnums=(argnums,) if isinstance(argnums, int) else argnums)) f_partial, dyn_args = argnums_partial(f, argnums, args, require_static_args_hashable=False) tree_map(partial(_check_input_dtype_jacfwd, holomorphic), dyn_args) @@ -661,7 +665,11 @@ def jacrev(fun: Callable, argnums: int | Sequence[int] = 0, @wraps(fun, docstr=docstr, argnums=argnums) def jacfun(*args, **kwargs): - f = lu.wrap_init(fun, kwargs) + f = lu.wrap_init( + fun, kwargs, + debug_info=debug_info( + "jacrev", fun, args, kwargs, + static_argnums=(argnums,) if isinstance(argnums, int) else argnums)) f_partial, dyn_args = argnums_partial(f, argnums, args, require_static_args_hashable=False) tree_map(partial(_check_input_dtype_jacrev, holomorphic, allow_int), dyn_args) diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 9eea81b9d..821518009 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -1065,18 +1065,20 @@ def lift_jvp(num_errs, num_consts, jvp_jaxpr_thunk): return [*primal_errs, *out_primals, *tangent_errs, *out_tangents] return jvp -def custom_vjp_call_jaxpr_rule(in_err, enabled_errors, *in_vals, fun_jaxpr, - fwd_jaxpr_thunk, num_consts, bwd, out_trees, - symbolic_zeros): +def custom_vjp_call_jaxpr_rule(in_err, enabled_errors, *in_vals, + fun_jaxpr: core.ClosedJaxpr, + fwd_jaxpr_thunk, num_consts, + bwd: lu.WrappedFun, out_trees, + symbolic_zeros: bool): err_vals, err_tree = jtu.tree_flatten(in_err) num_errs = err_tree.num_leaves checkified_fun = lu.wrap_init( functools.partial(checkify_jaxpr_flat, fun_jaxpr.jaxpr, - fun_jaxpr.consts, enabled_errors, err_tree)) + fun_jaxpr.consts, enabled_errors, err_tree), + debug_info=fun_jaxpr.jaxpr.debug_info) checkified_fun, fun_metadata = _flatten_and_get_error_metadata_thunk( checkified_fun) - @lu.wrap_init def checkified_fwd(*args): # 
TODO(lenamartens, sharadmv): why not checkify here? xs, zeros = args[::2], args[1::2] @@ -1085,10 +1087,15 @@ def custom_vjp_call_jaxpr_rule(in_err, enabled_errors, *in_vals, fun_jaxpr, xs_without_consts = xs[num_consts:] return core.eval_jaxpr(fwd_jaxpr, fwd_consts, *xs_without_consts) - bwd_ = lambda *args: (*(None,)*num_errs, *bwd(*args)) - checkified_fwd, fwd_out_tree = flatten_fun_output(checkified_fwd) + # TODO(necula): the fwd result_paths are not quite the same as fun_jaxpr + checkified_fwd_wrapped = lu.wrap_init(checkified_fwd, + debug_info=fun_jaxpr.jaxpr.debug_info) + bwd_ = lu.wrap_init(lambda *args: (*(None,)*num_errs, *bwd.call_wrapped(*args)), + debug_info=bwd.debug_info) + checkified_fwd_wrapped, fwd_out_tree = flatten_fun_output(checkified_fwd_wrapped) all_outs = custom_derivatives.custom_vjp_call_p.bind( - checkified_fun, checkified_fwd, bwd_, *err_vals, *in_vals, out_trees=out_trees, + checkified_fun, checkified_fwd_wrapped, + bwd_, *err_vals, *in_vals, out_trees=out_trees, symbolic_zeros=symbolic_zeros) fst, out_metadata = lu.merge_linear_aux(fun_metadata, fwd_out_tree) if fst: diff --git a/jax/_src/core.py b/jax/_src/core.py index 2aefe1544..2681bf939 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -392,7 +392,7 @@ class JaxprEqn: # TODO(mattjj): call typecheck rules here, so we don't form bad eqns def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None, - ctx=None): + ctx=None) -> JaxprEqn: source_info = source_info or source_info_util.new_source_info() ctx = ctx or JaxprEqnContext( compute_on.current_compute_type(), diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index c8f8cc7a3..5fafb007d 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -374,15 +374,16 @@ class CustomJVPCallPrimitive(core.Primitive): def get_bind_params(self, params): new_params = dict(params) - call_jaxpr = new_params.pop('call_jaxpr') - num_consts = new_params.pop('num_consts') + call_jaxpr: core.ClosedJaxpr = new_params.pop('call_jaxpr') + num_consts: int = new_params.pop('num_consts') jvp_jaxpr_thunk = new_params.pop('jvp_jaxpr_thunk') - fun = lu.wrap_init(core.jaxpr_as_fun(call_jaxpr)) - jvp = lift_jvp(num_consts, jvp_jaxpr_thunk) + fun = lu.wrap_init(core.jaxpr_as_fun(call_jaxpr), + debug_info=call_jaxpr.jaxpr.debug_info) + jvp = lift_jvp(num_consts, jvp_jaxpr_thunk, call_jaxpr.jaxpr.debug_info) return [fun, jvp], new_params -def lift_jvp(num_consts: int, jvp_jaxpr_thunk: Callable) -> lu.WrappedFun: - @lu.wrap_init +def lift_jvp(num_consts: int, jvp_jaxpr_thunk: Callable, + debug_info: core.DebugInfo | None) -> lu.WrappedFun: def jvp(*xs): n, ragged = divmod(len(xs), 2) assert not ragged @@ -398,7 +399,7 @@ def lift_jvp(num_consts: int, jvp_jaxpr_thunk: Callable) -> lu.WrappedFun: for p, z in zip(out_primals, out_zeros)] assert next(nz_out_tangents_, None) is None return [*out_primals, *out_tangents] - return jvp + return lu.wrap_init(jvp, debug_info=debug_info) effects.custom_derivatives_allowed_effects.add_type(lax.InOutFeedEffect) @@ -435,8 +436,9 @@ def _custom_jvp_call_transpose(params, jaxpr, args, ct, _): ad.primitive_transposes[custom_jvp_call_p] = _custom_jvp_call_transpose @weakref_lru_cache -def _cached_closed_call_dce_instantiate(jaxpr_, used_outputs: tuple[bool, ...] - ) -> tuple[core.ClosedJaxpr, list[bool]]: +def _cached_closed_call_dce_instantiate(jaxpr_: core.ClosedJaxpr, + used_outputs: tuple[bool, ...] 
+ ) -> tuple[core.ClosedJaxpr, list[bool]]: jaxpr, consts = jaxpr_.jaxpr, jaxpr_.consts new_jaxpr, used_inputs = pe.dce_jaxpr(jaxpr, used_outputs, True) return core.ClosedJaxpr(new_jaxpr, consts), used_inputs @@ -673,7 +675,7 @@ class custom_vjp(Generic[ReturnValue]): flat_fwd, out_trees = _flatten_fwd( fwd_, self.nondiff_argnums, self.symbolic_zeros, debug_fun, debug_fwd, in_tree, out_type) - flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees).call_wrapped + flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees) out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd, *args_flat, out_trees=out_trees, symbolic_zeros=self.symbolic_zeros) @@ -940,7 +942,9 @@ mlir.register_lowering(custom_vjp_call_jaxpr_p, mlir.lower_fun( def _custom_vjp_call_jaxpr_jvp( primals, tangents, *, fun_jaxpr: core.ClosedJaxpr, fwd_jaxpr_thunk: Callable[..., tuple[core.Jaxpr, Sequence[Any]]], - num_consts: int, bwd: Callable, out_trees: Callable, symbolic_zeros: bool): + num_consts: int, bwd: lu.WrappedFun, + out_trees: Callable[[], Sequence[PyTreeDef]], + symbolic_zeros: bool): _, args = split_list(primals, [num_consts]) consts_dot, args_dot = split_list(tangents, [num_consts]) if any(type(t) is not Zero for t in consts_dot): @@ -963,7 +967,8 @@ def _custom_vjp_call_jaxpr_vmap( axis_data, args, in_dims, *, fun_jaxpr: core.ClosedJaxpr, fwd_jaxpr_thunk: Callable[..., tuple[core.Jaxpr, Sequence[Any]]], - num_consts: int, bwd: Callable, out_trees: Callable, symbolic_zeros: bool): + num_consts: int, bwd: lu.WrappedFun, + out_trees: Callable, symbolic_zeros: bool): args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0 else x for x, d in zip(args, in_dims)] in_batched = [d is not not_mapped for d in in_dims] @@ -1000,12 +1005,13 @@ def _custom_vjp_call_jaxpr_dce( ) -> tuple[list[bool], core.JaxprEqn | None]: if not any(used_outs) and not pe.has_effects(eqn): return [False] * len(eqn.invars), None - - fun_jaxpr = eqn.params["fun_jaxpr"] + fun_jaxpr: core.ClosedJaxpr = eqn.params["fun_jaxpr"] fwd_jaxpr_thunk = eqn.params["fwd_jaxpr_thunk"] - bwd = eqn.params["bwd"] - out_trees = eqn.params["out_trees"] - symbolic_zeros = eqn.params["symbolic_zeros"] + bwd: lu.WrappedFun = eqn.params["bwd"] + out_trees: Callable[[], Sequence[PyTreeDef]] = eqn.params["out_trees"] + symbolic_zeros: bool = eqn.params["symbolic_zeros"] + dce_fun_jaxpr: core.ClosedJaxpr + used_ins: Sequence[bool] dce_fun_jaxpr, used_ins = _cached_closed_call_dce_instantiate( fun_jaxpr, tuple(used_outs)) assert all(used_ins) @@ -1019,7 +1025,6 @@ def _custom_vjp_call_jaxpr_dce( fwd_jaxpr, (True,) * num_res + tuple(used_outs)) return dce_fwd_jaxpr.jaxpr, dce_fwd_jaxpr.consts - @lu.wrap_init def dce_bwd(*args): _, res_tree = out_trees() res, cts = split_list(args, [res_tree.num_leaves]) @@ -1035,19 +1040,21 @@ def _custom_vjp_call_jaxpr_dce( else: all_cts.append(zeros_like_aval(ct_aval)) assert next(cts_, None) is None - return bwd(*res, *all_cts) + return bwd.call_wrapped(*res, *all_cts) + dce_bwd_wrapped = lu.wrap_init(dce_bwd, + debug_info=bwd.debug_info) outvars = [v for used, v in zip(used_outs, eqn.outvars) if used] new_params = dict( eqn.params, fun_jaxpr=dce_fun_jaxpr, fwd_jaxpr_thunk=dce_fwd_jaxpr_thunk, - bwd=dce_bwd.call_wrapped, + bwd=dce_bwd_wrapped, ) new_eqn = pe.new_jaxpr_eqn( eqn.invars, outvars, eqn.primitive, new_params, dce_fun_jaxpr.effects, eqn.source_info, eqn.ctx) - return used_ins, new_eqn + return list(used_ins), new_eqn pe.dce_rules[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_dce 
xla.register_initial_style_primitive(custom_vjp_call_jaxpr_p) @@ -1125,7 +1132,9 @@ def custom_gradient(fun): def fwd(*args, **kwargs): ans, rule = fun(*args, **kwargs) ans_flat, out_tree = tree_flatten((ans,)) - rule, in_tree = flatten_fun_nokwargs(lu.wrap_init(rule), out_tree) + debug_fwd = debug_info("custom_gradient fwd", rule, (ans,), {}) + rule, in_tree = flatten_fun_nokwargs(lu.wrap_init(rule, + debug_info=debug_fwd), out_tree) ans_avals = [core.get_aval(x).to_tangent_aval() for x in ans_flat] jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(rule, ans_avals) return ans, Residuals(jaxpr, in_tree(), out_tree, consts) @@ -1224,10 +1233,11 @@ def closure_convert(fun: Callable, *example_args) -> tuple[Callable, list[Any]]: """ flat_args, in_tree = tree_flatten(example_args) in_avals = tuple(map(core.get_aval, flat_args)) + debug = debug_info("closure_convert", fun, example_args, {}) if config.check_tracer_leaks.value: - return _closure_convert_for_avals.__wrapped__(fun, in_tree, in_avals) + return _closure_convert_for_avals.__wrapped__(fun, in_tree, in_avals, debug) else: - return _closure_convert_for_avals(fun, in_tree, in_avals) + return _closure_convert_for_avals(fun, in_tree, in_avals, debug) def _maybe_perturbed(x: Any) -> bool: # False if x can't represent an AD-perturbed value (i.e. a value @@ -1251,8 +1261,10 @@ def _maybe_perturbed(x: Any) -> bool: return True # We can't be sure! @cache() -def _closure_convert_for_avals(fun, in_tree, in_avals): - wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree) +def _closure_convert_for_avals(fun, in_tree, in_avals, + debug_info: core.DebugInfo): + wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun, debug_info=debug_info), + in_tree) jaxpr, out_pvals, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals) out_tree = out_tree() diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 8d0f0e7ca..307c620ca 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -25,7 +25,7 @@ from jax._src import config from jax._src import linear_util as lu from jax._src.interpreters import partial_eval as pe from jax.tree_util import (tree_flatten, tree_unflatten, - register_pytree_node, Partial) + register_pytree_node, Partial, PyTreeDef) from jax._src import core from jax._src import source_info_util from jax._src.ad_util import ( @@ -644,26 +644,28 @@ class LinearizeTrace(Trace): return [maybe_linearize_tracer(self, x, nz, t) for x, nz, t in zip(primals_out, tangent_nzs_out, tangents_out)] - def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, - symbolic_zeros): + def process_custom_vjp_call(self, prim, fun, fwd, + bwd: lu.WrappedFun, tracers, + out_trees: Callable[[], Sequence[PyTreeDef]], + symbolic_zeros: bool): primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers)) if all(type(t) is Zero for t in tangents_in): return prim.bind_with_trace(self.parent_trace, (fun, fwd, bwd, *primals_in), dict(out_trees=out_trees, symbolic_zeros=symbolic_zeros)) fwd_in = [(p, type(t) is not Zero) for p, t in zip(primals_in, tangents_in)] - fwd_in = [x for pair in fwd_in for x in pair] # flatten + fwd_in_flat = [x for pair in fwd_in for x in pair] # flatten with core.set_current_trace(self.parent_trace): - res_and_primals_out = fwd.call_wrapped(*fwd_in) + res_and_primals_out = fwd.call_wrapped(*fwd_in_flat) _, res_tree = out_trees() res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves]) avals_out = [core.get_aval(x).to_tangent_aval() 
for x in primals_out] - tangents_in = map(instantiate_zeros, tangents_in) + tangents_in_zeros = map(instantiate_zeros, tangents_in) with core.set_current_trace(self.tangent_trace): tangents_out = custom_lin_p.bind( - *res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd, + *res, *tangents_in_zeros, num_res=res_tree.num_leaves, bwd=bwd, out_avals=avals_out, symbolic_zeros=symbolic_zeros) tangent_nzs_out = [type(t) is not Zero for t in tangents_out] return map(partial(maybe_linearize_tracer, self), primals_out, tangent_nzs_out, tangents_out) @@ -975,9 +977,13 @@ def nonzero_outputs(f, store, *args, **kwargs): store.store([type(r) is not Zero for r in results]) return results -def map_transpose(primitive, params, call_jaxpr, args, ct, _): +def map_transpose(primitive: core.Primitive, params, + call_jaxpr: core.Jaxpr, args, ct, _): all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts - fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr, False) + # TODO(necula): use the right debug_info for the backwards pass + fun = lu.hashable_partial(lu.wrap_init(backward_pass, + debug_info=call_jaxpr.debug_info), + call_jaxpr, False) fun, nz_arg_cts = nonzero_outputs(fun) fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def) # Preserve axis for primal arguments, skip tangents (represented as undefined primals). @@ -1056,12 +1062,24 @@ def f_jvp_traceable(f, store, nonzeros, *primals_and_nztangents): def rearrange_binders(jaxpr: core.ClosedJaxpr, primals_in, tangents_in, primals_out, tangents_out): new_invars = _perm(primals_in, tangents_in, jaxpr.jaxpr.invars) new_outvars = _perm(primals_out, tangents_out, jaxpr.jaxpr.outvars) + new_debug_info = jaxpr.jaxpr.debug_info + if new_debug_info is not None: + new_arg_names = tuple(_perm(primals_in, tangents_in, + jaxpr.jaxpr.debug_info.safe_arg_names(len(jaxpr.jaxpr.invars)))) + new_result_paths = tuple(_perm(primals_out, tangents_out, + jaxpr.jaxpr.debug_info.safe_result_paths(len(jaxpr.jaxpr.outvars)))) + new_debug_info = new_debug_info._replace( + arg_names=new_arg_names, + result_paths=new_result_paths, + ) new_jaxpr = core.Jaxpr(jaxpr.jaxpr.constvars, new_invars, new_outvars, jaxpr.jaxpr.eqns, - jaxpr.jaxpr.effects) + jaxpr.jaxpr.effects, + new_debug_info) return core.ClosedJaxpr(new_jaxpr, jaxpr.consts) -def _perm(primal_counts, tangent_counts, lst): +def _perm(primal_counts: Sequence[int], tangent_counts: Sequence[int], + lst: Sequence[Any]) -> Sequence[Any]: n = sum(primal_counts) primals, tangents = lst[:n], lst[n:] primal_groups = split_list(primals, primal_counts[:-1]) @@ -1082,14 +1100,15 @@ def raise_custom_vjp_error_on_jvp(*_, **__): "function.") custom_lin_p.def_impl(raise_custom_vjp_error_on_jvp) -def _custom_lin_transpose(cts_out, *invals, num_res, bwd, out_avals, +def _custom_lin_transpose(cts_out, *invals, num_res, + bwd: lu.WrappedFun, out_avals, symbolic_zeros): res, _ = split_list(invals, [num_res]) if symbolic_zeros: cts_out = map(replace_internal_symbolic_zeros, cts_out) else: cts_out = map(instantiate_zeros, cts_out) - cts_in = bwd(*res, *cts_out) + cts_in = bwd.call_wrapped(*res, *cts_out) cts_in = map(replace_rule_output_symbolic_zeros, cts_in) return [None] * num_res + list(cts_in) primitive_transposes[custom_lin_p] = _custom_lin_transpose diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 1366c3d6c..36f75c533 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -799,8 +799,11 @@ def batch_jaxpr_axes(closed_jaxpr, axis_data, 
in_axes, out_axes_dest): return _batch_jaxpr_axes(closed_jaxpr, axis_data, tuple(in_axes), tuple(out_axes_dest)) @weakref_lru_cache -def _batch_jaxpr_axes(closed_jaxpr, axis_data, in_axes, out_axes_dest): - f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr)) +def _batch_jaxpr_axes(closed_jaxpr: core.ClosedJaxpr, + axis_data: AxisData, + in_axes: Sequence[int], out_axes_dest: Sequence[int]): + f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr), + debug_info=closed_jaxpr.jaxpr.debug_info) f, out_axes = _batch_jaxpr_inner(f, axis_data) f, out_batched = _match_axes_jaxpr(f, axis_data, out_axes_dest, out_axes) f = _batch_jaxpr_outer(f, axis_data, in_axes) @@ -896,7 +899,10 @@ def batch_custom_jvp_subtrace(f, store, tag, axis_data, in_dims, *in_vals): store.store(out_dims * 2) return out_primals + out_tangents -def batch_custom_vjp_bwd(bwd, tag, axis_data, in_dims, out_dim_dests): +def batch_custom_vjp_bwd(bwd: lu.WrappedFun, tag: core.TraceTag, + axis_data: AxisData, + in_dims: Callable[[], Sequence[int | None]], + out_dim_dests: Sequence[int | None]) -> lu.WrappedFun: axis_size = axis_data.size axis_name = axis_data.name mesh_axis = axis_data.explicit_mesh_axis @@ -907,11 +913,11 @@ def batch_custom_vjp_bwd(bwd, tag, axis_data, in_dims, out_dim_dests): for x, dim in zip(args, in_dims_)] in_dims_ = [None if type(x) is SymbolicZero else d for x, d in zip(args, in_dims_)] - bwd_, out_dims_thunk = batch_subtrace(lu.wrap_init(bwd), tag, axis_data, in_dims_) + bwd_, out_dims_thunk = batch_subtrace(bwd, tag, axis_data, in_dims_) bwd_ = _match_axes_and_sum(bwd_, axis_size, axis_name, mesh_axis, out_dims_thunk, out_dim_dests) return bwd_.call_wrapped(*args) - return new_bwd + return lu.wrap_init(new_bwd, debug_info=bwd.debug_info) @lu.transformation2 def _match_axes_and_sum(f, axis_size, axis_name, mesh_axis, out_dims_thunk, diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 721386565..bbc5f7e49 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -2022,8 +2022,11 @@ class DynamicJaxprTrace(core.Trace): self.frame.add_eqn(eqn) return out_tracers - def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, - symbolic_zeros): + def process_custom_vjp_call(self, prim: core.Primitive, + fun: lu.WrappedFun, + fwd: lu.WrappedFun, bwd: lu.WrappedFun, tracers, + out_trees: Callable[[], Sequence[PyTreeDef]], + symbolic_zeros: bool): tracers = map(self.to_jaxpr_tracer, tracers) in_avals = [t.aval for t in tracers] fun_jaxpr, out_avals, consts, _ = trace_to_jaxpr_dynamic(fun, in_avals) @@ -2041,7 +2044,8 @@ class DynamicJaxprTrace(core.Trace): invars = map(self.getvar, tracers) constvars = map(self.getvar, map(self.to_jaxpr_tracer, consts)) outvars = map(self.makevar, out_tracers) - eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim.initial_style, + eqn = new_jaxpr_eqn([*constvars, *invars], outvars, + prim.initial_style, # type: ignore[attribute-error] dict(fun_jaxpr=closed_fun_jaxpr, fwd_jaxpr_thunk=fwd_jaxpr_from_zeros, num_consts=len(consts), diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 60f96dec8..fd6715295 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -586,8 +586,10 @@ def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry, for p, nz in zip(primals_out, nonzeros_out)] return primals_out, tangents_out -def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry, - 
jaxpr, linear, unroll, _split_transpose): +def _scan_partial_eval(trace, *tracers, reverse: bool, + length: int, num_consts: int, num_carry: int, + jaxpr: core.ClosedJaxpr, linear: Sequence[bool], + unroll: int, _split_transpose: bool): num_ys = len(jaxpr.out_avals) - num_carry unknowns = [not t.pval.is_known() for t in tracers] const_uk, init_uk, xs_uk = split_list(unknowns, [num_consts, num_carry]) @@ -612,8 +614,8 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry, del res_avals, carry_uk_out # Instantiate those inputs which must be treated as unknown from the fixpoint. - tracers = [trace.instantiate_const(t) if uk else t - for t, uk in zip(tracers, unknowns)] + tracers = tuple(trace.instantiate_const(t) if uk else t + for t, uk in zip(tracers, unknowns)) # The residual inputs and outputs of the jaxprs produced haven't yet been # adapted to the scan calling convention; in particular, jaxpr_known has its @@ -638,7 +640,9 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry, for aval in jaxpr_known.in_avals[len(const_pvals):]] with source_info_util.reset_name_stack(): jaxpr_known_, invar_pvals_out, jaxpr_known_consts = pe.trace_to_jaxpr_nounits( - lu.wrap_init(core.jaxpr_as_fun(jaxpr_known)), const_pvals + other_pvals, + lu.wrap_init(core.jaxpr_as_fun(jaxpr_known), + debug_info=jaxpr_known.jaxpr.debug_info), + const_pvals + other_pvals, instantiate=[True] * (len(out_uk) - sum(out_uk)) + [False] * num_res) jaxpr_known = pe.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr_known_), ()) # The above trace_to_jaxpr_nounits call computed loop-invariant residuals @@ -880,8 +884,9 @@ def _scan_transpose(cts, *args, reverse, length, num_consts, # transpose_scan_jaxpr :: ([res1, c, a, res2] -> b) # -> ([res1, CT c, CT b, res2] -> [CT c, CT a]) @weakref_lru_cache -def _transpose_scan_jaxpr(jaxpr, num_res1, num_c, num_res2, - ct_ys_is_zeros): +def _transpose_scan_jaxpr(jaxpr: core.ClosedJaxpr, + num_res1: int, num_c: int, num_res2: int, + ct_ys_is_zeros: Sequence[bool]): num_a = len(jaxpr.in_avals) - num_res1 - num_c - num_res2 # TODO: allow input cotangent avals to be batched relative to jaxpr.in_avals # if an axis isn't reduced @@ -896,7 +901,6 @@ def _transpose_scan_jaxpr(jaxpr, num_res1, num_c, num_res2, aval for aval, is_zero in zip(b_ys_avals, ct_ys_is_zeros) if not is_zero ] - @lu.wrap_init def transposed(*res1_cbar_bbar_res2): res1, c_bar, b_bar, ys_bar_stripped, res2 = split_list( res1_cbar_bbar_res2, @@ -915,9 +919,14 @@ def _transpose_scan_jaxpr(jaxpr, num_res1, num_c, num_res2, a_bar = _map(ad.instantiate_zeros, a_bar) c_bar = _map(ad.instantiate_zeros, _map(ad.add_tangents, c_bar, new_c_bar)) return c_bar + a_bar + + # TODO(necula): fix arg names and results for transposed + transposed_wrapped = lu.wrap_init(transposed, + debug_info=jaxpr.jaxpr.debug_info) return _make_closed_jaxpr_attrs( - transposed, tuple(res1_avals + c_avals + b_carry_avals + - b_ys_avals_stripped + res2_avals)) + transposed_wrapped, + tuple(res1_avals + c_avals + b_carry_avals + + b_ys_avals_stripped + res2_avals)) def _scan_batching_rule(axis_data, args, diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 68f307654..a3b5e13ef 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1923,6 +1923,8 @@ def reduce(operands: Any, is undefined. 
""" flat_operands, operand_tree = tree_util.tree_flatten(operands) + comp_debug = api_util.debug_info("reduce comp", computation, + (init_values, init_values), {}) flat_init_values, init_value_tree = tree_util.tree_flatten(init_values) if operand_tree != init_value_tree: raise ValueError('Operands must have the same tree structure as init_values:' @@ -1939,7 +1941,7 @@ def reduce(operands: Any, else: flat_init_avals = safe_map(core.get_aval, flat_init_values) closed_jaxpr, out_tree = _variadic_reduction_jaxpr( - computation, tuple(flat_init_avals), init_value_tree) + computation, comp_debug, tuple(flat_init_avals), init_value_tree) out = reduce_p.bind(*flat_operands, *flat_init_values, computation=computation, jaxpr=closed_jaxpr, dimensions=tuple(dimensions)) return tree_util.tree_unflatten(out_tree, out) @@ -1967,10 +1969,13 @@ def _reduction_jaxpr(computation: Callable, return jaxpr, tuple(consts) @cache() -def _variadic_reduction_jaxpr(computation, flat_avals, aval_tree): +def _variadic_reduction_jaxpr(computation: Callable[[Any, Any], Any], + debug_info: core.DebugInfo, + flat_avals, + aval_tree: tree_util.PyTreeDef): avals = tree_util.tree_unflatten(aval_tree, flat_avals) flat_in_avals, in_tree = tree_util.tree_flatten((avals, avals)) - comp = lu.wrap_init(computation) + comp = lu.wrap_init(computation, debug_info=debug_info) flat_comp, out_tree = api_util.flatten_fun_nokwargs(comp, in_tree) jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_comp, tuple(flat_in_avals)) if any(isinstance(c, core.Tracer) for c in consts): @@ -5921,7 +5926,7 @@ def _argminmax_dtype_rule(operand, *, axes, index_dtype): class _ArgMinMaxReducer: - def __init__(self, value_comparator): + def __init__(self, value_comparator: Callable[[Any, Any], Any]): self._value_comparator = value_comparator def __repr__(self): diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py index d1ad4e75a..81eac41d5 100644 --- a/jax/_src/lax/windowed_reductions.py +++ b/jax/_src/lax/windowed_reductions.py @@ -19,6 +19,7 @@ from functools import partial import warnings from jax import tree_util +from jax._src import api_util from jax._src import core from jax._src import dispatch from jax._src import dtypes @@ -55,6 +56,8 @@ def _reduce_window( window_dilation: Sequence[int] | None = None, ): flat_operands, operand_tree = tree_util.tree_flatten(operand) + comp_debug = api_util.debug_info("reduce_window comp", computation, + (init_value, init_value), {}) flat_init_values, init_value_tree = tree_util.tree_flatten(init_value) if operand_tree != init_value_tree: raise ValueError( @@ -88,7 +91,7 @@ def _reduce_window( else: flat_init_avals = map(core.get_aval, flat_init_values) jaxpr, out_tree = lax._variadic_reduction_jaxpr( - computation, tuple(flat_init_avals), init_value_tree + computation, comp_debug, tuple(flat_init_avals), init_value_tree ) if operand_tree != out_tree: raise ValueError( diff --git a/jax/_src/pallas/hlo_interpreter.py b/jax/_src/pallas/hlo_interpreter.py index 7ad540cf3..a21fc1fb0 100644 --- a/jax/_src/pallas/hlo_interpreter.py +++ b/jax/_src/pallas/hlo_interpreter.py @@ -326,7 +326,7 @@ def resolve_physical_types(jaxpr: jax_core.Jaxpr, consts: Sequence[Any]): interp_fun = partial( eval_jaxpr_recursive, jaxpr, consts, recurse_hop_rule=resolve_physical_types) - wrapped = lu.wrap_init(interp_fun) + wrapped = lu.wrap_init(interp_fun, debug_info=jaxpr.debug_info) new_jaxpr, _, new_consts, () = pe.trace_to_jaxpr_dynamic( wrapped, kernel_avals) return new_jaxpr, new_consts diff --git 
a/jax/experimental/jax2tf/BUILD b/jax/experimental/jax2tf/BUILD index d60b4c333..85ad90326 100644 --- a/jax/experimental/jax2tf/BUILD +++ b/jax/experimental/jax2tf/BUILD @@ -30,7 +30,6 @@ package( py_library( name = "jax2tf", srcs = ["__init__.py"], - srcs_version = "PY3", visibility = ["//visibility:public"], deps = [":jax2tf_internal"], ) @@ -42,7 +41,6 @@ py_library( "impl_no_xla.py", "jax2tf.py", ], - srcs_version = "PY3", # TODO: b/255503696: enable pytype tags = ["pytype_unchecked_annotations"], visibility = jax_visibility("jax2tf_internal"), diff --git a/jax/experimental/jax2tf/tests/back_compat_testdata/BUILD b/jax/experimental/jax2tf/tests/back_compat_testdata/BUILD index 3417c1abf..d166f1308 100644 --- a/jax/experimental/jax2tf/tests/back_compat_testdata/BUILD +++ b/jax/experimental/jax2tf/tests/back_compat_testdata/BUILD @@ -24,7 +24,6 @@ package( py_library( name = "back_compat_testdata", srcs = glob(["*.py"]), - srcs_version = "PY3", deps = [ "//third_party/py/numpy", "//third_party/py/typing_extensions", diff --git a/jax/experimental/jax2tf/tests/flax_models/BUILD b/jax/experimental/jax2tf/tests/flax_models/BUILD index 331d4ab8e..c028f13a4 100644 --- a/jax/experimental/jax2tf/tests/flax_models/BUILD +++ b/jax/experimental/jax2tf/tests/flax_models/BUILD @@ -28,7 +28,6 @@ package( py_library( name = "flax_models", srcs = glob(["*.py"]), - srcs_version = "PY3", deps = [ "//jax", "//third_party/py/flax:core", diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 11538a368..8ab179b7f 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -1642,8 +1642,8 @@ def _promote_scalar_residuals(f: Callable, *args, **kwargs): for x in out_consts] return jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) -def _promote_scalar_residuals_jaxpr(jaxpr, which): - @lu.wrap_init +def _promote_scalar_residuals_jaxpr(jaxpr: core.Jaxpr, + which: Sequence[bool]): def fun(*res_and_args): res, args = split_list(res_and_args, [len(jaxpr.constvars)]) res = [_rem_singleton(x) if w else x for x, w in zip(res, which)] @@ -1651,7 +1651,8 @@ def _promote_scalar_residuals_jaxpr(jaxpr, which): res_avals = [core.unmapped_aval(1, 0, v.aval) if w else v.aval for v, w in zip(jaxpr.constvars, which)] in_avals = [*res_avals, *[v.aval for v in jaxpr.invars]] - jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(fun, in_avals) + jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(fun, debug_info=jaxpr.debug_info), in_avals) return jaxpr @@ -1663,20 +1664,20 @@ def _unmentioned2(mesh: Mesh, names: AxisNames, return [n for n in _all_mesh_names_except_spmd(mesh) if n not in name_set] -def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names, +def _shard_map_transpose(out_cts, *args, + jaxpr: core.Jaxpr, mesh, in_names, out_names, check_rep, rewrite, auto): mb_div = lambda x, y: x / y if y != 1 else x out_cts = [ad.Zero(_shard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero else x if rewrite or dtypes.dtype(x) == dtypes.float0 else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns, auto)))) for ns, x in zip(out_names, out_cts)] - args = [x if type(x) is not ad.UndefinedPrimal else - ad.UndefinedPrimal(_shard_aval(mesh, ns, x.aval)) - for ns, x in zip(in_names, args)] + args = tuple(x if type(x) is not ad.UndefinedPrimal else + ad.UndefinedPrimal(_shard_aval(mesh, ns, x.aval)) + for ns, x in zip(in_names, args)) all_args, in_tree = tree_flatten((out_cts, args)) - @lu.wrap_init - def fun_trans(out_cts, args): + def fun_trans_callable(out_cts, 
args): res, undefs = partition_list(map(ad.is_undefined_primal, args), args) jaxpr_known, jaxpr_unknown, _, _ = pe.partial_eval_jaxpr_nounits( pe.close_jaxpr(jaxpr), map(ad.is_undefined_primal, args), False) @@ -1690,6 +1691,8 @@ def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names, for ns, x in zip(in_names, out)] return out + fun_trans = lu.wrap_init(fun_trans_callable, + debug_info=jaxpr.debug_info) fun_trans, nz_arg_cts = ad.nonzero_outputs(fun_trans) fun_trans_flat, out_tree = api_util.flatten_fun_nokwargs(fun_trans, in_tree) @@ -1986,8 +1989,11 @@ class RewriteTrace(core.Trace): out_reps = out_reps[:len(out_reps) // 2] return map(partial(RewriteTracer, self), out_reps, out_vals) - def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, - symbolic_zeros): + def process_custom_vjp_call(self, prim: core.Primitive, fun: lu.WrappedFun, + fwd: lu.WrappedFun, bwd: lu.WrappedFun, + tracers, + out_trees: Callable[[], Sequence[PyTreeDef]], + symbolic_zeros: bool): if symbolic_zeros: msg = ("Please open an issue at https://github.com/jax-ml/jax/issues and " "as a temporary workaround pass the check_rep=False argument to " @@ -2055,13 +2061,15 @@ def _replication_rewrite_nomatch( jaxpr: core.ClosedJaxpr, in_rep: Sequence[set[AxisName]], ) -> tuple[core.ClosedJaxpr, list[set[AxisName]]]: - f = lu.wrap_init(partial(core.eval_jaxpr, jaxpr.jaxpr, jaxpr.consts)) + f = lu.wrap_init(partial(core.eval_jaxpr, jaxpr.jaxpr, jaxpr.consts), + debug_info=jaxpr.jaxpr.debug_info) f, out_rep = _efficient_transpose_rewrite_nomatch(f, mesh, in_rep) jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals) return core.ClosedJaxpr(jaxpr_, consts), out_rep() @lu.transformation_with_aux2 -def _rewrite_subtrace(f, store, tag, mesh, in_reps, *in_vals): +def _rewrite_subtrace(f: Callable, store: lu.Store, + tag: core.TraceTag, mesh: Mesh, in_reps, *in_vals): with core.take_current_trace() as parent_trace: assert len(in_reps) == len(in_vals), (len(in_reps), len(in_vals)) t = RewriteTrace(parent_trace, tag, mesh) @@ -2072,13 +2080,14 @@ def _rewrite_subtrace(f, store, tag, mesh, in_reps, *in_vals): store.store(out_reps) return out_vals -def _rewrite_bwd(bwd, mesh, in_reps, reps_dst): +def _rewrite_bwd(bwd: lu.WrappedFun, + mesh: Mesh, in_reps, reps_dst) -> lu.WrappedFun: def new_bwd(*args): tag = core.TraceTag() - bwd_, reps_thunk = _rewrite_subtrace(lu.wrap_init(bwd), tag, mesh, in_reps()) + bwd_, reps_thunk = _rewrite_subtrace(bwd, tag, mesh, in_reps()) out = bwd_.call_wrapped(*args) return map(_match_replication, reps_thunk(), reps_dst, out) - return new_bwd + return lu.wrap_init(new_bwd, debug_info=bwd.debug_info) def _match_replication(src, dst, x): if dst - src: diff --git a/jax/experimental/sparse/transform.py b/jax/experimental/sparse/transform.py index 76d4d957e..c0b643061 100644 --- a/jax/experimental/sparse/transform.py +++ b/jax/experimental/sparse/transform.py @@ -865,7 +865,7 @@ sparse_rules_bcoo[sparse.todense_p] = _todense_sparse_rule sparse_rules_bcsr[sparse.todense_p] = _todense_sparse_rule def _custom_jvp_sparse_rule(spenv, *spvalues, **params): - call_jaxpr = params.pop('call_jaxpr') + call_jaxpr: core.ClosedJaxpr = params.pop('call_jaxpr') jvp_jaxpr_thunk = params.pop('jvp_jaxpr_thunk') num_consts = params.pop('num_consts') sp_call_jaxpr, out_tree = _sparsify_jaxpr(spenv, call_jaxpr, *spvalues) @@ -874,7 +874,7 @@ def _custom_jvp_sparse_rule(spenv, *spvalues, **params): sparrs = arrays_to_spvalues(spenv, arrs) out = eval_sparse(call_jaxpr.jaxpr, 
call_jaxpr.consts, sparrs, spenv) return spvalues_to_arrays(spenv, out) - jvp = lift_jvp(num_consts, jvp_jaxpr_thunk) + jvp = lift_jvp(num_consts, jvp_jaxpr_thunk, call_jaxpr.jaxpr.debug_info) invals = spvalues_to_arrays(spenv, spvalues) outvals = jax.custom_derivatives.custom_jvp_call_p.bind(fun, jvp, *invals, **params) return arrays_to_spvalues(spenv, outvals) diff --git a/tests/debug_info_test.py b/tests/debug_info_test.py index 93ddfe61b..8ba571e6d 100644 --- a/tests/debug_info_test.py +++ b/tests/debug_info_test.py @@ -96,8 +96,12 @@ class TracerSpy: def __init__(self): self.tracers = [] - def append(self, t: core.Tracer): + def append(self, t: Any): try: + # We plan to do boolean conversion and catch the exception, but this works + # only for scalars + if isinstance(t, core.Tracer) and t.shape: + t = jnp.sum(t) if t: pass assert False, t @@ -862,6 +866,32 @@ class DebugInfoTest(jtu.JaxTestCase): re.compile(r".*func.func public @main\(.*jax.result_info = \"\[1\]\"}"), ]) + def test_vjp_remat(self): + tracer_spy = TracerSpy() + def apply_fn(inp): + tracer_spy.append(inp) + def to_remat(x): + tracer_spy.append(x) + return jax.nn.relu(x * x) + fn = jax.checkpoint(to_remat) + return jax.vjp(fn, inp) + + self._check_tracers_and_jaxprs( + jax.jit(apply_fn), + 2., + tracer_spy=tracer_spy, + expected_jaxpr_debug_infos=[ + # TODO(necula): what are these flat_index components? + "traced_for=jit, fun=apply_fn, arg_names=inp, result_paths=[0],[1][][0][][0][0]", + re.compile(r"traced_for=custom_jvp fun, fun=relu at .*/nn/functions.py:.*, arg_names=x, result_paths="), + re.compile(r"traced_for=jit, fun=relu at .*/nn/functions.py:.*, arg_names=x, result_paths="), + ], + check_tracer_arg_name=True, + expected_tracer_debug_infos=[ + "traced_for=checkpoint / remat, fun=to_remat, arg_names=x, from x", + "traced_for=jit, fun=apply_fn, arg_names=inp, from inp", + ]) + def test_custom_jvp(self): tracer_spy = TracerSpy() @jax.custom_jvp @@ -959,7 +989,7 @@ class DebugInfoTest(jtu.JaxTestCase): check_tracer_arg_name=True, expected_tracer_debug_infos=[ "traced_for=custom_vjp fun, fun=my_f, arg_names=x['a'], from x['a']", - # TODO(necula): from None + # TODO(necula): from None? "traced_for=jit, fun=to_diff, arg_names=x['a'], from None", "traced_for=jit, fun=to_diff, arg_names=x['a'], from x['a']", ]) @@ -1435,6 +1465,33 @@ class DebugInfoTest(jtu.JaxTestCase): ], ) + def test_hessian(self): + tracer_spy = TracerSpy() + + def my_f(x): + tracer_spy.append(x) + return jnp.square(x).mean() + + x = jax.random.uniform(jax.random.key(0), shape=(8, 4)) + + self._check_tracers_and_jaxprs( + jax.jit(jax.hessian(jax.jit(my_f))), + x, + expected_jaxpr_debug_infos=[ + "traced_for=jit, fun=my_f, arg_names=x, result_paths=", + # TODO(necula): arg_names and result_paths? 
+ "traced_for=jit, fun=my_f, arg_names=None,x, result_paths=,", + "traced_for=jit, fun=my_f, arg_names=x, result_paths=,,,", + ], + tracer_spy=tracer_spy, + check_tracer_arg_name=True, + expected_tracer_debug_infos=[ + "traced_for=jit, fun=my_f, arg_names=x, from x", + ], + ) + + (x).block_until_ready() + def test_remat(self): tracer_spy = TracerSpy() def my_f(x): @@ -1506,7 +1563,6 @@ class DebugInfoTest(jtu.JaxTestCase): "traced_for=checkpoint / remat, fun=my_f, arg_names=None,None, result_paths=", "traced_for=shard_map, fun=my_f, arg_names=x, result_paths=", "traced_for=shard_map, fun=my_f, arg_names=None,None, result_paths=", - "None", # TODO(necula): missing ], check_tracer_arg_name=True, expected_tracer_debug_infos=[ @@ -1657,6 +1713,7 @@ class DebugInfoTest(jtu.JaxTestCase): expected_tracer_debug_infos=[ "traced_for=custom_linear_solve solve, fun=my_solve, arg_names=x", "traced_for=custom_linear_solve transpose_solve, fun=my_tr_solve, arg_names=x", + "traced_for=custom_linear_solve, fun=my_high_precision_dot, arg_names=b", "None", # TODO(necula): there are missing debug info ])