From d12aead696afb09e21fc6d228361868cf388f318 Mon Sep 17 00:00:00 2001
From: George Necula
Date: Fri, 24 Jan 2025 10:57:28 +0200
Subject: [PATCH] [better_errors] Add debug info to more Jaxprs and WrappedFun (step 1)

The plan is for all `core.Jaxpr` and `lu.WrappedFun` to carry non-None
debug info.

We change `lu.wrap_init` to construct the result paths thunk whenever it
is passed a `debug_info`. The goal is to make sure that every
`WrappedFun` has a debug info with result paths support.

We change some calling conventions for internal functions so as not to
pass along a separate debug_info when we already have a `WrappedFun` or
a `Jaxpr`.

We obtain several improvements in the presence of debug info, as shown
in debug_info_test.py.
---
 jax/_src/ad_checkpoint.py             |   7 +-
 jax/_src/api.py                       |   8 +-
 jax/_src/api_util.py                  |  31 +---
 jax/_src/checkify.py                  |   2 +-
 jax/_src/core.py                      |  10 +-
 jax/_src/custom_batching.py           |   2 +-
 jax/_src/custom_dce.py                |  11 +-
 jax/_src/custom_partitioning.py       |   4 +-
 jax/_src/interpreters/ad.py           |  15 +-
 jax/_src/interpreters/batching.py     |   3 +-
 jax/_src/interpreters/mlir.py         |   1 -
 jax/_src/interpreters/partial_eval.py |  47 +++---
 jax/_src/interpreters/pxla.py         |  23 ++-
 jax/_src/lax/control_flow/common.py   |   2 +-
 jax/_src/lax/lax.py                   |   5 +-
 jax/_src/linear_util.py               |  51 +++---
 jax/_src/pallas/core.py               |   2 +-
 jax/_src/pallas/pallas_call.py        |  21 +--
 jax/_src/pjit.py                      |  70 ++++----
 jax/_src/state/discharge.py           |   5 +-
 jax/interpreters/partial_eval.py      |   2 +-
 tests/debug_info_test.py              | 228 ++++++++++++--------------
 22 files changed, 270 insertions(+), 280 deletions(-)

diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py
index d4bcc8d66..9a6c78fa5 100644
--- a/jax/_src/ad_checkpoint.py
+++ b/jax/_src/ad_checkpoint.py
@@ -420,9 +420,9 @@ def _trace_to_jaxpr(fun: Callable,
                     in_avals: Sequence[core.AbstractValue],
                     debug: core.DebugInfo
                     ) -> tuple[core.Jaxpr, Sequence[Any], PyTreeDef]:
-  flat_fun, out_tree = api_util.flatten_fun(lu.wrap_init(fun), in_tree)
+  flat_fun, out_tree = api_util.flatten_fun(lu.wrap_init(fun, debug_info=debug), in_tree)
   try:
-    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
+    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
   except core.ConcretizationTypeError as e:
     msg, = e.args
     if 'for checkpoint' in msg:
@@ -699,7 +699,8 @@ def _transpose_jaxpr(jaxpr, in_lin, out_zeros):
   assert next(ins_iter, None) is None
   with source_info_util.extend_name_stack('rematted_computation'):
     lin_jaxpr, _, consts = pe.trace_to_jaxpr_nounits(
-        lu.wrap_init(core.jaxpr_as_fun(jaxpr)), in_pvals, False)
+        lu.wrap_init(core.jaxpr_as_fun(jaxpr), debug_info=jaxpr.jaxpr.debug_info),
+        in_pvals, False)

   # Transpose the linear jaxpr (which only has linear inputs).
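
A minimal sketch of the new calling convention (assuming the internal
module paths shown in this patch; `my_f` and the "my_transform" tag are
made-up names): debug info is attached once, at `lu.wrap_init`, which now
also installs the result-paths thunk, and `pe.trace_to_jaxpr_dynamic`
reads it from the `WrappedFun` instead of taking a separate argument.

    import jax.numpy as jnp
    from jax._src import api_util, core
    from jax._src import linear_util as lu
    from jax._src.interpreters import partial_eval as pe

    def my_f(x):           # hypothetical function, already flattened:
      return [x * 2.]      # takes flat positional args, returns a flat list

    dbg = api_util.debug_info("my_transform", my_f, (1.,), {})
    fun = lu.wrap_init(my_f, debug_info=dbg)   # result-paths thunk added here
    jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
        fun, [core.ShapedArray((), jnp.float32)])  # no debug_info parameter
    assert jaxpr.debug_info is not None            # the Jaxpr carries it now
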
out_cts_iter = iter(out_cts_flat) diff --git a/jax/_src/api.py b/jax/_src/api.py index 1c9a706a9..97e2de54d 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -61,7 +61,7 @@ from jax._src.api_util import ( flatten_axes, donation_vector, rebase_donate_argnums, _ensure_index, _ensure_index_tuple, apply_flat_fun_nokwargs, check_callable, debug_info, - result_paths, flat_out_axes) + flat_out_axes) from jax._src.lax import lax as lax_internal from jax._src.lib import jax_jit from jax._src.lib import xla_client as xc @@ -1430,7 +1430,8 @@ def _prepare_pmap(fun: Callable, in_axes, out_axes, static_broadcasted_tuple, "pmap", fun, args, kwargs, static_argnums=static_broadcasted_tuple) - f = lu.wrap_init(fun) + f = lu.wrap_init(fun, debug_info=dbg) + del dbg if static_broadcasted_tuple: if max(static_broadcasted_tuple) >= len(args): raise ValueError( @@ -1477,9 +1478,6 @@ def _prepare_pmap(fun: Callable, in_axes, out_axes, static_broadcasted_tuple, raise ValueError(msg) from None local_axis_size = _mapped_axis_size(fun, in_tree, args, in_axes_flat, "pmap") - f, res_paths = result_paths(f) - dbg = dbg.add_result_paths(res_paths) - f = lu.add_debug_info(f, dbg) f, out_axes_thunk = flat_out_axes(f, out_axes) flat_fun, out_tree = flatten_fun(f, in_tree) diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index 99bbe2886..9e4478245 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -590,7 +590,7 @@ def debug_info( static_argnums: tuple[int, ...] = (), static_argnames: tuple[str, ...] = (), result_paths_thunk: Callable[[], tuple[str, ...]] | None = None, - # TODO(necula): check if we really need this, e.g., to speed up tracing. + # TODO(necula): check if we really need this, e.g., to speed up tracing? sourceinfo: str | None = None, signature: inspect.Signature | None = None, ) -> core.DebugInfo: @@ -674,29 +674,6 @@ def _non_static_arg_names(fn_signature: inspect.Signature | None, arg_names = args_arg_names + kwargs_arg_names return arg_names -@lu.transformation_with_aux2 -def result_paths(_fun, _store, *args, **kwargs): - "linear_util transform to get output pytree paths of pre-flattened function." - ans = _fun(*args, **kwargs) - _store.store([keystr(path) for path, _ in generate_key_paths(ans)]) - return ans - -# TODO(necula): simplify this function, all it needs is to add the trace_debug to the Jaxpr -def add_jaxpr_debug_info(jaxpr: core.Jaxpr, - debug: core.DebugInfo | None, - result_paths: tuple[str, ...] 
| None = None, - ) -> core.Jaxpr: - """Add debug info to jaxpr, given trace-time debug info and result paths.""" - if debug is None: - return jaxpr - # TODO(necula): re-enable this safety check - # assert (result_paths is not None) ^ (trace_debug.result_paths_thunk is not None) - if result_paths is not None: - debug = debug._replace(result_paths=tuple(result_paths)) - else: - debug = debug.resolve_result_paths() - return jaxpr.replace(debug_info=debug) - def hoist_obj_attrs(f, flat_args): idxs, objs, flat_args_ = [], [], [] for i, x in enumerate(flat_args): @@ -721,7 +698,7 @@ def register_class_with_attrs(t: type) -> None: _class_with_attrs: set[type] = set() # TODO(mattjj): make this function faster -def _check_no_aliased_ref_args(dbg, avals, args): +def _check_no_aliased_ref_args(dbg: core.DebugInfo | None, avals, args): assert config.mutable_array_checks.value refs: dict[int, int] = {} for i, (a, x) in enumerate(zip(avals, args)): @@ -735,7 +712,7 @@ def _check_no_aliased_ref_args(dbg, avals, args): if dbg else f"at both flat index {dup_idx} and flat index {i}") from None -def _check_no_aliased_closed_over_refs(dbg, consts, args) -> None: +def _check_no_aliased_closed_over_refs(dbg: core.DebugInfo | None, consts, args) -> None: assert config.mutable_array_checks.value refs: set[int] = {id(core.get_referent(c)) for c in consts if isinstance(core.get_aval(c), AbstractRef)} @@ -746,4 +723,4 @@ def _check_no_aliased_closed_over_refs(dbg, consts, args) -> None: f"when tracing {dbg.func_src_info} for {dbg.traced_for}, a mutable " f"array reference of type {a.str_short()} was both closed over and " f"passed as the argument " - f"{dbg.arg_names[i]}" if dbg else "at flat index {i}") + f"{dbg.safe_arg_names(len(args))[i]}" if dbg else "at flat index {i}") diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 80fd1b041..99bb72ffd 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -1206,7 +1206,7 @@ def checkify(f: Callable[..., Out], fun_, out_tree = api_util.flatten_fun(lu.wrap_init(closed_f, debug_info=debug), in_tree) - jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, (), debug) + jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, ()) jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr_)) # checkify: error, out_flat = checkify_jaxpr(jaxpr, errors, init_error, *consts) diff --git a/jax/_src/core.py b/jax/_src/core.py index 5f0d86305..947d1da56 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -2369,7 +2369,8 @@ class CallPrimitive(Primitive): def get_bind_params(self, params): new_params = dict(params) jaxpr = new_params.pop('call_jaxpr') - subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr), jaxpr, ()) + subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr, debug_info=jaxpr.debug_info), + jaxpr, ()) if config.dynamic_shapes.value: subfun = lu.annotate(subfun, _jaxpr_type_to_callable_annotation(jaxpr)) return [subfun], new_params @@ -2402,7 +2403,7 @@ class MapPrimitive(Primitive): map_primitive = True def bind_with_trace(self, trace, fun_and_args, params): - fun = fun_and_args[0] + fun: lu.WrappedFun = fun_and_args[0] args = fun_and_args[1:] assert len(params['in_axes']) == len(args) return trace.process_map(self, fun, args, params) @@ -2412,8 +2413,9 @@ class MapPrimitive(Primitive): def get_bind_params(self, params): new_params = dict(params) - jaxpr = new_params.pop('call_jaxpr') - subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr), jaxpr, ()) + jaxpr: Jaxpr = new_params.pop('call_jaxpr') + subfun = 
lu.hashable_partial(lu.wrap_init(eval_jaxpr, + debug_info=jaxpr.debug_info), jaxpr, ()) axes = new_params.pop('out_axes') new_params['out_axes_thunk'] = HashableFunction(lambda: axes, closure=axes) return [subfun], new_params diff --git a/jax/_src/custom_batching.py b/jax/_src/custom_batching.py index 040f86944..d4cebc37b 100644 --- a/jax/_src/custom_batching.py +++ b/jax/_src/custom_batching.py @@ -153,7 +153,7 @@ class custom_vmap: lu.wrap_init(self.fun, debug_info=debug), in_tree) in_avals = [core.get_aval(x) for x in args_flat] - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug) + jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals) closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ()) in_tree = treedef_tuple((tree_structure(consts), in_tree)) assert self.vmap_rule is not None diff --git a/jax/_src/custom_dce.py b/jax/_src/custom_dce.py index 9e666f8f6..9166965b5 100644 --- a/jax/_src/custom_dce.py +++ b/jax/_src/custom_dce.py @@ -147,11 +147,11 @@ class custom_dce: ) static_args = [args[i] for i in self.static_argnums] dce_rule = api_util.prepend_static_args( - lu.wrap_init(self.dce_rule), static_args + lu.wrap_init(self.dce_rule, debug_info=debug_rule), static_args ) else: fun = lu.wrap_init(self.fun, debug_info=debug) - dce_rule = lu.wrap_init(self.dce_rule) + dce_rule = lu.wrap_init(self.dce_rule, debug_info=debug_rule) dyn_args = args args_flat, in_tree = tree_util.tree_flatten(dyn_args) @@ -176,7 +176,7 @@ class custom_dce: ) assert self.dce_rule is not None dce_jaxpr, _, dce_consts, () = pe.trace_to_jaxpr_dynamic( - flat_rule, in_avals, debug_rule + flat_rule, in_avals ) # This second round of DCE is used to work out which inputs are actually @@ -191,7 +191,7 @@ class custom_dce: return core.ClosedJaxpr(dce_jaxpr, dce_consts), used_ins - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug) + jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals) closed_call = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr)) out_avals = closed_call.out_avals out_flat = custom_dce_p.bind( @@ -366,7 +366,8 @@ def custom_dce_jvp(primals, tangents, *, fun_jaxpr: core.ClosedJaxpr, **_): # that most users of this API would compose this with a custom_jvp or # custom_vjp, which makes this less urgent. 
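
The same rewrap pattern recurs throughout the patch (ad_checkpoint, core,
custom_dce, batching): whenever a jaxpr is turned back into a callable via
`core.jaxpr_as_fun`, the resulting `WrappedFun` now inherits the jaxpr's
own debug info instead of being wrapped bare. A hedged sketch:

    import jax
    from jax._src import core
    from jax._src import linear_util as lu

    # Any ClosedJaxpr will do; here one is built with the public tracing API.
    closed_jaxpr = jax.make_jaxpr(lambda x: x + 1.)(1.)
    # Rewrap it as a callable, propagating its own debug info so that any
    # re-tracing (transposition, batching, ...) keeps the original names.
    fun = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr),
                       debug_info=closed_jaxpr.jaxpr.debug_info)
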
out = core.call_p.bind( - lu.wrap_init(core.jaxpr_as_fun(jvp_jaxpr)), *primals, *tangents + lu.wrap_init(core.jaxpr_as_fun(jvp_jaxpr), + debug_info=jvp_jaxpr.jaxpr.debug_info), *primals, *tangents ) out_primals, out_tangents = util.split_list(out, [len(out_nz)]) diff --git a/jax/_src/custom_partitioning.py b/jax/_src/custom_partitioning.py index c94066725..e3a1efe43 100644 --- a/jax/_src/custom_partitioning.py +++ b/jax/_src/custom_partitioning.py @@ -485,13 +485,13 @@ class custom_partitioning: _check_for_tracers(static_args) else: static_args = [] - f_, dyn_args = lu.wrap_init(self.fun), args + f_, dyn_args = lu.wrap_init(self.fun, debug_info=debug), args args_flat, in_tree = tree_util.tree_flatten(dyn_args) flat_fun, out_tree = api_util.flatten_fun_nokwargs(f_, in_tree) in_avals = [core.get_aval(x) for x in args_flat] mesh = mesh_lib.thread_resources.env.physical_mesh with core.extend_axis_env_nd(mesh.shape.items()): - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug) + jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals) assert not len(consts) closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ()) diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 668d2a078..2d031d4b9 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -98,7 +98,7 @@ def linearize_subtrace(_f: Callable, _store, _tag, nzs_in, *primals, **params): nzs_out = tuple(type(t) is not Zero for t in out_tangents) out_tangents = tuple(t for t, nz in zip(out_tangents, nzs_out) if nz) out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents) # type: ignore[assignment] - jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents) + jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents, None) residual_avals = map(get_aval, consts) if attrs_tracked: raise NotImplementedError("TODO: attrs") @@ -166,16 +166,17 @@ def _linearize_jaxpr( out_primals, out_tangents = unzip2(map(lin_trace.to_primal_tangent_pair, ans)) del lin_trace, ans, tracers, new_arg + debug_info = jaxpr.jaxpr.debug_info nzs_out = [type(t) is not Zero for t in out_tangents] out_tangents = tuple(tangent_trace.to_jaxpr_tracer(t) for (nz, t) in zip(nzs_out, out_tangents) if nz) - tangent_jaxpr, tangent_consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents) + tangent_jaxpr, tangent_consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents, debug_info) tangent_trace.invalidate() if attrs_tracked: raise NotImplementedError("TODO: attrs") residuals_and_primals = (*tangent_consts, *out_primals) residuals_and_primals = map(primal_trace.to_jaxpr_tracer, residuals_and_primals) # type: ignore[assignment] - primal_jaxpr, primal_consts, attrs_tracked = primal_trace.to_jaxpr(residuals_and_primals) + primal_jaxpr, primal_consts, attrs_tracked = primal_trace.to_jaxpr(residuals_and_primals, debug_info) primal_trace.invalidate() num_residuals = len(tangent_consts) tangent_jaxpr = pe.close_jaxpr(convert_constvars_jaxpr_constvars_at_end(tangent_jaxpr)) @@ -207,7 +208,7 @@ def direct_linearize(traceable: lu.WrappedFun, out_nzs = [type(t) is not Zero for t in out_tangents] out_nz_tangents = [t for t, nz in zip(out_tangents, out_nzs) if nz] out_nz_tangents = map(tangent_trace.to_jaxpr_tracer, out_nz_tangents) # type: ignore - jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_nz_tangents) + jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_nz_tangents, traceable.debug_info) tangent_trace.invalidate() out_tangents_pvals = 
[pe.PartialVal.unknown(core.get_aval(t)) if nz else pe.PartialVal.known(zeros_like_aval(t.aval)) @@ -1019,12 +1020,14 @@ def jvp_jaxpr(jaxpr: core.ClosedJaxpr, nonzeros: Sequence[bool], def _jvp_jaxpr(jaxpr: core.ClosedJaxpr, nonzeros: Sequence[bool], instantiate: Sequence[bool]): assert len(jaxpr.in_avals) == len(nonzeros) - f = lu.wrap_init(core.jaxpr_as_fun(jaxpr)) + debug_info = jaxpr.jaxpr.debug_info + f = lu.wrap_init(core.jaxpr_as_fun(jaxpr), debug_info=debug_info) f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate, transform_stack=False), nonzeros) tangent_avals = [aval.to_tangent_aval() for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz] avals_in = list(it.chain(jaxpr.in_avals, tangent_avals)) - jaxpr_out, avals_out, literals_out, () = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in) + jaxpr_out, avals_out, literals_out, () = pe.trace_to_jaxpr_dynamic( + f_jvp, avals_in) return core.ClosedJaxpr(jaxpr_out, literals_out), out_nonzeros() @lu.transformation_with_aux2 diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index a51ee450a..3828f5392 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -760,7 +760,8 @@ def _batch_jaxpr2( axis_data, in_axes: tuple[int | NotMapped | RaggedAxis, ...], ) -> tuple[core.ClosedJaxpr, tuple[int | NotMapped, ...]]: - f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr)) + f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr), + debug_info=closed_jaxpr.jaxpr.debug_info) f, out_axes = _batch_jaxpr_inner(f, axis_data) f = _batch_jaxpr_outer(f, axis_data, in_axes) in_axes2, avals_in = unzip2([ diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index a9805929f..1d0d5a7c5 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -1393,7 +1393,6 @@ def lower_jaxpr_to_fun( MLIR func op """ util.test_event("lower_jaxpr_to_fun", name) - # The first dimension variable may be the platform index num_dim_vars = len(ctx.shape_poly_state.dim_vars) dim_var_avals = [core.ShapedArray((), dtypes.canonicalize_dtype(np.int64))] * num_dim_vars diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 7dd137975..98c5c97a6 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -501,7 +501,7 @@ call_param_updaters[core.closed_call_p] = _closed_call_param_updater def abstract_eval_fun(fun, *avals, debug_info=None, **params): _, avals_out, _, () = trace_to_jaxpr_dynamic( - lu.wrap_init(fun, params), avals, debug_info) + lu.wrap_init(fun, params, debug_info=debug_info), avals) assert all(isinstance(aval, AbstractValue) for aval in avals_out) return avals_out @@ -589,7 +589,7 @@ def trace_to_subjaxpr_nounits( @lu.transformation2 def trace_to_subjaxpr_nounits2( - f, + f: Callable, tag: TraceTag, instantiate: bool | Sequence[bool], in_pvals: Sequence[PartialVal]): @@ -950,7 +950,9 @@ def _partial_eval_jaxpr_nounits(jaxpr: ClosedJaxpr, return [*known_vals_out, *residuals] known_avals = [a for a, uk in zip(jaxpr.in_avals, in_unknowns) if not uk] - jaxpr_known, _, consts_known, () = trace_to_jaxpr_dynamic(lu.wrap_init(fun), known_avals) + jaxpr_known, _, consts_known, () = trace_to_jaxpr_dynamic( + lu.wrap_init(fun, debug_info=f.debug_info), + known_avals) (out_unknowns, jaxpr_unknown, res_avals), = cell # pytype: disable=bad-unpacking # check jaxpr_known and jaxpr_unknown in isolation @@ -1124,7 +1126,7 @@ def _partial_eval_jaxpr_custom_cached( known_effects = 
make_jaxpr_effects(jaxpr.constvars, ins_known_and_ref_res, known_outvars, known_eqns) jaxpr_known = Jaxpr(jaxpr.constvars, ins_known_and_ref_res, known_outvars, - known_eqns, known_effects) + known_eqns, known_effects, jaxpr.debug_info) config.enable_checks.value and core.check_jaxpr(jaxpr_known) _, ins_staged = partition_list(in_inst, jaxpr.invars) @@ -1336,8 +1338,7 @@ def _prune_jaxpr_outputs(jaxpr: Jaxpr, used_outputs: tuple[bool, ...]) -> Jaxpr: dbg = jaxpr.debug_info and core.DebugInfo( jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info, jaxpr.debug_info.arg_names, - tuple(v for v, b in zip(jaxpr.debug_info.safe_result_paths(len(used_outputs)), - used_outputs) if b)) + jaxpr.debug_info.filter_result_paths(used_outputs)) new_jaxpr = jaxpr.replace(outvars=outvars, debug_info=dbg) config.enable_checks.value and core.check_jaxpr(new_jaxpr) return new_jaxpr @@ -1424,10 +1425,8 @@ def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: tuple[bool, ...], dbg = jaxpr.debug_info and core.DebugInfo( jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info, - tuple(v for v, b in zip(jaxpr.debug_info.safe_arg_names(len(used_inputs)), - used_inputs) if b), - tuple(v for v, b in zip(jaxpr.debug_info.safe_result_paths(len(used_outputs)), - used_outputs) if b)) + jaxpr.debug_info.filter_arg_names(used_inputs), + jaxpr.debug_info.filter_result_paths(used_outputs)) new_jaxpr = Jaxpr(jaxpr.constvars, invars, outvars, eqns, jaxpr_effects, dbg) config.enable_checks.value and core.check_jaxpr(new_jaxpr) @@ -1644,7 +1643,9 @@ class JaxprStackFrame: def add_eqn(self, eqn: core.JaxprEqn): self.eqns.append(eqn) - def to_jaxpr(self, trace: DynamicJaxprTrace, out_tracers: Sequence[Tracer] + def to_jaxpr(self, trace: DynamicJaxprTrace, + out_tracers: Sequence[Tracer], + debug_info: core.DebugInfo | None, ) -> tuple[Jaxpr, list[Any], list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]: # It's not necessary, but we keep the tracer-to-var mapping injective: assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values())) @@ -1657,7 +1658,8 @@ class JaxprStackFrame: outvars = state_outvars + explicit_outvars constvars, constvals = unzip2(self.constvar_to_val.items()) jaxpr_effects = make_jaxpr_effects(constvars, self.invars, explicit_outvars, self.eqns) - jaxpr = Jaxpr(constvars, invars, outvars, self.eqns, jaxpr_effects) + jaxpr = Jaxpr(constvars, invars, outvars, self.eqns, jaxpr_effects, + debug_info) jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals) jaxpr, constvals = _inline_literals(jaxpr, constvals) # type: ignore init_trees = [tree_structure(init_val) for init_val in self.attrs_inits] @@ -1950,7 +1952,7 @@ class DynamicJaxprTrace(core.Trace): for a, in_axis in zip(in_avals, params['in_axes'])] with core.extend_axis_env_nd([(axis_name, params["global_axis_size"])]): jaxpr, reduced_out_avals, consts, () = trace_to_jaxpr_dynamic( - f, reduced_in_avals, f.debug_info) + f, reduced_in_avals) ordered_effects = effects.ordered_effects.filter_in(jaxpr.effects) if ordered_effects: raise ValueError("Ordered effects not supported for " @@ -2074,8 +2076,9 @@ class DynamicJaxprTrace(core.Trace): self.frame.add_eqn(eqn) return out_tracers - def to_jaxpr(self, out_tracers: Sequence[Tracer]): - return self.frame.to_jaxpr(self, out_tracers) + def to_jaxpr(self, out_tracers: Sequence[Tracer], + debug_info: core.DebugInfo | None): + return self.frame.to_jaxpr(self, out_tracers, debug_info) custom_staging_rules: dict[Primitive, Callable] = {} @@ -2116,14 +2119,12 @@ def _jvp_jaxpr_zeros(f, store, in_zeros, 
zero_avals, *primal_tangent_avals): def trace_to_jaxpr_dynamic( fun: lu.WrappedFun, in_avals: Sequence[AbstractValue], - debug_info: core.DebugInfo | None = None, *, keep_inputs: list[bool] | None = None, ) -> tuple[Jaxpr, list[AbstractValue], list[Any], list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]: keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs - - trace = DynamicJaxprTrace(debug_info) + trace = DynamicJaxprTrace(fun.debug_info) with core.ensure_no_leaks(trace), source_info_util.reset_name_stack(): in_tracers = _input_type_to_tracers(trace.new_arg, in_avals) in_tracers = [t for t, keep in zip(in_tracers, keep_inputs) if keep] @@ -2131,8 +2132,8 @@ def trace_to_jaxpr_dynamic( ans = fun.call_wrapped(*in_tracers) out_tracers = map(trace.to_jaxpr_tracer, ans) - _check_no_returned_refs(debug_info, out_tracers) - jaxpr, consts, attrs_tracked = trace.to_jaxpr(out_tracers) + _check_no_returned_refs(fun.debug_info, out_tracers) + jaxpr, consts, attrs_tracked = trace.to_jaxpr(out_tracers, fun.debug_info) del trace, fun, in_tracers, out_tracers, ans config.enable_checks.value and core.check_jaxpr(jaxpr) @@ -2160,7 +2161,7 @@ def _check_no_returned_refs( origin_info = ('\n\nThe returned mutable array was created on line ' f'{source_info_util.summarize(eqn.source_info)}.') elif v in frame.invars: - arg_name = dbg.arg_names[frame.invars.index(v)] # type: ignore + arg_name = dbg.safe_arg_names(len(frame.invars))[frame.invars.index(v)] # type: ignore origin_info = ('\n\nThe returned mutable array was passed in as the ' f'argument {arg_name}.') else: @@ -2172,10 +2173,10 @@ def _check_no_returned_refs( @profiler.annotate_function def trace_to_jaxpr_dynamic2( - fun: lu.WrappedFun, debug_info: core.DebugInfo | None = None + fun: lu.WrappedFun, ) -> tuple[Jaxpr, OutputType, list[Any]]: - trace = DynamicJaxprTrace(debug_info) + trace = DynamicJaxprTrace(fun.debug_info) with core.ensure_no_leaks(trace), source_info_util.reset_name_stack(): in_avals, keep_inputs = unzip2(fun.in_type) in_tracers = _input_type_to_tracers(trace.new_arg, in_avals) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 04c950aef..0ce2d0c46 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -33,7 +33,6 @@ import numpy as np import jax from jax._src import api -from jax._src import api_util from jax._src import compiler from jax._src import config from jax._src import core @@ -652,7 +651,6 @@ class ParallelCallableInfo: in_axes: Iterable[int | None] out_axes_thunk: Callable[[], Sequence[int | None]] avals: Sequence[core.AbstractValue] - debug_info: core.DebugInfo | None @cached_property def local_devices(self): @@ -723,8 +721,7 @@ def stage_parallel_callable( "Finished tracing + transforming {fun_name} for pmap in {elapsed_time} sec", fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT): jaxpr, out_sharded_avals, consts, _ = pe.trace_to_jaxpr_dynamic( - fun, sharded_avals, pci.debug_info) - jaxpr = api_util.add_jaxpr_debug_info(jaxpr, pci.debug_info) + fun, sharded_avals) assert len(out_sharded_avals) == len(pci.out_axes), ( len(out_sharded_avals), len(pci.out_axes)) @@ -758,7 +755,7 @@ def get_pmap_jaxpr( pci = ParallelCallableInfo( name, backend, axis_name, axis_size, global_axis_size, devices, - in_axes, out_axes_thunk, avals, fun.debug_info) + in_axes, out_axes_thunk, avals) with core.extend_axis_env_nd([(axis_name, axis_size)]): jaxpr, consts, replicas, shards = stage_parallel_callable(pci, fun) jaxpr = 
core.remove_named_axis_effects(jaxpr, {axis_name}) @@ -992,7 +989,7 @@ class UnloadedPmapExecutable: return PmapExecutable( self.compiled, self.build_execute_fun, fingerprint, - self.local_input_avals, self.jaxpr_debug_info, self) + self.local_input_avals, self) @staticmethod def from_hlo(hlo: ir.Module, @@ -1119,24 +1116,23 @@ class UnloadedPmapExecutable: class PmapExecutable(stages.XlaExecutable): __slots__ = ["xla_executable", "_unsafe_call", "build_unsafe_call", - "fingerprint", "in_avals", "_jaxpr_debug_info", - "_unloaded_executable"] + "fingerprint", "in_avals", "_unloaded_executable"] def __init__(self, xla_executable, build_unsafe_call, fingerprint, - in_avals, jaxpr_debug_info, unloaded_executable): + in_avals, + unloaded_executable: UnloadedPmapExecutable): self.xla_executable = xla_executable self._unsafe_call = None self.build_unsafe_call = build_unsafe_call self.fingerprint = fingerprint self.in_avals = in_avals - self._jaxpr_debug_info = jaxpr_debug_info self._unloaded_executable = unloaded_executable @property def unsafe_call(self) -> Callable[..., Any]: if self._unsafe_call is None: self._unsafe_call = self.build_unsafe_call() - return self._unsafe_call + return self._unsafe_call # type: ignore # -- stages.XlaExecutable overrides @@ -1147,7 +1143,8 @@ class PmapExecutable(stages.XlaExecutable): def call(self, *args): # TODO(frostig): do we need to check sharding and sharded avals? arg_avals = map(core.abstractify, args) - check_arg_avals_for_call(self.in_avals, arg_avals, self._jaxpr_debug_info) + check_arg_avals_for_call(self.in_avals, arg_avals, + self._unloaded_executable.jaxpr_debug_info) return self.unsafe_call(*args) # pylint: disable=not-callable @@ -3206,7 +3203,7 @@ def check_arg_avals_for_call(ref_avals, arg_avals, f"but called with {len(arg_avals)}") if jaxpr_debug_info is not None: - arg_names = [f"'{name}'" for name in jaxpr_debug_info.arg_names] + arg_names = [f"'{name}'" for name in jaxpr_debug_info.safe_arg_names(len(ref_avals))] else: num_args = len(ref_avals) arg_names = [f"{i + 1}/{num_args}" for i in range(num_args)] diff --git a/jax/_src/lax/control_flow/common.py b/jax/_src/lax/control_flow/common.py index 8844e7f3d..73827f1bc 100644 --- a/jax/_src/lax/control_flow/common.py +++ b/jax/_src/lax/control_flow/common.py @@ -58,7 +58,7 @@ def _initial_style_open_jaxpr(fun: Callable, lu.wrap_init(fun, debug_info=debug_info), in_tree) jaxpr, _, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic( - wrapped_fun, in_avals, debug_info) + wrapped_fun, in_avals) return jaxpr, consts, out_tree(), attrs_tracked @weakref_lru_cache diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 9946a9a67..315e01c8d 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -746,8 +746,9 @@ def _trace_composite_to_jaxpr(fun: Callable, in_avals: Sequence[core.AbstractValue], name: str, debug_info: core.DebugInfo): - flat_fun, out_tree = api_util.flatten_fun_nokwargs(lu.wrap_init(fun), in_tree) - jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug_info) + flat_fun, out_tree = api_util.flatten_fun_nokwargs( + lu.wrap_init(fun, debug_info=debug_info), in_tree) + jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals) if any(isinstance(c, core.Tracer) for c in consts): raise UnexpectedTracerError( "Found a JAX Tracer as a constant in the decomposition for the " diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py index 018ab41a4..bee179d08 100644 --- a/jax/_src/linear_util.py +++ b/jax/_src/linear_util.py @@ -63,7 +63,7 @@ data 
must be immutable, because it will be stored in function memoization tables """ from __future__ import annotations -from collections.abc import Callable +from collections.abc import Callable, Sequence from functools import partial from typing import Any, NamedTuple import weakref @@ -71,6 +71,7 @@ import weakref from jax._src import config from jax._src import core from jax._src import traceback_util +from jax._src.tree_util import keystr, generate_key_paths from jax._src.util import curry, cache_clearing_funs, HashableFunction @@ -275,13 +276,6 @@ class DebugInfo(NamedTuple): # e.g. ('[0]', '[1]', ...) result_paths: tuple[str, ...] | Callable[[], tuple[str, ...]] | None - def add_result_paths(self, - result_paths_thunk: Callable[[], tuple[str, ...]] - ) -> DebugInfo: - assert self.result_paths is None - return self._replace(result_paths=HashableFunction(result_paths_thunk, - closure=())) - def resolve_result_paths(self) -> DebugInfo: """Return a debug info with resolved result paths.""" if callable(self.result_paths): @@ -296,6 +290,10 @@ class DebugInfo(NamedTuple): # TODO(necula): this should not happen return (None,) * expected + def filter_arg_names(self, keep: Sequence[bool]) -> tuple[str | None, ...]: + """Keep only the arg_names for which `keep` is True.""" + return tuple(v for v, b in zip(self.safe_arg_names(len(keep)), keep) if b) + def safe_result_paths(self, expected: int) -> tuple[str, ...]: """Get the result paths with a safety check.""" assert not callable(self.result_paths), self @@ -305,15 +303,34 @@ class DebugInfo(NamedTuple): # TODO(necula): this should not happen return ("",) * expected + def filter_result_paths(self, keep: Sequence[bool]) -> tuple[str, ...]: + """Keep only the result_paths for which `keep` is True.""" + assert not callable(self.result_paths), self + return tuple(v for v, b in zip(self.safe_result_paths(len(keep)), keep) if b) + def wrap_init(f: Callable, params=None, *, debug_info: DebugInfo | None = None) -> WrappedFun: """Wraps function `f` as a `WrappedFun`, suitable for transformation.""" params_dict = {} if params is None else params params = () if params is None else tuple(sorted(params.items())) - return WrappedFun(f, partial(f, **params_dict), (), (), params, None, debug_info) + fun = WrappedFun(f, partial(f, **params_dict), (), (), params, None, None) + if debug_info: + if debug_info.result_paths is None: + fun, result_paths_thunk = _get_result_paths_thunk(fun) + debug_info = debug_info._replace( + result_paths=HashableFunction(result_paths_thunk, closure=())) + fun = WrappedFun(fun.f, fun.f_transformed, fun.transforms, fun.stores, + fun.params, fun.in_type, debug_info) + return fun +@transformation_with_aux2 +def _get_result_paths_thunk(_fun: Callable, _store: Store, *args, **kwargs): + ans = _fun(*args, **kwargs) + _store.store([keystr(path) for path, _ in generate_key_paths(ans)]) + return ans + def annotate(f: WrappedFun, in_type: core.InputType | None) -> WrappedFun: assert f.in_type is None if in_type is None: @@ -350,16 +367,9 @@ def _check_input_type(in_type: core.InputType) -> None: provided[d.val] = True assert all(provided) -def add_debug_info(f: WrappedFun, debug_info: DebugInfo | None - ) -> WrappedFun: - """Produce a new WrappedFun with debug_info attached.""" - assert f.debug_info is None - if debug_info is None: - return f - return WrappedFun(f.f, f.f_transformed, f.transforms, f.stores, f.params, f.in_type, debug_info) - -def cache(call: Callable, *, explain: Callable | None = None): +def cache(call: Callable, *, + explain: 
Callable[[WrappedFun, bool, dict, tuple], None] | None = None): """Memoization decorator for functions taking a WrappedFun as first argument. Args: @@ -367,6 +377,9 @@ def cache(call: Callable, *, explain: Callable | None = None): underlying transforms and params on the WrappedFun are used as part of the memoization cache key. + explain: a function that is invoked upon cache misses to log an explanation + of the miss. Invoked with `(fun, is_cache_first_use, cache, key)`. + Returns: A memoized version of ``call``. """ @@ -382,7 +395,7 @@ def cache(call: Callable, *, explain: Callable | None = None): else: ans = call(fun, *args) if explain and config.explain_cache_misses.value: - explain(fun.f, cache is new_cache, cache, key) + explain(fun, cache is new_cache, cache, key) cache[key] = (ans, fun.stores) return ans diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index cc156f375..35488530c 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -423,7 +423,7 @@ class BlockSpec: ) with tracing_grid_env(grid, mapped_dims): jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic( - flat_index_map_fun, index_map_avals, debug_info=debug + flat_index_map_fun, index_map_avals ) mapped_block_shape = tuple(mapped if s is None else s for s in block_shape) if len(out_avals) != len(block_shape): diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 98fafce09..5552856bf 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -101,12 +101,12 @@ def _pallas_call_jvp_rule( primals, tangents, *, - jaxpr, + jaxpr: jax_core.Jaxpr, name_and_src_info, input_output_aliases: tuple[tuple[int, int], ...], grid_mapping, - debug, - interpret, + debug: bool, + interpret: bool, compiler_params: Any, cost_estimate: CostEstimate | None, out_avals: tuple[jax_core.AbstractValue, ...], @@ -1098,13 +1098,14 @@ def pallas_call_checkify_rule(error: checkify.Error, retrace_in_avals = [*shaped_scalar_avals, *error_memref_aval, *input_aval, *error_memref_aval, *output_aval, *scratch_aval] jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(retrace_in_avals) - wrapped_kernel_with_err, out_tree_thunk = api_util.flatten_fun_nokwargs( - lu.wrap_init(checked_kernel_fn), jaxpr_in_tree) debug = api_util.debug_info("checkify_pallas", checked_kernel_fn, retrace_in_avals, {}) + wrapped_kernel_with_err, out_tree_thunk = api_util.flatten_fun_nokwargs( + lu.wrap_init(checked_kernel_fn, debug_info=debug), jaxpr_in_tree) + with pallas_core.tracing_grid_env(grid_mapping.grid, ()): final_jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( - wrapped_kernel_with_err, jaxpr_flat_avals, debug) + wrapped_kernel_with_err, jaxpr_flat_avals) # Prepare pallas_call inputs. We need to create new block specs # for the new error inputs and outputs. 
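
With the `explain` hook of `lu.cache` documented above, a callback now
receives the `WrappedFun` itself (the call site passes
`explain(fun, cache is new_cache, cache, key)` rather than `fun.f`), so it
can read debug info directly instead of unpacking it from the cache key.
A hedged sketch of such a callback (`my_explain` is a made-up name):

    from jax._src import linear_util as lu

    def my_explain(fun: lu.WrappedFun, is_first_use: bool,
                   cache: dict, key: tuple) -> None:
      # The debug info travels on the WrappedFun, per this patch.
      dbg = fun.debug_info
      where = dbg.func_src_info if dbg is not None else repr(fun.f)
      print(f"tracing cache miss for {where}; first use: {is_first_use}; "
            f"{len(cache)} entries cached")

    # Usage: cached_call = lu.cache(call, explain=my_explain)
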
@@ -1161,16 +1162,16 @@ def _trace_kernel_to_jaxpr( kernel_in_transforms: tuple[tuple[pallas_core.Transform, ...], ...], indexer: bool = False, ) -> tuple[jax_core.ClosedJaxpr, tuple[jax.Array, ...]]: + fake_kernel_args = kernel_in_tree.unflatten(kernel_avals) + debug = api_util.debug_info("pallas_call", fun, fake_kernel_args, {}) wrapped_kernel_fun, out_tree_thunk = api_util.flatten_fun_nokwargs( - lu.wrap_init(fun), kernel_in_tree) + lu.wrap_init(fun, debug_info=debug), kernel_in_tree) wrapped_kernel_fun = primitives.wrap_with_transforms( wrapped_kernel_fun, kernel_in_transforms ) - fake_kernel_args = kernel_in_tree.unflatten(kernel_avals) - debug = api_util.debug_info("pallas_call", fun, fake_kernel_args, {}) with grid_mapping.trace_env(): jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun, - kernel_avals, debug) + kernel_avals) if consts: consts_avals = [jax_core.get_aval(c) for c in consts] if any(not isinstance(aval, state.AbstractRef) for aval in consts_avals): diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index b435bfdc9..17c20c941 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -49,7 +49,7 @@ from jax._src import xla_bridge as xb from jax._src.api_util import ( argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs, donation_vector, check_callable, resolve_argnums, - argnames_partial_except, debug_info, result_paths, add_jaxpr_debug_info, + argnames_partial_except, debug_info, hoist_obj_attrs, _check_no_aliased_ref_args, _check_no_aliased_closed_over_refs) from jax._src.interpreters import partial_eval as pe @@ -567,9 +567,7 @@ def _infer_params_impl( axes_specs = _flat_axes_specs(ji.abstracted_axes, *args, **kwargs) - f = lu.wrap_init(fun) - f, res_paths = result_paths(f) - dbg = dbg and dbg.add_result_paths(result_paths_thunk=res_paths) + f = lu.wrap_init(fun, debug_info=dbg) f, dyn_args = argnums_partial_except(f, ji.static_argnums, args, allow_invalid=True) del args @@ -618,7 +616,7 @@ def _infer_params_impl( in_shardings_flat, in_layouts_flat = _process_in_axis_resources( in_shardings_treedef, in_shardings_leaves, ji.in_layouts_treedef, ji.in_layouts_leaves, - in_avals, in_tree, dbg, device_or_backend_set, have_kwargs) + in_avals, in_tree, flat_fun.debug_info, device_or_backend_set, have_kwargs) attr_token = _attr_token(flat_fun, in_type) @@ -627,8 +625,7 @@ def _infer_params_impl( if mesh_lib.get_abstract_mesh().empty else mesh_lib.get_abstract_mesh()) with mesh_lib.set_abstract_mesh(abstract_mesh): jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr( - flat_fun, in_type, attr_token, dbg, - HashableFunction(res_paths, closure=()), + flat_fun, in_type, attr_token, IgnoreKey(ji.inline)) if config.mutable_array_checks.value: _check_no_aliased_closed_over_refs(dbg, (*jaxpr.consts, *consts), explicit_args) @@ -1171,17 +1168,18 @@ def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves, callsites: set[str] = set() def explain_tracing_cache_miss( - f: Callable, unseen_f: bool, cache: dict, key: tuple): + fun: lu.WrappedFun, unseen_f: bool, cache: dict, key: tuple): if config.check_tracer_leaks.value: return def unpack(key): - transforms, (), _, (in_type, _, debug_info, _, inline), *_, ctx = key + transforms, (), _, (in_type, _, inline), *_, ctx = key # TODO(dougalm,mattjj): enable cache miss explanation with attrs _, (_, (in_tree,)), *_ = transforms - return in_tree, in_type, debug_info, inline.val, ctx - in_tree, in_type, debug_info, inline, ctx = unpack(key) + return in_tree, in_type, inline.val, ctx + 
in_tree, in_type, inline, ctx = unpack(key) if inline: return + debug_info = fun.debug_info msg: list[str] = [] p = msg.append done = lambda: logger.log(logging.WARNING, '\n'.join(msg)) @@ -1190,7 +1188,7 @@ def explain_tracing_cache_miss( p(f"TRACING CACHE MISS at {callsite} because:") # have we seen this function before at all? - fun_name = getattr(f, '__qualname__', f) + fun_name = getattr(fun.f, '__qualname__', fun.f) if debug_info is not None and debug_info.func_src_info: # TODO(necula): clean up the extraction of the source info _, *rest = debug_info.func_src_info.split(' at ') @@ -1198,7 +1196,7 @@ def explain_tracing_cache_miss( else: src_info = '' if unseen_f: - p(f" never seen function:\n {fun_name} id={id(f)}{src_info}") + p(f" never seen function:\n {fun_name} id={id(fun.f)}{src_info}") if callsite in callsites: p(" but seen another function defined on the same line; maybe the function is\n" " being re-defined repeatedly, preventing caching?") @@ -1263,7 +1261,7 @@ def explain_tracing_cache_miss( # have we never seen these input types (eg shapes, dtypes) before? types_match = [k for k in trees_match if k[1] == in_type] if not types_match: - if len(in_type) < 5: + if len(in_type) < 5 and debug_info is not None: in_type_str = ':\n {}'.format(', '.join( f'{n}: {ty.str_short(short_dtypes=True)}' for n, ty in zip(debug_info.arg_names, in_type))) @@ -1275,7 +1273,12 @@ def explain_tracing_cache_miss( num_mismatch = sum(map(op.ne, closest_ty, in_type)) p(f" closest seen input type signature has {num_mismatch} mismatches, including:") add_weak_type_hint = False - for name, ty1, ty2 in zip(debug_info.arg_names, closest_ty, in_type): + if debug_info: + arg_names = debug_info.safe_arg_names(len(in_type)) + else: + arg_names = (None,) * len(in_type) + + for name, ty1, ty2 in zip(arg_names, closest_ty, in_type): if ty1 != ty2: if type(ty1) == type(ty2) == core.ShapedArray: s1, s2 = ty1.str_short(True), ty2.str_short(True) @@ -1302,8 +1305,6 @@ def _create_pjit_jaxpr( fun: lu.WrappedFun, in_type: core.InputType | Sequence[core.AbstractValue], attr_data: int, - debug_info: core.DebugInfo, - result_paths: Callable, ignored_inline: IgnoreKey ) -> tuple[core.ClosedJaxpr, list[Any], list[core.AbstractValue], list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]: @@ -1317,17 +1318,13 @@ def _create_pjit_jaxpr( fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT): if config.dynamic_shapes.value: jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic2( - lu.annotate(fun, cast(core.InputType, in_type)), debug_info=debug_info) + lu.annotate(fun, cast(core.InputType, in_type))) attrs_tracked = [] else: jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic( - fun, in_type, debug_info=debug_info) + fun, in_type) # assert attr_data is sentinel or attr_data matches attrs_tracked - # TODO(dougalm,mattjj): enable debug info with attrs_tracked - if not config.dynamic_shapes.value and not attrs_tracked: - jaxpr = add_jaxpr_debug_info(jaxpr, debug_info, result_paths()) - if config.debug_key_reuse.value: # Import here to avoid circular imports from jax.experimental.key_reuse._core import check_key_reuse_jaxpr @@ -1928,7 +1925,9 @@ def _pjit_abstract_eval(*args, jaxpr, out_shardings, **_): pjit_p.def_effectful_abstract_eval(_pjit_abstract_eval) -def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings, +def _pjit_cached_lower_jaxpr_to_fun(ctx: mlir.LoweringRuleContext, + name: str, jaxpr: core.ClosedJaxpr, + effects, in_shardings, out_shardings, in_layouts, 
out_layouts, api_name): mod_ctx = ctx.module_context @@ -1959,7 +1958,8 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings, return func -def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings, +def _pjit_lowering(ctx: mlir.LoweringRuleContext, *args, name: str, + jaxpr: core.ClosedJaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, keep_unused, inline, compiler_options_kvs): effects = list(ctx.tokens_in.effects()) @@ -1987,8 +1987,10 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings, mlir.register_lowering(pjit_p, _pjit_lowering) -def _pjit_batcher(axis_data, vals_in, dims_in, - jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, +def _pjit_batcher(axis_data, vals_in, + dims_in: tuple[int, ...], + jaxpr: core.ClosedJaxpr, + in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline, compiler_options_kvs): segment_lens, dims_in = batching.indirectify_ragged_axes(dims_in) @@ -2037,7 +2039,8 @@ batching.ragged_prop_rules[pjit_p] = batching.ragged_mask_no_op_rule def _pjit_batcher_for_sharding( s: Sharding | UnspecifiedValue, - dim: int, spmd_axis_name: tuple[str, ...] | None, mesh, ndim: int): + dim: int | batching.RaggedAxis, spmd_axis_name: tuple[str, ...] | None, mesh, + ndim: int): if isinstance(s, UnspecifiedValue): return s hlo_s = s._to_xla_hlo_sharding(ndim) @@ -2049,7 +2052,7 @@ def _pjit_batcher_for_sharding( return NamedSharding._from_parsed_pspec(s.mesh, parsed_pspec) new_op = hlo_s.to_proto().clone() tad = list(new_op.tile_assignment_dimensions) - tad.insert(dim, 1) + tad.insert(dim, 1) # type: ignore new_op.tile_assignment_dimensions = tad new_gs = GSPMDSharding( s._device_assignment, new_op, @@ -2171,8 +2174,9 @@ def _pjit_linearization(nzs, *primals_in, jaxpr, ad.primitive_linearizations[pjit_p] = _pjit_linearization -def _pjit_partial_eval(trace, *in_tracers, - jaxpr, in_shardings, out_shardings, +def _pjit_partial_eval(trace: pe.JaxprTrace, + *in_tracers, + jaxpr: core.ClosedJaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline, compiler_options_kvs): in_pvals = [t.pval for t in in_tracers] @@ -2191,7 +2195,7 @@ def _pjit_partial_eval(trace, *in_tracers, else: known_jaxpr, unknown_jaxpr, unknown_outs, res_avals = \ pe.partial_eval_jaxpr_nounits(jaxpr, unknown_ins, instantiate=False) - unknown_outs = tuple(unknown_outs) + unknown_outs = tuple(unknown_outs) # type: ignore[assignment] known_outs = tuple(not uk for uk in unknown_outs) num_residuals = len(res_avals) res_shardings = (UNSPECIFIED,) * num_residuals @@ -2282,7 +2286,7 @@ def _pjit_partial_eval(trace, *in_tracers, unknown_tracers_in = [t for t in in_tracers if not t.pval.is_known()] unknown_out_avals = unknown_jaxpr.out_avals unknown_tracers_out = [ - pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None) + pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None) # type: ignore for aval in unknown_out_avals ] eqn = pe.new_eqn_recipe((*unknown_tracers_in, *residual_tracers), diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index 3f8d43707..dc781c424 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -998,9 +998,10 @@ def _initial_style_jaxpr(fun: Callable, in_tree: api_util.PyTreeDef, in_avals: Sequence[core.AbstractValue], debug: core.DebugInfo): - fun_, out_tree_thunk = api_util.flatten_fun_nokwargs(lu.wrap_init(fun), + fun_, out_tree_thunk = 
api_util.flatten_fun_nokwargs( + lu.wrap_init(fun, debug_info=debug), tree_util.treedef_tuple((in_tree,))) - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, in_avals, debug) + jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, in_avals) return jaxpr, consts, out_tree_thunk() diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index 39735ae5e..7750c1837 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -88,7 +88,7 @@ from jax._src.interpreters.partial_eval import ( # TODO(mattjj): remove temporary shim when trace_to_jaxpr_dynamic sig stabilizes def trace_to_jaxpr_dynamic(fun, in_avals, debug_info=None, *, keep_inputs=None): # noqa jaxpr, out_avals, consts, () = _trace_to_jaxpr_dynamic( - fun, in_avals, debug_info, keep_inputs=keep_inputs) + fun, in_avals, keep_inputs=keep_inputs) return jaxpr, out_avals, consts diff --git a/tests/debug_info_test.py b/tests/debug_info_test.py index 4780c703c..da98d81d5 100644 --- a/tests/debug_info_test.py +++ b/tests/debug_info_test.py @@ -14,9 +14,7 @@ from __future__ import annotations -import contextlib import functools -import math import operator import re from typing import Any @@ -74,12 +72,12 @@ def _debug_info_to_string(dbg: core.DebugInfo | None) -> list[str]: # Strip the absolute path and the line number but check that it references # this file (to catch errors when the source info points in JAX internals) fun_src_info = re.sub(r"^(\S+)( at .*/debug_info_test.py:.*)?", "\\1", dbg.func_src_info) - res = f"traced_for={dbg.traced_for}, fun={fun_src_info}, arg_names={','.join(dbg.arg_names)}" + arg_names_str = ",".join([str(a) for a in dbg.arg_names]) + res = f"traced_for={dbg.traced_for}, fun={fun_src_info}, arg_names={arg_names_str}" if isinstance(dbg.result_paths, tuple): - if dbg.result_paths: - res += f", result_paths={','.join(dbg.result_paths)}" - else: - res += ", result_paths=" + res += f", result_paths={','.join(dbg.result_paths)}" + elif dbg.result_paths is None: + res += ", result_paths=" return res @@ -151,7 +149,8 @@ class DebugInfoTest(jtu.JaxTestCase): found_jaxprs_debug_infos = [_debug_info_to_string(j.debug_info) for j in all_jaxprs] - self._check_matches(expected_jaxpr_debug_infos, found_jaxprs_debug_infos) # JAXPRS + self._check_matches(expected_jaxpr_debug_infos, found_jaxprs_debug_infos, + "Jaxprs debug_infos") # JAXPRS found_tracer_debug_infos = [] if tracer_spy is not None: @@ -173,7 +172,8 @@ class DebugInfoTest(jtu.JaxTestCase): else: found_tracer_debug_infos.append("None") - self._check_matches(expected_tracer_debug_infos, found_tracer_debug_infos) # INSPECTED TRACERS + self._check_matches(expected_tracer_debug_infos, found_tracer_debug_infos, + "Tracer debug_infos") # INSPECTED TRACERS if not check_lowering: return # Collect all the lines in all the MLIR modules @@ -186,36 +186,34 @@ class DebugInfoTest(jtu.JaxTestCase): mlir_modules_lines.extend( mlir.module_to_string(mod, enable_debug_info=True).split("\n")) - expected_and_found = set() - expected_and_not_found = set() - for exp in expected_lowering_lines: - for l in mlir_modules_lines: - ok = exp.match(l) if isinstance(exp, re.Pattern) else exp == l - if ok: - expected_and_found.add(exp) - break - else: - expected_and_not_found.add(exp) - - if expected_and_not_found: - msg = "\n".join(mlir_modules_lines) - self.assertEmpty(expected_and_not_found, "\nNot found in the MLIR module lines:\n" + msg) + self._check_matches(expected_lowering_lines, mlir_modules_lines, + "MLIR module lines", 
report_found_unexpected=False) def _check_matches(self, expected: list[str | re.Pattern], - found: list[str]): - expected_and_found = set() - unexpected: set[str] = set() - for debug_info in found: - for exp_re in expected: - ok = exp_re.match(debug_info) if isinstance(exp_re, re.Pattern) else exp_re == debug_info + found: list[str], + what: str, + report_found_unexpected: bool = True): + expected_and_found: set[str | re.Pattern] = set() + found_and_expected: set[str] = set() + for exp_re in expected: + for found_line in found: + ok = exp_re.match(found_line) if isinstance(exp_re, re.Pattern) else exp_re == found_line if ok: expected_and_found.add(exp_re) - break - else: - unexpected.add(debug_info) - self.assertEmpty(unexpected) # found unexpected debug_info - self.assertEmpty([e for e in expected if e not in expected_and_found]) # expected element that was not found + found_and_expected.add(found_line) + + found_and_unexpected = set(found) - found_and_expected + all_found = "\n ".join(found) + if report_found_unexpected and found_and_unexpected: + unexp_str = "\n ".join(found_and_unexpected) + msg = f"Found unexpected {what}:\n {unexp_str}\nAll found {what}:\n {all_found}" + self.assertTrue(False, msg) + + if expected_not_found := {e for e in expected if e not in expected_and_found}: + exp_str = "\n ".join([str(e) for e in expected_not_found]) + msg = f"Expected but not found in {what}:\n {exp_str}\nAll found {what}:\n {all_found}" + self.assertTrue(False, msg) def test_debug_info_basic(self): def my_f(x, y, z, w): @@ -634,8 +632,7 @@ class DebugInfoTest(jtu.JaxTestCase): 3, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - # TODO(necula): bad result names - 'traced_for=jit, fun=my_f, arg_names=a, result_paths=', + 'traced_for=jit, fun=my_f, arg_names=a, result_paths=', 'traced_for=jit, fun=my_g, arg_names=b, result_paths=', ], check_tracer_arg_name=True, @@ -694,7 +691,7 @@ class DebugInfoTest(jtu.JaxTestCase): check_tracer_arg_name=True, expected_tracer_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=y['hi'],z,args[0],args[1],kwargs['t'],kwargs['w'], from kwargs['w']", - "None", # TODO(necula) + "None", # TODO(necula) missing debug info ], expected_lowering_lines=[ re.compile(r".*func.func public @main\(%arg0: tensor loc\(\"y\['hi'\]\"\)"), @@ -780,7 +777,8 @@ class DebugInfoTest(jtu.JaxTestCase): lambda x, y, z: jax.jvp(jax.jit(f), (x, y, z), (x, y, z)), jnp.float32(1.), (jnp.float32(2.),), [jnp.float32(3.)], expected_jaxpr_debug_infos=[ - "None", # TODO(necula): missing debug info + # TODO(necula): arg_names, result_paths + "traced_for=jit, fun=f, arg_names=None,None,None,None, result_paths=,,,", ], tracer_spy=tracer_spy, expected_tracer_debug_infos=[ @@ -793,7 +791,7 @@ class DebugInfoTest(jtu.JaxTestCase): re.compile(r".*func.func public @main\(.*%arg2: tensor loc\(unknown\)"), re.compile(r".*func.func public @main\(.*%arg3: tensor loc\(unknown\)"), # TODO(necula): missing result names - re.compile(r".*func.func public @main\(.*-> \(tensor, tensor, tensor, tensor\) {"), + re.compile(r".*func.func public @main\(.*-> .*tensor {jax.result_info = \"\"}"), ]) def test_vjp_of_jit(self): @@ -805,6 +803,7 @@ class DebugInfoTest(jtu.JaxTestCase): lambda x, y, z: jax.vjp(jax.jit(my_f), x, y, z)[1](dict(a=x, b=[y])), jnp.float32(1.), (jnp.float32(2.),), [jnp.float32(3.)], expected_jaxpr_debug_infos=[ + "traced_for=jit, fun=my_f, arg_names=x,y[0], result_paths=", "None", # TODO(necula): missing debug info ], tracer_spy=tracer_spy, @@ -816,8 +815,7 @@ class 
DebugInfoTest(jtu.JaxTestCase): # TODO(necula): missing arg_names re.compile(r".*func.func public @main\(%arg0: tensor loc\(unknown\)"), re.compile(r".*func.func public @main\(.*%arg1: tensor loc\(unknown\)"), - # TODO(necula): missing result names - re.compile(r".*func.func public @main\(.*-> tensor {"), + re.compile(r".*func.func public @main\(.*-> \(tensor {jax.result_info = \"\"}"), ]) def test_vjp_of_nested_jit(self): @@ -837,6 +835,8 @@ class DebugInfoTest(jtu.JaxTestCase): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=, arg_names=x,y,res_ct, result_paths=[0],[1]", + # TODO(necula): result_paths + "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=", # TODO(necula): missing debug info "None", ], @@ -872,8 +872,7 @@ class DebugInfoTest(jtu.JaxTestCase): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x,y, result_paths=", - # TODO(necula): missing debug info - 'None', + "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=['c']", ], expected_tracer_debug_infos=[ # TODO(necula): missing debug info @@ -930,8 +929,9 @@ class DebugInfoTest(jtu.JaxTestCase): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x, result_paths=", - # TODO(necula): some Jaxprs without debug info - "None"], + "traced_for=cond, fun=my_false_branch, arg_names=c,d, result_paths=", + "traced_for=cond, fun=my_true_branch, arg_names=a,b, result_paths=", + ], expected_tracer_debug_infos=[ "traced_for=cond, fun=my_true_branch, arg_names=a,b", "traced_for=cond, fun=my_false_branch, arg_names=c,d" @@ -957,8 +957,10 @@ class DebugInfoTest(jtu.JaxTestCase): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x, result_paths=", - # TODO(necula): some Jaxprs without debug info - "None"], + "traced_for=switch, fun=my_branch0, arg_names=x0, result_paths=", + "traced_for=switch, fun=my_branch1, arg_names=x1, result_paths=", + "traced_for=switch, fun=my_branch2, arg_names=x2, result_paths=", + ], expected_tracer_debug_infos=[ "traced_for=switch, fun=my_branch0, arg_names=x0", "traced_for=switch, fun=my_branch1, arg_names=x1", @@ -1031,6 +1033,8 @@ class DebugInfoTest(jtu.JaxTestCase): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=the_grad, arg_names=c,as_, result_paths=[0],[1]", + # TODO(necula): bad result paths + "traced_for=jit, fun=my_f, arg_names=x,as_, result_paths=,,", 'None', # TODO(necula): some Jaxprs without debug info ], expected_tracer_debug_infos=[ @@ -1038,6 +1042,14 @@ class DebugInfoTest(jtu.JaxTestCase): "traced_for=scan, fun=f, arg_names=c,a", "traced_for=jit, fun=my_f, arg_names=x,as_", 'None', # TODO(necula): some missing debug info + ], + expected_lowering_lines=[ + re.compile(r".*func.func public @main\(%arg0: tensor loc\(\"c\"\)"), + re.compile(r".*func.func public @main\(.*, %arg1: tensor<3x2xf..> loc\(\"as_\"\)"), + re.compile(r".*func.func public @main\(.* -> .*tensor {jax.result_info = \"\[0\]\""), + re.compile(r".*func.func public @main\(.* -> .*tensor<3x2xf..> {jax.result_info = \"\[1\]\""), + # TODO(necula): unnamed function? 
+ re.compile(r".*func.func private @None"), ]) def test_while_loop(self): @@ -1059,7 +1071,8 @@ class DebugInfoTest(jtu.JaxTestCase): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x, result_paths=", - 'None', # TODO(necula): some missing debug info + 'traced_for=while_body, fun=my_body, arg_names=b, result_paths=', + 'traced_for=while_cond, fun=my_cond, arg_names=a, result_paths=', ], check_tracer_arg_name=True, expected_tracer_debug_infos=[ @@ -1080,7 +1093,9 @@ class DebugInfoTest(jtu.JaxTestCase): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=, arg_names=x, result_paths=", - 'None', # TODO(necula): some missing debug info + # TODO(necula): bad arg_names, result_paths + 'traced_for=scan, fun=my_body, arg_names=loop_carry[0],loop_carry[1], result_paths=[0][0],[0][1]', + ], expected_tracer_debug_infos=[ # TODO(necula): the arg_names are not right @@ -1097,7 +1112,9 @@ class DebugInfoTest(jtu.JaxTestCase): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=, arg_names=ub,x, result_paths=", - 'None', # TODO(necula): some missing debug info + re.compile(r'traced_for=while_cond, fun=_fori_cond_fun at .*/loops.py:.*, arg_names=loop_carry\[0\],loop_carry\[1\],loop_carry\[2\], result_paths='), + # TODO(necula): arg_names and result_paths are not right + "traced_for=while_body, fun=my_body, arg_names=loop_carry[0],loop_carry[1],loop_carry[2], result_paths=[0],[1],[2]", ], expected_tracer_debug_infos=[ # TODO(necula): the arg_names are not right @@ -1119,10 +1136,11 @@ class DebugInfoTest(jtu.JaxTestCase): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x, result_paths=[0],[1]", - # TODO(necula): some Jaxprs without debug info - 'None'], + "traced_for=scan, fun=my_scan_body, arg_names=carry,inp, result_paths=[0],[1]", + ], + check_tracer_arg_name=True, expected_tracer_debug_infos=[ - "traced_for=scan, fun=my_scan_body, arg_names=carry,inp" + "traced_for=scan, fun=my_scan_body, arg_names=carry,inp, from carry" ]) def test_eval_shape(self): @@ -1179,6 +1197,17 @@ class DebugInfoTest(jtu.JaxTestCase): expected_tracer_debug_infos=[ "traced_for=pmap, fun=my_f, arg_names=y,args[0],args[1],a,kwargs['b'],kwargs['d'], from args[1]", ], + expected_lowering_lines=[ + # TODO(necula): we did not DCE y? 
+ re.compile(r".*func.func public @main\(.*%arg0: tensor<1xf..> loc\(\"y\"\)"), + re.compile(r".*func.func public @main\(.*%arg1: tensor<1xf..> loc\(\"args\[0\]\"\)"), + re.compile(r".*func.func public @main\(.*%arg2: tensor<1xf..> loc\(\"args\[1\]\"\)"), + re.compile(r".*func.func public @main\(.*%arg3: tensor<1xf..> loc\(\"a\"\)"), + re.compile(r".*func.func public @main\(.*%arg4: tensor<1xf..> loc\(\"kwargs\['b'\]\"\)"), + re.compile(r".*func.func public @main\(.*%arg5: tensor<1xf..> loc\(\"kwargs\['d'\]\"\)"), + re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"\['u'\]\"\}"), + re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"\['v'\]\"\}"), + ] ) def test_pmap_of_grad(self): @@ -1243,7 +1272,7 @@ class DebugInfoTest(jtu.JaxTestCase): x, x_tan, expected_jaxpr_debug_infos=[ 'traced_for=jit, fun=, arg_names=x,x_tan, result_paths=[0],[1]', - "None", # TODO(necula): missing debug info + "traced_for=pmap, fun=my_f, arg_names=x,y, result_paths=", ], tracer_spy=tracer_spy, expected_tracer_debug_infos=[ @@ -1268,10 +1297,12 @@ class DebugInfoTest(jtu.JaxTestCase): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x, result_paths=", - # TODO(necula): some Jaxprs without debug info - 'None'], + # TODO(necula): missing result_paths + "traced_for=checkpoint / remat, fun=my_g, arg_names=y, result_paths=", + ], + check_tracer_arg_name=True, expected_tracer_debug_infos=[ - "traced_for=checkpoint / remat, fun=my_g, arg_names=y" + "traced_for=checkpoint / remat, fun=my_g, arg_names=y, from y" ]) def test_grad_remat(self): @@ -1360,8 +1391,8 @@ class DebugInfoTest(jtu.JaxTestCase): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=, arg_names=x, result_paths=", - # TODO(necula): some Jaxprs without debug info - 'None'], + "traced_for=custom_dce, fun=my_g, arg_names=x, result_paths=[0],[1]", + ], expected_tracer_debug_infos=[ # TODO(necula): no leaked tracer from my_g_dce? "traced_for=custom_dce, fun=my_g, arg_names=x", @@ -1388,8 +1419,9 @@ class DebugInfoTest(jtu.JaxTestCase): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=, arg_names=x, result_paths=", - # TODO(necula): some Jaxprs without debug info - 'None'], + # TODO(necula): bad arg_names (why None), bad result_paths + 'traced_for=custom_dce, fun=my_f, arg_names=None,x, result_paths=[0],[1]', + ], check_tracer_arg_name=True, expected_tracer_debug_infos=[ # TODO(necula): no leaked tracer from my_rule? @@ -1427,7 +1459,14 @@ class DebugInfoTest(jtu.JaxTestCase): re.compile(r"traced_for=jit, fun=_solve at .*scipy/linalg.py:.*, arg_names=a,b, result_paths="), re.compile(r"traced_for=jit, fun=solve at .*/linalg.py:.*, arg_names=a,b, result_paths="), re.compile(r"traced_for=jit, fun=_lu_solve at .*/linalg.py:.*, arg_names=lu,permutation,b, result_paths="), - "None", # TODO(necula): there are missing jaxpr debug info + # TODO(necula): why pointers to internal functions, arg_names, result_paths? 
+            re.compile(r'traced_for=custom_linear_solve solve, fun=<lambda> at .*linalg.py:.*, arg_names=None,None,x, result_paths='),
+            re.compile(r'traced_for=custom_linear_solve transpose_solve, fun=<lambda> at .*/linalg.py:.*, arg_names=None,None,x, result_paths='),
+            re.compile(r'traced_for=custom_linear_solve, fun=<lambda> at .*/linalg.py:.*, arg_names=None,x, result_paths='),
+            re.compile(r'traced_for=custom_linear_solve transpose_solve, fun=<lambda> at .*/linalg.py:.*, arg_names=None,x, result_paths='),
+            'traced_for=custom_linear_solve, fun=my_high_precision_dot, arg_names=None,b, result_paths=',
+            'traced_for=custom_linear_solve solve, fun=my_solve, arg_names=None,x, result_paths=',
+            'traced_for=custom_linear_solve transpose_solve, fun=my_tr_solve, arg_names=None,x, result_paths=',
         ],
         expected_tracer_debug_infos=[
             "traced_for=custom_linear_solve solve, fun=my_solve, arg_names=x",
@@ -1490,8 +1529,9 @@ class DebugInfoTest(jtu.JaxTestCase):
         tracer_spy=tracer_spy,
         expected_jaxpr_debug_infos=[
             "traced_for=jit, fun=my_f, arg_names=x, result_paths=",
-            # TODO(necula): missing Jaxpr debug info
-            "None"],
+            "traced_for=pallas_call index_map, fun=my_index_map, arg_names=i,j, result_paths=[0],[1]",
+            "traced_for=pallas_call, fun=my_kernel, arg_names=x_ref,y_ref,o_ref, result_paths=",
+        ],
         expected_tracer_debug_infos=[
             "traced_for=pallas_call index_map, fun=my_index_map, arg_names=i,j",
             "traced_for=pallas_call, fun=my_kernel, arg_names=x_ref,y_ref,o_ref",
@@ -1519,7 +1559,10 @@ class DebugInfoTest(jtu.JaxTestCase):
         tracer_spy=tracer_spy,
         expected_jaxpr_debug_infos=[
             "traced_for=jit, fun=my_f, arg_names=input, result_paths=",
-            "None",  # TODO(necula): missing tracer debug info
+            # TODO(necula): function source location points in JAX internals
+            # TODO(necula): arg_names and result_paths are wrong
+            re.compile(r"traced_for=checkify_pallas, fun=checked_kernel_fn at .*/pallas_call.py:.*, arg_names=args\[0\],.*, result_paths="),
+            re.compile(r"traced_for=pallas_call index_map, fun=<lambda> at .*/pallas/core.py:.*, arg_names=, result_paths="),
         ],
         expected_tracer_debug_infos=[
             "traced_for=pallas_call, fun=kernel, arg_names=x_ref,y_ref",
@@ -1543,64 +1586,11 @@ class DebugInfoTest(jtu.JaxTestCase):
         tracer_spy=tracer_spy,
         expected_jaxpr_debug_infos=[
             "traced_for=jit, fun=my_consts, arg_names=x, result_paths=",
-            "None"
+            "traced_for=composite, fun=my_consts, arg_names=x, result_paths=",
         ],
         expected_tracer_debug_infos=[
             "traced_for=composite, fun=my_consts, arg_names=x"])

-class EagerPmapMixin:
-
-  def setUp(self):
-    super().setUp()
-    stack = contextlib.ExitStack()
-    stack.enter_context(jtu.thread_local_config_context(jax_disable_jit=True, jax_eager_pmap=True))
-    stack.enter_context(jtu.ignore_warning(
-        message="Some donated buffers were not usable", category=UserWarning))
-    self.addCleanup(stack.close)
-
-
-@jtu.pytest_mark_if_available('multiaccelerator')
-class PythonPmapEagerTest(EagerPmapMixin, jtu.JaxTestCase):
-  def test_pmap_lower_arg_info(self):
-    def f(x, y, *args, **kwargs):
-      return y['hi'] + args[1] + sum(kwargs.values())
-
-    lowered = jax.pmap(f).lower(
-        {'hi': jnp.array([1.])}, {'hi': jnp.array([2.])}, jnp.array([3.]),
-        jnp.array([4.]), z=jnp.array([5.]), w=jnp.array([6.]))
-    hlo_str = lowered.as_text("stablehlo", debug_info=True)
-    self.assertNotIn("\"x\"", hlo_str)
-    self.assertIn("y['hi']", hlo_str)
-    self.assertIn("args[0]", hlo_str)
-    self.assertIn("args[1]", hlo_str)
-    self.assertIn("kwargs['z']", hlo_str)
-    self.assertIn("kwargs['w']", hlo_str)
-
-  def test_pmap_lower_result_info(self):
-    def f(x, y, z):
-      return {'a': x,
-              'b': [y]}
-
-    lowered = jax.pmap(f).lower(jnp.array([1.]), (jnp.array([2]),),
-                                [jnp.array([3])])
-    hlo_str = lowered.as_text("stablehlo", debug_info=True)
-    self.assertIn("jax.result_info = \"['a']\"", hlo_str)
-    self.assertIn("jax.result_info = \"['b'][0][0]\"", hlo_str)
-
-  def testLowerCompileArgTypeMismatch(self):
-    f = jax.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
-    shape = (jax.device_count(), 4)
-    x = np.arange(math.prod(shape), dtype=int).reshape(shape)
-    x_f32 = x.astype(jnp.float32)
-    x_i32 = x.astype(jnp.int32)
-    f_exe = f.lower(x_f32).compile()
-    self.assertRaisesRegex(
-        TypeError,
-        r"Argument types differ .*"
-        r"The mismatches are:\n"
-        r"Argument 'x' compiled with.*float32.*and called with.*int32.*",
-        lambda: f_exe(x_i32))
-
-
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())
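Note for reviewers: the `expected_jaxpr_debug_infos` strings asserted above are rendered from the fields of a `core.DebugInfo`. Below is a minimal sketch (not part of the patch) of how those fields can be inspected directly; it assumes a JAX build from this series where `core.Jaxpr` carries a populated `debug_info`, and `my_f` plus the printed values are illustrative only.

```python
# Illustrative sketch only -- not part of the patch. Assumes a JAX build
# from this series, where core.Jaxpr carries a populated debug_info.
import jax

def my_f(x, y):
  return {'c': x * y}

closed_jaxpr = jax.make_jaxpr(my_f)(1., 2.)
dbg = closed_jaxpr.jaxpr.debug_info
# The tests format these fields as
# "traced_for=..., fun=..., arg_names=..., result_paths=...".
print(dbg.traced_for)     # e.g. 'make_jaxpr'
print(dbg.func_src_info)  # e.g. 'my_f at <file>:<line>'
print(dbg.arg_names)      # e.g. ('x', 'y')
```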