[better_errors] Ensure debug_info.arg_names is never None.

Most places in the code assumed this already, but often
that usage is error reporting code, which is not yet well tested.

When we cannot get the `inspect.Signature` or when the
args and kwargs do not match the signature, we generate
the flattened argument names as: `args[0]`, `args[1]`,
`kwargs['foo']`, ... Previously, in these cases we
returned `arg_names` is None, and then the whole
debug_info ended up being `None`, throwing away even
available information.

We also add support for `api_util.fun_sourceinfo` even
for cases when the `fun.__code__` is not available. In
those cases we used to say that `fun_sourceinfo` is
`None`. Now, we use the string representation of `fun`
to get the name of built-in functions, or we use "<unknown>".
This commit is contained in:
George Necula 2025-01-20 17:17:44 +01:00
parent e41f4caa3e
commit 3f73f7b0eb
10 changed files with 175 additions and 47 deletions

View File

@ -454,9 +454,10 @@ def saved_residuals(f, *args, **kwargs) -> list[tuple[core.AbstractValue, str]]:
out_tree = lambda: tree_structure(out_shape)
assert len(jaxpr.invars) == len(in_leaves)
dbg = pe.tracing_debug_info(f, in_tree, out_tree, True, "saved_residuals")
return _saved_residuals(jaxpr, dbg.arg_names) # type: ignore
return _saved_residuals(jaxpr, dbg.arg_names)
def _saved_residuals(jaxpr, arg_names) -> list[tuple[core.AbstractValue, str]]:
def _saved_residuals(jaxpr: core.Jaxpr,
arg_names: tuple[str | None, ...]) -> list[tuple[core.AbstractValue, str]]:
res_lits = [x for x in jaxpr.outvars if isinstance(x, core.Literal)]
res_vars = {x for x in jaxpr.outvars if not isinstance(x, core.Literal)}
@ -471,7 +472,7 @@ def _saved_residuals(jaxpr, arg_names) -> list[tuple[core.AbstractValue, str]]:
for i, v in enumerate(jaxpr.invars):
if v in res_vars:
if arg_names is not None:
if arg_names[i] is not None:
src = f'from the argument {arg_names[i]}'
else:
src = 'from the argument at flattened index {i}'
@ -587,7 +588,8 @@ def remat_partial_eval(trace: pe.JaxprTrace, *tracers: core.Tracer,
_, staged_unk = partition_list(in_used_staged, in_unknowns)
res_invars, _ = partition_list(staged_unk, jaxpr_unknown.invars[num_res:])
res_outvars = jaxpr_known.outvars[len(jaxpr_known.outvars) - num_res:]
body_res = _saved_residuals(jaxpr_known.replace(outvars=res_outvars), None)
body_res = _saved_residuals(jaxpr_known.replace(outvars=res_outvars),
("",) * len(jaxpr_known.invars))
logger.log(log_level,
'remat-decorated function ' +
'saving inputs with shapes:\n' * bool(res_invars) +

View File

@ -18,6 +18,7 @@ from collections.abc import Callable, Iterable, Sequence
import inspect
import operator
from functools import partial, lru_cache
import re
from typing import Any
from jax._src import core
@ -603,15 +604,13 @@ def tracing_debug_info(
# TODO(necula): check if we really need this, e.g., to speed up tracing.
sourceinfo: str | None = None,
signature: inspect.Signature | None = None,
) -> TracingDebugInfo | None:
) -> TracingDebugInfo:
if sourceinfo is None:
sourceinfo = fun_sourceinfo(fun)
if signature is None:
signature = fun_signature(fun)
arg_names = _non_static_arg_names(signature, args, kwargs, static_argnums,
static_argnames)
if arg_names is None:
return None
return TracingDebugInfo(traced_for, sourceinfo, arg_names, result_paths_thunk)
@ -624,12 +623,13 @@ def fun_signature(fun: Callable) -> inspect.Signature | None:
def save_wrapped_fun_sourceinfo(wrapper: Callable, wrapped: Callable):
# Prefer this to functools.wraps because it does not create a reference to
# the wrapped function.
sourceinfo = fun_sourceinfo(wrapped)
if sourceinfo is not None:
setattr(wrapper, "__fun_sourceinfo__", fun_sourceinfo(wrapped))
setattr(wrapper, "__fun_sourceinfo__", fun_sourceinfo(wrapped))
_fun_name_re = re.compile(r"(?:<built-in function (\S+)>)")
# TODO(mattjj): make this function internal to this module
def fun_sourceinfo(fun: Callable) -> str | None:
def fun_sourceinfo(fun: Callable) -> str:
# See TracingDebugInfo.fun_src_info
res = getattr(fun, "__fun_sourceinfo__", None)
if res is not None: return res
while isinstance(fun, partial):
@ -639,28 +639,51 @@ def fun_sourceinfo(fun: Callable) -> str | None:
filename = fun.__code__.co_filename
lineno = fun.__code__.co_firstlineno
return f"{fun.__name__} at {filename}:{lineno}"
except AttributeError:
return None
except AttributeError as e:
try:
fun_str = str(fun)
except:
return "<unknown>"
# By contract, the function name has no spaces; also, we want to avoid
# fun_sourceinfo of the form "<object Foo at 0x1234>", because it makes
# lowering non-deterministic.
if m := _fun_name_re.match(fun_str):
return m.group(1)
return "<unknown>"
# TODO(necula): this should never return None
def _non_static_arg_names(fn_signature: inspect.Signature | None,
args: Sequence[Any], kwargs: dict[str, Any],
static_argnums: Sequence[int],
static_argnames: Sequence[str],
) -> tuple[str | None, ...] | None:
if fn_signature is None:
return None
) -> tuple[str | None, ...]:
"""Returns the names of the non-static arguments.
If the `fn_signature` is given then we get from it the names of the
top-level arguments. In other cases, including when the `args` and `kwargs`
do not match the signature, we use names like `args[0[]`, `args[1]`, etc.
"""
static = object()
static_argnums_ = _ensure_inbounds(True, len(args), static_argnums)
static_argnames_ = set(static_argnames)
args_ = [static if i in static_argnums_ else x for i, x in enumerate(args)]
kwargs = {k:static if k in static_argnames_ else x for k, x in kwargs.items()}
try:
ba = fn_signature.bind(*args_, **kwargs)
except (ValueError, TypeError):
return None
return tuple(f'{name}{keystr(path)}' for name, x in ba.arguments.items()
for path, l in generate_key_paths(x) if l is not static)
kwargs_ = {k:static if k in static_argnames_ else x for k, x in kwargs.items()}
if fn_signature is not None:
try:
ba = fn_signature.bind(*args_, **kwargs_)
except (ValueError, TypeError):
pass
else:
return tuple(f'{name}{keystr(path)}' for name, x in ba.arguments.items()
for path, l in generate_key_paths(x) if l is not static)
args_arg_names = tuple(f'args{keystr(path)}'
for path, l in generate_key_paths(args_)
if l is not static)
kwargs_arg_names = tuple(f'kwargs{keystr(path)}'
for path, l in generate_key_paths(kwargs_)
if l is not static)
arg_names = args_arg_names + kwargs_arg_names
return arg_names
@lu.transformation_with_aux2
def result_paths(_fun, _store, *args, **kwargs):

