Merge pull request #15085 from mattjj:arg-info-in-mlir-5

PiperOrigin-RevId: 518642948
jax authors 2023-03-22 12:31:35 -07:00
commit dd2ecf4bb5
5 changed files with 66 additions and 45 deletions

jax/_src/api.py

@@ -59,7 +59,7 @@ from jax._src.api_util import (
rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
shaped_abstractify, _ensure_str_tuple, argnames_partial_except,
validate_argnames, validate_argnums, check_callable, resolve_argnums,
FLAGS)
debug_info, result_paths, debug_info_final, FLAGS)
from jax._src.lax import lax as lax_internal
from jax._src.lib import jax_jit
from jax._src.lib import xla_client as xc
@@ -1634,6 +1634,8 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
if in_devices is not None and len(in_devices) == 0:
raise ValueError("'devices' argument to pmap must be non-empty, or None.")
dbg = debug_info('pmap', fun, args, kwargs, static_broadcasted_tuple, ())
f = lu.wrap_init(fun)
if static_broadcasted_tuple:
if max(static_broadcasted_tuple) >= len(args):
@@ -1671,7 +1673,9 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
kws=True))
local_axis_size = _mapped_axis_size(fun, in_tree, args, in_axes_flat, "pmap")
f, res_paths = result_paths(f)
flat_fun, out_tree = flatten_fun(f, in_tree)
flat_fun = debug_info_final(flat_fun, dbg, res_paths)
if any(out_axis is None for out_axis in tree_flatten(out_axes)):
raise NotImplementedError("None out_axes in pmap are not supported yet")
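
The hunks above wire debug info through pmap: _prepare_pmap now builds a TracingDebugInfo record from the original args and kwargs before staging, wraps the function with result_paths so the output pytree paths are captured as a side output of tracing, and attaches both the record and the lazy paths thunk to the flattened function via debug_info_final. The snippet below is a minimal standalone sketch of the result-paths idea using only public jax.tree_util helpers; it is an illustration, not the internal lu transformation used in the diff.

    import jax

    def with_result_paths(fun):
        # Record the output pytree paths as a side effect of calling `fun`;
        # they are only known once the function has produced its result.
        box = {}
        def wrapped(*args, **kwargs):
            ans = fun(*args, **kwargs)
            leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(ans)
            box['paths'] = tuple(jax.tree_util.keystr(path)
                                 for path, _ in leaves_with_paths)
            return ans
        return wrapped, lambda: box['paths']   # the function plus a lazy thunk

    f, res_paths = with_result_paths(lambda x: {'a': x, 'b': [x]})
    f(1.0)
    print(res_paths())   # ("['a']", "['b'][0]")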

jax/_src/api_util.py

@@ -15,8 +15,8 @@
import inspect
import operator
from functools import partial
from typing import (Any, Callable, Dict, Iterable, List, NamedTuple, Optional,
Sequence, Set, Tuple, Union,)
from typing import (Any, Callable, Dict, Iterable, List, Optional, Sequence,
Set, Tuple, Union)
import warnings
import numpy as np
@@ -30,7 +30,9 @@ from jax._src.tree_util import (
treedef_children, generate_key_paths, keystr)
from jax._src.tree_util import _replace_nones
from jax._src import linear_util as lu
from jax._src.util import safe_map, WrapKwArgs, Hashable, Unhashable
from jax._src.linear_util import TracingDebugInfo
from jax._src.util import (safe_map, WrapKwArgs, Hashable, HashableFunction,
Unhashable)
from jax._src.config import flags, bool_env
from jax._src import traceback_util
traceback_util.register_exclusion(__file__)
@@ -585,14 +587,6 @@ def api_hook(fun, tag: str):
return fun
class TracingDebugInfo(NamedTuple):
# Packages up trace/staging-time debug info about a func and its parameters,
# formed just before staging to a jaxpr and read in trace-time error messages.
# TODO(mattjj): delete partial_eval.DebugInfo, replace all uses with this cls
traced_for: str # e.g. 'jit', 'scan', etc
func_src_info: str # e.g. f'{fun.__name__} at {filename}:{lineno}'
arg_names: Tuple[str, ...] # e.g. ('args[0]', ... )
def debug_info(traced_for: str, fun: Callable, args: Tuple[Any],
kwargs: Dict[str, Any], static_argnums: Tuple[int, ...],
static_argnames: Tuple[str, ...]) -> Optional[TracingDebugInfo]:
@@ -600,7 +594,7 @@ def debug_info(traced_for: str, fun: Callable, args: Tuple[Any],
src = fun_sourceinfo(fun)
arg_names = _arg_names(fun, args, kwargs, static_argnums, static_argnames)
if src is None or arg_names is None: return None
return TracingDebugInfo(traced_for, src, arg_names)
return TracingDebugInfo(traced_for, src, arg_names, None)
# TODO(mattjj): make this function internal to this module
def fun_sourceinfo(fun: Callable) -> Optional[str]:
@@ -635,13 +629,23 @@ def result_paths(*args, **kwargs):
yield ans, [keystr(path) for path, _ in generate_key_paths(ans)]
def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: Optional[TracingDebugInfo],
result_paths: Optional[Tuple[Optional[str], ...]]
result_paths: Optional[Tuple[Optional[str], ...]] = None,
) -> core.Jaxpr:
"""Add debug info to jaxpr, given trace-time debug info and result paths."""
if trace_debug is not None and result_paths is not None:
debug_info = core.JaxprDebugInfo(
trace_debug.traced_for, trace_debug.func_src_info,
trace_debug.arg_names, result_paths)
else:
debug_info = None
return jaxpr.replace(debug_info=debug_info) if debug_info else jaxpr
if trace_debug is None:
return jaxpr
assert (result_paths is not None) ^ (trace_debug.result_paths is not None)
if result_paths is None:
result_paths = trace_debug.result_paths() # type: ignore
debug_info = core.JaxprDebugInfo(
trace_debug.traced_for, trace_debug.func_src_info,
trace_debug.arg_names, result_paths)
return jaxpr.replace(debug_info=debug_info)
def debug_info_final(f: lu.WrappedFun, dbg: Optional[TracingDebugInfo],
res_paths: Callable[[], Tuple[str, ...]]) -> lu.WrappedFun:
"Attach trace-time debug info and result paths lazy thunk to an lu.WrappedFun"
if dbg is None: return f
assert dbg.result_paths is None
res_paths_ = HashableFunction(res_paths, closure=())
return lu.add_debug_info(f, dbg._replace(result_paths=res_paths_))
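
Two behavioral points in this file: debug_info now fills the new result_paths field of TracingDebugInfo with None at trace time, and jaxpr_debug_info accepts result paths from exactly one of two sources, the explicit argument or the thunk that debug_info_final attached earlier (hence the XOR assertion). Below is a small standalone sketch of that contract using a stand-in NamedTuple rather than JAX's real TracingDebugInfo.

    from typing import Callable, NamedTuple, Optional, Tuple

    class ToyDebugInfo(NamedTuple):     # stand-in for TracingDebugInfo
        result_paths: Optional[Callable[[], Tuple[str, ...]]]

    def resolve_result_paths(dbg: Optional[ToyDebugInfo],
                             result_paths: Optional[Tuple[str, ...]] = None):
        if dbg is None:
            return None                 # no trace-time debug info at all
        # Exactly one source of result paths is expected, never both or neither.
        assert (result_paths is not None) ^ (dbg.result_paths is not None)
        return result_paths if result_paths is not None else dbg.result_paths()

    print(resolve_result_paths(ToyDebugInfo(lambda: ("['a']",))))  # from the thunk
    print(resolve_result_paths(ToyDebugInfo(None), ("['a']",)))    # passed explicitly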

jax/_src/interpreters/pxla.py

@@ -1116,6 +1116,7 @@ def stage_parallel_callable(
event=dispatch.JAXPR_TRACE_EVENT):
jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
fun, global_sharded_avals, pe.debug_info_final(fun, "pmap"))
jaxpr = api_util.jaxpr_debug_info(jaxpr, fun.debug_info)
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
assert len(out_sharded_avals) == len(pci.out_axes), (
@@ -1264,7 +1265,6 @@ def lower_parallel_callable(
raise ValueError("Ordered effects not supported in `pmap`.")
unordered_effects = list(
effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
arg_names, result_names = _debug_names(jaxpr.debug_info)
lowering_result = mlir.lower_jaxpr_to_module(
module_name,
closed_jaxpr,
@@ -1278,7 +1278,8 @@
replicated_args=replicated_args,
arg_shardings=_shardings_to_mlir_shardings(parts.arg_parts),
result_shardings=_shardings_to_mlir_shardings(parts.out_parts),
arg_names=arg_names, result_names=result_names)
arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths)
module, keepalive, host_callbacks = (
lowering_result.module, lowering_result.keepalive,
lowering_result.host_callbacks)
@@ -2564,7 +2565,6 @@ def lower_sharding_computation(
unordered_effects = list(
effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects))
arg_names, result_names = _debug_names(jaxpr.debug_info)
lowering_result = mlir.lower_jaxpr_to_module(
module_name,
closed_jaxpr,
@@ -2579,7 +2579,8 @@
replicated_args=replicated_args,
arg_shardings=in_op_shardings,
result_shardings=out_op_shardings,
arg_names=arg_names, result_names=result_names)
arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths)
module, keepalive, host_callbacks = (
lowering_result.module, lowering_result.keepalive,
@@ -2750,7 +2751,6 @@ def lower_mesh_computation(
closed_jaxpr.effects))
ordered_effects = list(effects.ordered_effects.filter_in(
closed_jaxpr.effects))
arg_names, result_names = _debug_names(jaxpr.debug_info)
lowering_result = mlir.lower_jaxpr_to_module(
module_name,
closed_jaxpr,
@@ -2764,8 +2764,8 @@
replicated_args=replicated_args,
arg_shardings=in_partitions,
result_shardings=out_partitions,
arg_names=arg_names,
result_names=result_names)
arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths)
module, keepalive, host_callbacks = (
lowering_result.module, lowering_result.keepalive,
lowering_result.host_callbacks)
@@ -2791,12 +2791,6 @@ def lower_mesh_computation(
device_assignment=list(mesh.devices.flat),
committed=True)
def _debug_names(
dbg: Optional[core.JaxprDebugInfo]
) -> Union[Tuple[None, None],
Tuple[Sequence[Optional[str]], Sequence[Optional[str]]]]:
return (None, None) if dbg is None else (dbg.arg_names, dbg.result_paths)
class MeshComputation(stages.XlaLowering):
_hlo: Optional[ir.Module]
_executable: Optional[MeshExecutable]
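
The lowering entry points above replace the _debug_names helper with the short-circuiting `dbg and dbg.attr` idiom: when jaxpr.debug_info is None the whole expression is None, otherwise it evaluates to the attribute. A quick standalone illustration of the idiom (SimpleNamespace stands in for core.JaxprDebugInfo):

    from types import SimpleNamespace

    dbg = None
    print(dbg and dbg.arg_names)        # -> None: no debug info, no names
    dbg = SimpleNamespace(arg_names=('x', "y['hi']"), result_paths=("['a']",))
    print(dbg and dbg.arg_names)        # -> ('x', "y['hi']")
    print(dbg and dbg.result_paths)     # -> ("['a']",)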

jax/_src/linear_util.py

@@ -64,7 +64,7 @@ data must be immutable, because it will be stored in function memoization tables
from __future__ import annotations
from functools import partial
from typing import Any, Tuple, Callable
from typing import Any, Tuple, Callable, Optional, NamedTuple
import weakref
from jax.tree_util import tree_map
@@ -124,14 +124,15 @@ class WrappedFun:
params: extra parameters to pass as keyword arguments to `f`, along with the
transformed keyword arguments.
"""
__slots__ = ("f", "transforms", "stores", "params", "in_type")
__slots__ = ("f", "transforms", "stores", "params", "in_type", "debug_info")
def __init__(self, f, transforms, stores, params, in_type):
def __init__(self, f, transforms, stores, params, in_type, debug_info):
self.f = f
self.transforms = transforms
self.stores = stores
self.params = params
self.in_type = in_type
self.debug_info = debug_info
@property
def __name__(self):
@@ -140,7 +141,7 @@ class WrappedFun:
def wrap(self, gen, gen_static_args, out_store) -> WrappedFun:
"""Add another transform and its store."""
return WrappedFun(self.f, ((gen, gen_static_args),) + self.transforms,
(out_store,) + self.stores, self.params, None)
(out_store,) + self.stores, self.params, None, None)
def populate_stores(self, stores):
"""Copy the values from the `stores` into `self.stores`."""
@@ -199,11 +200,13 @@ class WrappedFun:
return "Wrapped function:\n" + '\n'.join(transformation_stack) + '\nCore: ' + fun_name(self.f) + '\n'
def __hash__(self):
return hash((self.f, self.transforms, self.params, self.in_type))
return hash((self.f, self.transforms, self.params, self.in_type,
self.debug_info))
def __eq__(self, other):
return (self.f == other.f and self.transforms == other.transforms and
self.params == other.params and self.in_type == other.in_type)
self.params == other.params and self.in_type == other.in_type and
self.debug_info == other.debug_info)
@curry
def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun:
@@ -231,15 +234,15 @@ def fun_name(f):
def wrap_init(f, params=None) -> WrappedFun:
"""Wraps function `f` as a `WrappedFun`, suitable for transformation."""
params = () if params is None else tuple(sorted(params.items()))
return WrappedFun(f, (), (), params, None)
return WrappedFun(f, (), (), params, None, None)
def annotate(f: WrappedFun, in_type: core.InputType) -> WrappedFun:
def annotate(f: WrappedFun, in_type: Optional[core.InputType]) -> WrappedFun:
assert f.in_type is None
if in_type is None:
return f
_check_input_type(in_type)
return WrappedFun(f.f, f.transforms, f.stores, f.params, in_type)
return WrappedFun(f.f, f.transforms, f.stores, f.params, in_type, f.debug_info)
def _check_input_type(in_type: core.InputType) -> None:
# Check that in_type is syntactically well-formed
@@ -271,6 +274,24 @@ def _check_input_type(in_type: core.InputType) -> None:
assert all(provided)
class TracingDebugInfo(NamedTuple):
# Packages up trace/staging-time debug info about a func and its parameters,
# formed just before staging to a jaxpr and read in trace-time error messages.
# TODO(mattjj): delete partial_eval.DebugInfo, replace all uses with this cls
traced_for: str # e.g. 'jit', 'scan', etc
func_src_info: str # e.g. f'{fun.__name__} at {filename}:{lineno}'
arg_names: Tuple[str, ...] # e.g. ('args[0]', ... )
result_paths: Optional[Callable[[], Tuple[str, ...]]]
def add_debug_info(f: WrappedFun, debug_info: Optional[TracingDebugInfo]
) -> WrappedFun:
"""Produce a new WrappedFun with debug_info attached."""
assert f.debug_info is None
if debug_info is None:
return f
return WrappedFun(f.f, f.transforms, f.stores, f.params, f.in_type, debug_info)
def cache(call: Callable):
"""Memoization decorator for functions taking a WrappedFun as first argument.

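TracingDebugInfo moves into linear_util and gains an optional result_paths thunk, and WrappedFun grows a debug_info slot that participates in __hash__ and __eq__, so debug info becomes part of memoization keys (presumably why debug_info_final wraps the thunk in a HashableFunction with an empty closure). The sketch below mirrors the NamedTuple added here and the _replace step performed by debug_info_final; it is standalone and does not import JAX internals.

    from typing import Callable, NamedTuple, Optional, Tuple

    class TracingDebugInfo(NamedTuple):   # mirrors the class added in this diff
        traced_for: str                   # e.g. 'jit', 'scan', 'pmap'
        func_src_info: str                # e.g. 'f at example.py:10'
        arg_names: Tuple[str, ...]        # e.g. ('x', "y['hi']")
        result_paths: Optional[Callable[[], Tuple[str, ...]]]

    dbg = TracingDebugInfo('pmap', 'f at example.py:10', ('x', "y['hi']"), None)
    # Once tracing can report result paths, the thunk is filled in via _replace,
    # as debug_info_final does above:
    dbg = dbg._replace(result_paths=lambda: ("['a']", "['b'][0]"))
    print(dbg.result_paths())
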
tests/pmap_test.py

@@ -2036,7 +2036,6 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertEqual(jaxpr_text.count(' cos '), 2)
def test_pmap_lower_arg_info(self):
raise SkipTest("arg info not plumbed to pmap yet") # TODO(mattjj)
def f(x, y, *args, **kwargs):
return y['hi'] + args[1] + sum(kwargs.values())
@@ -2052,7 +2051,6 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertIn("kwargs['w']", mhlo_str)
def test_pmap_lower_result_info(self):
raise SkipTest("arg info not plumbed to pmap yet") # TODO(mattjj)
def f(x, y, z):
return {'a': x, 'b': [y]}
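
Both skips are removed because argument and result info now reach pmap's lowering. The snippet below is a user-level spot check using only public APIs; it assumes at least one available device (the leading axis of size 1 keeps it runnable on a single device), and the exact spelling of the debug attributes in the lowered text may vary across JAX versions.

    import jax
    import jax.numpy as jnp

    def f(x, y):
        return {'a': x, 'b': [y['hi']]}

    lowered = jax.pmap(f).lower(jnp.ones((1, 4)), {'hi': jnp.ones((1, 4))})
    text = lowered.as_text()
    # Expect the argument name and the result path to appear in the lowered module:
    print("y['hi']" in text)
    print("['a']" in text)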