add result info to mhlo, fixes #14780

incidentally fixes #14787
This commit is contained in:
Matthew Johnson 2023-03-06 20:15:38 -08:00
parent 0ec82f4d62
commit b05975b964
14 changed files with 115 additions and 39 deletions

View File

@ -23,7 +23,7 @@ import jax
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.tree_util import tree_flatten, tree_unflatten, keystr
from jax.tree_util import tree_flatten, tree_unflatten, tree_structure, keystr
from jax._src import ad_util
from jax._src import core
from jax._src import linear_util as lu
@ -357,8 +357,8 @@ def _dyn_args_fun(fun: Callable, static_argnums: FrozenSet[int],
# remat-specific errors.
@weakref_lru_cache
def _trace_to_jaxpr(fun, in_tree, in_avals):
debug = pe.debug_info(fun, in_tree, True, "checkpoint")
flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
debug = pe.debug_info(fun, in_tree, out_tree, True, "checkpoint")
try:
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
except core.ConcretizationTypeError as e:
@ -386,8 +386,12 @@ def saved_residuals(f, *args, **kwargs) -> List[Tuple[core.AbstractValue, str]]:
args, kwargs = tree_unflatten(in_tree, args)
return f(*args, **kwargs)
jaxpr = jax.make_jaxpr(lambda *args: jax.linearize(f_, *args)[1])(
*in_leaves).jaxpr
out = jax.make_jaxpr(lambda *args: jax.linearize(f_, *args)[1],
return_shape=True)(*in_leaves)
assert isinstance(out, tuple)
jaxpr_, out_shape = out
jaxpr = jaxpr_.jaxpr
out_tree = lambda: tree_structure(out_shape)
res_lits = [x for x in jaxpr.outvars if isinstance(x, core.Literal)]
res_vars = {x for x in jaxpr.outvars if not isinstance(x, core.Literal)}
@ -401,7 +405,7 @@ def saved_residuals(f, *args, **kwargs) -> List[Tuple[core.AbstractValue, str]]:
results.append((v.aval, 'from a constant'))
assert len(jaxpr.invars) == len(in_leaves)
dbg = pe.debug_info(f, in_tree, True, "saved_residuals")
dbg = pe.debug_info(f, in_tree, out_tree, True, "saved_residuals")
arg_info = pe.arg_info_all(dbg)
for i, v in enumerate(jaxpr.invars):
if v in res_vars:

View File

@ -2909,7 +2909,8 @@ def make_jaxpr(fun: Callable,
axis_env: Optional[Sequence[Tuple[AxisName, int]]] = None,
return_shape: bool = False,
abstracted_axes: Optional[Any] = None,
) -> Callable[..., core.ClosedJaxpr]:
) -> Callable[..., Union[core.ClosedJaxpr,
Tuple[core.ClosedJaxpr, Any]]]:
"""Creates a function that produces its jaxpr given example args.
Args:
@ -2990,7 +2991,7 @@ def make_jaxpr(fun: Callable,
in_type = tuple(zip(in_avals, keep_inputs))
f, out_tree = flatten_fun(f, in_tree)
f = lu.annotate(f, in_type)
debug_info = pe.debug_info(fun, in_tree, True, 'make_jaxpr')
debug_info = pe.debug_info(fun, in_tree, out_tree, True, 'make_jaxpr')
with ExitStack() as stack:
for axis_name, size in axis_env or []:
stack.enter_context(core.extend_axis_env(axis_name, size, None))
@ -3342,7 +3343,7 @@ def eval_shape(fun: Callable, *args, **kwargs):
"""
args_flat, in_tree = tree_flatten((args, kwargs))
wrapped_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
debug_info = pe.debug_info(fun, in_tree, True, "eval_shape")
debug_info = pe.debug_info(fun, in_tree, out_tree, True, "eval_shape")
out = pe.abstract_eval_fun(wrapped_fun.call_wrapped,
*map(shaped_abstractify, args_flat),
debug_info=debug_info)

View File

