From a0812cd57e57d6e4b9238a5d6671e9e4c4c440fe Mon Sep 17 00:00:00 2001
From: George Necula <necula@google.com>
Date: Thu, 13 Feb 2025 22:06:18 -0800
Subject: [PATCH] [better_errors] Make it explicit that debug_info is not None.

Now all internal uses of lu.wrap_init and core.Jaxpr are with actual
debug info. This enables us to clean up the type declarations and
to remove the checks whether debug_info is present.

For usage outside of the JAX internals, we change
`jax.extend.linear_util.wrap_init` to be usable without debug_info,
for temporary backwards compatibility. We emit a deprecation
warning and fill-in some fake debugging info.

See https://github.com/jax-ml/jax/issues/26480 for more details.

PiperOrigin-RevId: 726770483
---
 CHANGELOG.md                          |  9 +++++
 jax/_src/api_util.py                  |  8 ++---
 jax/_src/core.py                      | 16 ++++++---
 jax/_src/custom_derivatives.py        | 24 +++++---------
 jax/_src/interpreters/ad.py           | 23 +++++++------
 jax/_src/interpreters/partial_eval.py | 47 +++++++++++++--------------
 jax/_src/interpreters/pxla.py         | 25 ++++++--------
 jax/_src/lax/control_flow/for_loop.py |  2 +-
 jax/_src/lax/control_flow/solves.py   |  2 +-
 jax/_src/linear_util.py               | 18 +++++-----
 jax/_src/pjit.py                      | 29 +++++++----------
 jax/extend/linear_util.py             | 11 ++++++-
 tests/debug_info_test.py              |  4 +--
 tests/extend_test.py                  |  3 +-
 14 files changed, 113 insertions(+), 108 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index b29ad3f1a..11f93483d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -37,6 +37,15 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
     This package may safely be removed if it is present on your machine; JAX now
     uses `libtpu` instead.
 
+* Deprecations
+  * The internal function `linear_util.wrap_init` and the constructor
+    `core.Jaxpr` now must take a non-empty `core.DebugInfo` kwarg. For
+    a limited time, a `DeprecationWarning` is printed if
+    `jax.extend.linear_util.wrap_init` is used without debugging info.
+    A downstream effect of this several other internal functions need debug
+    info. This change does not affect public APIs.
+    See https://github.com/jax-ml/jax/issues/26480 for more detail.
+
 ## jax 0.5.0 (Jan 17, 2025)
 
 As of this release, JAX now uses
diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py
index 5cc31f53b..a597e8b5b 100644
--- a/jax/_src/api_util.py
+++ b/jax/_src/api_util.py
@@ -620,7 +620,7 @@ def fun_signature(fun: Callable) -> inspect.Signature | None:
     return None
 
 def save_wrapped_fun_sourceinfo(wrapper: Callable,
-                                wrapped: Callable | core.DebugInfo | None) -> None:
+                                wrapped: Callable | core.DebugInfo) -> None:
   # Prefer this to functools.wraps because it does not create a reference to
   # the wrapped function.
   if isinstance(wrapped, core.DebugInfo):
@@ -628,7 +628,7 @@ def save_wrapped_fun_sourceinfo(wrapper: Callable,
   elif callable(wrapped):
     func_src_info = fun_sourceinfo(wrapped)
   else:
-    return
+    assert False, wrapped  # Unreachable
   setattr(wrapper, "__fun_sourceinfo__", func_src_info)
 
 _fun_name_re = re.compile(r"(?:<built-in function (\S+)>)")
@@ -716,7 +716,7 @@ def register_class_with_attrs(t: type) -> None:
 _class_with_attrs: set[type] = set()
 
 # TODO(mattjj): make this function faster
-def _check_no_aliased_ref_args(dbg: core.DebugInfo | None, avals, args):
+def _check_no_aliased_ref_args(dbg: core.DebugInfo, avals, args):
   assert config.mutable_array_checks.value
   refs: dict[int, int] = {}
   for i, (a, x) in enumerate(zip(avals, args)):
@@ -730,7 +730,7 @@ def _check_no_aliased_ref_args(dbg: core.DebugInfo | None, avals, args):
         if dbg else
         f"at both flat index {dup_idx} and flat index {i}") from None
 
-def _check_no_aliased_closed_over_refs(dbg: core.DebugInfo | None, consts, args) -> None:
+def _check_no_aliased_closed_over_refs(dbg: core.DebugInfo, consts, args) -> None:
   assert config.mutable_array_checks.value
   refs: set[int] = {id(core.get_referent(c)) for c in consts
                     if isinstance(core.get_aval(c), AbstractRef)}
diff --git a/jax/_src/core.py b/jax/_src/core.py
index f6a95e4de..65b0286eb 100644
--- a/jax/_src/core.py
+++ b/jax/_src/core.py
@@ -94,7 +94,7 @@ class Jaxpr:
   _outvars: list[Atom]
   _eqns: list[JaxprEqn]
   _effects: Effects
-  _debug_info: DebugInfo | None
+  _debug_info: DebugInfo
 
   @property
   def constvars(self) -> list[Var]:
@@ -117,13 +117,17 @@ class Jaxpr:
     return self._effects
 
   @property