View File

@ -77,12 +77,17 @@ Effects = effects.Effects
EffectTypeSet = effects.EffectTypeSet
no_effects: Effects = effects.no_effects
# TODO(necula): make this an extension of TracingDebugInfo
class JaxprDebugInfo(NamedTuple):
traced_for: str # e.g. 'jit', 'scan', etc
func_src_info: str | None # e.g. f'{fun.__name__} at {filename}:{lineno}'
arg_names: tuple[str | None, ...] # e.g. ('args[0]', ... )
# An extension of lu.TracingDebugInfo; see comments there
traced_for: str
func_src_info: str
arg_names: tuple[str | None, ...]
# This is formed after tracing, when we have concrete `result_paths`
result_paths: tuple[str, ...] # e.g. ('[0]', '[1]', ...)
class Jaxpr:
__slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns',
'_effects', '_debug_info']
@ -140,7 +145,7 @@ class Jaxpr:
self._eqns = list(eqns)
self._effects = effects
self._debug_info = debug_info
assert (not debug_info or debug_info.arg_names is None or len(debug_info.arg_names) == len(invars)), (debug_info, invars)
assert (not debug_info or len(debug_info.arg_names) == len(invars)), (debug_info, invars)
assert (not debug_info or len(debug_info.result_paths) == len(outvars)), (debug_info, outvars)
def __str__(self):

View File

@ -1545,7 +1545,7 @@ class DynamicJaxprTracer(core.Tracer):
return ""
origin = ("The error occurred while tracing the function "
f"{dbg.func_src_info or '<unknown>'} for {dbg.traced_for}. ")
f"{dbg.func_src_info} for {dbg.traced_for}. ")
if invar_pos and dbg.arg_names:
try:
arg_names = [dbg.arg_names[i] for i in invar_pos]
@ -2116,7 +2116,7 @@ def tracing_debug_info(
out_tree_thunk: Callable[[], PyTreeDef],
has_kwargs: bool,
traced_for: str
) -> lu.TracingDebugInfo | None:
) -> lu.TracingDebugInfo:
# TODO(necula): we should not need this function, and can use api_util.tracing_debug_info instead
# We just have to make sure we grad the debugging information when we have
# the unflattened args