@ -109,7 +109,9 @@ def apply_flat_fun_nokwargs(fun, io_tree, py_args):
ans = fun(*args)
return tree_unflatten(out_tree, ans)
def flattened_fun_in_tree(fn: lu.WrappedFun) -> Optional[Tuple[PyTreeDef, bool]]:
def flattened_fun_in_tree(
fn: lu.WrappedFun
) -> Optional[Tuple[PyTreeDef, Callable[[], PyTreeDef], bool]]:
# This implementation relies on internal details of linear_util.py's
# WrappedFun, but it's for the worthy cause of better user error messages.
# It can fail (i.e. return None) if its WrappedFun argument is not transformed
@ -119,14 +121,15 @@ def flattened_fun_in_tree(fn: lu.WrappedFun) -> Optional[Tuple[PyTreeDef, bool]]
assert isinstance(flatten_fun, partial) and len(flatten_fun.args) == 1
assert (isinstance(flatten_fun_nokwargs, partial) and
len(flatten_fun_nokwargs.args) == 1)
flat_xforms = {flatten_fun.args[0], flatten_fun_nokwargs.args[0]}
flattens = {flatten_fun.args[0], flatten_fun_nokwargs.args[0]}
try:
(in_tree, has_kwargs), = ((args[0], f is flatten_fun.args[0])
for f, args in fn.transforms if f in flat_xforms)
((in_tree,), out_tree_store, has_kwargs), = (
(args, store, f is flatten_fun.args[0])
for (f, args), store in zip(fn.transforms, fn.stores) if f in flattens)
except ValueError:
return None
else:
return in_tree, has_kwargs
return in_tree, lambda: out_tree_store.val, has_kwargs
@lu.transformation_with_aux
def flatten_fun_nokwargs2(in_tree, *args_flat):

View File

@ -391,7 +391,7 @@ def initial_style_jaxpr(
def _initial_style_jaxpr(fun, in_tree, in_avals):
# like control_flow._initial_style_jaxpr, but use flatten_fun not _nokwargs
fun_, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
debug = pe.debug_info(fun, in_tree, False, 'checkify')
debug = pe.debug_info(fun, in_tree, out_tree, False, 'checkify')
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun_, in_avals, debug)
return jaxpr, consts, out_tree()

View File

@ -71,6 +71,7 @@ class DebugInfo(NamedTuple):
func_src_info: Optional[str] # f'{fun.__name__} at {filename}:{lineno}'
signature: Optional[inspect.Signature] # inspect.signature(fun)
in_tree: Optional[PyTreeDef] # caller/constructor might not have this info
out_tree: Optional[Callable[[], PyTreeDef]] # lazy, not avail at trace time
has_kwargs: bool # whether in_tree corresponds to (args, kwargs) or args
traced_for: str # "jit", "scan", "make_jaxpr", etc
# TODO(mattjj): add input type signature

View File

@ -65,7 +65,7 @@ class custom_vmap:
args_flat, in_tree = tree_flatten(args)
flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
debug = pe.debug_info(self.fun, in_tree, False, "custom_vmap")
debug = pe.debug_info(self.fun, in_tree, out_tree, False, "custom_vmap")
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
in_tree = treedef_tuple((tree_structure(consts), in_tree))

View File