-  def debug_info(self) -> DebugInfo | None:
+  def debug_info(self) -> DebugInfo:
     return self._debug_info
 
   def __init__(self, constvars: Sequence[Var], invars: Sequence[Var],
                outvars: Sequence[Atom], eqns: Sequence[JaxprEqn],
                effects: Effects = no_effects,
-               debug_info: DebugInfo | None = None):
+               # We want all calls to pass a DebugInfo object, but for backwards
+               # compatibility we have to allow calls when the debug_info
+               # is missing.
+               debug_info: DebugInfo = None,  # type: ignore[annotation-type-mismatch,assignment]
+               ):
     """
     Args:
       constvars: list of variables introduced for constants. Array constants are
@@ -134,14 +138,16 @@ class Jaxpr:
       eqns: list of equations.
       effects: set of effects. The effects on a jaxpr are a superset of the
         union of the effects for each equation.
-      debug_info: optional DebugInfo.
+      debug_info: debugging information.
     """
     self._constvars = list(constvars)
     self._invars = list(invars)
     self._outvars = list(outvars)
     self._eqns = list(eqns)
     self._effects = effects
-    self._debug_info = debug_info and debug_info.resolve_result_paths()
+    # TODO(https://github.com/jax-ml/jax/issues/26480)
+    debug_info = debug_info or lu._missing_debug_info("core.Jaxpr")
+    self._debug_info = debug_info.resolve_result_paths()
     # TODO(necula): re-enable these safety checks
     # assert (not debug_info or len(debug_info.arg_names) == len(invars)), (debug_info, invars)
     # assert (not debug_info or len(debug_info.result_paths) == len(outvars)), (debug_info, outvars)
diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py
index 2c6c41ff5..1cea84110 100644
--- a/jax/_src/custom_derivatives.py
+++ b/jax/_src/custom_derivatives.py
@@ -30,8 +30,8 @@ from jax._src import traceback_util
 from jax._src.ad_util import (
     stop_gradient_p, SymbolicZero, Zero, zeros_like_aval)
 from jax._src.api_util import (
-  argnums_partial, flatten_fun_nokwargs, resolve_kwargs, fun_signature,
-  _non_static_arg_names, prepend_static_args, debug_info)
+  argnums_partial, flatten_fun_nokwargs, resolve_kwargs,
+  prepend_static_args, debug_info)
 from jax._src.errors import UnexpectedTracerError
 from jax._src.state.types import AbstractRef
 from jax._src.interpreters import ad
@@ -686,7 +686,7 @@ class custom_vjp(Generic[ReturnValue]):
 
 @lu.transformation2
 def _check_primal_refs(f: Callable, nondiff_argnums: Sequence[int],
-                       debug_info: core.DebugInfo | None, *args):
+                       debug_info: core.DebugInfo, *args):
   _check_for_aliased_refs(f, nondiff_argnums, debug_info, args)
   out = f(*args)
   _check_for_returned_refs(f, out, 'primal')
