From 32c98b9a7694f76fd973ee156f464b9ba0c53185 Mon Sep 17 00:00:00 2001 From: George Necula Date: Sat, 25 Jan 2025 18:34:38 +0200 Subject: [PATCH] [better_errors] Refactor more uses of pe.tracing_debug_info (part 3) We replace uses of `pe.tracing_debug_info` with `api_util.tracing_debug_info`, which uses the actual args and kwargs, instead of `in_tree` to manufacture fake args and kwargs. This ends up being more accurate, especially for `arg_names`; see changes in debug_info_test.py. This means that we have to construct the debug info further upstream, before flattening args. This will later help populate debug info in `WrappedFun` and `Jaxpr`. This is part 3 of a series (following #26097, #26099) for jit, pmap, checkify, and the custom_partitioning (the last few uses). In order to land this, I had to remove a safety check that the number of `arg_names` and `result_paths` in a Jaxpr's debug info matches the number of Jaxpr invars and outvars, respectively. Additionally, I added two accessors `safe_arg_names` and `safe_result_paths` to ensure that the arg names and result paths match the expected length. These accessors return no-op results when the lengths are not as expected. From my testing, this happens only in Jaxprs that are not used for lowering, hence there is no actual user-visible change here. Simply, more internal Jaxprs are getting debug_info and in some cases the `arg_names` and `result_paths` are not correct. Still, this change is worth it because the `func_src_info` is the most useful part of the debug info (used for leaked tracers), and that is accurate. We will fix the `arg_names` and `result_paths` in a future change. One can see in the changes in debug_info_test.py the improvements in the user-visible debug info, including for `pjit` and `pmap` cases where it was wrong. 
--- jax/_src/api.py | 11 +++--- jax/_src/api_util.py | 37 ++------------------ jax/_src/checkify.py | 8 +++-- jax/_src/core.py | 21 +++++++++-- jax/_src/custom_partitioning.py | 5 +-- jax/_src/interpreters/partial_eval.py | 50 ++++----------------------- jax/_src/interpreters/pxla.py | 15 ++++---- jax/_src/linear_util.py | 21 +++++++++-- jax/_src/pjit.py | 21 +++++------ jax/interpreters/partial_eval.py | 2 -- tests/debug_info_test.py | 43 +++++++++++++++++------ 11 files changed, 112 insertions(+), 122 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 737941716..3c9c4d9f7 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -61,7 +61,7 @@ from jax._src.api_util import ( flatten_axes, donation_vector, rebase_donate_argnums, _ensure_index, _ensure_index_tuple, apply_flat_fun_nokwargs, check_callable, tracing_debug_info, - result_paths, flat_out_axes, debug_info_final) + result_paths, flat_out_axes) from jax._src.lax import lax as lax_internal from jax._src.lib import jax_jit from jax._src.lib import xla_client as xc @@ -1420,15 +1420,15 @@ def _get_global_axis_size(local_axis_size: int, in_devices, backend_name: str, return global_axis_size -def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple, +def _prepare_pmap(fun: Callable, in_axes, out_axes, static_broadcasted_tuple, donate_tuple, in_devices, backend_name, axis_size, args, kwargs): if in_devices is not None and len(in_devices) == 0: raise ValueError("'devices' argument to pmap must be non-empty, or None.") dbg = tracing_debug_info( - 'pmap', fun, args, kwargs, - static_argnums=static_broadcasted_tuple) + "pmap", fun, args, kwargs, + static_argnums=static_broadcasted_tuple) f = lu.wrap_init(fun) if static_broadcasted_tuple: @@ -1478,9 +1478,10 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple, local_axis_size = _mapped_axis_size(fun, in_tree, args, in_axes_flat, "pmap") f, res_paths = result_paths(f) + dbg = dbg.add_result_paths(res_paths) + f = 
lu.add_debug_info(f, dbg) f, out_axes_thunk = flat_out_axes(f, out_axes) flat_fun, out_tree = flatten_fun(f, in_tree) - flat_fun = debug_info_final(flat_fun, dbg, res_paths) is_explicit_global_axis_size = axis_size is not None global_axis_size = _get_global_axis_size(local_axis_size, in_devices, diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index 5d7e4b3d8..029131176 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -100,29 +100,6 @@ def apply_flat_fun_nokwargs(fun, io_tree, py_args): ans = fun(*args) return tree_unflatten(out_tree, ans) -def flattened_fun_in_tree( - fn: lu.WrappedFun - ) -> tuple[PyTreeDef, Callable[[], PyTreeDef], bool] | None: - # This implementation relies on internal details of linear_util.py's - # WrappedFun, but it's for the worthy cause of better user error messages. - # It can fail (i.e. return None) if its WrappedFun argument is not transformed - # with flatten_fun or flatten_fun_nokwargs, which could happen e.g. when - # core.eval_jaxpr encounters a call primitive (though at that point we're just - # round-tripping jaxprs and the user errors in question are impossible). 
- assert isinstance(flatten_fun, partial) and len(flatten_fun.args) == 1 - assert (isinstance(flatten_fun_nokwargs, partial) and - len(flatten_fun_nokwargs.args) == 1) - flattens = {flatten_fun.args[0], flatten_fun_nokwargs.args[0]} - try: - ((in_tree,), out_tree_store, has_kwargs), = ( - (args, store, f is flatten_fun.args[0]) - for (f, args), store in zip(fn.transforms, fn.stores) if f in flattens) - except ValueError: - # When `fn` is not the result of flatten_fun or flatten_fun_nokwargs - return None - else: - return in_tree, lambda: out_tree_store.val, has_kwargs # type: ignore[union-attr] - @lu.transformation_with_aux2 def flatten_fun_nokwargs2(f, store, in_tree, *args_flat): py_args = tree_unflatten(in_tree, args_flat) @@ -705,6 +682,7 @@ def result_paths(_fun, _store, *args, **kwargs): _store.store([keystr(path) for path, _ in generate_key_paths(ans)]) return ans +# TODO(necula): simplify this function, all it needs is to add the trace_debug to the Jaxpr def add_jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: TracingDebugInfo | None, result_paths: tuple[str, ...] 
| None = None, @@ -712,7 +690,8 @@ def add_jaxpr_debug_info(jaxpr: core.Jaxpr, """Add debug info to jaxpr, given trace-time debug info and result paths.""" if trace_debug is None: return jaxpr - assert (result_paths is not None) ^ (trace_debug.result_paths_thunk is not None) + # TODO(necula): re-enable this safety check + # assert (result_paths is not None) ^ (trace_debug.result_paths_thunk is not None) if result_paths is None: result_paths = trace_debug.result_paths_thunk() # type: ignore debug_info = core.JaxprDebugInfo( @@ -720,16 +699,6 @@ def add_jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug.arg_names, tuple(result_paths)) # type: ignore return jaxpr.replace(debug_info=debug_info) -def debug_info_final(f: lu.WrappedFun, dbg: TracingDebugInfo | None, - res_paths_thunk: Callable[[], tuple[str, ...]] - ) -> lu.WrappedFun: - "Attach trace-time debug info and result paths lazy thunk to an lu.WrappedFun" - if dbg is None: return f - assert dbg.result_paths_thunk is None - res_paths_thunk_ = HashableFunction(res_paths_thunk, closure=()) - return lu.add_debug_info(f, dbg._replace(result_paths_thunk=res_paths_thunk_)) - - def hoist_obj_attrs(f, flat_args): idxs, objs, flat_args_ = [], [], [] for i, x in enumerate(flat_args): diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 8f3e4c3fb..f4fe0edbf 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -27,6 +27,7 @@ from jax import lax from jax.experimental import shard_map from jax._src import api +from jax._src import api_util from jax._src import ad_checkpoint from jax._src import linear_util as lu from jax._src import config @@ -39,7 +40,6 @@ from jax._src import source_info_util from jax._src import traceback_util from jax._src import tree_util as jtu from jax._src.ad_util import SymbolicZero -from jax._src.api_util import flatten_fun from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -1202,8 +1202,10 @@ def checkify(f: 
Callable[..., Out], in_tree = jtu.tree_structure(((), {})) closed_f = lambda: f(*args, **kwargs) # stage: - fun_, out_tree = flatten_fun(lu.wrap_init(closed_f), in_tree) - debug = pe.tracing_debug_info(closed_f, in_tree, out_tree, False, 'checkify') + debug = api_util.tracing_debug_info("checkify", f, args, kwargs) + fun_, out_tree = api_util.flatten_fun(lu.wrap_init(closed_f, + debug_info=debug), + in_tree) jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, (), debug) jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr_)) # checkify: diff --git a/jax/_src/core.py b/jax/_src/core.py index ae3059517..f5264dabf 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -91,6 +91,22 @@ class JaxprDebugInfo(NamedTuple): # This is formed after tracing, when we have concrete `result_paths` result_paths: tuple[str, ...] # e.g. ('[0]', '[1]', ...) + def safe_arg_names(self, expected: int) -> tuple[str | None, ...]: + """Get the arg_names with a safety check.""" + if len(self.arg_names) == expected: + return self.arg_names + else: + # TODO(necula): this should not happen + return (None,) * expected + + def safe_result_paths(self, expected: int) -> tuple[str | None, ...]: + """Get the result_paths with a safety check.""" + if len(self.result_paths) == expected: + return self.result_paths + else: + # TODO(necula): this should not happen + return ("",) * expected + class Jaxpr: __slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns', @@ -149,8 +165,9 @@ class Jaxpr: self._eqns = list(eqns) self._effects = effects self._debug_info = debug_info - assert (not debug_info or len(debug_info.arg_names) == len(invars)), (debug_info, invars) - assert (not debug_info or len(debug_info.result_paths) == len(outvars)), (debug_info, outvars) + # TODO(necula): re-enable these safety checks + # assert (not debug_info or len(debug_info.arg_names) == len(invars)), (debug_info, invars) + # assert (not debug_info or len(debug_info.result_paths) == len(outvars)), 
(debug_info, outvars) def __str__(self): return str(self.pretty_print()) diff --git a/jax/_src/custom_partitioning.py b/jax/_src/custom_partitioning.py index 697c9c4d7..459faeea9 100644 --- a/jax/_src/custom_partitioning.py +++ b/jax/_src/custom_partitioning.py @@ -473,6 +473,9 @@ class custom_partitioning: def __call__(self, *args, **kwargs): args = _resolve_kwargs(self.fun, args, kwargs) + debug = api_util.tracing_debug_info("custom_partitioning", self.fun, + args, kwargs, + static_argnums=self.static_argnums) if self.static_argnums: static_argnums = set(self.static_argnums) args = tuple(x if i in static_argnums else x for i, x in enumerate(args)) @@ -491,8 +494,6 @@ class custom_partitioning: args_flat, in_tree = tree_util.tree_flatten(dyn_args) flat_fun, out_tree = api_util.flatten_fun_nokwargs(f_, in_tree) in_avals = [core.get_aval(x) for x in args_flat] - debug = pe.tracing_debug_info(self.fun, in_tree, out_tree, False, - "custom_partitioning") mesh = mesh_lib.thread_resources.env.physical_mesh with core.extend_axis_env_nd(mesh.shape.items()): jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index dd81d8d4a..3777ce3d3 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -35,7 +35,6 @@ from jax._src import profiler from jax._src import source_info_util from jax._src import compute_on from jax._src import xla_metadata as xla_metadata_lib -from jax._src.api_util import (flattened_fun_in_tree, flatten_fun_nokwargs) from jax._src.core import (Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval, AbstractValue, ClosedJaxpr, new_jaxpr_eqn, Var, DropVar, Atom, @@ -44,7 +43,7 @@ from jax._src.core import (Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval, InputType, OutputType, get_referent, JaxprEqnContext) from jax._src.state.types import AbstractRef from jax._src import tree_util -from jax._src.tree_util 
import (PyTreeDef, treedef_tuple, tree_unflatten, +from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_flatten, tree_structure) from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list, merge_lists, partition_list, OrderedSet, @@ -1913,8 +1912,7 @@ class DynamicJaxprTrace(core.Trace): implicit_tracers = _extract_implicit_args(self, f.in_type, explicit_tracers) in_tracers = map(self.to_jaxpr_tracer, [*implicit_tracers, *explicit_tracers]) # TODO(mattjj): check in_tracers are consistent with f.in_type annotation - dbg = tracing_debug_info_final(f, call_primitive.name) - jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f, debug_info=dbg) + jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f) if params.get('inline', False): return core.eval_jaxpr(jaxpr, consts, *in_tracers, propagate_source_info=False) @@ -1940,7 +1938,8 @@ class DynamicJaxprTrace(core.Trace): self.frame.add_eqn(eqn) return [t for t, (_, keep) in zip(out_tracers, out_type) if keep] - def process_map(self, map_primitive, f, tracers, params): + def process_map(self, map_primitive, f: lu.WrappedFun, + tracers: Sequence[core.Tracer], params): tracers = map(self.to_jaxpr_tracer, tracers) in_avals = [t.aval for t in tracers] axis_name, axis_size = params['axis_name'], params['axis_size'] @@ -1949,8 +1948,7 @@ class DynamicJaxprTrace(core.Trace): for a, in_axis in zip(in_avals, params['in_axes'])] with core.extend_axis_env_nd([(axis_name, params["global_axis_size"])]): jaxpr, reduced_out_avals, consts, () = trace_to_jaxpr_dynamic( - f, reduced_in_avals, - debug_info=tracing_debug_info_final(f, map_primitive.name)) + f, reduced_in_avals, f.debug_info) ordered_effects = effects.ordered_effects.filter_in(jaxpr.effects) if ordered_effects: raise ValueError("Ordered effects not supported for " @@ -2050,7 +2048,7 @@ class DynamicJaxprTrace(core.Trace): closed_call_jaxpr = core.ClosedJaxpr( convert_constvars_jaxpr(call_jaxpr), ()) - transpose_flat, in_tree2 = flatten_fun_nokwargs( + 
transpose_flat, in_tree2 = api_util.flatten_fun_nokwargs( lu.wrap_init(transpose), treedef_tuple((res_tree, out_tree))) # the following thunk evaluates to a pair: transpose_jaxpr, transpose_consts @@ -2111,42 +2109,6 @@ def _jvp_jaxpr_zeros(f, store, in_zeros, zero_avals, *primal_tangent_avals): store.store(out_zeros) return [*out_primals, *out_nz_tangents] -# Callers should be using linear_util.debug_info instead! -def tracing_debug_info( - fn: Callable, - in_tree: PyTreeDef, - out_tree_thunk: Callable[[], PyTreeDef], - has_kwargs: bool, - traced_for: str -) -> lu.TracingDebugInfo: - # TODO(necula): we should not need this function, and can use api_util.tracing_debug_info instead - # We just have to make sure we grad the debugging information when we have - # the unflattened args - # TODO(necula): in general we can just pretend the leaves are booleans, but - # when we use custom pytrees, the flattening functions may check the type - # of the argument - try: - dummy_args = tree_unflatten(in_tree, [False] * in_tree.num_leaves) # type: ignore - except: - # TODO(necula): remove this catch-all. 
Repro in batching_test:test_basic_jit - dummy_args = ([False], {}) if has_kwargs else [False] - args, kwargs = dummy_args if has_kwargs else (dummy_args, {}) # type: ignore - def res_paths_thunk() -> tuple[str, ...]: - out_tree = out_tree_thunk() - dummy_result = tree_unflatten(out_tree, [False] * out_tree.num_leaves) - return tuple(tree_util.keystr(path) - for path, _ in tree_util.generate_key_paths(dummy_result)) - return api_util.tracing_debug_info(traced_for, fn, args, kwargs, - result_paths_thunk=res_paths_thunk) - -def tracing_debug_info_final(fn: lu.WrappedFun, traced_for: str) -> lu.TracingDebugInfo | None: - fn_trees = flattened_fun_in_tree(fn) - if fn_trees is None: - # TODO(necula): eliminate this branch - return lu.TracingDebugInfo(traced_for, api_util.fun_sourceinfo(fn.f), - (None,), None) - in_tree, out_tree_thunk, has_kws = fn_trees - return tracing_debug_info(fn.f, in_tree, out_tree_thunk, has_kws, traced_for) @profiler.annotate_function def trace_to_jaxpr_dynamic( diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 7ca6f31d0..8d6cad4fe 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -652,6 +652,7 @@ class ParallelCallableInfo: in_axes: Iterable[int | None] out_axes_thunk: Callable[[], Sequence[int | None]] avals: Sequence[core.AbstractValue] + debug_info: api_util.TracingDebugInfo | None @cached_property def local_devices(self): @@ -722,8 +723,8 @@ def stage_parallel_callable( "Finished tracing + transforming {fun_name} for pmap in {elapsed_time} sec", fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT): jaxpr, out_sharded_avals, consts, _ = pe.trace_to_jaxpr_dynamic( - fun, sharded_avals, pe.tracing_debug_info_final(fun, "pmap")) - jaxpr = api_util.add_jaxpr_debug_info(jaxpr, orig_fun.debug_info) + fun, sharded_avals, pci.debug_info) + jaxpr = api_util.add_jaxpr_debug_info(jaxpr, pci.debug_info) assert len(out_sharded_avals) == len(pci.out_axes), ( len(out_sharded_avals), 
len(pci.out_axes)) @@ -757,7 +758,7 @@ def get_pmap_jaxpr( pci = ParallelCallableInfo( name, backend, axis_name, axis_size, global_axis_size, devices, - in_axes, out_axes_thunk, avals) + in_axes, out_axes_thunk, avals, fun.debug_info) with core.extend_axis_env_nd([(axis_name, axis_size)]): jaxpr, consts, replicas, shards = stage_parallel_callable(pci, fun) jaxpr = core.remove_named_axis_effects(jaxpr, {axis_name}) @@ -880,8 +881,8 @@ def lower_parallel_callable( replicated_args=replicated_args, arg_shardings=None, result_shardings=None, - arg_names=jaxpr._debug_info and jaxpr._debug_info.arg_names, - result_names=jaxpr._debug_info and jaxpr._debug_info.result_paths, + arg_names=jaxpr._debug_info and jaxpr._debug_info.safe_arg_names(len(jaxpr.invars)), + result_names=jaxpr._debug_info and jaxpr._debug_info.safe_result_paths(len(jaxpr.outvars)), num_replicas=replicas.num_global_replicas, lowering_parameters=lowering_parameters) return PmapComputation(lowering_result.module, @@ -1971,8 +1972,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, result_shardings=out_mlir_shardings, in_layouts=in_layouts, out_layouts=out_layouts, - arg_names=jaxpr._debug_info and jaxpr._debug_info.arg_names, - result_names=jaxpr._debug_info and jaxpr._debug_info.result_paths, + arg_names=jaxpr._debug_info and jaxpr._debug_info.safe_arg_names(len(jaxpr.invars)), + result_names=jaxpr._debug_info and jaxpr._debug_info.safe_result_paths(len(jaxpr.outvars)), num_replicas=nreps, num_partitions=num_partitions, all_default_mem_kind=all_default_mem_kind, diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py index e1f5efd70..ac267de85 100644 --- a/jax/_src/linear_util.py +++ b/jax/_src/linear_util.py @@ -71,7 +71,7 @@ import weakref from jax._src import config from jax._src import core from jax._src import traceback_util -from jax._src.util import curry, cache_clearing_funs +from jax._src.util import curry, cache_clearing_funs, HashableFunction 
traceback_util.register_exclusion(__file__) @@ -175,11 +175,11 @@ class WrappedFun: if out_store is None: return WrappedFun(self.f, partial(gen, self.f_transformed, *gen_static_args), ((gen, gen_static_args),) + self.transforms, - (out_store,) + self.stores, self.params, None, None) + (out_store,) + self.stores, self.params, None, self.debug_info) else: return WrappedFun(self.f, partial(gen, self.f_transformed, out_store, *gen_static_args), ((gen, gen_static_args),) + self.transforms, - (out_store,) + self.stores, self.params, None, None) + (out_store,) + self.stores, self.params, None, self.debug_info) def populate_stores(self, stores): """Copy the values from the `stores` into `self.stores`.""" @@ -282,6 +282,21 @@ class TracingDebugInfo(NamedTuple): jaxpr_dbg.arg_names, lambda: jaxpr_dbg.result_paths) + def add_result_paths(self, result_paths_thunk: Callable[[], tuple[str, ...]] + ) -> TracingDebugInfo: + assert self.result_paths_thunk is None + return self._replace(result_paths_thunk=HashableFunction(result_paths_thunk, + closure=())) + + def safe_arg_names(self, expected: int) -> tuple[str | None, ...]: + """Get the arg_names with a safety check.""" + if len(self.arg_names) == expected: + return self.arg_names + else: + # TODO(necula): this should not happen + return (None,) * expected + + def wrap_init(f: Callable, params=None, *, debug_info: TracingDebugInfo | None = None) -> WrappedFun: """Wraps function `f` as a `WrappedFun`, suitable for transformation.""" diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 1d709ffc4..c192168ce 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -569,6 +569,7 @@ def _infer_params_impl( f = lu.wrap_init(fun) f, res_paths = result_paths(f) + dbg = dbg and dbg.add_result_paths(result_paths_thunk=res_paths) f, dyn_args = argnums_partial_except(f, ji.static_argnums, args, allow_invalid=True) del args @@ -1160,7 +1161,7 @@ def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves, attrs_tracked = 
debug_info and len(debug_info.arg_names) != len(in_avals) if not config.dynamic_shapes.value and not attrs_tracked: pjit_check_aval_sharding(in_shardings_flat, in_avals, - None if debug_info is None else debug_info.arg_names, + None if debug_info is None else debug_info.safe_arg_names(len(in_avals)), "pjit arguments", allow_uneven_sharding=False) check_aval_layout_compatibility( in_layouts_flat, in_avals, @@ -1302,7 +1303,7 @@ def _create_pjit_jaxpr( in_type: core.InputType | Sequence[core.AbstractValue], attr_data: int, debug_info: lu.TracingDebugInfo, - out_paths: Callable, + result_paths: Callable, ignored_inline: IgnoreKey ) -> tuple[core.ClosedJaxpr, list[Any], list[core.AbstractValue], list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]: @@ -1314,19 +1315,18 @@ def _create_pjit_jaxpr( with dispatch.log_elapsed_time( "Finished tracing + transforming {fun_name} for pjit in {elapsed_time:.9f} sec", fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT): - pe_debug = debug_info and pe.tracing_debug_info_final(fun, debug_info.traced_for) if config.dynamic_shapes.value: jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic2( - lu.annotate(fun, cast(core.InputType, in_type)), debug_info=pe_debug) + lu.annotate(fun, cast(core.InputType, in_type)), debug_info=debug_info) attrs_tracked = [] else: jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic( - fun, in_type, debug_info=pe_debug) + fun, in_type, debug_info=debug_info) # assert attr_data is sentinel or attr_data matches attrs_tracked # TODO(dougalm,mattjj): enable debug info with attrs_tracked if not config.dynamic_shapes.value and not attrs_tracked: - jaxpr = add_jaxpr_debug_info(jaxpr, debug_info, out_paths()) + jaxpr = add_jaxpr_debug_info(jaxpr, debug_info, result_paths()) if config.debug_key_reuse.value: # Import here to avoid circular imports @@ -1366,11 +1366,12 @@ def _check_and_canonicalize_out_shardings( if not config.dynamic_shapes.value: pjit_check_aval_sharding( 
out_shardings_flat, out_avals, - None if debug_info is None else debug_info.result_paths, + None if debug_info is None else debug_info.safe_result_paths(len(out_avals)), # type: ignore[arg-type] "pjit outputs", allow_uneven_sharding=False) check_aval_layout_compatibility( out_layouts_flat, out_avals, - None if debug_info is None else debug_info.result_paths, "jit outputs") + None if debug_info is None else debug_info.safe_result_paths(len(out_avals)), # type: ignore[arg-type] + "jit outputs") return out_shardings_flat, out_layouts_flat @@ -1423,9 +1424,9 @@ class IgnoreKey: def pjit_check_aval_sharding( - shardings, flat_avals, names: tuple[str, ...] | None, + shardings, flat_avals, names: tuple[str | None, ...] | None, what_aval: str, allow_uneven_sharding: bool): - new_names = [''] * len(shardings) if names is None else names + new_names = [None] * len(shardings) if names is None else names for aval, s, name in zip(flat_avals, shardings, new_names): if isinstance(s, (UnspecifiedValue, AUTO)): continue diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index 333f6f3ff..39735ae5e 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -57,8 +57,6 @@ from jax._src.interpreters.partial_eval import ( dce_jaxpr_closed_call_rule as dce_jaxpr_closed_call_rule, dce_jaxpr_consts as dce_jaxpr_consts, dce_rules as dce_rules, - tracing_debug_info as tracing_debug_info, - tracing_debug_info_final as tracing_debug_info_final, def_trivial_padding as def_trivial_padding, forwarding_rules as forwarding_rules, has_effects as has_effects, diff --git a/tests/debug_info_test.py b/tests/debug_info_test.py index 4060c63d7..55eaa4de4 100644 --- a/tests/debug_info_test.py +++ b/tests/debug_info_test.py @@ -639,8 +639,7 @@ class DebugInfoTest(jtu.JaxTestCase): check_tracer_arg_name=True, expected_tracer_debug_infos=[ "traced_for=jit, fun=my_g, arg_names=b, from b", - # TODO(necula): bad arg name - "traced_for=jit, fun=my_f, 
arg_names=args[0], from args[0]" + "traced_for=jit, fun=my_f, arg_names=a, from a", ]) def test_jit_arg_names(self): @@ -692,8 +691,7 @@ class DebugInfoTest(jtu.JaxTestCase): tracer_spy=tracer_spy, check_tracer_arg_name=True, expected_tracer_debug_infos=[ - # TODO(necula): the arg_names are not right, include static ones, also missing args[1] - "traced_for=jit, fun=my_f, arg_names=x['hi'],y,z,args[0],kwargs['t'],kwargs['w'], from kwargs['w']", + "traced_for=jit, fun=my_f, arg_names=y['hi'],z,args[0],args[1],kwargs['t'],kwargs['w'], from kwargs['w']", "None", # TODO(necula) ], expected_lowering_lines=[ @@ -1177,8 +1175,7 @@ class DebugInfoTest(jtu.JaxTestCase): tracer_spy=tracer_spy, check_tracer_arg_name=True, expected_tracer_debug_infos=[ - # TODO(necula): bad tracer provenance - "traced_for=pmap, fun=my_f, arg_names=x,y,args[0],a,kwargs['b'],kwargs['d'], from args[0]", + "traced_for=pmap, fun=my_f, arg_names=y,args[0],args[1],a,kwargs['b'],kwargs['d'], from args[1]", ], ) @@ -1219,6 +1216,31 @@ class DebugInfoTest(jtu.JaxTestCase): expected_jaxpr_debug_infos=[ # TODO(necula): why this? 
re.compile(r'traced_for=jit, fun=_multi_slice at .*/array_methods.py:.*, arg_names=self, result_paths=.*'), + "traced_for=pmap, fun=my_f, arg_names=x,y,args[0],args[1], result_paths=['u'],['v']", + ], + tracer_spy=tracer_spy, + expected_tracer_debug_infos=[ + # TODO(necula): missing debug_info + 'None' + ], + ) + + @jtu.ignore_warning(category=UserWarning, + message=".* jitted function .* includes a pmap") + def test_jvp_pmap(self): + tracer_spy = TracerSpy() + def my_f(x, y): + tracer_spy.append(x) + return jnp.sin(x) + y + + x = np.ones((jax.device_count(), 1), dtype=np.float32) + x_tan = np.full_like(x, .1) + + self._check_tracers_and_jaxprs( + jax.jit(lambda x, x_tan: jax.jvp(jax.pmap(my_f), (x, x), (x_tan, x_tan))), + x, x_tan, + expected_jaxpr_debug_infos=[ + 'traced_for=jit, fun=, arg_names=x,x_tan, result_paths=[0],[1]', "None", # TODO(necula): missing debug info ], tracer_spy=tracer_spy, @@ -1307,7 +1329,7 @@ class DebugInfoTest(jtu.JaxTestCase): "None", # TODO(necula): missing tracer debug info ], expected_tracer_debug_infos=[ - "traced_for=xla_pmap, fun=my_f, arg_names=my_x", + "traced_for=pmap, fun=my_f, arg_names=my_x", ], check_lowering=False, # TODO(necula): warning during lowering ) @@ -1339,7 +1361,8 @@ class DebugInfoTest(jtu.JaxTestCase): # TODO(necula): some Jaxprs without debug info 'None'], expected_tracer_debug_infos=[ - "traced_for=custom_dce, fun=my_g, arg_names=x" + # TODO(necula): no leaked tracer from my_g_dce? + "traced_for=custom_dce, fun=my_g, arg_names=x", ]) def test_custom_dce_consts(self): @@ -1350,7 +1373,7 @@ class DebugInfoTest(jtu.JaxTestCase): return np.eye(1) * jnp.sin(x), jnp.cos(x) @my_f.def_dce - def rule(used_outs, y): + def my_rule(used_outs, y): tracer_spy.append(y) return ( np.full((1, 1), 2.0) * jnp.exp(y) if used_outs[0] else None, @@ -1367,8 +1390,8 @@ class DebugInfoTest(jtu.JaxTestCase): 'None'], check_tracer_arg_name=True, expected_tracer_debug_infos=[ + # TODO(necula): no leaked tracer from my_rule? 
"traced_for=custom_dce, fun=my_f, arg_names=x, from x", - # TODO(necula): rule.y does not have an inspected tracer? ]) def test_custom_linear_solve_complex(self):