Merge pull request #14764 from mattjj:arg-info-in-mlir

PiperOrigin-RevId: 513686779
This commit is contained in:
jax authors 2023-03-02 17:45:11 -08:00
commit afdcd44c96
9 changed files with 144 additions and 86 deletions

View File

@ -401,9 +401,15 @@ def saved_residuals(f, *args, **kwargs) -> List[Tuple[core.AbstractValue, str]]:
results.append((v.aval, 'from a constant'))
assert len(jaxpr.invars) == len(in_leaves)
dbg = pe.debug_info(f, in_tree, True, "saved_residuals")
arg_info = pe.arg_info_all(dbg)
for i, v in enumerate(jaxpr.invars):
if v in res_vars:
src = f'from {pe.arg_info_pytree(f, in_tree, True, [i])}'
if arg_info is not None:
arg_name, arg_path = arg_info[i]
src = f'from the argument {arg_name}{arg_path.pprint("")}'
else:
# Must be an f-string: without the prefix the message contains a literal "{i}"
# instead of the flattened argument index.
src = f'from the argument at flattened index {i}'
results.append((v.aval, src))
for eqn in jaxpr.eqns:

View File

@ -2984,10 +2984,12 @@ def make_jaxpr(fun: Callable,
in_type = tuple(zip(in_avals, keep_inputs))
f, out_tree = flatten_fun(f, in_tree)
f = lu.annotate(f, in_type)
debug_info = pe.debug_info(fun, in_tree, True, 'make_jaxpr')
with ExitStack() as stack:
for axis_name, size in axis_env or []:
stack.enter_context(core.extend_axis_env(axis_name, size, None))
jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f)
jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(
f, debug_info=debug_info)
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
if return_shape:
out_avals, _ = unzip2(out_type)

View File

@ -391,7 +391,7 @@ def initial_style_jaxpr(
def _initial_style_jaxpr(fun, in_tree, in_avals):
# like control_flow._initial_style_jaxpr, but use flatten_fun not _nokwargs
fun_, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
debug = pe.debug_info(fun_, in_tree, False, 'checkify')
debug = pe.debug_info(fun, in_tree, False, 'checkify')
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun_, in_avals, debug)
return jaxpr, consts, out_tree()

View File

