[better_errors] More cleanup

This commit is contained in:
George Necula 2025-01-14 13:04:16 +00:00
parent f270739f9f
commit f9dfe7f646
18 changed files with 63 additions and 60 deletions

View File

@ -449,7 +449,7 @@ def saved_residuals(f, *args, **kwargs) -> list[tuple[core.AbstractValue, str]]:
jaxpr = jaxpr_.jaxpr
out_tree = lambda: tree_structure(out_shape)
assert len(jaxpr.invars) == len(in_leaves)
dbg = pe.debug_info(f, in_tree, out_tree, True, "saved_residuals")
dbg = pe.tracing_debug_info(f, in_tree, out_tree, True, "saved_residuals")
return _saved_residuals(jaxpr, dbg.arg_names) # type: ignore
def _saved_residuals(jaxpr, arg_names) -> list[tuple[core.AbstractValue, str]]:

View File

@ -653,7 +653,8 @@ def result_paths(_fun, _store, *args, **kwargs):
_store.store([keystr(path) for path, _ in generate_key_paths(ans)])
return ans
def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: TracingDebugInfo | None,
def add_jaxpr_debug_info(jaxpr: core.Jaxpr,
trace_debug: TracingDebugInfo | None,
result_paths: tuple[str, ...] | None = None,
) -> core.Jaxpr:
"""Add debug info to jaxpr, given trace-time debug info and result paths."""

View File

@ -1203,7 +1203,7 @@ def checkify(f: Callable[..., Out],
closed_f = lambda: f(*args, **kwargs)
# stage:
fun_, out_tree = flatten_fun(lu.wrap_init(closed_f), in_tree)
debug = pe.debug_info(closed_f, in_tree, out_tree, False, 'checkify')
debug = pe.tracing_debug_info(closed_f, in_tree, out_tree, False, 'checkify')
jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, (), debug)
jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr_))
# checkify:

View File

@ -150,7 +150,7 @@ class custom_vmap:
args_flat, in_tree = tree_flatten(args)
flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
in_avals = [core.get_aval(x) for x in args_flat]
debug = pe.debug_info(self.fun, in_tree, out_tree, False, "custom_vmap")
debug = pe.tracing_debug_info(self.fun, in_tree, out_tree, False, "custom_vmap")
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
in_tree = treedef_tuple((tree_structure(consts), in_tree))

View File

