mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
[better_errors] Ensure debug_info.arg_names is never None.
Most places in the code assumed this already, but often that usage is error reporting code, which is not yet well tested. When we cannot get the `inspect.Signature` or when the args and kwargs do not match the signature, we generate the flattened argument names as: `args[0]`, `args[1]`, `kwargs['foo']`, ... Previously, in these cases we returned `arg_names` is None, and then the whole debug_info ended up being `None`, throwing away even available information. We also add support for `api_util.fun_sourceinfo` even for cases when the `fun.__code__` is not available. In those cases we used to say that `fun_sourceinfo` is `None`. Now, we use the string representation of `fun` to get the name of built-in functions, or we use "<unknown>".
This commit is contained in:
parent
e41f4caa3e
commit
3f73f7b0eb
@ -454,9 +454,10 @@ def saved_residuals(f, *args, **kwargs) -> list[tuple[core.AbstractValue, str]]:
|
||||
out_tree = lambda: tree_structure(out_shape)
|
||||
assert len(jaxpr.invars) == len(in_leaves)
|
||||
dbg = pe.tracing_debug_info(f, in_tree, out_tree, True, "saved_residuals")
|
||||
return _saved_residuals(jaxpr, dbg.arg_names) # type: ignore
|
||||
return _saved_residuals(jaxpr, dbg.arg_names)
|
||||
|
||||
def _saved_residuals(jaxpr, arg_names) -> list[tuple[core.AbstractValue, str]]:
|
||||
def _saved_residuals(jaxpr: core.Jaxpr,
|
||||
arg_names: tuple[str | None, ...]) -> list[tuple[core.AbstractValue, str]]:
|
||||
res_lits = [x for x in jaxpr.outvars if isinstance(x, core.Literal)]
|
||||
res_vars = {x for x in jaxpr.outvars if not isinstance(x, core.Literal)}
|
||||
|
||||
@ -471,7 +472,7 @@ def _saved_residuals(jaxpr, arg_names) -> list[tuple[core.AbstractValue, str]]:
|
||||
|
||||
for i, v in enumerate(jaxpr.invars):
|
||||
if v in res_vars:
|
||||
if arg_names is not None:
|
||||
if arg_names[i] is not None:
|
||||
src = f'from the argument {arg_names[i]}'
|
||||
else:
|
||||
src = 'from the argument at flattened index {i}'
|
||||
@ -587,7 +588,8 @@ def remat_partial_eval(trace: pe.JaxprTrace, *tracers: core.Tracer,
|
||||
_, staged_unk = partition_list(in_used_staged, in_unknowns)
|
||||
res_invars, _ = partition_list(staged_unk, jaxpr_unknown.invars[num_res:])
|
||||
res_outvars = jaxpr_known.outvars[len(jaxpr_known.outvars) - num_res:]
|
||||
body_res = _saved_residuals(jaxpr_known.replace(outvars=res_outvars), None)
|
||||
body_res = _saved_residuals(jaxpr_known.replace(outvars=res_outvars),
|
||||
("",) * len(jaxpr_known.invars))
|
||||
logger.log(log_level,
|
||||
'remat-decorated function ' +
|
||||
'saving inputs with shapes:\n' * bool(res_invars) +
|
||||
|
@ -18,6 +18,7 @@ from collections.abc import Callable, Iterable, Sequence
|
||||
import inspect
|
||||
import operator
|
||||
from functools import partial, lru_cache
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from jax._src import core
|
||||
@ -603,15 +604,13 @@ def tracing_debug_info(
|
||||
# TODO(necula): check if we really need this, e.g., to speed up tracing.
|
||||
sourceinfo: str | None = None,
|
||||
signature: inspect.Signature | None = None,
|
||||
) -> TracingDebugInfo | None:
|
||||
) -> TracingDebugInfo:
|
||||
if sourceinfo is None:
|
||||
sourceinfo = fun_sourceinfo(fun)
|
||||
if signature is None:
|
||||
signature = fun_signature(fun)
|
||||
arg_names = _non_static_arg_names(signature, args, kwargs, static_argnums,
|
||||
static_argnames)
|
||||
if arg_names is None:
|
||||
return None
|
||||
return TracingDebugInfo(traced_for, sourceinfo, arg_names, result_paths_thunk)
|
||||
|
||||
|
||||
@ -624,12 +623,13 @@ def fun_signature(fun: Callable) -> inspect.Signature | None:
|
||||
def save_wrapped_fun_sourceinfo(wrapper: Callable, wrapped: Callable):
|
||||
# Prefer this to functools.wraps because it does not create a reference to
|
||||
# the wrapped function.
|
||||
sourceinfo = fun_sourceinfo(wrapped)
|
||||
if sourceinfo is not None:
|
||||
setattr(wrapper, "__fun_sourceinfo__", fun_sourceinfo(wrapped))
|
||||
setattr(wrapper, "__fun_sourceinfo__", fun_sourceinfo(wrapped))
|
||||
|
||||
_fun_name_re = re.compile(r"(?:<built-in function (\S+)>)")
|
||||
|
||||
# TODO(mattjj): make this function internal to this module
|
||||
def fun_sourceinfo(fun: Callable) -> str | None:
|
||||
def fun_sourceinfo(fun: Callable) -> str:
|
||||
# See TracingDebugInfo.fun_src_info
|
||||
res = getattr(fun, "__fun_sourceinfo__", None)
|
||||
if res is not None: return res
|
||||
while isinstance(fun, partial):
|
||||
@ -639,28 +639,51 @@ def fun_sourceinfo(fun: Callable) -> str | None:
|
||||
filename = fun.__code__.co_filename
|
||||
lineno = fun.__code__.co_firstlineno
|
||||
return f"{fun.__name__} at {filename}:{lineno}"
|
||||
except AttributeError:
|
||||
return None
|
||||
except AttributeError as e:
|
||||
try:
|
||||
fun_str = str(fun)
|
||||
except:
|
||||
return "<unknown>"
|
||||
# By contract, the function name has no spaces; also, we want to avoid
|
||||
# fun_sourceinfo of the form "<object Foo at 0x1234>", because it makes
|
||||
# lowering non-deterministic.
|
||||
if m := _fun_name_re.match(fun_str):
|
||||
return m.group(1)
|
||||
return "<unknown>"
|
||||
|
||||
|
||||
# TODO(necula): this should never return None
|
||||
def _non_static_arg_names(fn_signature: inspect.Signature | None,
|
||||
args: Sequence[Any], kwargs: dict[str, Any],
|
||||
static_argnums: Sequence[int],
|
||||
static_argnames: Sequence[str],
|
||||
) -> tuple[str | None, ...] | None:
|
||||
if fn_signature is None:
|
||||
return None
|
||||
) -> tuple[str | None, ...]:
|
||||
"""Returns the names of the non-static arguments.
|
||||
|
||||
If the `fn_signature` is given then we get from it the names of the
|
||||
top-level arguments. In other cases, including when the `args` and `kwargs`
|
||||
do not match the signature, we use names like `args[0[]`, `args[1]`, etc.
|
||||
"""
|
||||
static = object()
|
||||
static_argnums_ = _ensure_inbounds(True, len(args), static_argnums)
|
||||
static_argnames_ = set(static_argnames)
|
||||
args_ = [static if i in static_argnums_ else x for i, x in enumerate(args)]
|
||||
kwargs = {k:static if k in static_argnames_ else x for k, x in kwargs.items()}
|
||||
try:
|
||||
ba = fn_signature.bind(*args_, **kwargs)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
return tuple(f'{name}{keystr(path)}' for name, x in ba.arguments.items()
|
||||
for path, l in generate_key_paths(x) if l is not static)
|
||||
kwargs_ = {k:static if k in static_argnames_ else x for k, x in kwargs.items()}
|
||||
if fn_signature is not None:
|
||||
try:
|
||||
ba = fn_signature.bind(*args_, **kwargs_)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
else:
|
||||
return tuple(f'{name}{keystr(path)}' for name, x in ba.arguments.items()
|
||||
for path, l in generate_key_paths(x) if l is not static)
|
||||
args_arg_names = tuple(f'args{keystr(path)}'
|
||||
for path, l in generate_key_paths(args_)
|
||||
if l is not static)
|
||||
kwargs_arg_names = tuple(f'kwargs{keystr(path)}'
|
||||
for path, l in generate_key_paths(kwargs_)
|
||||
if l is not static)
|
||||
arg_names = args_arg_names + kwargs_arg_names
|
||||
return arg_names
|
||||
|
||||
@lu.transformation_with_aux2
|
||||
def result_paths(_fun, _store, *args, **kwargs):
|
||||
|
@ -77,12 +77,17 @@ Effects = effects.Effects
|
||||
EffectTypeSet = effects.EffectTypeSet
|
||||
no_effects: Effects = effects.no_effects
|
||||
|
||||
|
||||
# TODO(necula): make this an extension of TracingDebugInfo
|
||||
class JaxprDebugInfo(NamedTuple):
|
||||
traced_for: str # e.g. 'jit', 'scan', etc
|
||||
func_src_info: str | None # e.g. f'{fun.__name__} at {filename}:{lineno}'
|
||||
arg_names: tuple[str | None, ...] # e.g. ('args[0]', ... )
|
||||
# An extension of lu.TracingDebugInfo; see comments there
|
||||
traced_for: str
|
||||
func_src_info: str
|
||||
arg_names: tuple[str | None, ...]
|
||||
# This is formed after tracing, when we have concrete `result_paths`
|
||||
result_paths: tuple[str, ...] # e.g. ('[0]', '[1]', ...)
|
||||
|
||||
|
||||
class Jaxpr:
|
||||
__slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns',
|
||||
'_effects', '_debug_info']
|
||||
@ -140,7 +145,7 @@ class Jaxpr:
|
||||
self._eqns = list(eqns)
|
||||
self._effects = effects
|
||||
self._debug_info = debug_info
|
||||
assert (not debug_info or debug_info.arg_names is None or len(debug_info.arg_names) == len(invars)), (debug_info, invars)
|
||||
assert (not debug_info or len(debug_info.arg_names) == len(invars)), (debug_info, invars)
|
||||
assert (not debug_info or len(debug_info.result_paths) == len(outvars)), (debug_info, outvars)
|
||||
|
||||
def __str__(self):
|
||||
|
@ -1545,7 +1545,7 @@ class DynamicJaxprTracer(core.Tracer):
|
||||
return ""
|
||||
|
||||
origin = ("The error occurred while tracing the function "
|
||||
f"{dbg.func_src_info or '<unknown>'} for {dbg.traced_for}. ")
|
||||
f"{dbg.func_src_info} for {dbg.traced_for}. ")
|
||||
if invar_pos and dbg.arg_names:
|
||||
try:
|
||||
arg_names = [dbg.arg_names[i] for i in invar_pos]
|
||||
@ -2116,7 +2116,7 @@ def tracing_debug_info(
|
||||
out_tree_thunk: Callable[[], PyTreeDef],
|
||||
has_kwargs: bool,
|
||||
traced_for: str
|
||||
) -> lu.TracingDebugInfo | None:
|
||||
) -> lu.TracingDebugInfo:
|
||||
# TODO(necula): we should not need this function, and can use api_util.tracing_debug_info instead
|
||||
# We just have to make sure we grad the debugging information when we have
|
||||
# the unflattened args
|
||||
|
@ -259,7 +259,10 @@ class TracingDebugInfo(NamedTuple):
|
||||
Formed just before staging to a jaxpr and read in trace-time error messages.
|
||||
"""
|
||||
traced_for: str # e.g. 'jit', 'scan', etc
|
||||
func_src_info: str | None # e.g. f'{fun.__name__} at {filename}:{lineno}'
|
||||
# e.g. f'{fun.__name__} at {filename}:{lineno}' or {fun.__name__} if we have
|
||||
# no source location information. The first word is always the function name,
|
||||
# which may be '<unknown>'.
|
||||
func_src_info: str
|
||||
|
||||
# The paths of the flattened non-static argnames,
|
||||
# e.g. ('x', 'dict_arg["a"]', ... ).
|
||||
|
@ -75,6 +75,7 @@ class CompilerParams(Protocol):
|
||||
__dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]]
|
||||
|
||||
|
||||
# TODO(necula): clean up the splitting of the fun_sourceinfo
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class NameAndSrcInfo:
|
||||
#: The name of the pallas_call or the name of the kernel function.
|
||||
@ -108,9 +109,12 @@ class NameAndSrcInfo:
|
||||
if pallas_call_name is not None:
|
||||
return NameAndSrcInfo(pallas_call_name,
|
||||
f"for kernel function {src_info}")
|
||||
src_info_parts = src_info.split(" ")
|
||||
return NameAndSrcInfo(src_info_parts[0],
|
||||
" ".join(src_info_parts[1:]))
|
||||
src_info_parts = src_info.split(" at ")
|
||||
if len(src_info_parts) > 1:
|
||||
return NameAndSrcInfo(src_info_parts[0],
|
||||
"at " + " ".join(src_info_parts[1:]))
|
||||
else:
|
||||
return NameAndSrcInfo(src_info_parts[0], "")
|
||||
|
||||
|
||||
split_list = util.split_list
|
||||
|
@ -1814,7 +1814,7 @@ def pallas_call(
|
||||
"pallas_call kernel",
|
||||
kernel,
|
||||
[1] * len(kernel_fun_sig.parameters), {})
|
||||
arg_names = kernel_debug_info and kernel_debug_info.arg_names
|
||||
arg_names = kernel_debug_info.arg_names
|
||||
del kernel_debug_info
|
||||
in_origins = tuple(in_path_to_input_origin(p, arg_names)
|
||||
for p in in_paths)
|
||||
|
@ -143,7 +143,7 @@ class PjitInfo(NamedTuple):
|
||||
In other words, this structure contains arguments to jit()/pjit(),
|
||||
preprocessed and validated.
|
||||
"""
|
||||
fun_sourceinfo: str | None
|
||||
fun_sourceinfo: str
|
||||
fun_signature: inspect.Signature | None
|
||||
# Shardings, as specified by the user. These can either be UNSPECIFIED or they
|
||||
# can be a tree (prefix) of shardings or None.
|
||||
@ -537,7 +537,7 @@ class PjitParams(NamedTuple):
|
||||
in_tree: PyTreeDef
|
||||
out_tree: PyTreeDef
|
||||
donated_invars: tuple[bool, ...]
|
||||
arg_names: tuple[str | None, ...] | None
|
||||
arg_names: tuple[str | None, ...]
|
||||
num_consts: int
|
||||
attrs_tracked: list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]
|
||||
abstract_mesh: AbstractMesh
|
||||
@ -1189,7 +1189,8 @@ def explain_tracing_cache_miss(
|
||||
# have we seen this function before at all?
|
||||
fun_name = getattr(f, '__qualname__', f)
|
||||
if debug_info is not None and debug_info.func_src_info:
|
||||
_, _, *rest = debug_info.func_src_info.split(' ')
|
||||
# TODO(necula): clean up the extraction of the source info
|
||||
_, *rest = debug_info.func_src_info.split(' at ')
|
||||
src_info = " defined at " + ' '.join(rest)
|
||||
else:
|
||||
src_info = ''
|
||||
|
@ -22,6 +22,7 @@ import operator
|
||||
from absl.testing import absltest, parameterized
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax._src import api_util
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import test_util as jtu
|
||||
@ -40,6 +41,98 @@ jtu.request_cpu_devices(8)
|
||||
|
||||
class DebugInfoTest(jtu.JaxTestCase):
|
||||
|
||||
def test_debug_info_basic(self):
|
||||
def my_f(x, y, z, w):
|
||||
pass
|
||||
|
||||
dbg = api_util.tracing_debug_info("jit", my_f, (1, 2), dict(z=3, w=4))
|
||||
self.assertRegex(dbg.func_src_info, r"^my_f at .*debug_info_test.py:\d+")
|
||||
self.assertEqual(dbg.arg_names, ("x", "y", "z", "w"))
|
||||
self.assertIsNone(dbg.result_paths_thunk)
|
||||
|
||||
def test_debug_info_arg_passed_as_kwarg(self):
|
||||
def my_f(x, y, z):
|
||||
pass
|
||||
|
||||
dbg = api_util.tracing_debug_info("jit", my_f, (1, 2), dict(z=3))
|
||||
self.assertEqual(dbg.arg_names, ("x", "y", "z"))
|
||||
|
||||
def test_debug_info_pytrees(self):
|
||||
def my_f(x_tree, *, y_tree):
|
||||
pass
|
||||
|
||||
dbg = api_util.tracing_debug_info("jit", my_f, ((1, 2),),
|
||||
dict(y_tree=dict(z=3, w=4)))
|
||||
self.assertEqual(dbg.arg_names, ("x_tree[0]", "x_tree[1]",
|
||||
"y_tree['w']", "y_tree['z']"))
|
||||
|
||||
def test_debug_info_with_statics(self):
|
||||
def my_f(x, y, *, z, w):
|
||||
pass
|
||||
|
||||
dbg = api_util.tracing_debug_info("jit", my_f, (1, 2), dict(z=3, w=4),
|
||||
static_argnums=(1,),
|
||||
static_argnames=("w",))
|
||||
self.assertEqual(dbg.arg_names, ("x", "z"))
|
||||
|
||||
def test_debug_info_with_pytrees_and_statics(self):
|
||||
def my_f(x, y, *, z, w):
|
||||
pass
|
||||
|
||||
dbg = api_util.tracing_debug_info("jit", my_f, ((1, 2), (2, 3)),
|
||||
dict(z=(3, 4), w=(5, 6)),
|
||||
static_argnums=(1,),
|
||||
static_argnames=("w",))
|
||||
self.assertEqual(dbg.arg_names, ("x[0]", "x[1]", "z[0]", "z[1]"))
|
||||
|
||||
def test_debug_info_too_many_args(self):
|
||||
def my_f(x):
|
||||
pass
|
||||
|
||||
dbg = api_util.tracing_debug_info("jit", my_f, (1, 2, 3), dict(z=3))
|
||||
self.assertEqual(dbg.arg_names, ('args[0]', 'args[1]', 'args[2]', "kwargs['z']"))
|
||||
|
||||
def test_debug_info_no_source_info_built_in(self):
|
||||
# built-in function "int" does not have an inspect.Signature
|
||||
dbg = api_util.tracing_debug_info("jit", max, (1,), {})
|
||||
self.assertEqual(dbg.func_src_info, "max")
|
||||
self.assertEqual(dbg.arg_names, ("args[0]",))
|
||||
|
||||
def test_debug_info_lambda(self):
|
||||
# built-in function "int" does not have an inspect.Signature
|
||||
dbg = api_util.tracing_debug_info("jit", lambda my_arg: False, (1,), {})
|
||||
self.assertRegex(dbg.func_src_info, r"^<lambda> at .*debug_info_test.py:\d+")
|
||||
self.assertEqual(dbg.arg_names, ("my_arg",))
|
||||
|
||||
def test_debug_info_no_source_info_not_callable(self):
|
||||
# built-in function "int" does not have an inspect.Signature
|
||||
dbg = api_util.tracing_debug_info("jit", False, (1,), {})
|
||||
self.assertEqual(dbg.func_src_info, "<unknown>")
|
||||
self.assertEqual(dbg.arg_names, ("args[0]",))
|
||||
|
||||
def test_debug_info_no_source_info_callable(self):
|
||||
class Foo:
|
||||
x: int
|
||||
def __call__(self, y):
|
||||
return self.x + y
|
||||
|
||||
dbg = api_util.tracing_debug_info("jit", Foo(), (1,), {})
|
||||
self.assertRegex(dbg.func_src_info, "<unknown>")
|
||||
self.assertEqual(dbg.arg_names, ("y",))
|
||||
|
||||
def test_debug_info_no_source_info_callable_with_repr_errors(self):
|
||||
class Foo:
|
||||
x: int
|
||||
def __call__(self, y):
|
||||
return self.x + y
|
||||
|
||||
def __repr__(self):
|
||||
raise NotImplementedError
|
||||
|
||||
dbg = api_util.tracing_debug_info("jit", Foo(), (1,), {})
|
||||
self.assertRegex(dbg.func_src_info, "<unknown>")
|
||||
self.assertEqual(dbg.arg_names, ("y",))
|
||||
|
||||
def helper_save_tracer(self, x):
|
||||
self._saved_tracer = x
|
||||
return x
|
||||
|
@ -966,9 +966,8 @@ class ApiErrorTest(PallasBaseTest):
|
||||
in_specs=[pl.BlockSpec((4,), my_index_map)])
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
# TODO(necula): the function name should be "my_index_map"
|
||||
"Index map function unknown .* "
|
||||
"must return 1 values to match .*"
|
||||
"Index map function my_index_map at .*pallas_test.py.* "
|
||||
"for x_ref must return 1 values to match .*"
|
||||
"Currently returning 2 values."):
|
||||
f(a)
|
||||
|
||||
@ -982,9 +981,8 @@ class ApiErrorTest(PallasBaseTest):
|
||||
in_specs=[pl.BlockSpec((4,), my_index_map)])
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
# TODO(necula): the function name should be "my_index_map"
|
||||
"Index map function unknown .* "
|
||||
"must return integer scalars. Output\\[0\\] has "
|
||||
"Index map function my_index_map at .*pallas_test.py.* "
|
||||
"for x_ref must return integer scalars. Output\\[0\\] has "
|
||||
"type .*float"):
|
||||
f(a)
|
||||
|
||||
@ -998,9 +996,8 @@ class ApiErrorTest(PallasBaseTest):
|
||||
in_specs=[pl.BlockSpec((4,), my_index_map)])
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
# TODO(necula): the function name should be "my_index_map"
|
||||
"Index map function unknown .* "
|
||||
"must return integer scalars. Output\\[0\\] has "
|
||||
"Index map function my_index_map at .*pallas_test.py.* "
|
||||
"for x_ref must return integer scalars. Output\\[0\\] has "
|
||||
"type .*int32\\[4\\]"):
|
||||
f(a)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user