[better_errors] Make it explicit that debug_info is not None.

Now all internal uses of lu.wrap_init and core.Jaxpr are with actual
debug info. This enables us to clean up the type declarations and
to remove the checks whether debug_info is present.

For usage outside of the JAX internals, we change
`jax.extend.linear_util.wrap_init` to be usable without debug_info,
for temporary backwards compatibility. We emit a deprecation
warning and fill-in some fake debugging info.

See https://github.com/jax-ml/jax/issues/26480 for more details.

PiperOrigin-RevId: 726770483
This commit is contained in:
George Necula 2025-02-13 22:06:18 -08:00 committed by jax authors
parent 60dcded2af
commit a0812cd57e
14 changed files with 113 additions and 108 deletions

View File

@ -37,6 +37,15 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
This package may safely be removed if it is present on your machine; JAX now
uses `libtpu` instead.
* Deprecations
* The internal function `linear_util.wrap_init` and the constructor
`core.Jaxpr` now must take a non-empty `core.DebugInfo` kwarg. For
a limited time, a `DeprecationWarning` is printed if
`jax.extend.linear_util.wrap_init` is used without debugging info.
A downstream effect of this several other internal functions need debug
info. This change does not affect public APIs.
See https://github.com/jax-ml/jax/issues/26480 for more detail.
## jax 0.5.0 (Jan 17, 2025)
As of this release, JAX now uses

View File

@ -620,7 +620,7 @@ def fun_signature(fun: Callable) -> inspect.Signature | None:
return None
def save_wrapped_fun_sourceinfo(wrapper: Callable,
wrapped: Callable | core.DebugInfo | None) -> None:
wrapped: Callable | core.DebugInfo) -> None:
# Prefer this to functools.wraps because it does not create a reference to
# the wrapped function.
if isinstance(wrapped, core.DebugInfo):
@ -628,7 +628,7 @@ def save_wrapped_fun_sourceinfo(wrapper: Callable,
elif callable(wrapped):
func_src_info = fun_sourceinfo(wrapped)
else:
return
assert False, wrapped # Unreachable
setattr(wrapper, "__fun_sourceinfo__", func_src_info)
_fun_name_re = re.compile(r"(?:<built-in function (\S+)>)")
@ -716,7 +716,7 @@ def register_class_with_attrs(t: type) -> None:
_class_with_attrs: set[type] = set()
# TODO(mattjj): make this function faster
def _check_no_aliased_ref_args(dbg: core.DebugInfo | None, avals, args):
def _check_no_aliased_ref_args(dbg: core.DebugInfo, avals, args):
assert config.mutable_array_checks.value
refs: dict[int, int] = {}
for i, (a, x) in enumerate(zip(avals, args)):
@ -730,7 +730,7 @@ def _check_no_aliased_ref_args(dbg: core.DebugInfo | None, avals, args):
if dbg else
f"at both flat index {dup_idx} and flat index {i}") from None
def _check_no_aliased_closed_over_refs(dbg: core.DebugInfo | None, consts, args) -> None:
def _check_no_aliased_closed_over_refs(dbg: core.DebugInfo, consts, args) -> None:
assert config.mutable_array_checks.value
refs: set[int] = {id(core.get_referent(c)) for c in consts
if isinstance(core.get_aval(c), AbstractRef)}

View File

@ -94,7 +94,7 @@ class Jaxpr:
_outvars: list[Atom]
_eqns: list[JaxprEqn]
_effects: Effects
_debug_info: DebugInfo | None
_debug_info: DebugInfo
@property
def constvars(self) -> list[Var]:
@ -117,13 +117,17 @@ class Jaxpr:
return self._effects
@property
def debug_info(self) -> DebugInfo | None:
def debug_info(self) -> DebugInfo:
return self._debug_info
def __init__(self, constvars: Sequence[Var], invars: Sequence[Var],
outvars: Sequence[Atom], eqns: Sequence[JaxprEqn],
effects: Effects = no_effects,
debug_info: DebugInfo | None = None):
# We want all calls to pass a DebugInfo object, but for backwards
# compatibility we have to allow calls when the debug_info
# is missing.
debug_info: DebugInfo = None, # type: ignore[annotation-type-mismatch,assignment]
):
"""
Args:
constvars: list of variables introduced for constants. Array constants are
@ -134,14 +138,16 @@ class Jaxpr:
eqns: list of equations.
effects: set of effects. The effects on a jaxpr are a superset of the
union of the effects for each equation.
debug_info: optional DebugInfo.
debug_info: debugging information.
"""
self._constvars = list(constvars)
self._invars = list(invars)
self._outvars = list(outvars)
self._eqns = list(eqns)
self._effects = effects
self._debug_info = debug_info and debug_info.resolve_result_paths()
# TODO(https://github.com/jax-ml/jax/issues/26480)
debug_info = debug_info or lu._missing_debug_info("core.Jaxpr")
self._debug_info = debug_info.resolve_result_paths()
# TODO(necula): re-enable these safety checks
# assert (not debug_info or len(debug_info.arg_names) == len(invars)), (debug_info, invars)
# assert (not debug_info or len(debug_info.result_paths) == len(outvars)), (debug_info, outvars)

