mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[better_errors] More cleanup
This commit is contained in:
parent
f270739f9f
commit
f9dfe7f646
@ -449,7 +449,7 @@ def saved_residuals(f, *args, **kwargs) -> list[tuple[core.AbstractValue, str]]:
|
||||
jaxpr = jaxpr_.jaxpr
|
||||
out_tree = lambda: tree_structure(out_shape)
|
||||
assert len(jaxpr.invars) == len(in_leaves)
|
||||
dbg = pe.debug_info(f, in_tree, out_tree, True, "saved_residuals")
|
||||
dbg = pe.tracing_debug_info(f, in_tree, out_tree, True, "saved_residuals")
|
||||
return _saved_residuals(jaxpr, dbg.arg_names) # type: ignore
|
||||
|
||||
def _saved_residuals(jaxpr, arg_names) -> list[tuple[core.AbstractValue, str]]:
|
||||
|
@ -653,7 +653,8 @@ def result_paths(_fun, _store, *args, **kwargs):
|
||||
_store.store([keystr(path) for path, _ in generate_key_paths(ans)])
|
||||
return ans
|
||||
|
||||
def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: TracingDebugInfo | None,
|
||||
def add_jaxpr_debug_info(jaxpr: core.Jaxpr,
|
||||
trace_debug: TracingDebugInfo | None,
|
||||
result_paths: tuple[str, ...] | None = None,
|
||||
) -> core.Jaxpr:
|
||||
"""Add debug info to jaxpr, given trace-time debug info and result paths."""
|
||||
|
@ -1203,7 +1203,7 @@ def checkify(f: Callable[..., Out],
|
||||
closed_f = lambda: f(*args, **kwargs)
|
||||
# stage:
|
||||
fun_, out_tree = flatten_fun(lu.wrap_init(closed_f), in_tree)
|
||||
debug = pe.debug_info(closed_f, in_tree, out_tree, False, 'checkify')
|
||||
debug = pe.tracing_debug_info(closed_f, in_tree, out_tree, False, 'checkify')
|
||||
jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, (), debug)
|
||||
jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr_))
|
||||
# checkify:
|
||||
|
@ -150,7 +150,7 @@ class custom_vmap:
|
||||
args_flat, in_tree = tree_flatten(args)
|
||||
flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
|
||||
in_avals = [core.get_aval(x) for x in args_flat]
|
||||
debug = pe.debug_info(self.fun, in_tree, out_tree, False, "custom_vmap")
|
||||
debug = pe.tracing_debug_info(self.fun, in_tree, out_tree, False, "custom_vmap")
|
||||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
|
||||
closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
|
||||
in_tree = treedef_tuple((tree_structure(consts), in_tree))
|
||||
|
@ -488,7 +488,7 @@ class custom_partitioning:
|
||||
args_flat, in_tree = tree_util.tree_flatten(dyn_args)
|
||||
flat_fun, out_tree = api_util.flatten_fun_nokwargs(f_, in_tree)
|
||||
in_avals = [core.get_aval(x) for x in args_flat]
|
||||
debug = pe.debug_info(self.fun, in_tree, out_tree, False,
|
||||
debug = pe.tracing_debug_info(self.fun, in_tree, out_tree, False,
|
||||
"custom_partitioning")
|
||||
mesh = mesh_lib.thread_resources.env.physical_mesh
|
||||
with core.extend_axis_env_nd(mesh.shape.items()):
|
||||
|
@ -1907,7 +1907,7 @@ class DynamicJaxprTrace(core.Trace):
|
||||
implicit_tracers = _extract_implicit_args(self, f.in_type, explicit_tracers)
|
||||
in_tracers = map(self.to_jaxpr_tracer, [*implicit_tracers, *explicit_tracers])
|
||||
# TODO(mattjj): check in_tracers are consistent with f.in_type annotation
|
||||
dbg = debug_info_final(f, call_primitive.name)
|
||||
dbg = tracing_debug_info_final(f, call_primitive.name)
|
||||
jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f, debug_info=dbg)
|
||||
if params.get('inline', False):
|
||||
return core.eval_jaxpr(jaxpr, consts, *in_tracers,
|
||||
@ -1944,7 +1944,7 @@ class DynamicJaxprTrace(core.Trace):
|
||||
with core.extend_axis_env_nd([(axis_name, params["global_axis_size"])]):
|
||||
jaxpr, reduced_out_avals, consts, () = trace_to_jaxpr_dynamic(
|
||||
f, reduced_in_avals,
|
||||
debug_info=debug_info_final(f, map_primitive.name))
|
||||
debug_info=tracing_debug_info_final(f, map_primitive.name))
|
||||
ordered_effects = effects.ordered_effects.filter_in(jaxpr.effects)
|
||||
if ordered_effects:
|
||||
raise ValueError("Ordered effects not supported for "
|
||||
@ -2106,14 +2106,15 @@ def _jvp_jaxpr_zeros(f, store, in_zeros, zero_avals, *primal_tangent_avals):
|
||||
return [*out_primals, *out_nz_tangents]
|
||||
|
||||
# Callers should be using linear_util.debug_info instead!
|
||||
def debug_info(
|
||||
def tracing_debug_info(
|
||||
fn: Callable,
|
||||
in_tree: PyTreeDef | None,
|
||||
out_tree_thunk: Callable[[], PyTreeDef] | None,
|
||||
has_kwargs: bool,
|
||||
traced_for: str
|
||||
) -> lu.TracingDebugInfo | None:
|
||||
) -> lu.TracingDebugInfo:
|
||||
src_info = fun_sourceinfo(fn)
|
||||
arg_names: tuple[str | None, ...] | None
|
||||
try:
|
||||
dummy_args = tree_unflatten(in_tree, [False] * in_tree.num_leaves) # type: ignore
|
||||
args, kwargs = dummy_args if has_kwargs else (dummy_args, {})
|
||||
@ -2121,19 +2122,20 @@ def debug_info(
|
||||
arg_names = tuple(f'{name}{keystr(path)}' for name, dummy in ba.arguments.items()
|
||||
for path, _ in generate_key_paths(dummy))
|
||||
except:
|
||||
arg_names = None
|
||||
arg_names = None # TODO(necula): we should not need this
|
||||
def result_paths():
|
||||
try:
|
||||
out_tree = out_tree_thunk()
|
||||
dummy_result = tree_unflatten(out_tree, [False] * out_tree.num_leaves)
|
||||
except:
|
||||
return None
|
||||
return None # TODO(necula): this does not seem to be needed
|
||||
return tuple(path for path, _ in generate_key_paths(dummy_result))
|
||||
return lu.TracingDebugInfo(traced_for, src_info, arg_names, result_paths) # type: ignore
|
||||
# TODO(necula): clean up the type: ignore below
|
||||
return lu.TracingDebugInfo(traced_for, src_info, arg_names, result_paths) # type: ignore[arg-type]
|
||||
|
||||
def debug_info_final(fn: lu.WrappedFun, traced_for: str) -> lu.TracingDebugInfo | None:
|
||||
def tracing_debug_info_final(fn: lu.WrappedFun, traced_for: str) -> lu.TracingDebugInfo:
|
||||
in_tree, out_tree, has_kws = flattened_fun_in_tree(fn) or (None, None, False)
|
||||
return debug_info(fn.f, in_tree, out_tree, has_kws, traced_for)
|
||||
return tracing_debug_info(fn.f, in_tree, out_tree, has_kws, traced_for)
|
||||
|
||||
|
||||
@profiler.annotate_function
|
||||
|
@ -721,8 +721,8 @@ def stage_parallel_callable(
|
||||
"Finished tracing + transforming {fun_name} for pmap in {elapsed_time} sec",
|
||||
fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
|
||||
jaxpr, out_sharded_avals, consts, _ = pe.trace_to_jaxpr_dynamic(
|
||||
fun, sharded_avals, pe.debug_info_final(fun, "pmap"))
|
||||
jaxpr = api_util.jaxpr_debug_info(jaxpr, orig_fun.debug_info)
|
||||
fun, sharded_avals, pe.tracing_debug_info_final(fun, "pmap"))
|
||||
jaxpr = api_util.add_jaxpr_debug_info(jaxpr, orig_fun.debug_info)
|
||||
|
||||
assert len(out_sharded_avals) == len(pci.out_axes), (
|
||||
len(out_sharded_avals), len(pci.out_axes))
|
||||
@ -879,8 +879,8 @@ def lower_parallel_callable(
|
||||
replicated_args=replicated_args,
|
||||
arg_shardings=None,
|
||||
result_shardings=None,
|
||||
arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
|
||||
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
|
||||
arg_names=jaxpr._debug_info and jaxpr._debug_info.arg_names,
|
||||
result_names=jaxpr._debug_info and jaxpr._debug_info.result_paths,
|
||||
num_replicas=replicas.num_global_replicas,
|
||||
lowering_parameters=lowering_parameters)
|
||||
return PmapComputation(lowering_result.module,
|
||||
@ -891,7 +891,7 @@ def lower_parallel_callable(
|
||||
ordered_effects=ordered_effects,
|
||||
keepalive=lowering_result.keepalive,
|
||||
host_callbacks=lowering_result.host_callbacks,
|
||||
jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info,
|
||||
jaxpr_debug_info=closed_jaxpr.jaxpr._debug_info,
|
||||
shape_poly_state=lowering_result.shape_poly_state)
|
||||
|
||||
|
||||
@ -1833,7 +1833,7 @@ def _move_mutable_consts(
|
||||
def _discharge_internal_refs(jaxpr: core.ClosedJaxpr) -> core.ClosedJaxpr:
|
||||
from jax._src.state.discharge import discharge_state
|
||||
jaxpr_, consts = discharge_state(jaxpr.jaxpr, jaxpr.consts)
|
||||
jaxpr_._debug_info = jaxpr.jaxpr.debug_info
|
||||
jaxpr_._debug_info = jaxpr.jaxpr._debug_info
|
||||
return core.ClosedJaxpr(jaxpr_, consts)
|
||||
|
||||
|
||||
@ -1971,8 +1971,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
|
||||
result_shardings=out_mlir_shardings,
|
||||
in_layouts=in_layouts,
|
||||
out_layouts=out_layouts,
|
||||
arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
|
||||
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
|
||||
arg_names=jaxpr._debug_info and jaxpr._debug_info.arg_names,
|
||||
result_names=jaxpr._debug_info and jaxpr._debug_info.result_paths,
|
||||
num_replicas=nreps,
|
||||
num_partitions=num_partitions,
|
||||
all_default_mem_kind=all_default_mem_kind,
|
||||
@ -2208,7 +2208,7 @@ def lower_sharding_computation(
|
||||
auto_spmd_lowering = check_if_any_auto(
|
||||
it.chain.from_iterable([in_shardings, out_shardings]))
|
||||
|
||||
all_args_info = AllArgsInfo(closed_jaxpr.in_avals, closed_jaxpr.jaxpr.debug_info)
|
||||
all_args_info = AllArgsInfo(closed_jaxpr.in_avals, closed_jaxpr.jaxpr._debug_info)
|
||||
|
||||
closed_jaxpr, donated_invars, kept_var_idx, name_stack = _dce_jaxpr(
|
||||
closed_jaxpr, api_name, fun_name, keep_unused, donated_invars,
|
||||
|
@ -53,7 +53,7 @@ def _typecheck_param(prim, param, name, msg_required, pred):
|
||||
def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals,
|
||||
primitive_name: str | None = None):
|
||||
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
|
||||
debug = pe.debug_info(fun, in_tree, out_tree, False,
|
||||
debug = pe.tracing_debug_info(fun, in_tree, out_tree, False,
|
||||
primitive_name or "<unknown>")
|
||||
jaxpr, _, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
|
||||
wrapped_fun, in_avals, debug)
|
||||
|
@ -139,7 +139,7 @@ def switch(index, branches: Sequence[Callable], *operands,
|
||||
ops_avals = tuple(map(core.get_aval, ops))
|
||||
|
||||
if config.mutable_array_checks.value:
|
||||
dbg = pe.debug_info(branches[0], ops_tree, None, False, 'switch')
|
||||
dbg = pe.tracing_debug_info(branches[0], ops_tree, None, False, 'switch')
|
||||
_check_no_aliased_ref_args(dbg, ops_avals, ops)
|
||||
|
||||
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
|
||||
@ -238,7 +238,7 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
|
||||
ops_avals = tuple(map(core.get_aval, ops))
|
||||
|
||||
if config.mutable_array_checks.value:
|
||||
dbg = pe.debug_info(true_fun, ops_tree, None, False, 'cond')
|
||||
dbg = pe.tracing_debug_info(true_fun, ops_tree, None, False, 'cond')
|
||||
_check_no_aliased_ref_args(dbg, ops_avals, ops)
|
||||
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
|
||||
(true_fun, false_fun), ops_tree, ops_avals, 'cond')
|
||||
|
@ -275,7 +275,7 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
|
||||
|
||||
if config.mutable_array_checks.value:
|
||||
in_flat, in_tree = tree_flatten((init, xs))
|
||||
dbg = pe.debug_info(f, in_tree, None, False, 'scan')
|
||||
dbg = pe.tracing_debug_info(f, in_tree, None, False, 'scan')
|
||||
in_avals = tuple(_map(core.get_aval, in_flat))
|
||||
_check_no_aliased_ref_args(dbg, in_avals, in_flat)
|
||||
|
||||
|
@ -657,7 +657,7 @@ def clamp(min: ArrayLike, x: ArrayLike, max: ArrayLike) -> Array:
|
||||
@weakref_lru_cache
|
||||
def _trace_composite_to_jaxpr(fun, in_tree, in_avals):
|
||||
flat_fun, out_tree = api_util.flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
|
||||
debug_info = pe.debug_info(fun, in_tree, out_tree, False, "composite")
|
||||
debug_info = pe.tracing_debug_info(fun, in_tree, out_tree, False, "composite")
|
||||
jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug_info)
|
||||
# TODO(danfm): support const inputs to composite.
|
||||
assert not consts
|
||||
|
@ -258,7 +258,7 @@ class TracingDebugInfo(NamedTuple):
|
||||
# formed just before staging to a jaxpr and read in trace-time error messages.
|
||||
traced_for: str # e.g. 'jit', 'scan', etc
|
||||
func_src_info: str | None # e.g. f'{fun.__name__} at {filename}:{lineno}'
|
||||
arg_names: tuple[str, ...] # e.g. ('args[0]', ... )
|
||||
arg_names: tuple[str | None, ...] # e.g. ('args[0]', ... )
|
||||
# e.g. ('[0]', '[1]', ...)
|
||||
result_paths_thunk: Callable[[], tuple[str, ...]] | None
|
||||
|
||||
|
@ -421,7 +421,7 @@ class BlockSpec:
|
||||
flat_index_map_fun, index_map_out_tree_thunk = api_util.flatten_fun(
|
||||
lu.wrap_init(index_map_func), index_map_tree
|
||||
)
|
||||
debug = pe.debug_info(
|
||||
debug = pe.tracing_debug_info(
|
||||
index_map_func,
|
||||
index_map_tree,
|
||||
index_map_out_tree_thunk,
|
||||
|
@ -1340,7 +1340,7 @@ def pallas_call_checkify_rule(error: checkify.Error,
|
||||
jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(retrace_in_avals)
|
||||
wrapped_kernel_with_err, out_tree_thunk = api_util.flatten_fun_nokwargs(
|
||||
lu.wrap_init(checked_kernel_fn), jaxpr_in_tree)
|
||||
debug = pe.debug_info(
|
||||
debug = pe.tracing_debug_info(
|
||||
checked_kernel_fn, jaxpr_in_tree, out_tree_thunk, False, "checkify_pallas")
|
||||
with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
|
||||
final_jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
|
||||
@ -1410,7 +1410,7 @@ def _trace_kernel_to_jaxpr(
|
||||
wrapped_kernel_fun = primitives.wrap_with_transforms(
|
||||
wrapped_kernel_fun, kernel_in_transforms
|
||||
)
|
||||
debug = pe.debug_info(fun, kernel_in_tree, out_tree_thunk, False, "pallas_call")
|
||||
debug = pe.tracing_debug_info(fun, kernel_in_tree, out_tree_thunk, False, "pallas_call")
|
||||
with grid_mapping.trace_env():
|
||||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun,
|
||||
kernel_avals, debug)
|
||||
|
@ -49,7 +49,7 @@ from jax._src import xla_bridge as xb
|
||||
from jax._src.api_util import (
|
||||
argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs,
|
||||
donation_vector, check_callable, resolve_argnums,
|
||||
argnames_partial_except, debug_info, result_paths, jaxpr_debug_info,
|
||||
argnames_partial_except, debug_info, result_paths, add_jaxpr_debug_info,
|
||||
hoist_obj_attrs, _check_no_aliased_ref_args,
|
||||
_check_no_aliased_closed_over_refs)
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
@ -537,7 +537,7 @@ class PjitParams(NamedTuple):
|
||||
in_tree: PyTreeDef
|
||||
out_tree: PyTreeDef
|
||||
donated_invars: tuple[bool, ...]
|
||||
arg_names: tuple[str, ...] | None
|
||||
arg_names: tuple[str | None, ...] | None
|
||||
num_consts: int
|
||||
attrs_tracked: list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]
|
||||
abstract_mesh: AbstractMesh
|
||||
@ -637,7 +637,7 @@ def _infer_params_impl(
|
||||
out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings(
|
||||
out_shardings_treedef, out_shardings_leaves, ji.out_layouts_treedef,
|
||||
ji.out_layouts_leaves, HashableFunction(out_tree, closure=()),
|
||||
tuple(out_avals), jaxpr.jaxpr.debug_info, device_or_backend_set)
|
||||
tuple(out_avals), jaxpr.jaxpr._debug_info, device_or_backend_set)
|
||||
|
||||
assert len(explicit_args) == len(in_shardings_flat) == len(in_layouts_flat)
|
||||
|
||||
@ -1301,7 +1301,7 @@ def _create_pjit_jaxpr(
|
||||
with dispatch.log_elapsed_time(
|
||||
"Finished tracing + transforming {fun_name} for pjit in {elapsed_time:.9f} sec",
|
||||
fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
|
||||
pe_debug = debug_info and pe.debug_info_final(fun, debug_info.traced_for)
|
||||
pe_debug = debug_info and pe.tracing_debug_info_final(fun, debug_info.traced_for)
|
||||
if config.dynamic_shapes.value:
|
||||
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic2(
|
||||
lu.annotate(fun, cast(core.InputType, in_type)), debug_info=pe_debug)
|
||||
@ -1313,7 +1313,7 @@ def _create_pjit_jaxpr(
|
||||
|
||||
# TODO(dougalm,mattjj): enable debug info with attrs_tracked
|
||||
if not config.dynamic_shapes.value and not attrs_tracked:
|
||||
jaxpr = jaxpr_debug_info(jaxpr, debug_info, out_paths())
|
||||
jaxpr = add_jaxpr_debug_info(jaxpr, debug_info, out_paths())
|
||||
|
||||
if config.debug_key_reuse.value:
|
||||
# Import here to avoid circular imports
|
||||
@ -1672,7 +1672,7 @@ def _pjit_call_impl_python(
|
||||
if compiled._auto_spmd_lowering and config.enable_checks.value:
|
||||
pxla.check_array_xla_sharding_layout_match(
|
||||
args, compiled._in_shardings, compiled._in_layouts,
|
||||
jaxpr.jaxpr.debug_info, compiled._kept_var_idx)
|
||||
jaxpr.jaxpr.tracing_debug_info, compiled._kept_var_idx)
|
||||
if config.distributed_debug.value:
|
||||
# Defensively only perform fingerprint logic if debug logging is enabled
|
||||
# NOTE(skyewm): I didn't benchmark this
|
||||
|
@ -993,7 +993,7 @@ def initial_style_jaxpr(
|
||||
def _initial_style_jaxpr(fun, in_tree, in_avals):
|
||||
fun_, out_tree_thunk = api_util.flatten_fun_nokwargs(lu.wrap_init(fun),
|
||||
tree_util.treedef_tuple((in_tree,)))
|
||||
debug = pe.debug_info(fun_, in_tree, out_tree_thunk, False, 'run_state')
|
||||
debug = pe.tracing_debug_info(fun_, in_tree, out_tree_thunk, False, 'run_state')
|
||||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, in_avals, debug)
|
||||
return jaxpr, consts, out_tree_thunk()
|
||||
|
||||
|
@ -57,8 +57,8 @@ from jax._src.interpreters.partial_eval import (
|
||||
dce_jaxpr_closed_call_rule as dce_jaxpr_closed_call_rule,
|
||||
dce_jaxpr_consts as dce_jaxpr_consts,
|
||||
dce_rules as dce_rules,
|
||||
debug_info as debug_info,
|
||||
debug_info_final as debug_info_final,
|
||||
tracing_debug_info as tracing_debug_info,
|
||||
tracing_debug_info_final as tracing_debug_info_final,
|
||||
def_trivial_padding as def_trivial_padding,
|
||||
forwarding_rules as forwarding_rules,
|
||||
has_effects as has_effects,
|
||||
|
Loading…
x
Reference in New Issue
Block a user