View File

@ -259,7 +259,10 @@ class TracingDebugInfo(NamedTuple):
Formed just before staging to a jaxpr and read in trace-time error messages.
"""
traced_for: str # e.g. 'jit', 'scan', etc
func_src_info: str | None # e.g. f'{fun.__name__} at {filename}:{lineno}'
# e.g. f'{fun.__name__} at {filename}:{lineno}' or {fun.__name__} if we have
# no source location information. The first word is always the function name,
# which may be '<unknown>'.
func_src_info: str
# The paths of the flattened non-static argnames,
# e.g. ('x', 'dict_arg["a"]', ... ).

View File

@ -75,6 +75,7 @@ class CompilerParams(Protocol):
__dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]]
# TODO(necula): clean up the splitting of the fun_sourceinfo
@dataclasses.dataclass(frozen=True)
class NameAndSrcInfo:
#: The name of the pallas_call or the name of the kernel function.
@ -108,9 +109,12 @@ class NameAndSrcInfo:
if pallas_call_name is not None:
return NameAndSrcInfo(pallas_call_name,
f"for kernel function {src_info}")
src_info_parts = src_info.split(" ")
return NameAndSrcInfo(src_info_parts[0],
" ".join(src_info_parts[1:]))
src_info_parts = src_info.split(" at ")
if len(src_info_parts) > 1:
return NameAndSrcInfo(src_info_parts[0],
"at " + " ".join(src_info_parts[1:]))
else:
return NameAndSrcInfo(src_info_parts[0], "")
split_list = util.split_list

View File

@ -1814,7 +1814,7 @@ def pallas_call(
"pallas_call kernel",
kernel,
[1] * len(kernel_fun_sig.parameters), {})
arg_names = kernel_debug_info and kernel_debug_info.arg_names
arg_names = kernel_debug_info.arg_names
del kernel_debug_info
in_origins = tuple(in_path_to_input_origin(p, arg_names)
for p in in_paths)

View File

@ -143,7 +143,7 @@ class PjitInfo(NamedTuple):
In other words, this structure contains arguments to jit()/pjit(),
preprocessed and validated.
"""
fun_sourceinfo: str | None
fun_sourceinfo: str
fun_signature: inspect.Signature | None
# Shardings, as specified by the user. These can either be UNSPECIFIED or they
# can be a tree (prefix) of shardings or None.
@ -537,7 +537,7 @@ class PjitParams(NamedTuple):
in_tree: PyTreeDef
out_tree: PyTreeDef
donated_invars: tuple[bool, ...]
arg_names: tuple[str | None, ...] | None
arg_names: tuple[str | None, ...]
num_consts: int
attrs_tracked: list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]
abstract_mesh: AbstractMesh
@ -1189,7 +1189,8 @@ def explain_tracing_cache_miss(
# have we seen this function before at all?
fun_name = getattr(f, '__qualname__', f)
if debug_info is not None and debug_info.func_src_info:
_, _, *rest = debug_info.func_src_info.split(' ')
# TODO(necula): clean up the extraction of the source info
_, *rest = debug_info.func_src_info.split(' at ')
src_info = " defined at " + ' '.join(rest)
else:
src_info = ''

View File

@ -22,6 +22,7 @@ import operator
from absl.testing import absltest, parameterized
import jax
from jax import lax
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import test_util as jtu
@ -40,6 +41,98 @@ jtu.request_cpu_devices(8)
class DebugInfoTest(jtu.JaxTestCase):
def test_debug_info_basic(self):
def my_f(x, y, z, w):
pass
dbg = api_util.tracing_debug_info("jit", my_f, (1, 2), dict(z=3, w=4))
self.assertRegex(dbg.func_src_info, r"^my_f at .*debug_info_test.py:\d+")
self.assertEqual(dbg.arg_names, ("x", "y", "z", "w"))
self.assertIsNone(dbg.result_paths_thunk)
def test_debug_info_arg_passed_as_kwarg(self):
def my_f(x, y, z):
pass
dbg = api_util.tracing_debug_info("jit", my_f, (1, 2), dict(z=3))
self.assertEqual(dbg.arg_names, ("x", "y", "z"))
def test_debug_info_pytrees(self):
def my_f(x_tree, *, y_tree):
pass
dbg = api_util.tracing_debug_info("jit", my_f, ((1, 2),),
dict(y_tree=dict(z=3, w=4)))
self.assertEqual(dbg.arg_names, ("x_tree[0]", "x_tree[1]",
"y_tree['w']", "y_tree['z']"))
def test_debug_info_with_statics(self):
def my_f(x, y, *, z, w):
pass
dbg = api_util.tracing_debug_info("jit", my_f, (1, 2), dict(z=3, w=4),
static_argnums=(1,),
static_argnames=("w",))
self.assertEqual(dbg.arg_names, ("x", "z"))
def test_debug_info_with_pytrees_and_statics(self):
def my_f(x, y, *, z, w):
pass
dbg = api_util.tracing_debug_info("jit", my_f, ((1, 2), (2, 3)),
dict(z=(3, 4), w=(5, 6)),
static_argnums=(1,),
static_argnames=("w",))
self.assertEqual(dbg.arg_names, ("x[0]", "x[1]", "z[0]", "z[1]"))
def test_debug_info_too_many_args(self):
def my_f(x):
pass
dbg = api_util.tracing_debug_info("jit", my_f, (1, 2, 3), dict(z=3))
self.assertEqual(dbg.arg_names, ('args[0]', 'args[1]', 'args[2]', "kwargs['z']"))
def test_debug_info_no_source_info_built_in(self):
# built-in function "int" does not have an inspect.Signature
dbg = api_util.tracing_debug_info("jit", max, (1,), {})
self.assertEqual(dbg.func_src_info, "max")
self.assertEqual(dbg.arg_names, ("args[0]",))
def test_debug_info_lambda(self):
# built-in function "int" does not have an inspect.Signature
dbg = api_util.tracing_debug_info("jit", lambda my_arg: False, (1,), {})
self.assertRegex(dbg.func_src_info, r"^<lambda> at .*debug_info_test.py:\d+")
self.assertEqual(dbg.arg_names, ("my_arg",))
def test_debug_info_no_source_info_not_callable(self):
# built-in function "int" does not have an inspect.Signature
dbg = api_util.tracing_debug_info("jit", False, (1,), {})
self.assertEqual(dbg.func_src_info, "<unknown>")
self.assertEqual(dbg.arg_names, ("args[0]",))
def test_debug_info_no_source_info_callable(self):
class Foo:
x: int
def __call__(self, y):
return self.x + y
dbg = api_util.tracing_debug_info("jit", Foo(), (1,), {})
self.assertRegex(dbg.func_src_info, "<unknown>")
self.assertEqual(dbg.arg_names, ("y",))
def test_debug_info_no_source_info_callable_with_repr_errors(self):
class Foo:
x: int
def __call__(self, y):
return self.x + y
def __repr__(self):
raise NotImplementedError
dbg = api_util.tracing_debug_info("jit", Foo(), (1,), {})
self.assertRegex(dbg.func_src_info, "<unknown>")
self.assertEqual(dbg.arg_names, ("y",))
def helper_save_tracer(self, x):
self._saved_tracer = x
return x

View File

@ -966,9 +966,8 @@ class ApiErrorTest(PallasBaseTest):
in_specs=[pl.BlockSpec((4,), my_index_map)])
with self.assertRaisesRegex(
ValueError,
# TODO(necula): the function name should be "my_index_map"
"Index map function unknown .* "
"must return 1 values to match .*"
"Index map function my_index_map at .*pallas_test.py.* "
"for x_ref must return 1 values to match .*"
"Currently returning 2 values."):
f(a)
@ -982,9 +981,8 @@ class ApiErrorTest(PallasBaseTest):
in_specs=[pl.BlockSpec((4,), my_index_map)])
with self.assertRaisesRegex(
ValueError,
# TODO(necula): the function name should be "my_index_map"
"Index map function unknown .* "
"must return integer scalars. Output\\[0\\] has "
"Index map function my_index_map at .*pallas_test.py.* "
"for x_ref must return integer scalars. Output\\[0\\] has "
"type .*float"):
f(a)
@ -998,9 +996,8 @@ class ApiErrorTest(PallasBaseTest):
in_specs=[pl.BlockSpec((4,), my_index_map)])
with self.assertRaisesRegex(
ValueError,
# TODO(necula): the function name should be "my_index_map"
"Index map function unknown .* "
"must return integer scalars. Output\\[0\\] has "
"Index map function my_index_map at .*pallas_test.py.* "
"for x_ref must return integer scalars. Output\\[0\\] has "
"type .*int32\\[4\\]"):
f(a)