View File

@ -30,8 +30,8 @@ from jax._src import traceback_util
from jax._src.ad_util import (
stop_gradient_p, SymbolicZero, Zero, zeros_like_aval)
from jax._src.api_util import (
argnums_partial, flatten_fun_nokwargs, resolve_kwargs, fun_signature,
_non_static_arg_names, prepend_static_args, debug_info)
argnums_partial, flatten_fun_nokwargs, resolve_kwargs,
prepend_static_args, debug_info)
from jax._src.errors import UnexpectedTracerError
from jax._src.state.types import AbstractRef
from jax._src.interpreters import ad
@ -686,7 +686,7 @@ class custom_vjp(Generic[ReturnValue]):
@lu.transformation2
def _check_primal_refs(f: Callable, nondiff_argnums: Sequence[int],
debug_info: core.DebugInfo | None, *args):
debug_info: core.DebugInfo, *args):
_check_for_aliased_refs(f, nondiff_argnums, debug_info, args)
out = f(*args)
_check_for_returned_refs(f, out, 'primal')
@ -694,20 +694,14 @@ def _check_primal_refs(f: Callable, nondiff_argnums: Sequence[int],
def _check_for_aliased_refs(f: Callable,
nondiff_argnums: Sequence[int],
debug: core.DebugInfo | None,
debug: core.DebugInfo,
args):
leaves = tree_leaves(args)
refs: dict[int, int] = {}
for i, x in enumerate(leaves):
if (isinstance((a := core.get_aval(x)), AbstractRef) and
(dup_idx := refs.setdefault(id(core.get_referent(x)), i)) != i):
if debug is not None:
arg_names = debug.safe_arg_names(len(leaves))
else:
# TODO(necula): drop this branch
arg_names = _non_static_arg_names(fun_signature(f), args, {}, nondiff_argnums, ())
if arg_names is None:
arg_names = [f'flat index {j}' for j in range(len(leaves))]
arg_names = debug.safe_arg_names(len(leaves))
raise ValueError(
"only one reference to a mutable array may be passed as an argument "
f"to a function, but custom_vjp function {f} got the same mutable "
@ -763,8 +757,8 @@ def _check_for_tracers(x):
def _flatten_fwd(f: Callable, store: lu.EqualStore,
nondiff_argnums: Sequence[int],
symbolic_zeros: bool,
debug_primal: core.DebugInfo | None,
debug_fwd: core.DebugInfo | None,
debug_primal: core.DebugInfo,
debug_fwd: core.DebugInfo,
in_tree: PyTreeDef, maybe_out_type, *args):
primal_name = debug_primal.func_name if debug_primal else str(f)
fwd_name = debug_fwd.func_name if debug_fwd else "<unknown>"
@ -1560,9 +1554,9 @@ custom_jvp_call_jaxpr_p = core.Primitive("custom_jvp_call_jaxpr")
# simpler, but it would be worth revisiting this.
def optimize_remat_of_custom_vjp_fwd(
fun: Callable[..., ReturnValue],
debug_fun: core.DebugInfo | None,
debug_fun: core.DebugInfo,
fwd: Callable[..., tuple[ReturnValue, Any]],
debug_fwd: core.DebugInfo | None,
debug_fwd: core.DebugInfo,
nondiff_argnums: Sequence[int] = (),
symbolic_zeros: bool = False,
) -> Callable[..., tuple[ReturnValue, Any]]:

View File

@ -86,7 +86,7 @@ def jvpfun(f: Callable, instantiate, transform_stack, primals, tangents):
@lu.transformation_with_aux2
def linearize_subtrace(_f: Callable, _store: lu.Store, _tag: core.TraceTag,
nzs_in: Sequence[bool],
debug_info: core.DebugInfo | None,
debug_info: core.DebugInfo,
*primals, **params):
with core.take_current_trace() as parent_trace:
tangent_trace = pe.DynamicJaxprTrace(debug_info)
@ -133,7 +133,7 @@ def jvp_subtrace_aux(f, store, tag, primals, tangents):
return out_primals, out_tangents
def convert_constvars_jaxpr_constvars_at_end(jaxpr: core.Jaxpr) -> core.Jaxpr:
dbg = jaxpr.debug_info and jaxpr.debug_info._replace(
dbg = jaxpr.debug_info._replace(
arg_names=jaxpr.debug_info.arg_names + (None,) * len(jaxpr.constvars))
return core.Jaxpr(constvars=(),
invars=jaxpr.invars + jaxpr.constvars,
@ -768,7 +768,7 @@ def linearize_from_jvp(jvp: Callable,
multiple_results: bool,
nonzeros: Sequence[bool],
user_facing_symbolic_zeros: bool, instantiate_input_zeros: bool,
debug_info: core.DebugInfo | None,
debug_info: core.DebugInfo,
primals, params):
current_name_stack = source_info_util.current_name_stack()
with core.take_current_trace() as parent_trace:
@ -1100,15 +1100,14 @@ def rearrange_binders(jaxpr: core.ClosedJaxpr, primals_in, tangents_in, primals_
new_invars = _perm(primals_in, tangents_in, jaxpr.jaxpr.invars)
new_outvars = _perm(primals_out, tangents_out, jaxpr.jaxpr.outvars)
new_debug_info = jaxpr.jaxpr.debug_info
if new_debug_info is not None:
new_arg_names = tuple(_perm(primals_in, tangents_in,
jaxpr.jaxpr.debug_info.safe_arg_names(len(jaxpr.jaxpr.invars))))
new_result_paths = tuple(_perm(primals_out, tangents_out,
jaxpr.jaxpr.debug_info.safe_result_paths(len(jaxpr.jaxpr.outvars))))
new_debug_info = new_debug_info._replace(
arg_names=new_arg_names,
result_paths=new_result_paths,
)
new_arg_names = tuple(_perm(primals_in, tangents_in,
jaxpr.jaxpr.debug_info.safe_arg_names(len(jaxpr.jaxpr.invars))))
new_result_paths = tuple(_perm(primals_out, tangents_out,
jaxpr.jaxpr.debug_info.safe_result_paths(len(jaxpr.jaxpr.outvars))))
new_debug_info = new_debug_info._replace(
arg_names=new_arg_names,
result_paths=new_result_paths,
)
new_jaxpr = core.Jaxpr(jaxpr.jaxpr.constvars,
new_invars, new_outvars, jaxpr.jaxpr.eqns,
jaxpr.jaxpr.effects,

View File

@ -502,7 +502,8 @@ def _closed_call_param_updater(params, _, __):
return dict(params, call_jaxpr=core.ClosedJaxpr(jaxpr, ()))
call_param_updaters[core.closed_call_p] = _closed_call_param_updater
def abstract_eval_fun(fun: Callable, *avals, debug_info=None, **params):
def abstract_eval_fun(fun: Callable, *avals,
debug_info: core.DebugInfo, **params):
_, avals_out, _, () = trace_to_jaxpr_dynamic(
lu.wrap_init(fun, params, debug_info=debug_info), avals)
assert all(isinstance(aval, AbstractValue) for aval in avals_out)
@ -582,7 +583,7 @@ def trace_to_subjaxpr_nounits(
f: Callable,
trace: JaxprTrace,
instantiate: Sequence[bool] | bool,
debug_info: core.DebugInfo | None,
debug_info: core.DebugInfo,
in_pvals: Sequence[PartialVal]):
assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals
out_tracers, jaxpr, out_consts, env = _trace_to_subjaxpr_nounits(
@ -595,7 +596,7 @@ def trace_to_subjaxpr_nounits(
def trace_to_subjaxpr_nounits2(
f: Callable,
tag: TraceTag,
debug_info: core.DebugInfo | None,
debug_info: core.DebugInfo,
instantiate: bool | Sequence[bool],
in_pvals: Sequence[PartialVal]):
assert isinstance(tag, TraceTag)
@ -612,7 +613,7 @@ def trace_to_subjaxpr_nounits2(
def _trace_to_subjaxpr_nounits(f: Callable, trace: JaxprTrace,
instantiate: Sequence[bool] | bool,
in_pvals: Sequence[PartialVal],
debug_info: core.DebugInfo | None):
debug_info: core.DebugInfo):
in_knowns = [pval.is_known() for pval in in_pvals]
in_consts = [pval.get_known() for pval in in_pvals if pval.is_known()]
in_tracers = [trace.new_arg(pval) for pval in in_pvals if not pval.is_known()]
@ -639,7 +640,7 @@ def _trace_to_subjaxpr_nounits(f: Callable, trace: JaxprTrace,
def trace_to_subjaxpr_nounits_fwd(
f: Callable,
tag: TraceTag,
debug_info: core.DebugInfo | None,
debug_info: core.DebugInfo,
instantiate: bool | Sequence[bool],
in_pvals: Sequence[PartialVal]):
assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals
@ -669,7 +670,7 @@ def trace_to_subjaxpr_nounits_fwd(
def trace_to_subjaxpr_nounits_fwd2(
f: Callable,
tag: TraceTag,
debug_info: core.DebugInfo | None,
debug_info: core.DebugInfo,
instantiate: bool | Sequence[bool],
in_pvals: Sequence[PartialVal]):
assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals
@ -752,13 +753,14 @@ def recipe_to_eqn(getvar: Callable[[JaxprTracer], Atom],
def tracers_to_jaxpr(
in_tracers: Sequence[JaxprTracer],
out_tracers: Sequence[JaxprTracer],
debug_info: core.DebugInfo | None,
debug_info: core.DebugInfo,
) -> tuple[Jaxpr, tuple[Any, ...], tuple[Any, ...]]:
"""Constructs Jaxpr given tracers for inputs and outputs.
Params:
in_tracers: the tracers that were created for the function inputs
out_tracers: the tracers that were output by the function.
debug_info: the debug info for the function.
Returns: a triple of a `Jaxpr`, a list of constant values corresponding to
the `constvars` in the returned Jaxps, and a list of environment values.
@ -838,7 +840,7 @@ def tracers_to_jaxpr(
def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr:
"""Moves the constvars to the start of invars."""
config.enable_checks.value and core.check_jaxpr(jaxpr)
dbg = jaxpr.debug_info and jaxpr.debug_info._replace(
dbg = jaxpr.debug_info._replace(
arg_names=(None,) * len(jaxpr.constvars) + jaxpr.debug_info.arg_names)
lifted_jaxpr = Jaxpr(constvars=(),
invars=jaxpr.constvars + jaxpr.invars,
@ -854,7 +856,7 @@ def convert_invars_to_constvars(jaxpr: Jaxpr, n: int) -> Jaxpr:
return jaxpr.replace() # 'return jaxpr' would create cache reference cycle
config.enable_checks.value and core.check_jaxpr(jaxpr)
constvars, invars = split_list(jaxpr.invars, [n])
dbg = jaxpr.debug_info and jaxpr.debug_info._replace(
dbg = jaxpr.debug_info._replace(
arg_names=jaxpr.debug_info.arg_names[n:])
lifted_jaxpr = jaxpr.replace(constvars=tuple(constvars), invars=invars,
debug_info=dbg)
@ -868,7 +870,7 @@ def convert_envvars_to_constvars(jaxpr: Jaxpr, num_env_vars: int) -> Jaxpr:
env_vars, invars = split_list(jaxpr.invars, [num_env_vars])
converted_jaxpr = Jaxpr(constvars=jaxpr.constvars + env_vars,
invars=invars, outvars=jaxpr.outvars, eqns=jaxpr.eqns,
effects=jaxpr.effects)
effects=jaxpr.effects, debug_info=jaxpr.debug_info)
config.enable_checks.value and core.check_jaxpr(converted_jaxpr)
return converted_jaxpr
@ -1363,7 +1365,7 @@ def prune_jaxpr_outputs(jaxpr: Jaxpr, used_outputs: Sequence[bool]) -> Jaxpr:
def _prune_jaxpr_outputs(jaxpr: Jaxpr, used_outputs: tuple[bool, ...]) -> Jaxpr:
outvars = [v for v, b in zip(jaxpr.outvars, used_outputs) if b]
dbg = jaxpr.debug_info and core.DebugInfo(
dbg = core.DebugInfo(
jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info,
jaxpr.debug_info.arg_names,
jaxpr.debug_info.filter_result_paths(used_outputs))
@ -1451,7 +1453,7 @@ def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: tuple[bool, ...],
eqns = new_eqns[::-1]
jaxpr_effects = make_jaxpr_effects(jaxpr.constvars, invars, outvars, eqns)
dbg = jaxpr.debug_info and core.DebugInfo(
dbg = core.DebugInfo(
jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info,
jaxpr.debug_info.filter_arg_names(used_inputs),
jaxpr.debug_info.filter_result_paths(used_outputs))
@ -1653,9 +1655,9 @@ class JaxprStackFrame:
attrs_tracked: list[tuple[Any, str]]
attrs_inits: list
attrs_vars: list[Var]
debug_info: core.DebugInfo | None
debug_info: core.DebugInfo
def __init__(self, debug_info: core.DebugInfo | None):
def __init__(self, debug_info: core.DebugInfo):
self.gensym = core.gensym()
self.tracer_to_var = {}
self.constid_to_tracer = {}
@ -1674,7 +1676,7 @@ class JaxprStackFrame:
def to_jaxpr(self, trace: DynamicJaxprTrace,
out_tracers: Sequence[Tracer],
debug_info: core.DebugInfo | None,
debug_info: core.DebugInfo,
) -> tuple[Jaxpr, list[Any], list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]:
# It's not necessary, but we keep the tracer-to-var mapping injective:
assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values()))
@ -1696,7 +1698,7 @@ class JaxprStackFrame:
return jaxpr, list(constvals), zip(init_trees, end_trees, self.attrs_tracked)
def to_jaxpr2(self, out_tracers: Sequence[core.Tracer],
debug_info: core.DebugInfo | None):
debug_info: core.DebugInfo):
# It's not necessary, but we keep the tracer-to-var mapping injective:
assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values()))
constvars, constvals = unzip2(self.constvar_to_val.items())
@ -1843,7 +1845,7 @@ def _inline_literals(
class DynamicJaxprTrace(core.Trace):
__slots__ = ("frame",)
def __init__(self, debug_info: core.DebugInfo | None):
def __init__(self, debug_info: core.DebugInfo):
self.frame = JaxprStackFrame(debug_info)
def invalidate(self):
@ -2117,7 +2119,7 @@ class DynamicJaxprTrace(core.Trace):
return out_tracers
def to_jaxpr(self, out_tracers: Sequence[Tracer],
debug_info: core.DebugInfo | None):
debug_info: core.DebugInfo):
return self.frame.to_jaxpr(self, out_tracers, debug_info)
@ -2180,17 +2182,13 @@ def trace_to_jaxpr_dynamic(
return jaxpr, [v.aval for v in jaxpr.outvars], consts, attrs_tracked
def _check_no_returned_refs(
dbg: core.DebugInfo | None,
dbg: core.DebugInfo,
out_tracers: Sequence[DynamicJaxprTracer]
) -> None:
if not config.mutable_array_checks.value: return
for i, t in enumerate(out_tracers):
a = t.aval
if isinstance(a, AbstractRef):
if dbg is None:
raise ValueError(
f"function returned a mutable array reference of type {a.str_short()}, "
"but mutable array references cannot be returned.")
result_paths = dbg.resolve_result_paths().safe_result_paths(len(out_tracers))
loc = result_paths[i] and f' at output tree path {result_paths[i]}'
frame = t._trace.frame
@ -2469,7 +2467,8 @@ def pad_jaxpr(jaxpr: Jaxpr, consts: Sequence[Const]
return aval
in_avals = [substitute(v.aval) for v in jaxpr.invars]
eval_padded = lu.wrap_init(partial(_eval_jaxpr_padded, jaxpr, consts))
eval_padded = lu.wrap_init(partial(_eval_jaxpr_padded, jaxpr, consts),
debug_info=jaxpr.debug_info)
padded_jaxpr, _, padded_consts, () = trace_to_jaxpr_dynamic(eval_padded, in_avals)
return padded_jaxpr, padded_consts

View File

@ -877,8 +877,8 @@ def lower_parallel_callable(
replicated_args=replicated_args,
arg_shardings=None,
result_shardings=None,
arg_names=jaxpr._debug_info and jaxpr._debug_info.safe_arg_names(len(jaxpr.invars)),
result_names=jaxpr._debug_info and jaxpr._debug_info.safe_result_paths(len(jaxpr.outvars)),
arg_names=jaxpr._debug_info.safe_arg_names(len(jaxpr.invars)),
result_names=jaxpr._debug_info.safe_result_paths(len(jaxpr.outvars)),
num_replicas=replicas.num_global_replicas,
lowering_parameters=lowering_parameters)
return PmapComputation(lowering_result.module,
@ -1968,8 +1968,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
result_shardings=out_mlir_shardings,
in_layouts=in_layouts,
out_layouts=out_layouts,
arg_names=jaxpr._debug_info and jaxpr._debug_info.safe_arg_names(len(jaxpr.invars)),
result_names=jaxpr._debug_info and jaxpr._debug_info.safe_result_paths(len(jaxpr.outvars)),
arg_names=jaxpr._debug_info.safe_arg_names(len(jaxpr.invars)),
result_names=jaxpr._debug_info.safe_result_paths(len(jaxpr.outvars)),
num_replicas=nreps,
num_partitions=num_partitions,
all_default_mem_kind=all_default_mem_kind,
@ -2125,7 +2125,7 @@ MaybeLayout = Sequence[Union[DeviceLocalLayout, AutoLayout, None]]
class AllArgsInfo(NamedTuple):
"""Avals and debug_info for all arguments prior to DCE."""
in_avals: Sequence[core.ShapedArray]
debug_info: core.DebugInfo | None
debug_info: core.DebugInfo
@lru_cache(maxsize=2048)
@ -3202,17 +3202,13 @@ def cc_shard_arg(x, sharding, layout):
def check_arg_avals_for_call(ref_avals, arg_avals,
jaxpr_debug_info: core.DebugInfo | None = None):
jaxpr_debug_info: core.DebugInfo):
if len(ref_avals) != len(arg_avals):
raise TypeError(
f"Computation compiled for {len(ref_avals)} inputs "
f"but called with {len(arg_avals)}")
if jaxpr_debug_info is not None:
arg_names = [f"'{name}'" for name in jaxpr_debug_info.safe_arg_names(len(ref_avals))]
else:
num_args = len(ref_avals)
arg_names = [f"{i + 1}/{num_args}" for i in range(num_args)]
arg_names = [f"'{name}'" for name in jaxpr_debug_info.safe_arg_names(len(ref_avals))]
errors = []
for ref_aval, arg_aval, name in safe_zip(ref_avals, arg_avals, arg_names):
@ -3264,14 +3260,13 @@ def check_array_xla_sharding_layout_match(
args_after_dce,
in_xla_shardings: Sequence[JSharding],
in_xla_layouts: Sequence[DeviceLocalLayout],
jaxpr_debug_info: core.DebugInfo | None,
jaxpr_debug_info: core.DebugInfo,
kept_var_idx: set[int]) -> None:
from jax._src.array import ArrayImpl
# jaxpr_debug_info.arg_names are before DCE, so need to DCE them.
arg_names = (
[""] * len(args_after_dce) if jaxpr_debug_info is None
else [a for i, a in enumerate(jaxpr_debug_info.arg_names) # type: ignore
if i in kept_var_idx]
[a for i, a in enumerate(jaxpr_debug_info.arg_names) # type: ignore
if i in kept_var_idx]
)
errors = []
num_errors = 5

View File

@ -73,7 +73,7 @@ for_p.skip_canonicalization = True
def _trace_to_jaxpr_with_refs(f: Callable, state_tree: PyTreeDef,
state_avals: Sequence[core.AbstractValue],
debug_info: core.DebugInfo | None,
debug_info: core.DebugInfo,
) -> tuple[core.Jaxpr, list[Any], PyTreeDef]:
f, out_tree_thunk = api_util.flatten_fun_nokwargs(
lu.wrap_init(f, debug_info=debug_info),

View File

@ -336,7 +336,7 @@ def _custom_linear_solve_impl(*args, const_lengths, jaxprs):
def _tangent_linear_map(func: Callable, params, params_dot,
debug_info: core.DebugInfo | None,
debug_info: core.DebugInfo,
*x):
"""Compute the tangent of a linear map.

View File

@ -161,7 +161,7 @@ class WrappedFun:
f_transformed: Callable,
transforms,
stores: tuple[Store | EqualStore | None, ...], params, in_type,
debug_info: DebugInfo | None):
debug_info: DebugInfo):
self.f = f
self.f_transformed = f_transformed
self.transforms = transforms
@ -258,6 +258,7 @@ def fun_name(f):
except:
return str(f)
class DebugInfo(NamedTuple):
"""Debugging info about a func, its arguments, and results."""
traced_for: str # e.g. 'jit', 'scan', etc
@ -331,18 +332,17 @@ def _missing_debug_info(for_what: str) -> DebugInfo:
return DebugInfo("missing_debug_info", "<missing_debug_info>", (), ())
def wrap_init(f: Callable, params=None, *,
debug_info: DebugInfo | None = None) -> WrappedFun:
debug_info: DebugInfo) -> WrappedFun:
"""Wraps function `f` as a `WrappedFun`, suitable for transformation."""
params_dict = {} if params is None else params
params = () if params is None else tuple(sorted(params.items()))
fun = WrappedFun(f, partial(f, **params_dict), (), (), params, None, debug_info)
if debug_info:
if debug_info.result_paths is None:
fun, result_paths_thunk = _get_result_paths_thunk(fun)
debug_info = debug_info._replace(
result_paths=HashableFunction(result_paths_thunk, closure=()))
fun = WrappedFun(fun.f, fun.f_transformed, fun.transforms, fun.stores,
fun.params, fun.in_type, debug_info)
if debug_info.result_paths is None:
fun, result_paths_thunk = _get_result_paths_thunk(fun)
debug_info = debug_info._replace(
result_paths=HashableFunction(result_paths_thunk, closure=()))
fun = WrappedFun(fun.f, fun.f_transformed, fun.transforms, fun.stores,
fun.params, fun.in_type, debug_info)
return fun

View File

@ -731,22 +731,20 @@ def _infer_params(
entry.pjit_params = p
return entry.pjit_params, entry.pjit_params.consts + dynargs
def _infer_input_type(fun: Callable, dbg: core.DebugInfo | None,
def _infer_input_type(fun: Callable, dbg: core.DebugInfo,
explicit_args) -> tuple[core.AbstractValue, ...]:
avals = []
try:
for i, x in enumerate(explicit_args):
avals.append(core.shaped_abstractify(x))
except OverflowError:
arg_path = (f"argument path is {dbg.arg_names[i]}" if dbg
else f"flattened argument number is {i}")
arg_path = f"argument path is {dbg.arg_names[i]}" # type: ignore
raise OverflowError(
"An overflow was encountered while parsing an argument to a jitted "
f"computation, whose {arg_path}."
) from None
except TypeError:
arg_description = (f"path {dbg.arg_names[i]}" if dbg
else f"flattened argument number {i}")
arg_description = f"path {dbg.arg_names[i]}" # type: ignore
raise TypeError(
f"Error interpreting argument to {fun} as an abstract array."
f" The problematic value is of type {type(x)} and was passed to"
@ -1111,7 +1109,7 @@ class PytreeLeaf:
@util.cache(max_size=4096, trace_context_in_key=False)
def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves,
in_layouts_treedef, in_layouts_leaves,
in_avals, in_tree, debug_info,
in_avals, in_tree, debug_info: core.DebugInfo,
device_or_backend_set, kws):
if not kws:
in_tree, _ = treedef_children(in_tree)
@ -1136,11 +1134,11 @@ def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves,
attrs_tracked = debug_info and len(debug_info.arg_names) != len(in_avals)
if not config.dynamic_shapes.value and not attrs_tracked:
pjit_check_aval_sharding(in_shardings_flat, in_avals,
None if debug_info is None else debug_info.safe_arg_names(len(in_avals)),
debug_info.safe_arg_names(len(in_avals)),
"pjit arguments", allow_uneven_sharding=False)
check_aval_layout_compatibility(
in_layouts_flat, in_avals,
None if debug_info is None else debug_info.arg_names, "jit arguments")
debug_info.safe_arg_names(len(in_avals)), "jit arguments") # type: ignore[arg-type]
return in_shardings_flat, in_layouts_flat
callsites: set[str] = set()
@ -1167,7 +1165,7 @@ def explain_tracing_cache_miss(
# have we seen this function before at all?
fun_name = getattr(fun.f, '__qualname__', fun.f)
if debug_info is not None and debug_info.func_src_info:
if debug_info.func_src_info:
# TODO(necula): clean up the extraction of the source info
_, *rest = debug_info.func_src_info.split(' at ')
src_info = " defined at " + ' '.join(rest)
@ -1239,7 +1237,7 @@ def explain_tracing_cache_miss(
# have we never seen these input types (eg shapes, dtypes) before?
types_match = [k for k in trees_match if k[1] == in_type]
if not types_match:
if len(in_type) < 5 and debug_info is not None:
if len(in_type) < 5:
in_type_str = ':\n {}'.format(', '.join(
f'{n}: {ty.str_short(short_dtypes=True)}'
for n, ty in zip(debug_info.arg_names, in_type)))
@ -1251,10 +1249,7 @@ def explain_tracing_cache_miss(
num_mismatch = sum(map(op.ne, closest_ty, in_type))
p(f" closest seen input type signature has {num_mismatch} mismatches, including:")
add_weak_type_hint = False
if debug_info:
arg_names = debug_info.safe_arg_names(len(in_type))
else:
arg_names = (None,) * len(in_type)
arg_names = debug_info.safe_arg_names(len(in_type))
for name, ty1, ty2 in zip(arg_names, closest_ty, in_type):
if ty1 != ty2:
@ -1320,7 +1315,7 @@ def _create_pjit_jaxpr(
def _check_and_canonicalize_out_shardings(
out_shardings_treedef, out_shardings_leaves, out_layouts_treedef,
out_layouts_leaves, out_tree, out_avals,
debug_info: core.DebugInfo | None,
debug_info: core.DebugInfo,
device_or_backend_set):
orig_out_shardings = tree_unflatten(out_shardings_treedef, out_shardings_leaves)
if isinstance(orig_out_shardings, (UnspecifiedValue, Sharding)):
@ -1340,11 +1335,11 @@ def _check_and_canonicalize_out_shardings(
if not config.dynamic_shapes.value:
pjit_check_aval_sharding(
out_shardings_flat, out_avals,
None if debug_info is None else debug_info.safe_result_paths(len(out_avals)),
debug_info.safe_result_paths(len(out_avals)), # type: ignore[arg-type]
"pjit outputs", allow_uneven_sharding=False)
check_aval_layout_compatibility(
out_layouts_flat, out_avals,
None if debug_info is None else debug_info.safe_result_paths(len(out_avals)),
debug_info.safe_result_paths(len(out_avals)), # type: ignore[arg-type]
"jit outputs")
return out_shardings_flat, out_layouts_flat

View File

@ -15,6 +15,8 @@
# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from typing import Callable
from jax._src.linear_util import (
StoreException as StoreException,
WrappedFun as WrappedFun,
@ -24,7 +26,14 @@ from jax._src.linear_util import (
transformation_with_aux as transformation_with_aux,
transformation2 as transformation2,
transformation_with_aux2 as transformation_with_aux2,
wrap_init as wrap_init,
# TODO(b/396086979): remove this once we pass debug_info everywhere.
wrap_init as _wrap_init,
_missing_debug_info as _missing_debug_info,
)
# Version of wrap_init that does not require a DebugInfo object.
# This usage is deprecated, use api_util.debug_info() to construct a proper
# DebugInfo object.
def wrap_init(f: Callable, params=None, *, debug_info=None) -> WrappedFun:
debug_info = debug_info or _missing_debug_info("linear_util.wrap_init")
return _wrap_init(f, params, debug_info=debug_info)

View File

@ -71,8 +71,7 @@ def _collect_jaxprs(jaxpr: core.Jaxpr,
return acc
def _debug_info_to_string(dbg: core.DebugInfo | None) -> list[str]:
if dbg is None: return "None"
def _debug_info_to_string(dbg: core.DebugInfo) -> list[str]:
# Strip the absolute path and the line number but check that it references
# this file (to catch errors when the source info points in JAX internals)
func_src_info = re.sub(r"^(\S+)( at .*.debug_info_test.py:\d+)?", "\\1", dbg.func_src_info)
@ -294,7 +293,6 @@ class DebugInfoTest(jtu.JaxTestCase):
def wrapper(x, y):
return x
api_util.save_wrapped_fun_sourceinfo(wrapper, None) # No effect
dbg = api_util.debug_info("test", wrapper, (1, 2), {})
self.assertEqual("wrapper", dbg.func_name)

View File

@ -53,7 +53,8 @@ class ExtendTest(jtu.JaxTestCase):
self.assertIs(jex.linear_util.merge_linear_aux, linear_util.merge_linear_aux)
self.assertIs(jex.linear_util.transformation, linear_util.transformation)
self.assertIs(jex.linear_util.transformation_with_aux, linear_util.transformation_with_aux)
self.assertIs(jex.linear_util.wrap_init, linear_util.wrap_init)
# TODO(necula): revert this change once we deprecate the old wrap_init
# self.assertIs(jex.linear_util.wrap_init, linear_util.wrap_init)
class RandomTest(jtu.JaxTestCase):