@ -45,6 +45,7 @@ from jax.errors import (ConcretizationTypeError, TracerArrayConversionError,
from jax._src import linear_util as lu
from jax._src import source_info_util
from jax._src.tree_util import PyTreeDef
from jax._src.util import (safe_zip, safe_map, curry, tuple_insert,
tuple_delete, as_hashable_function,
HashableFunction, HashableWrapper, weakref_lru_cache)
@ -66,24 +67,36 @@ Effects = effects.Effects
EffectTypeSet = effects.EffectTypeSet
no_effects: Effects = effects.no_effects
class DebugInfo(NamedTuple):
  """Debugging metadata about the Python function a jaxpr was traced from.

  The `Optional` fields are None when the corresponding information is not
  available to whoever constructs the record.
  """
  func_src_info: Optional[str] # f'{fun.__name__} at {filename}:{lineno}'
  signature: Optional[inspect.Signature] # inspect.signature(fun)
  in_tree: Optional[PyTreeDef] # caller/constructor might not have this info
  has_kwargs: bool # whether in_tree corresponds to (args, kwargs) or args
  traced_for: str # "jit", "scan", "make_jaxpr", etc
  # TODO(mattjj): add input type signature
class Jaxpr:
__slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns', '_effects']
__slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns', '_effects', '_debug_info']
_constvars: List[Var]
_invars: List[Var]
_outvars: List[Atom]
_eqns: List[JaxprEqn]
_effects: Effects
_debug_info: Optional[DebugInfo]
constvars = property(lambda self: self._constvars)
invars = property(lambda self: self._invars)
outvars = property(lambda self: self._outvars)
eqns = property(lambda self: self._eqns)
effects = property(lambda self: self._effects)
debug_info = property(lambda self: self._debug_info)
def __init__(self, constvars: Sequence[Var], invars: Sequence[Var],
outvars: Sequence[Atom], eqns: Sequence[JaxprEqn],
effects: Effects = no_effects):
effects: Effects = no_effects,
debug_info: Optional[DebugInfo] = None):
"""
Args:
constvars: list of variables introduced for constants. Array constants are
@ -100,6 +113,7 @@ class Jaxpr:
self._outvars = list(outvars)
self._eqns = list(eqns)
self._effects = effects
self._debug_info = debug_info
def __str__(self):
return str(pp_jaxpr(self, JaxprPpContext(), JaxprPpSettings()))

View File

@ -666,7 +666,8 @@ def lower_jaxpr_to_module(
donated_args: Sequence[bool],
replicated_args: Optional[Sequence[bool]] = None,
arg_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
result_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None
result_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
arg_names: Optional[Sequence[str]] = None,
) -> LoweringResult:
"""Lowers a top-level jaxpr to an MLIR module.
@ -743,7 +744,8 @@ def lower_jaxpr_to_module(
num_output_tokens=0,
replicated_args=replicated_args,
arg_shardings=arg_shardings, result_shardings=result_shardings,
input_output_aliases=input_output_aliases)
input_output_aliases=input_output_aliases,
arg_names=arg_names)
if not ctx.module.operation.verify():
module_string = module_to_string(ctx.module)
@ -860,6 +862,7 @@ def lower_jaxpr_to_fun(
input_output_aliases: Optional[Sequence[Optional[int]]] = None,
num_output_tokens: int = 0,
api_name: str = 'jit',
arg_names: Optional[Sequence[str]] = None,
) -> func_dialect.FuncOp:
"""Lowers jaxpr and its callees to an IR function.
@ -980,6 +983,12 @@ def lower_jaxpr_to_fun(
if alias is not None:
attrs["tf.aliasing_output"] = i32_attr(alias)
if config.jax_jit_pjit_api_merge and arg_names:
named_arg_attrs = arg_attrs[num_dim_vars + num_tokens:]
if len(named_arg_attrs) == len(arg_names):
for attrs, name_ in zip(named_arg_attrs, arg_names):
attrs['jax.arg_info'] = ir.StringAttr.get(name_)
func_op.arg_attrs = ir.ArrayAttr.get(
[ir.DictAttr.get(attrs) for attrs in arg_attrs])

View File

@ -3052,6 +3052,10 @@ def lower_sharding_computation(
unordered_effects = list(
effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects))
arg_info = jaxpr.debug_info and pe.arg_info_all(jaxpr.debug_info)
arg_names = None if arg_info is None else [
f'{name}{path.pprint("")}' for i, (name, path) in enumerate(arg_info)
if i in kept_var_idx]
lowering_result = mlir.lower_jaxpr_to_module(
module_name,
closed_jaxpr,
@ -3064,7 +3068,8 @@ def lower_sharding_computation(
donated_invars,
replicated_args=replicated_args,
arg_shardings=in_op_shardings,
result_shardings=out_op_shardings)
result_shardings=out_op_shardings,
arg_names=arg_names)
module, keepalive, host_callbacks = (
lowering_result.module, lowering_result.keepalive,

View File

@ -14,6 +14,7 @@
import functools
from functools import partial
import inspect
import itertools as it
import logging
import operator
@ -535,3 +536,15 @@ def use_cpp_method(is_enabled=True):
"Decorator got wrong type: @use_cpp_method(is_enabled: bool=True)"
)
return decorator
def fun_sourceinfo(fun: Callable) -> Optional[str]:
  """Describe where `fun` was defined, as ``"<name> at <filename>:<lineno>"``.

  `functools.partial` layers and `__wrapped__` decorator layers are peeled off
  in any order and any nesting, so the description refers to the innermost
  underlying function.

  Args:
    fun: the callable to describe.

  Returns:
    The description string, or None when the underlying object has no
    `__code__`/`__name__` (e.g. builtins) or its `__wrapped__` chain is cyclic.
  """
  # Alternate between stripping partials and following `__wrapped__` until
  # neither makes progress: unwrapping a decorator can expose a partial and
  # vice versa. `visited` guards against a partial/wrapper reference cycle.
  visited = []
  while not any(fun is seen for seen in visited):
    visited.append(fun)
    if isinstance(fun, partial):
      fun = fun.func
    else:
      try:
        fun = inspect.unwrap(fun)
      except ValueError:
        # inspect.unwrap reports a cyclic `__wrapped__` chain this way; this
        # helper is best-effort, so report "unknown" rather than failing.
        return None
  try:
    code = fun.__code__
    return f"{fun.__name__} at {code.co_filename}:{code.co_firstlineno}"
  except AttributeError:
    return None

View File

@ -41,12 +41,13 @@ from jax._src.core import (Trace, Tracer, Jaxpr, Literal, get_aval,
raise_to_shaped, Atom, JaxprEqn, Primitive,
ShapedArray, DShapedArray, mapped_aval,
unmapped_aval, DBIdx, InDBIdx, OutDBIdx,
InputType, OutputType, get_referent)
InputType, OutputType, get_referent, DebugInfo)
from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_unflatten,
tree_leaves)
KeyPath, _generate_key_paths)
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
merge_lists, partition_list, OrderedSet,
as_hashable_function, weakref_lru_cache)
as_hashable_function, weakref_lru_cache,
fun_sourceinfo)
map, unsafe_map = safe_map, map
@ -951,7 +952,7 @@ def tracers_to_jaxpr(
outvars = map(get_atom, out_tracers) # type: ignore[arg-type]
jaxpr_effects = make_jaxpr_effects(const_vars, invars, outvars, eqns)
jaxpr = Jaxpr(const_vars, invars, # type: ignore[list-item,arg-type]
outvars, eqns, jaxpr_effects)
outvars, eqns, jaxpr_effects, None)
config.jax_enable_checks and core.check_jaxpr(jaxpr)
# del getvar # needed to avoid cyclic-reference closure, apparently!
return jaxpr, const_vals, env_vals
@ -963,7 +964,7 @@ def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr:
lifted_jaxpr = Jaxpr(constvars=(),
invars=jaxpr.constvars + jaxpr.invars,
outvars=jaxpr.outvars, eqns=jaxpr.eqns,
effects=jaxpr.effects)
effects=jaxpr.effects, debug_info=jaxpr.debug_info)
config.jax_enable_checks and core.check_jaxpr(lifted_jaxpr)
return lifted_jaxpr
@ -976,7 +977,7 @@ def convert_invars_to_constvars(jaxpr: Jaxpr, n: int) -> Jaxpr:
constvars, invars = split_list(jaxpr.invars, [n])
lifted_jaxpr = Jaxpr(constvars=tuple(constvars), invars=invars,
outvars=jaxpr.outvars, eqns=jaxpr.eqns,
effects=jaxpr.effects)
effects=jaxpr.effects, debug_info=jaxpr.debug_info)
config.jax_enable_checks and core.check_jaxpr(lifted_jaxpr)
return lifted_jaxpr
@ -1382,7 +1383,8 @@ def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: Tuple[bool, ...],
eqns = new_eqns[::-1]
jaxpr_effects = make_jaxpr_effects(jaxpr.constvars, invars, outvars, eqns)
new_jaxpr = Jaxpr(jaxpr.constvars, invars, outvars, eqns, jaxpr_effects)
new_jaxpr = Jaxpr(jaxpr.constvars, invars, outvars, eqns, jaxpr_effects,
jaxpr.debug_info)
config.jax_enable_checks and core.check_jaxpr(new_jaxpr)
return new_jaxpr, used_inputs
@ -1486,12 +1488,22 @@ class DynamicJaxprTracer(core.Tracer):
if dbg is None:
return ""
origin = (f"The error occurred while tracing the function {dbg.func_src_info} "
f"for {dbg.traced_for}. ")
if invar_pos:
origin = ("The error occurred while tracing the function "
f"{dbg.func_src_info or '<unknown>'} for {dbg.traced_for}. ")
arg_info = arg_info_all(dbg)
if invar_pos and arg_info:
arg_info = [arg_info[i] for i in invar_pos]
arg_names = [f'{name}{path.pprint("")}' for name, path in arg_info]
if len(arg_names) == 1:
arg_info_str = f"the argument {arg_names[0]}"
elif len(arg_names) == 2:
arg_info_str = f"the arguments {arg_names[0]} and {arg_names[1]}"
else:
*rest, last = arg_names
arg_info_str = f"the arguments {', '.join(rest)}, and {last}"
origin += ("This concrete value was not available in Python because it "
f"depends on the value{'s' if len(invar_pos) > 1 else ''} "
f"of {dbg.arg_info(invar_pos)}.")
f"of {arg_info_str}.")
elif progenitor_eqns:
msts = [" operation "
f"{core.pp_eqn(eqn, core.JaxprPpContext(), core.JaxprPpSettings(print_shapes=True))}\n"
@ -1561,7 +1573,8 @@ class JaxprStackFrame:
constvars, constvals = unzip2(self.constvar_to_val.items())
jaxpr_effects = make_jaxpr_effects(constvars, self.invars, outvars,
self.eqns)
jaxpr = Jaxpr(constvars, self.invars, outvars, self.eqns, jaxpr_effects)
jaxpr = Jaxpr(constvars, self.invars, outvars, self.eqns, jaxpr_effects,
self.debug_info)
jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals)
jaxpr, constvals = _inline_literals(jaxpr, constvals)
return jaxpr, constvals
@ -1574,7 +1587,7 @@ class JaxprStackFrame:
jaxpr_effects = make_jaxpr_effects(constvars, self.invars, expl_outvars,
self.eqns)
jaxpr = Jaxpr(constvars, self.invars, expl_outvars, self.eqns,
jaxpr_effects)
jaxpr_effects, self.debug_info)
# We can't run check_jaxpr until after we normalize.
jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals)
jaxpr, constvals = _inline_literals(jaxpr, constvals)
@ -1638,7 +1651,7 @@ def _const_folding_and_forwarding(jaxpr, constvals):
jaxpr_effects = make_jaxpr_effects(new_constvars, jaxpr.invars, new_outvars,
new_eqns)
new_jaxpr = Jaxpr(new_constvars, jaxpr.invars, new_outvars, new_eqns,
jaxpr_effects)
jaxpr_effects, jaxpr.debug_info)
return new_jaxpr, new_constvals
ConstFoldRule = Callable[[List[Optional[Any]], JaxprEqn],
@ -1688,7 +1701,7 @@ def _inline_literals(jaxpr, constvals):
jaxpr_effects = make_jaxpr_effects(new_constvars, new_invars, new_outvars,
new_eqns)
new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns,
jaxpr_effects)
jaxpr_effects, jaxpr.debug_info)
return new_jaxpr, new_constvals
class DynamicJaxprTrace(core.Trace):
@ -1828,7 +1841,8 @@ class DynamicJaxprTrace(core.Trace):
with core.extend_axis_env(axis_name, params["global_axis_size"], None): # type: ignore
with core.new_sublevel():
jaxpr, reduced_out_avals, consts = trace_to_subjaxpr_dynamic(
f, self.main, reduced_in_avals, debug_info=debug_info_final(f, map_primitive.name))
f, self.main, reduced_in_avals,
debug_info=debug_info_final(f, map_primitive.name))
ordered_effects = effects.ordered_effects.filter_in(jaxpr.effects)
if ordered_effects:
raise ValueError("Ordered effects not supported for "
@ -1962,65 +1976,35 @@ def _memoize(thunk):
return memoized
class DebugInfo(NamedTuple):
func_src_info: str
traced_for: str
arg_info: Callable[[int], str]
def debug_info(fn: Callable, in_tree: Optional[PyTreeDef], has_kwargs: bool,
               traced_for: str) -> DebugInfo:
  """Build a `DebugInfo` record for `fn`.

  Args:
    fn: the callable being traced.
    in_tree: pytree structure of the inputs, or None if unknown.
    has_kwargs: whether `in_tree` corresponds to (args, kwargs) or just args.
    traced_for: name of the transformation doing the tracing.

  Returns:
    A `DebugInfo` whose signature field is None when `inspect.signature`
    cannot produce one for `fn`.
  """
  try:
    signature = inspect.signature(fn)
  except (ValueError, TypeError):
    # inspect.signature raises these when it cannot provide a signature.
    signature = None
  return DebugInfo(func_src_info=fun_sourceinfo(fn), signature=signature,
                   in_tree=in_tree, has_kwargs=has_kwargs,
                   traced_for=traced_for)
def debug_info_final(fn: lu.WrappedFun, traced_for: str) -> DebugInfo:
  """Make a DebugInfo from data available to final-style primitives like pmap.

  The input tree and kwargs flag are taken from `flattened_fun_in_tree(fn)`;
  when that yields nothing, they default to `(None, False)`.
  """
  tree_info = flattened_fun_in_tree(fn)
  if not tree_info:
    tree_info = (None, False)
  in_tree, has_kwargs = tree_info
  return debug_info(fn.f, in_tree, has_kwargs, traced_for)
def debug_info(fn: Callable, in_tree: Optional[PyTreeDef], has_kwargs: bool,
traced_for: str) -> DebugInfo:
func_src_info = fun_sourceinfo(fn)
if in_tree is not None:
arg_info = partial(arg_info_pytree, fn, in_tree, has_kwargs)
else:
arg_info = arg_info_flattened # type: ignore
return DebugInfo(func_src_info, traced_for, arg_info)
def arg_info_all(dbg: DebugInfo) -> Optional[List[Tuple[str, KeyPath]]]:
  """List (parameter name, key path) pairs for the arguments described by `dbg`.

  Returns None when `dbg` has no input tree or when `sig_info` cannot bind
  the arguments. Otherwise yields one pair per key path produced by
  `_generate_key_paths` for each bound argument, in the order of the bound
  arguments.
  """
  if dbg.in_tree is None:
    return None
  bound = sig_info(dbg)
  if bound is None:
    return None
  pairs = []
  for param_name, placeholder in bound.arguments.items():
    for leaf_path, _ in _generate_key_paths(placeholder):
      pairs.append((param_name, leaf_path))
  return pairs
def fun_sourceinfo(fun: Callable):
while isinstance(fun, functools.partial):
fun = fun.func
fun = inspect.unwrap(fun)
def sig_info(dbg: DebugInfo) -> Optional[inspect.BoundArguments]:
if dbg.in_tree is None or dbg.signature is None: return None
try:
filename = fun.__code__.co_filename
lineno = fun.__code__.co_firstlineno
line_info = f"{fun.__name__} at {filename}:{lineno}"
return line_info
except AttributeError:
return "<unknown>"
def arg_info_pytree(fn: Callable, in_tree: PyTreeDef, has_kwargs: bool,
flat_pos: List[int]) -> str:
dummy_args = [False] * in_tree.num_leaves
for i in flat_pos: dummy_args[i] = True
if has_kwargs:
args, kwargs = tree_unflatten(in_tree, dummy_args)
else:
args, kwargs = tree_unflatten(in_tree, dummy_args), {}
dummy_args = tree_unflatten(dbg.in_tree, [False] * dbg.in_tree.num_leaves)
except:
return None
args, kwargs = dummy_args if dbg.has_kwargs else (dummy_args, {})
try:
ba = inspect.signature(fn).bind(*args, **kwargs)
return dbg.signature.bind(*args, **kwargs)
except (TypeError, ValueError):
return arg_info_flattened(flat_pos)
arg_names = [f"'{name}'" for name, x in ba.arguments.items()
if any(tree_leaves(x))]
if len(arg_names) == 1:
return f"the argument {arg_names[0]}"
elif len(arg_names) == 2:
return f"the arguments {arg_names[0]} and {arg_names[1]}"
else:
*rest, last = arg_names
return f"the arguments {', '.join(rest)}, and {last}"
def arg_info_flattened(flat_pos: List[int]) -> str:
if len(flat_pos) > 1:
return f"the argument passed at flattened positions {flat_pos}"
else:
return f"the argument passed at flattened position {flat_pos[0]}"
return None
@profiler.annotate_function
def trace_to_jaxpr_dynamic(fun: lu.WrappedFun,
@ -2252,7 +2236,7 @@ def _add_implicit_outputs(jaxpr: Jaxpr) -> Tuple[Jaxpr, OutputType]:
out_type = tuple(zip(out_avals, kept_outs))
new_jaxpr = Jaxpr(jaxpr.constvars, jaxpr.invars, outvars, jaxpr.eqns,
jaxpr.effects)
jaxpr.effects, jaxpr.debug_info)
config.jax_enable_checks and core.check_jaxpr(jaxpr)
return new_jaxpr, out_type

View File

@ -28,7 +28,7 @@ import re
import subprocess
import sys
import types
from typing import Callable, List, Optional
from typing import Callable, List, Optional, NamedTuple
import unittest
import warnings
import weakref
@ -1145,6 +1145,22 @@ class CPPJitTest(jtu.BufferDonationTestCase):
self.assertIsNotNone(f.runtime_executable())
self.assertIsNotNone(g.runtime_executable())
def test_jit_lower_arg_info(self):
  """Lowered IR text mentions the used argument leaves and not the unused ones.

  `f` reads only y['hi'], args[1] and the kwargs values, so those names are
  expected in the MHLO text while x and args[0] are expected to be absent.
  """
  if not (config.jax_array and jax.config.jax_jit_pjit_api_merge):
    raise unittest.SkipTest("test only applies after jit-pjit api merge")

  def f(x, y, *args, **kwargs):
    return y['hi'] + args[1] + sum(kwargs.values())

  lowered = jax.jit(f).lower({'hi': 1.}, {'hi': 2.}, 3., 4., z=5., w=6.)
  mhlo_str = str(lowered.compiler_ir('mhlo'))

  # Checks run in the listed order, each against the same IR text.
  expectations = [
      (self.assertNotIn, "\"x\""),
      (self.assertIn, "y['hi']"),
      (self.assertNotIn, "args[0]"),
      (self.assertIn, "args[1]"),
      (self.assertIn, "kwargs['z']"),
      (self.assertIn, "kwargs['w']"),
  ]
  for check, needle in expectations:
    check(needle, mhlo_str)
def test_jit_enum_as_dict_keys_fails(self):
class E(enum.Enum):
A = 0
@ -3389,7 +3405,7 @@ class APITest(jtu.JaxTestCase):
else:
return 0
msg = r"on the value of the argument 'x'"
msg = r"on the value of the argument x"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1)
@ -3401,7 +3417,7 @@ class APITest(jtu.JaxTestCase):
else:
return y
msg = r"on the values of the arguments 'x' and 'y'"
msg = r"on the values of the arguments x and y"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1, 2)
@ -3413,7 +3429,7 @@ class APITest(jtu.JaxTestCase):
else:
return y
msg = r"on the values of the arguments 'x' and 'z'"
msg = r"on the values of the arguments x and z"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1, 2, 3)
@ -3432,7 +3448,7 @@ class APITest(jtu.JaxTestCase):
else:
return y
msg = r"on the values of the argument 'args'"
msg = r"on the values of the arguments args"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1, 2, 3)
@ -3445,7 +3461,7 @@ class APITest(jtu.JaxTestCase):
else:
return y
msg = r"on the values of the argument 'kwargs'"
msg = r"on the values of the arguments kwargs"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(x=1, y=2, z=3)
@ -3458,7 +3474,7 @@ class APITest(jtu.JaxTestCase):
else:
return y
msg = r"on the value of the argument 'xy'"
msg = r"on the value of the argument xy"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f((1, 2), z=3)
@ -3491,7 +3507,7 @@ class APITest(jtu.JaxTestCase):
def g(x):
return f(x, True)
msg = r"on the value of the argument 'y'"
msg = r"on the value of the argument y"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
g(1)
@ -4164,6 +4180,15 @@ class APITest(jtu.JaxTestCase):
jax.ShapeDtypeStruct((8, 2), np.float32,
sharding=jax.sharding.PartitionSpec('x'))
def test_make_jaxpr_weakref(self):
  """Smoke test: make_jaxpr accepts a callable NamedTuple instance."""
  # NOTE(review): presumably guards against code paths that need a weak
  # reference to the traced callable -- confirm against the make_jaxpr change.
  class Foo(NamedTuple):
    x: int

    def __call__(self, y):
      return self.x + y

  callable_tuple = Foo(1)
  jax.make_jaxpr(callable_tuple)(3)  # don't crash
class RematTest(jtu.JaxTestCase):
@ -5184,11 +5209,11 @@ class RematTest(jtu.JaxTestCase):
self.assertEqual(res[0][0].shape, (1,))
self.assertEqual(res[0][1], "from a constant")
self.assertEqual(res[1][0].shape, ())
self.assertEqual(res[1][1], "from the argument 'x'")
self.assertEqual(res[1][1], "from the argument x[0]")
self.assertEqual(res[2][0].shape, ())
self.assertEqual(res[2][1], "from the argument 'x'")
self.assertEqual(res[2][1], "from the argument x[1]")
self.assertEqual(res[3][0].shape, ())
self.assertEqual(res[3][1], "from the argument 'y'")
self.assertEqual(res[3][1], "from the argument y")
self.assertEqual(res[4][0].shape, ())
self.assertStartsWith(res[4][1], "named 'z'")
self.assertEqual(res[5][0].shape, ())