@ -668,6 +668,7 @@ def lower_jaxpr_to_module(
arg_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
result_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
arg_names: Optional[Sequence[str]] = None,
result_names: Optional[Sequence[str]] = None,
) -> LoweringResult:
"""Lowers a top-level jaxpr to an MLIR module.
@ -745,7 +746,7 @@ def lower_jaxpr_to_module(
replicated_args=replicated_args,
arg_shardings=arg_shardings, result_shardings=result_shardings,
input_output_aliases=input_output_aliases,
arg_names=arg_names)
arg_names=arg_names, result_names=result_names)
if not ctx.module.operation.verify():
module_string = module_to_string(ctx.module)
@ -863,6 +864,7 @@ def lower_jaxpr_to_fun(
num_output_tokens: int = 0,
api_name: str = 'jit',
arg_names: Optional[Sequence[str]] = None,
result_names: Optional[Sequence[str]] = None,
) -> func_dialect.FuncOp:
"""Lowers jaxpr and its callees to an IR function.
@ -992,13 +994,29 @@ def lower_jaxpr_to_fun(
func_op.arg_attrs = ir.ArrayAttr.get(
[ir.DictAttr.get(attrs) for attrs in arg_attrs])
result_attrs: List[Dict[str, ir.Attribute]] = [
{} for _ in range(len(flat_output_types))]
if config.jax_jit_pjit_api_merge and result_names:
named_result_attrs = result_attrs[num_tokens:]
if len(named_result_attrs) == len(result_names):
for attrs, name_ in zip(named_result_attrs, result_names):
attrs['jax.result_info'] = ir.StringAttr.get(name_)
if use_sharding_annotations and ir_result_shardings is not None:
func_op.result_attrs = ir.ArrayAttr.get([
ir.DictAttr.get(
{} if sharding is None else
{"mhlo.sharding": get_sharding_attr(sharding)}
) for sharding in ir_result_shardings
])
for attrs, sharding in zip(result_attrs, ir_result_shardings):
if sharding is not None:
attrs['mhlo.sharding'] = get_sharding_attr(sharding)
# func_op.result_attrs = ir.ArrayAttr.get([
# ir.DictAttr.get(
# {} if sharding is None else
# {"mhlo.sharding": get_sharding_attr(sharding)}
# ) for sharding in ir_result_shardings
# ])
func_op.result_attrs = ir.ArrayAttr.get(
[ir.DictAttr.get(attrs) for attrs in result_attrs])
entry_block = func_op.add_entry_block()
with ir.InsertionPoint(entry_block):

View File

@ -3063,10 +3063,7 @@ def lower_sharding_computation(
unordered_effects = list(
effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects))
arg_info = jaxpr.debug_info and pe.arg_info_all(jaxpr.debug_info)
arg_names = None if arg_info is None else [
f'{name}{keystr(path)}' for i, (name, path) in enumerate(arg_info)
if i in kept_var_idx]
arg_names, result_names = _debug_names(jaxpr.debug_info, kept_var_idx)
lowering_result = mlir.lower_jaxpr_to_module(
module_name,
closed_jaxpr,
@ -3081,7 +3078,7 @@ def lower_sharding_computation(
replicated_args=replicated_args,
arg_shardings=in_op_shardings,
result_shardings=out_op_shardings,
arg_names=arg_names)
arg_names=arg_names, result_names=result_names)
module, keepalive, host_callbacks = (
lowering_result.module, lowering_result.keepalive,
@ -3254,9 +3251,7 @@ def lower_mesh_computation(
closed_jaxpr.effects))
ordered_effects = list(effects.ordered_effects.filter_in(
closed_jaxpr.effects))
arg_info = jaxpr.debug_info and pe.arg_info_all(jaxpr.debug_info)
arg_names = None if arg_info is None else [
f'{name}{keystr(path)}' for i, (name, path) in enumerate(arg_info)]
arg_names, result_names = _debug_names(jaxpr.debug_info)
lowering_result = mlir.lower_jaxpr_to_module(
module_name,
closed_jaxpr,
@ -3270,7 +3265,8 @@ def lower_mesh_computation(
replicated_args=replicated_args,
arg_shardings=in_partitions,
result_shardings=out_partitions,
arg_names=arg_names)
arg_names=arg_names,
result_names=result_names)
module, keepalive, host_callbacks = (
lowering_result.module, lowering_result.keepalive,
lowering_result.host_callbacks)
@ -3297,6 +3293,17 @@ def lower_mesh_computation(
device_assignment=list(mesh.devices.flat),
committed=True)
def _debug_names(
dbg: Optional[core.DebugInfo], kept_var_idx: Optional[Set] = None
) -> Tuple[Optional[List[str]], Optional[List[str]]]:
if dbg is None: return (None, None)
arg_info = pe.arg_info_all(dbg)
arg_names = None if arg_info is None else [
f'{name}{keystr(path)}' for i, (name, path) in enumerate(arg_info)
if kept_var_idx is None or i in kept_var_idx]
result_info = pe.result_info(dbg)
result_names = None if result_info is None else map(keystr, result_info)
return arg_names, result_names
class MeshComputation(stages.XlaLowering):
_hlo: Optional[ir.Module]

View File

@ -49,7 +49,8 @@ def _typecheck_param(prim, param, name, msg_required, pred):
def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals,
primitive_name: Optional[str] = None):
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
debug = pe.debug_info(fun, in_tree, False, primitive_name or "<unknown>")
debug = pe.debug_info(fun, in_tree, out_tree, False,
primitive_name or "<unknown>")
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
return jaxpr, consts, out_tree()

View File

@ -346,7 +346,8 @@ class custom_partitioning:
args_flat, in_tree = tree_util.tree_flatten(dyn_args)
flat_fun, out_tree = flatten_fun_nokwargs(f_, in_tree)
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
debug = pe.debug_info(self.fun, in_tree, False, "custom_partitioning")
debug = pe.debug_info(self.fun, in_tree, out_tree, False,
"custom_partitioning")
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
assert not len(consts)
closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())

View File

