attach debug info to jaxpr, pass to mlir/mhlo

Co-authored-by: Peter Hawkins <phawkins@google.com>
This commit is contained in:
Matthew Johnson 2023-03-02 09:58:14 -08:00
parent a002643a4a
commit 8440e27a5a
9 changed files with 146 additions and 90 deletions

View File

@ -401,9 +401,15 @@ def saved_residuals(f, *args, **kwargs) -> List[Tuple[core.AbstractValue, str]]:
results.append((v.aval, 'from a constant'))
assert len(jaxpr.invars) == len(in_leaves)
dbg = pe.debug_info(f, in_tree, True, "saved_residuals")
arg_info = pe.arg_info_all(dbg)
for i, v in enumerate(jaxpr.invars):
if v in res_vars:
src = f'from {pe.arg_info_pytree(f, in_tree, True, [i])}'
if arg_info is not None:
arg_name, arg_path = arg_info[i]
src = f'from the argument {arg_name}{arg_path.pprint("")}'
else:
src = 'from the argument at flattened index {i}'
results.append((v.aval, src))
for eqn in jaxpr.eqns:

View File

@ -2978,10 +2978,12 @@ def make_jaxpr(fun: Callable,
in_type = tuple(zip(in_avals, keep_inputs))
f, out_tree = flatten_fun(f, in_tree)
f = lu.annotate(f, in_type)
debug_info = pe.debug_info(fun, in_tree, True, 'make_jaxpr')
with ExitStack() as stack:
for axis_name, size in axis_env or []:
stack.enter_context(core.extend_axis_env(axis_name, size, None))
jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f)
jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(
f, debug_info=debug_info)
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
if return_shape:
out_avals, _ = unzip2(out_type)

View File

@ -391,7 +391,7 @@ def initial_style_jaxpr(
def _initial_style_jaxpr(fun, in_tree, in_avals):
# like control_flow._initial_style_jaxpr, but use flatten_fun not _nokwargs
fun_, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
debug = pe.debug_info(fun_, in_tree, False, 'checkify')
debug = pe.debug_info(fun, in_tree, False, 'checkify')
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun_, in_avals, debug)
return jaxpr, consts, out_tree()

View File