@ -488,7 +488,7 @@ class custom_partitioning:
args_flat, in_tree = tree_util.tree_flatten(dyn_args)
flat_fun, out_tree = api_util.flatten_fun_nokwargs(f_, in_tree)
in_avals = [core.get_aval(x) for x in args_flat]
debug = pe.debug_info(self.fun, in_tree, out_tree, False,
debug = pe.tracing_debug_info(self.fun, in_tree, out_tree, False,
"custom_partitioning")
mesh = mesh_lib.thread_resources.env.physical_mesh
with core.extend_axis_env_nd(mesh.shape.items()):

View File

@ -1907,7 +1907,7 @@ class DynamicJaxprTrace(core.Trace):
implicit_tracers = _extract_implicit_args(self, f.in_type, explicit_tracers)
in_tracers = map(self.to_jaxpr_tracer, [*implicit_tracers, *explicit_tracers])
# TODO(mattjj): check in_tracers are consistent with f.in_type annotation
dbg = debug_info_final(f, call_primitive.name)
dbg = tracing_debug_info_final(f, call_primitive.name)
jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f, debug_info=dbg)
if params.get('inline', False):
return core.eval_jaxpr(jaxpr, consts, *in_tracers,
@ -1944,7 +1944,7 @@ class DynamicJaxprTrace(core.Trace):
with core.extend_axis_env_nd([(axis_name, params["global_axis_size"])]):
jaxpr, reduced_out_avals, consts, () = trace_to_jaxpr_dynamic(
f, reduced_in_avals,
debug_info=debug_info_final(f, map_primitive.name))
debug_info=tracing_debug_info_final(f, map_primitive.name))
ordered_effects = effects.ordered_effects.filter_in(jaxpr.effects)
if ordered_effects:
raise ValueError("Ordered effects not supported for "
@ -2106,14 +2106,15 @@ def _jvp_jaxpr_zeros(f, store, in_zeros, zero_avals, *primal_tangent_avals):
return [*out_primals, *out_nz_tangents]
# Callers should be using linear_util.debug_info instead!
def debug_info(
def tracing_debug_info(
fn: Callable,
in_tree: PyTreeDef | None,
out_tree_thunk: Callable[[], PyTreeDef] | None,
has_kwargs: bool,
traced_for: str
) -> lu.TracingDebugInfo | None:
) -> lu.TracingDebugInfo:
src_info = fun_sourceinfo(fn)
arg_names: tuple[str | None, ...] | None
try:
dummy_args = tree_unflatten(in_tree, [False] * in_tree.num_leaves) # type: ignore
args, kwargs = dummy_args if has_kwargs else (dummy_args, {})
@ -2121,19 +2122,20 @@ def debug_info(
arg_names = tuple(f'{name}{keystr(path)}' for name, dummy in ba.arguments.items()
for path, _ in generate_key_paths(dummy))
except:
arg_names = None
arg_names = None # TODO(necula): we should not need this
def result_paths():
try:
out_tree = out_tree_thunk()
dummy_result = tree_unflatten(out_tree, [False] * out_tree.num_leaves)
except:
return None
return None # TODO(necula): this does not seem to be needed
return tuple(path for path, _ in generate_key_paths(dummy_result))
return lu.TracingDebugInfo(traced_for, src_info, arg_names, result_paths) # type: ignore
# TODO(necula): clean up the type: ignore below
return lu.TracingDebugInfo(traced_for, src_info, arg_names, result_paths) # type: ignore[arg-type]
def debug_info_final(fn: lu.WrappedFun, traced_for: str) -> lu.TracingDebugInfo | None:
def tracing_debug_info_final(fn: lu.WrappedFun, traced_for: str) -> lu.TracingDebugInfo:
in_tree, out_tree, has_kws = flattened_fun_in_tree(fn) or (None, None, False)
return debug_info(fn.f, in_tree, out_tree, has_kws, traced_for)
return tracing_debug_info(fn.f, in_tree, out_tree, has_kws, traced_for)
@profiler.annotate_function

View File

@ -721,8 +721,8 @@ def stage_parallel_callable(
"Finished tracing + transforming {fun_name} for pmap in {elapsed_time} sec",
fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
jaxpr, out_sharded_avals, consts, _ = pe.trace_to_jaxpr_dynamic(
fun, sharded_avals, pe.debug_info_final(fun, "pmap"))
jaxpr = api_util.jaxpr_debug_info(jaxpr, orig_fun.debug_info)
fun, sharded_avals, pe.tracing_debug_info_final(fun, "pmap"))
jaxpr = api_util.add_jaxpr_debug_info(jaxpr, orig_fun.debug_info)
assert len(out_sharded_avals) == len(pci.out_axes), (
len(out_sharded_avals), len(pci.out_axes))
@ -879,8 +879,8 @@ def lower_parallel_callable(
replicated_args=replicated_args,
arg_shardings=None,
result_shardings=None,
arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
arg_names=jaxpr._debug_info and jaxpr._debug_info.arg_names,
result_names=jaxpr._debug_info and jaxpr._debug_info.result_paths,
num_replicas=replicas.num_global_replicas,
lowering_parameters=lowering_parameters)
return PmapComputation(lowering_result.module,
@ -891,7 +891,7 @@ def lower_parallel_callable(
ordered_effects=ordered_effects,
keepalive=lowering_result.keepalive,
host_callbacks=lowering_result.host_callbacks,
jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info,
jaxpr_debug_info=closed_jaxpr.jaxpr._debug_info,
shape_poly_state=lowering_result.shape_poly_state)
@ -1833,7 +1833,7 @@ def _move_mutable_consts(
def _discharge_internal_refs(jaxpr: core.ClosedJaxpr) -> core.ClosedJaxpr:
from jax._src.state.discharge import discharge_state
jaxpr_, consts = discharge_state(jaxpr.jaxpr, jaxpr.consts)
jaxpr_._debug_info = jaxpr.jaxpr.debug_info
jaxpr_._debug_info = jaxpr.jaxpr._debug_info
return core.ClosedJaxpr(jaxpr_, consts)
@ -1971,8 +1971,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
result_shardings=out_mlir_shardings,
in_layouts=in_layouts,
out_layouts=out_layouts,
arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
arg_names=jaxpr._debug_info and jaxpr._debug_info.arg_names,
result_names=jaxpr._debug_info and jaxpr._debug_info.result_paths,
num_replicas=nreps,
num_partitions=num_partitions,
all_default_mem_kind=all_default_mem_kind,
@ -2208,7 +2208,7 @@ def lower_sharding_computation(
auto_spmd_lowering = check_if_any_auto(
it.chain.from_iterable([in_shardings, out_shardings]))
all_args_info = AllArgsInfo(closed_jaxpr.in_avals, closed_jaxpr.jaxpr.debug_info)
all_args_info = AllArgsInfo(closed_jaxpr.in_avals, closed_jaxpr.jaxpr._debug_info)
closed_jaxpr, donated_invars, kept_var_idx, name_stack = _dce_jaxpr(
closed_jaxpr, api_name, fun_name, keep_unused, donated_invars,

View File

@ -53,7 +53,7 @@ def _typecheck_param(prim, param, name, msg_required, pred):
def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals,
primitive_name: str | None = None):
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
debug = pe.debug_info(fun, in_tree, out_tree, False,
debug = pe.tracing_debug_info(fun, in_tree, out_tree, False,
primitive_name or "<unknown>")
jaxpr, _, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
wrapped_fun, in_avals, debug)

View File

@ -139,7 +139,7 @@ def switch(index, branches: Sequence[Callable], *operands,
ops_avals = tuple(map(core.get_aval, ops))
if config.mutable_array_checks.value:
dbg = pe.debug_info(branches[0], ops_tree, None, False, 'switch')
dbg = pe.tracing_debug_info(branches[0], ops_tree, None, False, 'switch')
_check_no_aliased_ref_args(dbg, ops_avals, ops)
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
@ -238,7 +238,7 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
ops_avals = tuple(map(core.get_aval, ops))
if config.mutable_array_checks.value:
dbg = pe.debug_info(true_fun, ops_tree, None, False, 'cond')
dbg = pe.tracing_debug_info(true_fun, ops_tree, None, False, 'cond')
_check_no_aliased_ref_args(dbg, ops_avals, ops)
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
(true_fun, false_fun), ops_tree, ops_avals, 'cond')

View File

@ -275,7 +275,7 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
if config.mutable_array_checks.value:
in_flat, in_tree = tree_flatten((init, xs))
dbg = pe.debug_info(f, in_tree, None, False, 'scan')
dbg = pe.tracing_debug_info(f, in_tree, None, False, 'scan')
in_avals = tuple(_map(core.get_aval, in_flat))
_check_no_aliased_ref_args(dbg, in_avals, in_flat)

View File

@ -657,7 +657,7 @@ def clamp(min: ArrayLike, x: ArrayLike, max: ArrayLike) -> Array:
@weakref_lru_cache
def _trace_composite_to_jaxpr(fun, in_tree, in_avals):
flat_fun, out_tree = api_util.flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
debug_info = pe.debug_info(fun, in_tree, out_tree, False, "composite")
debug_info = pe.tracing_debug_info(fun, in_tree, out_tree, False, "composite")
jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug_info)
# TODO(danfm): support const inputs to composite.
assert not consts

View File

@ -258,7 +258,7 @@ class TracingDebugInfo(NamedTuple):
# formed just before staging to a jaxpr and read in trace-time error messages.
traced_for: str # e.g. 'jit', 'scan', etc
func_src_info: str | None # e.g. f'{fun.__name__} at {filename}:{lineno}'
arg_names: tuple[str, ...] # e.g. ('args[0]', ... )
arg_names: tuple[str | None, ...] # e.g. ('args[0]', ... )
# e.g. ('[0]', '[1]', ...)
result_paths_thunk: Callable[[], tuple[str, ...]] | None

View File

@ -421,7 +421,7 @@ class BlockSpec:
flat_index_map_fun, index_map_out_tree_thunk = api_util.flatten_fun(
lu.wrap_init(index_map_func), index_map_tree
)
debug = pe.debug_info(
debug = pe.tracing_debug_info(
index_map_func,
index_map_tree,
index_map_out_tree_thunk,

View File

@ -1340,7 +1340,7 @@ def pallas_call_checkify_rule(error: checkify.Error,
jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(retrace_in_avals)
wrapped_kernel_with_err, out_tree_thunk = api_util.flatten_fun_nokwargs(
lu.wrap_init(checked_kernel_fn), jaxpr_in_tree)
debug = pe.debug_info(
debug = pe.tracing_debug_info(
checked_kernel_fn, jaxpr_in_tree, out_tree_thunk, False, "checkify_pallas")
with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
final_jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
@ -1410,7 +1410,7 @@ def _trace_kernel_to_jaxpr(
wrapped_kernel_fun = primitives.wrap_with_transforms(
wrapped_kernel_fun, kernel_in_transforms
)
debug = pe.debug_info(fun, kernel_in_tree, out_tree_thunk, False, "pallas_call")
debug = pe.tracing_debug_info(fun, kernel_in_tree, out_tree_thunk, False, "pallas_call")
with grid_mapping.trace_env():
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun,
kernel_avals, debug)

View File

@ -49,7 +49,7 @@ from jax._src import xla_bridge as xb
from jax._src.api_util import (
argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs,
donation_vector, check_callable, resolve_argnums,
argnames_partial_except, debug_info, result_paths, jaxpr_debug_info,
argnames_partial_except, debug_info, result_paths, add_jaxpr_debug_info,
hoist_obj_attrs, _check_no_aliased_ref_args,
_check_no_aliased_closed_over_refs)
from jax._src.interpreters import partial_eval as pe
@ -537,7 +537,7 @@ class PjitParams(NamedTuple):
in_tree: PyTreeDef
out_tree: PyTreeDef
donated_invars: tuple[bool, ...]
arg_names: tuple[str, ...] | None
arg_names: tuple[str | None, ...] | None
num_consts: int
attrs_tracked: list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]
abstract_mesh: AbstractMesh
@ -637,7 +637,7 @@ def _infer_params_impl(
out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings(
out_shardings_treedef, out_shardings_leaves, ji.out_layouts_treedef,
ji.out_layouts_leaves, HashableFunction(out_tree, closure=()),
tuple(out_avals), jaxpr.jaxpr.debug_info, device_or_backend_set)
tuple(out_avals), jaxpr.jaxpr._debug_info, device_or_backend_set)
assert len(explicit_args) == len(in_shardings_flat) == len(in_layouts_flat)
@ -1301,7 +1301,7 @@ def _create_pjit_jaxpr(
with dispatch.log_elapsed_time(
"Finished tracing + transforming {fun_name} for pjit in {elapsed_time:.9f} sec",
fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
pe_debug = debug_info and pe.debug_info_final(fun, debug_info.traced_for)
pe_debug = debug_info and pe.tracing_debug_info_final(fun, debug_info.traced_for)
if config.dynamic_shapes.value:
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic2(
lu.annotate(fun, cast(core.InputType, in_type)), debug_info=pe_debug)
@ -1313,7 +1313,7 @@ def _create_pjit_jaxpr(
# TODO(dougalm,mattjj): enable debug info with attrs_tracked
if not config.dynamic_shapes.value and not attrs_tracked:
jaxpr = jaxpr_debug_info(jaxpr, debug_info, out_paths())
jaxpr = add_jaxpr_debug_info(jaxpr, debug_info, out_paths())
if config.debug_key_reuse.value:
# Import here to avoid circular imports
@ -1672,7 +1672,7 @@ def _pjit_call_impl_python(
if compiled._auto_spmd_lowering and config.enable_checks.value:
pxla.check_array_xla_sharding_layout_match(
args, compiled._in_shardings, compiled._in_layouts,
jaxpr.jaxpr.debug_info, compiled._kept_var_idx)
jaxpr.jaxpr.tracing_debug_info, compiled._kept_var_idx)
if config.distributed_debug.value:
# Defensively only perform fingerprint logic if debug logging is enabled
# NOTE(skyewm): I didn't benchmark this

View File

@ -993,7 +993,7 @@ def initial_style_jaxpr(
def _initial_style_jaxpr(fun, in_tree, in_avals):
fun_, out_tree_thunk = api_util.flatten_fun_nokwargs(lu.wrap_init(fun),
tree_util.treedef_tuple((in_tree,)))
debug = pe.debug_info(fun_, in_tree, out_tree_thunk, False, 'run_state')
debug = pe.tracing_debug_info(fun_, in_tree, out_tree_thunk, False, 'run_state')
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, in_avals, debug)
return jaxpr, consts, out_tree_thunk()

View File

@ -57,8 +57,8 @@ from jax._src.interpreters.partial_eval import (
dce_jaxpr_closed_call_rule as dce_jaxpr_closed_call_rule,
dce_jaxpr_consts as dce_jaxpr_consts,
dce_rules as dce_rules,
debug_info as debug_info,
debug_info_final as debug_info_final,
tracing_debug_info as tracing_debug_info,
tracing_debug_info_final as tracing_debug_info_final,
def_trivial_padding as def_trivial_padding,
forwarding_rules as forwarding_rules,
has_effects as has_effects,