@ -2311,7 +2311,7 @@ def bcoo_conv_general_dilated(lhs, rhs, *, window_strides, padding,
precision=precision, preferred_element_type=preferred_element_type)
jaxpr = jax.make_jaxpr(func)(jax.ShapeDtypeStruct(lhs.shape, lhs.dtype),
jax.ShapeDtypeStruct(rhs.shape, rhs.dtype))
assert len(jaxpr.eqns) == 1
assert isinstance(jaxpr, core.ClosedJaxpr) and len(jaxpr.eqns) == 1
params = jaxpr.eqns[0].params
if params['lhs_dilation'] != (1,) * (lhs.ndim - 2):

View File

@ -1976,17 +1976,19 @@ def _memoize(thunk):
return memoized
def debug_info(fn: Callable, in_tree: Optional[PyTreeDef], has_kwargs: bool,
traced_for: str) -> DebugInfo:
def debug_info(fn: Callable, in_tree: Optional[PyTreeDef],
out_tree_thunk: Optional[Callable[[], PyTreeDef]],
has_kwargs: bool, traced_for: str) -> DebugInfo:
try: sig = inspect.signature(fn)
except (ValueError, TypeError): sig = None
src_info = fun_sourceinfo(fn)
return DebugInfo(src_info, sig, in_tree, has_kwargs, traced_for)
return DebugInfo(src_info, sig, in_tree, out_tree_thunk, has_kwargs,
traced_for)
def debug_info_final(fn: lu.WrappedFun, traced_for: str) -> DebugInfo:
"Make a DebugInfo from data available to final-style primitives like pmap."
in_tree, has_kwargs = flattened_fun_in_tree(fn) or (None, False)
return debug_info(fn.f, in_tree, has_kwargs, traced_for)
in_tree, out_tree, has_kws = flattened_fun_in_tree(fn) or (None, None, False)
return debug_info(fn.f, in_tree, out_tree, has_kws, traced_for)
def arg_info_all(dbg: DebugInfo) -> Optional[List[Tuple[str, KeyPath]]]:
ba = None if dbg.in_tree is None else sig_info(dbg)
@ -2006,6 +2008,16 @@ def sig_info(dbg: DebugInfo) -> Optional[inspect.BoundArguments]:
except (TypeError, ValueError):
return None
def result_info(dbg: DebugInfo) -> Optional[List[KeyPath]]:
if dbg.out_tree is None: return None
try:
num_leaves = dbg.out_tree().num_leaves
dummy_result = tree_unflatten(dbg.out_tree(), [False] * num_leaves)
except:
return None
else:
return [path for path, _ in _generate_key_paths(dummy_result)]
@profiler.annotate_function
def trace_to_jaxpr_dynamic(fun: lu.WrappedFun,
in_avals: Sequence[AbstractValue],

View File

@ -1161,6 +1161,18 @@ class CPPJitTest(jtu.BufferDonationTestCase):
self.assertIn("kwargs['z']", mhlo_str)
self.assertIn("kwargs['w']", mhlo_str)
def test_jit_lower_result_info(self):
if not config.jax_array or not jax.config.jax_jit_pjit_api_merge:
raise unittest.SkipTest("test only applies after jit-pjit api merge")
def f(x, y, z):
return {'a': x, 'b': [y]}
lowered = jax.jit(f).lower(1., (2,), [3])
mhlo_str = str(lowered.compiler_ir('mhlo'))
self.assertIn("jax.result_info = \"['a']\"", mhlo_str)
self.assertIn("jax.result_info = \"['b'][0][0]\"", mhlo_str)
def test_jit_enum_as_dict_keys_fails(self):
class E(enum.Enum):
A = 0

View File

@ -3476,6 +3476,22 @@ class ArrayPjitTest(jtu.JaxTestCase):
# self.assertNotIn("args[0]", mhlo_str)
self.assertIn("args[1]", mhlo_str)
@jtu.with_mesh([('x', 2), ('y', 1)])
def test_jit_nested_xmap_lower_result_info(self):
if not config.jax_array or not jax.config.jax_jit_pjit_api_merge:
raise unittest.SkipTest("test only applies after jit-pjit api merge")
def f(x, y, z):
_ = xmap(lambda x: x * 2, in_axes=['i', ...], out_axes=['i', ...],
axis_resources={'i': 'y'})(jnp.arange(8.))
return {'a': x, 'b': [y]}
lowered = pjit(f, in_shardings=P(), out_shardings=P()).lower(
1., (2.,), [3.])
mhlo_str = str(lowered.compiler_ir('mhlo'))
self.assertIn("jax.result_info = \"['a']\"", mhlo_str)
self.assertIn("jax.result_info = \"['b'][0][0]\"", mhlo_str)
def test_with_sharding_constraint_with_two_meshes(self):
if jax.device_count() < 4:
self.skipTest("Requires more than 4 devices.")