@ -32,7 +32,7 @@ from typing import (Any, Callable, ClassVar, DefaultDict, Dict, FrozenSet,
NamedTuple, Optional, Sequence, Set, Tuple, Type,
Union, cast)
import warnings
from weakref import ref
from weakref import ref, ReferenceType
import numpy as np
@ -45,9 +45,11 @@ from jax.errors import (ConcretizationTypeError, TracerArrayConversionError,
from jax._src import linear_util as lu
from jax._src import source_info_util
from jax._src.tree_util import PyTreeDef
from jax._src.util import (safe_zip, safe_map, curry, tuple_insert,
tuple_delete, as_hashable_function,
HashableFunction, HashableWrapper, weakref_lru_cache)
HashableFunction, HashableWrapper, weakref_lru_cache,
fun_sourceinfo)
import jax._src.pretty_printer as pp
from jax._src.lib import jax_jit
from jax._src import traceback_util
@ -66,24 +68,43 @@ Effects = effects.Effects
EffectTypeSet = effects.EffectTypeSet
no_effects: Effects = effects.no_effects
class DebugInfo(NamedTuple):
fn_ref: ReferenceType[Callable] # function being traced/staged
in_tree: Optional[PyTreeDef] # caller/constructor might not have this info
has_kwargs: bool # whether in_tree corresponds to (args, kwargs) or args
traced_for: str # "jit", "scan", "make_jaxpr", etc
# TODO(mattjj): add input type signature
@property
def fn(self):
return self.fn_ref()
@property
def func_src_info(self) -> Optional[str]:
return fun_sourceinfo(self.fn)
class Jaxpr:
__slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns', '_effects']
__slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns', '_effects', '_debug_info']
_constvars: List[Var]
_invars: List[Var]
_outvars: List[Atom]
_eqns: List[JaxprEqn]
_effects: Effects
_debug_info: Optional[DebugInfo]
constvars = property(lambda self: self._constvars)
invars = property(lambda self: self._invars)
outvars = property(lambda self: self._outvars)
eqns = property(lambda self: self._eqns)
effects = property(lambda self: self._effects)
debug_info = property(lambda self: self._debug_info)
def __init__(self, constvars: Sequence[Var], invars: Sequence[Var],
outvars: Sequence[Atom], eqns: Sequence[JaxprEqn],
effects: Effects = no_effects):
effects: Effects = no_effects,
debug_info: Optional[DebugInfo] = None):
"""
Args:
constvars: list of variables introduced for constants. Array constants are
@ -100,6 +121,7 @@ class Jaxpr:
self._outvars = list(outvars)
self._eqns = list(eqns)
self._effects = effects
self._debug_info = debug_info
def __str__(self):
return str(pp_jaxpr(self, JaxprPpContext(), JaxprPpSettings()))

View File

@ -666,7 +666,8 @@ def lower_jaxpr_to_module(
donated_args: Sequence[bool],
replicated_args: Optional[Sequence[bool]] = None,
arg_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
result_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None
result_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
arg_names: Optional[Sequence[str]] = None,
) -> LoweringResult:
"""Lowers a top-level jaxpr to an MLIR module.
@ -743,7 +744,8 @@ def lower_jaxpr_to_module(
num_output_tokens=0,
replicated_args=replicated_args,
arg_shardings=arg_shardings, result_shardings=result_shardings,
input_output_aliases=input_output_aliases)
input_output_aliases=input_output_aliases,
arg_names=arg_names)
if not ctx.module.operation.verify():
module_string = module_to_string(ctx.module)
@ -860,6 +862,7 @@ def lower_jaxpr_to_fun(
input_output_aliases: Optional[Sequence[Optional[int]]] = None,
num_output_tokens: int = 0,
api_name: str = 'jit',
arg_names: Optional[Sequence[str]] = None,
) -> func_dialect.FuncOp:
"""Lowers jaxpr and its callees to an IR function.
@ -980,6 +983,10 @@ def lower_jaxpr_to_fun(
if alias is not None:
attrs["tf.aliasing_output"] = i32_attr(alias)
if config.jax_jit_pjit_api_merge and arg_names:
for attrs, name_ in zip(arg_attrs[num_dim_vars + num_tokens:], arg_names):
attrs['jax.arg_info'] = ir.StringAttr.get(name_)
func_op.arg_attrs = ir.ArrayAttr.get(
[ir.DictAttr.get(attrs) for attrs in arg_attrs])

View File

@ -3046,6 +3046,10 @@ def lower_sharding_computation(
unordered_effects = list(
effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects))
arg_info = jaxpr.debug_info and pe.arg_info_all(jaxpr.debug_info)
arg_names = None if arg_info is None else [
f'{name}{path.pprint("")}' for i, (name, path) in enumerate(arg_info)
if i in kept_var_idx]
lowering_result = mlir.lower_jaxpr_to_module(
module_name,
closed_jaxpr,
@ -3058,7 +3062,8 @@ def lower_sharding_computation(
donated_invars,
replicated_args=replicated_args,
arg_shardings=in_op_shardings,
result_shardings=out_op_shardings)
result_shardings=out_op_shardings,
arg_names=arg_names)
module, keepalive, host_callbacks = (
lowering_result.module, lowering_result.keepalive,

View File

@ -14,6 +14,7 @@
import functools
from functools import partial
import inspect
import itertools as it
import logging
import operator
@ -535,3 +536,15 @@ def use_cpp_method(is_enabled=True):
"Decorator got wrong type: @use_cpp_method(is_enabled: bool=True)"
)
return decorator
def fun_sourceinfo(fun: Callable) -> Optional[str]:
while isinstance(fun, partial):
fun = fun.func
fun = inspect.unwrap(fun)
try:
filename = fun.__code__.co_filename
lineno = fun.__code__.co_firstlineno
return f"{fun.__name__} at {filename}:{lineno}"
except AttributeError:
return None

View File

@ -41,12 +41,13 @@ from jax._src.core import (Trace, Tracer, Jaxpr, Literal, get_aval,
raise_to_shaped, Atom, JaxprEqn, Primitive,
ShapedArray, DShapedArray, mapped_aval,
unmapped_aval, DBIdx, InDBIdx, OutDBIdx,
InputType, OutputType, get_referent)
InputType, OutputType, get_referent, DebugInfo)
from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_unflatten,
tree_leaves)
KeyPath, _generate_key_paths)
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
merge_lists, partition_list, OrderedSet,
as_hashable_function, weakref_lru_cache)
as_hashable_function, weakref_lru_cache,
fun_sourceinfo)
map, unsafe_map = safe_map, map
@ -951,7 +952,7 @@ def tracers_to_jaxpr(
outvars = map(get_atom, out_tracers) # type: ignore[arg-type]
jaxpr_effects = make_jaxpr_effects(const_vars, invars, outvars, eqns)
jaxpr = Jaxpr(const_vars, invars, # type: ignore[list-item,arg-type]
outvars, eqns, jaxpr_effects)
outvars, eqns, jaxpr_effects, None)
config.jax_enable_checks and core.check_jaxpr(jaxpr)
# del getvar # needed to avoid cyclic-reference closure, apparently!
return jaxpr, const_vals, env_vals
@ -963,7 +964,7 @@ def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr:
lifted_jaxpr = Jaxpr(constvars=(),
invars=jaxpr.constvars + jaxpr.invars,
outvars=jaxpr.outvars, eqns=jaxpr.eqns,
effects=jaxpr.effects)
effects=jaxpr.effects, debug_info=jaxpr.debug_info)
config.jax_enable_checks and core.check_jaxpr(lifted_jaxpr)
return lifted_jaxpr
@ -976,7 +977,7 @@ def convert_invars_to_constvars(jaxpr: Jaxpr, n: int) -> Jaxpr:
constvars, invars = split_list(jaxpr.invars, [n])
lifted_jaxpr = Jaxpr(constvars=tuple(constvars), invars=invars,
outvars=jaxpr.outvars, eqns=jaxpr.eqns,
effects=jaxpr.effects)
effects=jaxpr.effects, debug_info=jaxpr.debug_info)
config.jax_enable_checks and core.check_jaxpr(lifted_jaxpr)
return lifted_jaxpr
@ -1382,7 +1383,8 @@ def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: Tuple[bool, ...],
eqns = new_eqns[::-1]
jaxpr_effects = make_jaxpr_effects(jaxpr.constvars, invars, outvars, eqns)
new_jaxpr = Jaxpr(jaxpr.constvars, invars, outvars, eqns, jaxpr_effects)
new_jaxpr = Jaxpr(jaxpr.constvars, invars, outvars, eqns, jaxpr_effects,
jaxpr.debug_info)
config.jax_enable_checks and core.check_jaxpr(new_jaxpr)
return new_jaxpr, used_inputs
@ -1486,12 +1488,22 @@ class DynamicJaxprTracer(core.Tracer):
if dbg is None:
return ""
origin = (f"The error occurred while tracing the function {dbg.func_src_info} "
f"for {dbg.traced_for}. ")
if invar_pos:
origin = ("The error occurred while tracing the function "
f"{fun_sourceinfo(dbg.fn)} for {dbg.traced_for}. ")
arg_info = arg_info_all(dbg)
if invar_pos and arg_info:
arg_info = [arg_info[i] for i in invar_pos]
arg_names = [f'{name}{path.pprint("")}' for name, path in arg_info]
if len(arg_names) == 1:
arg_info_str = f"the argument {arg_names[0]}"
elif len(arg_names) == 2:
arg_info_str = f"the arguments {arg_names[0]} and {arg_names[1]}"
else:
*rest, last = arg_names
arg_info_str = f"the arguments {', '.join(rest)}, and {last}"
origin += ("This concrete value was not available in Python because it "
f"depends on the value{'s' if len(invar_pos) > 1 else ''} "
f"of {dbg.arg_info(invar_pos)}.")
f"of {arg_info_str}.")
elif progenitor_eqns:
msts = [" operation "
f"{core.pp_eqn(eqn, core.JaxprPpContext(), core.JaxprPpSettings(print_shapes=True))}\n"
@ -1561,7 +1573,8 @@ class JaxprStackFrame:
constvars, constvals = unzip2(self.constvar_to_val.items())
jaxpr_effects = make_jaxpr_effects(constvars, self.invars, outvars,
self.eqns)
jaxpr = Jaxpr(constvars, self.invars, outvars, self.eqns, jaxpr_effects)
jaxpr = Jaxpr(constvars, self.invars, outvars, self.eqns, jaxpr_effects,
self.debug_info)
jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals)
jaxpr, constvals = _inline_literals(jaxpr, constvals)
return jaxpr, constvals
@ -1574,7 +1587,7 @@ class JaxprStackFrame:
jaxpr_effects = make_jaxpr_effects(constvars, self.invars, expl_outvars,
self.eqns)
jaxpr = Jaxpr(constvars, self.invars, expl_outvars, self.eqns,
jaxpr_effects)
jaxpr_effects, self.debug_info)
# We can't run check_jaxpr until after we normalize.
jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals)
jaxpr, constvals = _inline_literals(jaxpr, constvals)
@ -1638,7 +1651,7 @@ def _const_folding_and_forwarding(jaxpr, constvals):
jaxpr_effects = make_jaxpr_effects(new_constvars, jaxpr.invars, new_outvars,
new_eqns)
new_jaxpr = Jaxpr(new_constvars, jaxpr.invars, new_outvars, new_eqns,
jaxpr_effects)
jaxpr_effects, jaxpr.debug_info)
return new_jaxpr, new_constvals
ConstFoldRule = Callable[[List[Optional[Any]], JaxprEqn],
@ -1688,7 +1701,7 @@ def _inline_literals(jaxpr, constvals):
jaxpr_effects = make_jaxpr_effects(new_constvars, new_invars, new_outvars,
new_eqns)
new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns,
jaxpr_effects)
jaxpr_effects, jaxpr.debug_info)
return new_jaxpr, new_constvals
class DynamicJaxprTrace(core.Trace):
@ -1828,7 +1841,8 @@ class DynamicJaxprTrace(core.Trace):
with core.extend_axis_env(axis_name, params["global_axis_size"], None): # type: ignore
with core.new_sublevel():
jaxpr, reduced_out_avals, consts = trace_to_subjaxpr_dynamic(
f, self.main, reduced_in_avals, debug_info=debug_info_final(f, map_primitive.name))
f, self.main, reduced_in_avals,
debug_info=debug_info_final(f, map_primitive.name))
ordered_effects = effects.ordered_effects.filter_in(jaxpr.effects)
if ordered_effects:
raise ValueError("Ordered effects not supported for "
@ -1962,65 +1976,36 @@ def _memoize(thunk):
return memoized
class DebugInfo(NamedTuple):
func_src_info: str
traced_for: str
arg_info: Callable[[int], str]
def debug_info_final(fn: lu.WrappedFun, traced_for: str) -> DebugInfo:
in_tree, has_kwargs = flattened_fun_in_tree(fn) or (None, False)
return debug_info(fn.f, in_tree, has_kwargs, traced_for)
def debug_info(fn: Callable, in_tree: Optional[PyTreeDef], has_kwargs: bool,
traced_for: str) -> DebugInfo:
func_src_info = fun_sourceinfo(fn)
if in_tree is not None:
arg_info = partial(arg_info_pytree, fn, in_tree, has_kwargs)
else:
arg_info = arg_info_flattened # type: ignore
return DebugInfo(func_src_info, traced_for, arg_info)
return DebugInfo(ref(fn), in_tree, has_kwargs, traced_for)
def fun_sourceinfo(fun: Callable):
while isinstance(fun, functools.partial):
fun = fun.func
fun = inspect.unwrap(fun)
try:
filename = fun.__code__.co_filename
lineno = fun.__code__.co_firstlineno
line_info = f"{fun.__name__} at {filename}:{lineno}"
return line_info
except AttributeError:
return "<unknown>"
def debug_info_final(fn: lu.WrappedFun, traced_for: str) -> DebugInfo:
"Make a DebugInfo from data available to final-style primitives like pmap."
in_tree, has_kwargs = flattened_fun_in_tree(fn) or (None, False)
return DebugInfo(ref(fn.f), in_tree, has_kwargs, traced_for)
def arg_info_pytree(fn: Callable, in_tree: PyTreeDef, has_kwargs: bool,
flat_pos: List[int]) -> str:
dummy_args = [False] * in_tree.num_leaves
for i in flat_pos: dummy_args[i] = True
if has_kwargs:
args, kwargs = tree_unflatten(in_tree, dummy_args)
else:
args, kwargs = tree_unflatten(in_tree, dummy_args), {}
def arg_info(dbg: DebugInfo, flat_pos: int) -> Optional[Tuple[str, KeyPath]]:
infos = arg_info_all(dbg)
return None if infos is None else infos[flat_pos]
def arg_info_all(dbg: DebugInfo) -> Optional[List[Tuple[str, KeyPath]]]:
ba = None if dbg.in_tree is None else sig_info(dbg)
if ba is None: return None
return [(name, key_path) for name, dummy_arg in ba.arguments.items()
for key_path, _ in _generate_key_paths(dummy_arg)]
def sig_info(dbg: DebugInfo) -> Optional[inspect.BoundArguments]:
if dbg.in_tree is None or dbg.fn is None: return None
try:
ba = inspect.signature(fn).bind(*args, **kwargs)
dummy_args = tree_unflatten(dbg.in_tree, [False] * dbg.in_tree.num_leaves)
except:
return None
args, kwargs = dummy_args if dbg.has_kwargs else (dummy_args, {})
try:
return inspect.signature(dbg.fn).bind(*args, **kwargs)
except (TypeError, ValueError):
return arg_info_flattened(flat_pos)
arg_names = [f"'{name}'" for name, x in ba.arguments.items()
if any(tree_leaves(x))]
if len(arg_names) == 1:
return f"the argument {arg_names[0]}"
elif len(arg_names) == 2:
return f"the arguments {arg_names[0]} and {arg_names[1]}"
else:
*rest, last = arg_names
return f"the arguments {', '.join(rest)}, and {last}"
def arg_info_flattened(flat_pos: List[int]) -> str:
if len(flat_pos) > 1:
return f"the argument passed at flattened positions {flat_pos}"
else:
return f"the argument passed at flattened position {flat_pos[0]}"
return None
@profiler.annotate_function
def trace_to_jaxpr_dynamic(fun: lu.WrappedFun,
@ -2252,7 +2237,7 @@ def _add_implicit_outputs(jaxpr: Jaxpr) -> Tuple[Jaxpr, OutputType]:
out_type = tuple(zip(out_avals, kept_outs))
new_jaxpr = Jaxpr(jaxpr.constvars, jaxpr.invars, outvars, jaxpr.eqns,
jaxpr.effects)
jaxpr.effects, jaxpr.debug_info)
config.jax_enable_checks and core.check_jaxpr(jaxpr)
return new_jaxpr, out_type

View File

@ -1145,6 +1145,22 @@ class CPPJitTest(jtu.BufferDonationTestCase):
self.assertIsNotNone(f.runtime_executable())
self.assertIsNotNone(g.runtime_executable())
def test_jit_lower_arg_info(self):
if not config.jax_array or not jax.config.jax_jit_pjit_api_merge:
raise unittest.SkipTest("test only applies after jit-pjit api merge")
def f(x, y, *args, **kwargs):
return y['hi'] + args[1] + sum(kwargs.values())
lowered = jax.jit(f).lower({'hi': 1.}, {'hi': 2.}, 3., 4., z=5., w=6.)
mhlo_str = str(lowered.compiler_ir('mhlo'))
self.assertNotIn("\"x\"", mhlo_str)
self.assertIn("y['hi']", mhlo_str)
self.assertNotIn("args[0]", mhlo_str)
self.assertIn("args[1]", mhlo_str)
self.assertIn("kwargs['z']", mhlo_str)
self.assertIn("kwargs['w']", mhlo_str)
def test_jit_enum_as_dict_keys_fails(self):
class E(enum.Enum):
A = 0
@ -3389,7 +3405,7 @@ class APITest(jtu.JaxTestCase):
else:
return 0
msg = r"on the value of the argument 'x'"
msg = r"on the value of the argument x"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1)
@ -3401,7 +3417,7 @@ class APITest(jtu.JaxTestCase):
else:
return y
msg = r"on the values of the arguments 'x' and 'y'"
msg = r"on the values of the arguments x and y"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1, 2)
@ -3413,7 +3429,7 @@ class APITest(jtu.JaxTestCase):
else:
return y
msg = r"on the values of the arguments 'x' and 'z'"
msg = r"on the values of the arguments x and z"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1, 2, 3)
@ -3432,7 +3448,7 @@ class APITest(jtu.JaxTestCase):
else:
return y
msg = r"on the values of the argument 'args'"
msg = r"on the values of the arguments args"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1, 2, 3)
@ -3445,7 +3461,7 @@ class APITest(jtu.JaxTestCase):
else:
return y
msg = r"on the values of the argument 'kwargs'"
msg = r"on the values of the arguments kwargs"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(x=1, y=2, z=3)
@ -3458,7 +3474,7 @@ class APITest(jtu.JaxTestCase):
else:
return y
msg = r"on the value of the argument 'xy'"
msg = r"on the value of the argument xy"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f((1, 2), z=3)
@ -3491,7 +3507,7 @@ class APITest(jtu.JaxTestCase):
def g(x):
return f(x, True)
msg = r"on the value of the argument 'y'"
msg = r"on the value of the argument y"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
g(1)
@ -5184,11 +5200,11 @@ class RematTest(jtu.JaxTestCase):
self.assertEqual(res[0][0].shape, (1,))
self.assertEqual(res[0][1], "from a constant")
self.assertEqual(res[1][0].shape, ())
self.assertEqual(res[1][1], "from the argument 'x'")
self.assertEqual(res[1][1], "from the argument x[0]")
self.assertEqual(res[2][0].shape, ())
self.assertEqual(res[2][1], "from the argument 'x'")
self.assertEqual(res[2][1], "from the argument x[1]")
self.assertEqual(res[3][0].shape, ())
self.assertEqual(res[3][1], "from the argument 'y'")
self.assertEqual(res[3][1], "from the argument y")
self.assertEqual(res[4][0].shape, ())
self.assertStartsWith(res[4][1], "named 'z'")
self.assertEqual(res[5][0].shape, ())