@@ -694,20 +694,14 @@ def _check_primal_refs(f: Callable, nondiff_argnums: Sequence[int],
 
 def _check_for_aliased_refs(f: Callable,
                             nondiff_argnums: Sequence[int],
-                            debug: core.DebugInfo | None,
+                            debug: core.DebugInfo,
                             args):
   leaves = tree_leaves(args)
   refs: dict[int, int] = {}
   for i, x in enumerate(leaves):
     if (isinstance((a := core.get_aval(x)), AbstractRef) and
         (dup_idx := refs.setdefault(id(core.get_referent(x)), i)) != i):
-      if debug is not None:
-        arg_names = debug.safe_arg_names(len(leaves))
-      else:
-        # TODO(necula): drop this branch
-        arg_names = _non_static_arg_names(fun_signature(f), args, {}, nondiff_argnums, ())
-      if arg_names is None:
-        arg_names = [f'flat index {j}' for j in range(len(leaves))]
+      arg_names = debug.safe_arg_names(len(leaves))
       raise ValueError(
           "only one reference to a mutable array may be passed as an argument "
           f"to a function, but custom_vjp function {f} got the same mutable "
@@ -763,8 +757,8 @@ def _check_for_tracers(x):
 def _flatten_fwd(f: Callable, store: lu.EqualStore,
                  nondiff_argnums: Sequence[int],
                  symbolic_zeros: bool,
-                 debug_primal: core.DebugInfo | None,
-                 debug_fwd: core.DebugInfo | None,
+                 debug_primal: core.DebugInfo,
+                 debug_fwd: core.DebugInfo,
                  in_tree: PyTreeDef, maybe_out_type, *args):
   primal_name = debug_primal.func_name if debug_primal else str(f)
   fwd_name = debug_fwd.func_name if debug_fwd else "<unknown>"
@@ -1560,9 +1554,9 @@ custom_jvp_call_jaxpr_p = core.Primitive("custom_jvp_call_jaxpr")
 # simpler, but it would be worth revisiting this.
 def optimize_remat_of_custom_vjp_fwd(
     fun: Callable[..., ReturnValue],
-    debug_fun: core.DebugInfo | None,
+    debug_fun: core.DebugInfo,
     fwd: Callable[..., tuple[ReturnValue, Any]],
-    debug_fwd: core.DebugInfo | None,
+    debug_fwd: core.DebugInfo,
     nondiff_argnums: Sequence[int] = (),
     symbolic_zeros: bool = False,
 ) -> Callable[..., tuple[ReturnValue, Any]]:
diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py
index ec33fb02e..0131db6c9 100644
--- a/jax/_src/interpreters/ad.py
+++ b/jax/_src/interpreters/ad.py
@@ -86,7 +86,7 @@ def jvpfun(f: Callable, instantiate, transform_stack, primals, tangents):
 @lu.transformation_with_aux2
 def linearize_subtrace(_f: Callable, _store: lu.Store, _tag: core.TraceTag,
                        nzs_in: Sequence[bool],
-                       debug_info: core.DebugInfo | None,
+                       debug_info: core.DebugInfo,
                        *primals, **params):
   with core.take_current_trace() as parent_trace:
     tangent_trace = pe.DynamicJaxprTrace(debug_info)
@@ -133,7 +133,7 @@ def jvp_subtrace_aux(f, store, tag, primals, tangents):
   return out_primals, out_tangents
 
 def convert_constvars_jaxpr_constvars_at_end(jaxpr: core.Jaxpr) -> core.Jaxpr:
-  dbg = jaxpr.debug_info and jaxpr.debug_info._replace(
+  dbg = jaxpr.debug_info._replace(
       arg_names=jaxpr.debug_info.arg_names + (None,) * len(jaxpr.constvars))
   return core.Jaxpr(constvars=(),
                     invars=jaxpr.invars + jaxpr.constvars,
@@ -768,7 +768,7 @@ def linearize_from_jvp(jvp: Callable,
                        multiple_results: bool,
                        nonzeros: Sequence[bool],
                        user_facing_symbolic_zeros: bool, instantiate_input_zeros: bool,
-                       debug_info: core.DebugInfo | None,
+                       debug_info: core.DebugInfo,
                        primals, params):
   current_name_stack = source_info_util.current_name_stack()
   with core.take_current_trace() as parent_trace:
@@ -1100,15 +1100,14 @@ def rearrange_binders(jaxpr: core.ClosedJaxpr, primals_in, tangents_in, primals_
   new_invars = _perm(primals_in, tangents_in, jaxpr.jaxpr.invars)
   new_outvars = _perm(primals_out, tangents_out, jaxpr.jaxpr.outvars)
   new_debug_info = jaxpr.jaxpr.debug_info
-  if new_debug_info is not None:
-    new_arg_names = tuple(_perm(primals_in, tangents_in,
-                                jaxpr.jaxpr.debug_info.safe_arg_names(len(jaxpr.jaxpr.invars))))
-    new_result_paths = tuple(_perm(primals_out, tangents_out,
-                                   jaxpr.jaxpr.debug_info.safe_result_paths(len(jaxpr.jaxpr.outvars))))
-    new_debug_info = new_debug_info._replace(
-        arg_names=new_arg_names,
-        result_paths=new_result_paths,
-    )
+  new_arg_names = tuple(_perm(primals_in, tangents_in,
+                              jaxpr.jaxpr.debug_info.safe_arg_names(len(jaxpr.jaxpr.invars))))
+  new_result_paths = tuple(_perm(primals_out, tangents_out,
+                                  jaxpr.jaxpr.debug_info.safe_result_paths(len(jaxpr.jaxpr.outvars))))
+  new_debug_info = new_debug_info._replace(
+      arg_names=new_arg_names,
+      result_paths=new_result_paths,
+  )
   new_jaxpr = core.Jaxpr(jaxpr.jaxpr.constvars,
                          new_invars, new_outvars, jaxpr.jaxpr.eqns,
                          jaxpr.jaxpr.effects,
diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py
index 9f7934378..b77817d99 100644
--- a/jax/_src/interpreters/partial_eval.py
+++ b/jax/_src/interpreters/partial_eval.py
@@ -502,7 +502,8 @@ def _closed_call_param_updater(params, _, __):
   return dict(params, call_jaxpr=core.ClosedJaxpr(jaxpr, ()))
 call_param_updaters[core.closed_call_p] = _closed_call_param_updater
 
-def abstract_eval_fun(fun: Callable, *avals, debug_info=None, **params):
+def abstract_eval_fun(fun: Callable, *avals,
+                      debug_info: core.DebugInfo, **params):
   _, avals_out, _, () = trace_to_jaxpr_dynamic(
       lu.wrap_init(fun, params, debug_info=debug_info), avals)
   assert all(isinstance(aval, AbstractValue) for aval in avals_out)
@@ -582,7 +583,7 @@ def trace_to_subjaxpr_nounits(
     f: Callable,
     trace: JaxprTrace,
     instantiate: Sequence[bool] | bool,
-    debug_info: core.DebugInfo | None,
+    debug_info: core.DebugInfo,
     in_pvals: Sequence[PartialVal]):
   assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals
   out_tracers, jaxpr, out_consts, env = _trace_to_subjaxpr_nounits(
@@ -595,7 +596,7 @@ def trace_to_subjaxpr_nounits(
 def trace_to_subjaxpr_nounits2(
     f: Callable,
     tag: TraceTag,
-    debug_info: core.DebugInfo | None,
+    debug_info: core.DebugInfo,
     instantiate: bool | Sequence[bool],
     in_pvals: Sequence[PartialVal]):
   assert isinstance(tag, TraceTag)
@@ -612,7 +613,7 @@ def trace_to_subjaxpr_nounits2(
 def _trace_to_subjaxpr_nounits(f: Callable, trace: JaxprTrace,
                                instantiate: Sequence[bool] | bool,
                                in_pvals: Sequence[PartialVal],
-                               debug_info: core.DebugInfo | None):
+                               debug_info: core.DebugInfo):
   in_knowns  = [pval.is_known()     for pval in in_pvals]
   in_consts  = [pval.get_known()    for pval in in_pvals if     pval.is_known()]
   in_tracers = [trace.new_arg(pval) for pval in in_pvals if not pval.is_known()]
@@ -639,7 +640,7 @@ def _trace_to_subjaxpr_nounits(f: Callable, trace: JaxprTrace,
 def trace_to_subjaxpr_nounits_fwd(
     f: Callable,
     tag: TraceTag,
-    debug_info: core.DebugInfo | None,
+    debug_info: core.DebugInfo,
     instantiate: bool | Sequence[bool],
     in_pvals: Sequence[PartialVal]):
   assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals
@@ -669,7 +670,7 @@ def trace_to_subjaxpr_nounits_fwd(
 def trace_to_subjaxpr_nounits_fwd2(
     f: Callable,
     tag: TraceTag,
-    debug_info: core.DebugInfo | None,
+    debug_info: core.DebugInfo,
     instantiate: bool | Sequence[bool],
     in_pvals: Sequence[PartialVal]):
   assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals
@@ -752,13 +753,14 @@ def recipe_to_eqn(getvar: Callable[[JaxprTracer], Atom],
 def tracers_to_jaxpr(
   in_tracers: Sequence[JaxprTracer],
   out_tracers: Sequence[JaxprTracer],
-  debug_info: core.DebugInfo | None,
+  debug_info: core.DebugInfo,
   ) -> tuple[Jaxpr, tuple[Any, ...], tuple[Any, ...]]:
   """Constructs Jaxpr given tracers for inputs and outputs.
 
   Params:
     in_tracers: the tracers that were created for the function inputs
     out_tracers: the tracers that were output by the function.
+    debug_info: the debug info for the function.
 
   Returns: a triple of a `Jaxpr`, a list of constant values corresponding to
     the `constvars` in the returned Jaxps, and a list of environment values.
@@ -838,7 +840,7 @@ def tracers_to_jaxpr(
 def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr:
   """Moves the constvars to the start of invars."""
   config.enable_checks.value and core.check_jaxpr(jaxpr)
-  dbg = jaxpr.debug_info and jaxpr.debug_info._replace(
+  dbg = jaxpr.debug_info._replace(
       arg_names=(None,) * len(jaxpr.constvars) + jaxpr.debug_info.arg_names)
   lifted_jaxpr = Jaxpr(constvars=(),
                        invars=jaxpr.constvars + jaxpr.invars,
@@ -854,7 +856,7 @@ def convert_invars_to_constvars(jaxpr: Jaxpr, n: int) -> Jaxpr:
     return jaxpr.replace()  # 'return jaxpr' would create cache reference cycle
   config.enable_checks.value and core.check_jaxpr(jaxpr)
   constvars, invars = split_list(jaxpr.invars, [n])
-  dbg = jaxpr.debug_info and jaxpr.debug_info._replace(
+  dbg = jaxpr.debug_info._replace(
       arg_names=jaxpr.debug_info.arg_names[n:])
   lifted_jaxpr = jaxpr.replace(constvars=tuple(constvars), invars=invars,
                                debug_info=dbg)
@@ -868,7 +870,7 @@ def convert_envvars_to_constvars(jaxpr: Jaxpr, num_env_vars: int) -> Jaxpr:
   env_vars, invars = split_list(jaxpr.invars, [num_env_vars])
   converted_jaxpr = Jaxpr(constvars=jaxpr.constvars + env_vars,
                           invars=invars, outvars=jaxpr.outvars, eqns=jaxpr.eqns,
-                          effects=jaxpr.effects)
+                          effects=jaxpr.effects, debug_info=jaxpr.debug_info)
   config.enable_checks.value and core.check_jaxpr(converted_jaxpr)
   return converted_jaxpr
 
@@ -1363,7 +1365,7 @@ def prune_jaxpr_outputs(jaxpr: Jaxpr, used_outputs: Sequence[bool]) -> Jaxpr:
 
 def _prune_jaxpr_outputs(jaxpr: Jaxpr, used_outputs: tuple[bool, ...]) -> Jaxpr:
   outvars = [v for v, b in zip(jaxpr.outvars, used_outputs) if b]
-  dbg = jaxpr.debug_info and core.DebugInfo(
+  dbg = core.DebugInfo(
       jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info,
       jaxpr.debug_info.arg_names,
       jaxpr.debug_info.filter_result_paths(used_outputs))
@@ -1451,7 +1453,7 @@ def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: tuple[bool, ...],
   eqns = new_eqns[::-1]
   jaxpr_effects = make_jaxpr_effects(jaxpr.constvars, invars, outvars, eqns)
 
-  dbg = jaxpr.debug_info and core.DebugInfo(
+  dbg = core.DebugInfo(
       jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info,
       jaxpr.debug_info.filter_arg_names(used_inputs),
       jaxpr.debug_info.filter_result_paths(used_outputs))
@@ -1653,9 +1655,9 @@ class JaxprStackFrame:
   attrs_tracked: list[tuple[Any, str]]
   attrs_inits: list
   attrs_vars: list[Var]
-  debug_info: core.DebugInfo | None
+  debug_info: core.DebugInfo
 
-  def __init__(self, debug_info: core.DebugInfo | None):
+  def __init__(self, debug_info: core.DebugInfo):
     self.gensym = core.gensym()
     self.tracer_to_var = {}
     self.constid_to_tracer = {}
@@ -1674,7 +1676,7 @@ class JaxprStackFrame:
 
   def to_jaxpr(self, trace: DynamicJaxprTrace,
                out_tracers: Sequence[Tracer],
-               debug_info: core.DebugInfo | None,
+               debug_info: core.DebugInfo,
                ) -> tuple[Jaxpr, list[Any], list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]:
     # It's not necessary, but we keep the tracer-to-var mapping injective:
     assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values()))
@@ -1696,7 +1698,7 @@ class JaxprStackFrame:
     return jaxpr, list(constvals), zip(init_trees, end_trees, self.attrs_tracked)
 
   def to_jaxpr2(self, out_tracers: Sequence[core.Tracer],
-                debug_info: core.DebugInfo | None):
+                debug_info: core.DebugInfo):
     # It's not necessary, but we keep the tracer-to-var mapping injective:
     assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values()))
     constvars, constvals = unzip2(self.constvar_to_val.items())
@@ -1843,7 +1845,7 @@ def _inline_literals(
 class DynamicJaxprTrace(core.Trace):
   __slots__ = ("frame",)
 
-  def __init__(self, debug_info: core.DebugInfo | None):
+  def __init__(self, debug_info: core.DebugInfo):
     self.frame = JaxprStackFrame(debug_info)
 
   def invalidate(self):
@@ -2117,7 +2119,7 @@ class DynamicJaxprTrace(core.Trace):
     return out_tracers
 
   def to_jaxpr(self, out_tracers: Sequence[Tracer],
-               debug_info: core.DebugInfo | None):
+               debug_info: core.DebugInfo):
     return self.frame.to_jaxpr(self, out_tracers, debug_info)
 
 
@@ -2180,17 +2182,13 @@ def trace_to_jaxpr_dynamic(
   return jaxpr, [v.aval for v in jaxpr.outvars], consts, attrs_tracked
 
 def _check_no_returned_refs(
-    dbg: core.DebugInfo | None,
+    dbg: core.DebugInfo,
     out_tracers: Sequence[DynamicJaxprTracer]
 ) -> None:
   if not config.mutable_array_checks.value: return
   for i, t in enumerate(out_tracers):
     a = t.aval
     if isinstance(a, AbstractRef):
-      if dbg is None:
-        raise ValueError(
-        f"function returned a mutable array reference of type {a.str_short()}, "
-        "but mutable array references cannot be returned.")
       result_paths = dbg.resolve_result_paths().safe_result_paths(len(out_tracers))
       loc = result_paths[i] and f' at output tree path {result_paths[i]}'
       frame = t._trace.frame
@@ -2469,7 +2467,8 @@ def pad_jaxpr(jaxpr: Jaxpr, consts: Sequence[Const]
       return aval
 
   in_avals = [substitute(v.aval) for v in jaxpr.invars]
-  eval_padded = lu.wrap_init(partial(_eval_jaxpr_padded, jaxpr, consts))
+  eval_padded = lu.wrap_init(partial(_eval_jaxpr_padded, jaxpr, consts),
+                             debug_info=jaxpr.debug_info)
   padded_jaxpr, _, padded_consts, () = trace_to_jaxpr_dynamic(eval_padded, in_avals)
   return padded_jaxpr, padded_consts
 
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 7655f4ab5..eb6c6b940 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -877,8 +877,8 @@ def lower_parallel_callable(
           replicated_args=replicated_args,
           arg_shardings=None,
           result_shardings=None,
-          arg_names=jaxpr._debug_info and jaxpr._debug_info.safe_arg_names(len(jaxpr.invars)),
-          result_names=jaxpr._debug_info and jaxpr._debug_info.safe_result_paths(len(jaxpr.outvars)),
+          arg_names=jaxpr._debug_info.safe_arg_names(len(jaxpr.invars)),
+          result_names=jaxpr._debug_info.safe_result_paths(len(jaxpr.outvars)),
           num_replicas=replicas.num_global_replicas,
           lowering_parameters=lowering_parameters)
   return PmapComputation(lowering_result.module,
@@ -1968,8 +1968,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
         result_shardings=out_mlir_shardings,
         in_layouts=in_layouts,
         out_layouts=out_layouts,
-        arg_names=jaxpr._debug_info and jaxpr._debug_info.safe_arg_names(len(jaxpr.invars)),
-        result_names=jaxpr._debug_info and jaxpr._debug_info.safe_result_paths(len(jaxpr.outvars)),
+        arg_names=jaxpr._debug_info.safe_arg_names(len(jaxpr.invars)),
+        result_names=jaxpr._debug_info.safe_result_paths(len(jaxpr.outvars)),
         num_replicas=nreps,
         num_partitions=num_partitions,
         all_default_mem_kind=all_default_mem_kind,
@@ -2125,7 +2125,7 @@ MaybeLayout = Sequence[Union[DeviceLocalLayout, AutoLayout, None]]
 class AllArgsInfo(NamedTuple):
   """Avals and debug_info for all arguments prior to DCE."""
   in_avals: Sequence[core.ShapedArray]
-  debug_info: core.DebugInfo | None
+  debug_info: core.DebugInfo
 
 
 @lru_cache(maxsize=2048)
@@ -3202,17 +3202,13 @@ def cc_shard_arg(x, sharding, layout):
 
 
 def check_arg_avals_for_call(ref_avals, arg_avals,
-                             jaxpr_debug_info: core.DebugInfo | None = None):
+                             jaxpr_debug_info: core.DebugInfo):
   if len(ref_avals) != len(arg_avals):
     raise TypeError(
         f"Computation compiled for {len(ref_avals)} inputs "
         f"but called with {len(arg_avals)}")
 
-  if jaxpr_debug_info is not None:
-    arg_names = [f"'{name}'" for name in jaxpr_debug_info.safe_arg_names(len(ref_avals))]
-  else:
-    num_args = len(ref_avals)
-    arg_names = [f"{i + 1}/{num_args}" for i in range(num_args)]
+  arg_names = [f"'{name}'" for name in jaxpr_debug_info.safe_arg_names(len(ref_avals))]
 
   errors = []
   for ref_aval, arg_aval, name in safe_zip(ref_avals, arg_avals, arg_names):
@@ -3264,14 +3260,13 @@ def check_array_xla_sharding_layout_match(
     args_after_dce,
     in_xla_shardings: Sequence[JSharding],
     in_xla_layouts: Sequence[DeviceLocalLayout],
-    jaxpr_debug_info: core.DebugInfo | None,
+    jaxpr_debug_info: core.DebugInfo,
     kept_var_idx: set[int]) -> None:
   from jax._src.array import ArrayImpl
   # jaxpr_debug_info.arg_names are before DCE, so need to DCE them.
   arg_names = (
-      [""] * len(args_after_dce) if jaxpr_debug_info is None
-      else [a for i, a in enumerate(jaxpr_debug_info.arg_names)  # type: ignore
-            if i in kept_var_idx]
+      [a for i, a in enumerate(jaxpr_debug_info.arg_names)  # type: ignore
+       if i in kept_var_idx]
   )
   errors = []
   num_errors = 5
diff --git a/jax/_src/lax/control_flow/for_loop.py b/jax/_src/lax/control_flow/for_loop.py
index b6966cf18..fc7ebde4c 100644
--- a/jax/_src/lax/control_flow/for_loop.py
+++ b/jax/_src/lax/control_flow/for_loop.py
@@ -73,7 +73,7 @@ for_p.skip_canonicalization = True
 
 def _trace_to_jaxpr_with_refs(f: Callable, state_tree: PyTreeDef,
                               state_avals: Sequence[core.AbstractValue],
-                              debug_info: core.DebugInfo | None,
+                              debug_info: core.DebugInfo,
                               ) -> tuple[core.Jaxpr, list[Any], PyTreeDef]:
   f, out_tree_thunk = api_util.flatten_fun_nokwargs(
       lu.wrap_init(f, debug_info=debug_info),
diff --git a/jax/_src/lax/control_flow/solves.py b/jax/_src/lax/control_flow/solves.py
index e278e7ee9..acfcfd7ff 100644
--- a/jax/_src/lax/control_flow/solves.py
+++ b/jax/_src/lax/control_flow/solves.py
@@ -336,7 +336,7 @@ def _custom_linear_solve_impl(*args, const_lengths, jaxprs):
 
 
 def _tangent_linear_map(func: Callable, params, params_dot,
-                        debug_info: core.DebugInfo | None,
+                        debug_info: core.DebugInfo,
                         *x):
   """Compute the tangent of a linear map.
 
diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py
index 565598e51..b58f2ee2b 100644
--- a/jax/_src/linear_util.py
+++ b/jax/_src/linear_util.py
@@ -161,7 +161,7 @@ class WrappedFun:
                f_transformed: Callable,
                transforms,
                stores: tuple[Store | EqualStore | None, ...], params, in_type,
-               debug_info: DebugInfo | None):
+               debug_info: DebugInfo):
     self.f = f
     self.f_transformed = f_transformed
     self.transforms = transforms
@@ -258,6 +258,7 @@ def fun_name(f):
   except:
     return str(f)
 
+
 class DebugInfo(NamedTuple):
   """Debugging info about a func, its arguments, and results."""
   traced_for: str             # e.g. 'jit', 'scan', etc
@@ -331,18 +332,17 @@ def _missing_debug_info(for_what: str) -> DebugInfo:
   return DebugInfo("missing_debug_info", "<missing_debug_info>", (), ())
 
 def wrap_init(f: Callable, params=None, *,
-              debug_info: DebugInfo | None = None) -> WrappedFun:
+              debug_info: DebugInfo) -> WrappedFun:
   """Wraps function `f` as a `WrappedFun`, suitable for transformation."""
   params_dict = {} if params is None else params
   params = () if params is None else tuple(sorted(params.items()))
   fun = WrappedFun(f, partial(f, **params_dict), (), (), params, None, debug_info)
-  if debug_info:
-    if debug_info.result_paths is None:
-      fun, result_paths_thunk = _get_result_paths_thunk(fun)
-      debug_info = debug_info._replace(
-          result_paths=HashableFunction(result_paths_thunk, closure=()))
-    fun = WrappedFun(fun.f, fun.f_transformed, fun.transforms, fun.stores,
-                     fun.params, fun.in_type, debug_info)
+  if debug_info.result_paths is None:
+    fun, result_paths_thunk = _get_result_paths_thunk(fun)
+    debug_info = debug_info._replace(
+        result_paths=HashableFunction(result_paths_thunk, closure=()))
+  fun = WrappedFun(fun.f, fun.f_transformed, fun.transforms, fun.stores,
+                    fun.params, fun.in_type, debug_info)
   return fun
 
 
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index 4d6a0ac20..beb4be1cc 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -731,22 +731,20 @@ def _infer_params(
     entry.pjit_params = p
   return entry.pjit_params, entry.pjit_params.consts + dynargs
 
-def _infer_input_type(fun: Callable, dbg: core.DebugInfo | None,
+def _infer_input_type(fun: Callable, dbg: core.DebugInfo,
                       explicit_args) -> tuple[core.AbstractValue, ...]:
   avals = []
   try:
     for i, x in enumerate(explicit_args):
       avals.append(core.shaped_abstractify(x))
   except OverflowError:
-    arg_path = (f"argument path is {dbg.arg_names[i]}" if dbg
-                else f"flattened argument number is {i}")
+    arg_path = f"argument path is {dbg.arg_names[i]}"  # type: ignore
     raise OverflowError(
       "An overflow was encountered while parsing an argument to a jitted "
       f"computation, whose {arg_path}."
     ) from None
   except TypeError:
-    arg_description = (f"path {dbg.arg_names[i]}" if dbg
-                       else f"flattened argument number {i}")
+    arg_description = f"path {dbg.arg_names[i]}"  # type: ignore
     raise TypeError(
       f"Error interpreting argument to {fun} as an abstract array."
       f" The problematic value is of type {type(x)} and was passed to"
@@ -1111,7 +1109,7 @@ class PytreeLeaf:
 @util.cache(max_size=4096, trace_context_in_key=False)
 def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves,
                                in_layouts_treedef, in_layouts_leaves,
-                               in_avals, in_tree, debug_info,
+                               in_avals, in_tree, debug_info: core.DebugInfo,
                                device_or_backend_set, kws):
   if not kws:
     in_tree, _ = treedef_children(in_tree)
@@ -1136,11 +1134,11 @@ def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves,
   attrs_tracked = debug_info and len(debug_info.arg_names) != len(in_avals)
   if not config.dynamic_shapes.value and not attrs_tracked:
     pjit_check_aval_sharding(in_shardings_flat, in_avals,
-                             None if debug_info is None else debug_info.safe_arg_names(len(in_avals)),
+                             debug_info.safe_arg_names(len(in_avals)),
                              "pjit arguments", allow_uneven_sharding=False)
     check_aval_layout_compatibility(
         in_layouts_flat, in_avals,
-        None if debug_info is None else debug_info.arg_names, "jit arguments")
+        debug_info.safe_arg_names(len(in_avals)), "jit arguments")  # type: ignore[arg-type]
   return in_shardings_flat, in_layouts_flat
 
 callsites: set[str] = set()
@@ -1167,7 +1165,7 @@ def explain_tracing_cache_miss(
 
   # have we seen this function before at all?
   fun_name = getattr(fun.f, '__qualname__', fun.f)
-  if debug_info is not None and debug_info.func_src_info:
+  if debug_info.func_src_info:
     # TODO(necula): clean up the extraction of the source info
     _, *rest = debug_info.func_src_info.split(' at ')
     src_info = " defined at "  + ' '.join(rest)
@@ -1239,7 +1237,7 @@ def explain_tracing_cache_miss(
   # have we never seen these input types (eg shapes, dtypes) before?
   types_match = [k for k in trees_match if k[1] == in_type]
   if not types_match:
-    if len(in_type) < 5 and debug_info is not None:
+    if len(in_type) < 5:
       in_type_str = ':\n    {}'.format(',  '.join(
           f'{n}: {ty.str_short(short_dtypes=True)}'
           for n, ty in zip(debug_info.arg_names, in_type)))
@@ -1251,10 +1249,7 @@ def explain_tracing_cache_miss(
     num_mismatch = sum(map(op.ne, closest_ty, in_type))
     p(f"  closest seen input type signature has {num_mismatch} mismatches, including:")
     add_weak_type_hint = False
-    if debug_info:
-      arg_names = debug_info.safe_arg_names(len(in_type))
-    else:
-      arg_names = (None,) * len(in_type)
+    arg_names = debug_info.safe_arg_names(len(in_type))
 
     for name, ty1, ty2 in zip(arg_names, closest_ty, in_type):
       if ty1 != ty2:
@@ -1320,7 +1315,7 @@ def _create_pjit_jaxpr(
 def _check_and_canonicalize_out_shardings(
     out_shardings_treedef, out_shardings_leaves, out_layouts_treedef,
     out_layouts_leaves, out_tree, out_avals,
-    debug_info: core.DebugInfo | None,
+    debug_info: core.DebugInfo,
     device_or_backend_set):
   orig_out_shardings = tree_unflatten(out_shardings_treedef, out_shardings_leaves)
   if isinstance(orig_out_shardings, (UnspecifiedValue, Sharding)):
@@ -1340,11 +1335,11 @@ def _check_and_canonicalize_out_shardings(
   if not config.dynamic_shapes.value:
     pjit_check_aval_sharding(
         out_shardings_flat, out_avals,
-        None if debug_info is None else debug_info.safe_result_paths(len(out_avals)),
+        debug_info.safe_result_paths(len(out_avals)),  # type: ignore[arg-type]
         "pjit outputs", allow_uneven_sharding=False)
     check_aval_layout_compatibility(
         out_layouts_flat, out_avals,
-        None if debug_info is None else debug_info.safe_result_paths(len(out_avals)),
+        debug_info.safe_result_paths(len(out_avals)),  # type: ignore[arg-type]
         "jit outputs")
   return out_shardings_flat, out_layouts_flat
 
diff --git a/jax/extend/linear_util.py b/jax/extend/linear_util.py
index 7a3bc9bc8..0cf9a013a 100644
--- a/jax/extend/linear_util.py
+++ b/jax/extend/linear_util.py
@@ -15,6 +15,8 @@
 # Note: import <name> as <name> is required for names to be exported.
 # See PEP 484 & https://github.com/jax-ml/jax/issues/7570
 
+from typing import Callable
+
 from jax._src.linear_util import (
   StoreException as StoreException,
   WrappedFun as WrappedFun,
@@ -24,7 +26,14 @@ from jax._src.linear_util import (
   transformation_with_aux as transformation_with_aux,
   transformation2 as transformation2,
   transformation_with_aux2 as transformation_with_aux2,
-  wrap_init as wrap_init,
   # TODO(b/396086979): remove this once we pass debug_info everywhere.
+  wrap_init as _wrap_init,
   _missing_debug_info as _missing_debug_info,
 )
+
+# Version of wrap_init that does not require a DebugInfo object.
+# This usage is deprecated, use api_util.debug_info() to construct a proper
+# DebugInfo object.
+def wrap_init(f: Callable, params=None, *, debug_info=None) -> WrappedFun:
+  debug_info = debug_info or _missing_debug_info("linear_util.wrap_init")
+  return _wrap_init(f, params, debug_info=debug_info)
diff --git a/tests/debug_info_test.py b/tests/debug_info_test.py
index bca62a2de..8a79f3867 100644
--- a/tests/debug_info_test.py
+++ b/tests/debug_info_test.py
@@ -71,8 +71,7 @@ def _collect_jaxprs(jaxpr: core.Jaxpr,
   return acc
 
 
-def _debug_info_to_string(dbg: core.DebugInfo | None) -> list[str]:
-  if dbg is None: return "None"
+def _debug_info_to_string(dbg: core.DebugInfo) -> list[str]:
   # Strip the absolute path and the line number but check that it references
   # this file (to catch errors when the source info points in JAX internals)
   func_src_info = re.sub(r"^(\S+)( at .*.debug_info_test.py:\d+)?", "\\1", dbg.func_src_info)
@@ -294,7 +293,6 @@ class DebugInfoTest(jtu.JaxTestCase):
     def wrapper(x, y):
       return x
 
-    api_util.save_wrapped_fun_sourceinfo(wrapper, None)  # No effect
     dbg = api_util.debug_info("test", wrapper, (1, 2), {})
     self.assertEqual("wrapper", dbg.func_name)
 
diff --git a/tests/extend_test.py b/tests/extend_test.py
index e37bea42c..fcf9d3b54 100644
--- a/tests/extend_test.py
+++ b/tests/extend_test.py
@@ -53,7 +53,8 @@ class ExtendTest(jtu.JaxTestCase):
     self.assertIs(jex.linear_util.merge_linear_aux, linear_util.merge_linear_aux)
     self.assertIs(jex.linear_util.transformation, linear_util.transformation)
     self.assertIs(jex.linear_util.transformation_with_aux, linear_util.transformation_with_aux)
-    self.assertIs(jex.linear_util.wrap_init, linear_util.wrap_init)
+    # TODO(necula): revert this change once we deprecate the old wrap_init
+    # self.assertIs(jex.linear_util.wrap_init, linear_util.wrap_init)
 
 
 class RandomTest(jtu.JaxTestCase):