Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
Merge pull request #15085 from mattjj:arg-info-in-mlir-5
PiperOrigin-RevId: 518642948
Commit: dd2ecf4bb5
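In brief: this change plumbs argument and result names through to MLIR lowering for pmap-style computations. It moves TracingDebugInfo from api_util into linear_util, adds a lazy result_paths thunk to it, stores it on lu.WrappedFun (including in its hash/equality, for memoization), and has the lowering entry points read names directly off the jaxpr's debug info. The mirror did not preserve per-file diff headers; from their contents the hunks below appear to come from jax/_src/api.py, jax/_src/api_util.py, jax/_src/interpreters/pxla.py, jax/_src/linear_util.py, and tests/pmap_test.py (an inference, not part of the original diff). Short annotated sketches follow the less self-explanatory hunks.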
@@ -59,7 +59,7 @@ from jax._src.api_util import (
     rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
     shaped_abstractify, _ensure_str_tuple, argnames_partial_except,
     validate_argnames, validate_argnums, check_callable, resolve_argnums,
-    FLAGS)
+    debug_info, result_paths, debug_info_final, FLAGS)
 from jax._src.lax import lax as lax_internal
 from jax._src.lib import jax_jit
 from jax._src.lib import xla_client as xc
@@ -1634,6 +1634,8 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
   if in_devices is not None and len(in_devices) == 0:
     raise ValueError("'devices' argument to pmap must be non-empty, or None.")

+  dbg = debug_info('pmap', fun, args, kwargs, static_broadcasted_tuple, ())
+
   f = lu.wrap_init(fun)
   if static_broadcasted_tuple:
     if max(static_broadcasted_tuple) >= len(args):
@@ -1671,7 +1673,9 @@
                                           kws=True))
   local_axis_size = _mapped_axis_size(fun, in_tree, args, in_axes_flat, "pmap")

+  f, res_paths = result_paths(f)
   flat_fun, out_tree = flatten_fun(f, in_tree)
+  flat_fun = debug_info_final(flat_fun, dbg, res_paths)

   if any(out_axis is None for out_axis in tree_flatten(out_axes)):
     raise NotImplementedError("None out_axes in pmap are not supported yet")
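The two `_prepare_pmap` hunks above follow the pattern `jit` already uses: argument names can be computed eagerly from the call, but result paths exist only after the traced function has produced outputs, so they travel as a thunk that `debug_info_final` attaches alongside the eager info. A runnable toy analogue of that eager/lazy split (illustrative only, not the internal API; it borrows `generate_key_paths` and `keystr`, the same helpers the diff imports):

    from jax._src.tree_util import generate_key_paths, keystr

    def toy_trace_with_debug(fun, *args):
      # Eager: argument names are known from the call itself.
      arg_names = tuple(f'args[{i}]' for i in range(len(args)))
      box = []
      def result_paths():  # Lazy thunk: filled in only after fun has run.
        return tuple(box)
      ans = fun(*args)
      box.extend(keystr(path) for path, _ in generate_key_paths(ans))
      return ans, arg_names, result_paths

    ans, names, paths = toy_trace_with_debug(lambda x: {'a': x, 'b': [x]}, 1.0)
    print(names)    # ('args[0]',)
    print(paths())  # ("['a']", "['b'][0]")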
@@ -15,8 +15,8 @@
 import inspect
 import operator
 from functools import partial
-from typing import (Any, Callable, Dict, Iterable, List, NamedTuple, Optional,
-                    Sequence, Set, Tuple, Union,)
+from typing import (Any, Callable, Dict, Iterable, List, Optional, Sequence,
+                    Set, Tuple, Union)
 import warnings

 import numpy as np
@@ -30,7 +30,9 @@ from jax._src.tree_util import (
     treedef_children, generate_key_paths, keystr)
 from jax._src.tree_util import _replace_nones
 from jax._src import linear_util as lu
-from jax._src.util import safe_map, WrapKwArgs, Hashable, Unhashable
+from jax._src.linear_util import TracingDebugInfo
+from jax._src.util import (safe_map, WrapKwArgs, Hashable, HashableFunction,
+                           Unhashable)
 from jax._src.config import flags, bool_env
 from jax._src import traceback_util
 traceback_util.register_exclusion(__file__)
@@ -585,14 +587,6 @@ def api_hook(fun, tag: str):
   return fun


-class TracingDebugInfo(NamedTuple):
-  # Packages up trace/staging-time debug info about a func and its parameters,
-  # formed just before staging to a jaxpr and read in trace-time error messages.
-  # TODO(mattjj): delete partial_eval.DebugInfo, replace all uses with this cls
-  traced_for: str             # e.g. 'jit', 'scan', etc
-  func_src_info: str          # e.g. f'{fun.__name__} at {filename}:{lineno}'
-  arg_names: Tuple[str, ...]  # e.g. ('args[0]', ... )
-
 def debug_info(traced_for: str, fun: Callable, args: Tuple[Any],
                kwargs: Dict[str, Any], static_argnums: Tuple[int, ...],
                static_argnames: Tuple[str, ...]) -> Optional[TracingDebugInfo]:
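Note: `TracingDebugInfo` is not being deleted here; a later hunk in this same commit re-creates it in linear_util (with a new `result_paths` field), and api_util now imports it from there, as the import hunk above already shows.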
@@ -600,7 +594,7 @@ def debug_info(traced_for: str, fun: Callable, args: Tuple[Any],
   src = fun_sourceinfo(fun)
   arg_names = _arg_names(fun, args, kwargs, static_argnums, static_argnames)
   if src is None or arg_names is None: return None
-  return TracingDebugInfo(traced_for, src, arg_names)
+  return TracingDebugInfo(traced_for, src, arg_names, None)

 # TODO(mattjj): make this function internal to this module
 def fun_sourceinfo(fun: Callable) -> Optional[str]:
@@ -635,13 +629,23 @@ def result_paths(*args, **kwargs):
   yield ans, [keystr(path) for path, _ in generate_key_paths(ans)]

 def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: Optional[TracingDebugInfo],
-                     result_paths: Optional[Tuple[Optional[str], ...]]
+                     result_paths: Optional[Tuple[Optional[str], ...]] = None,
                      ) -> core.Jaxpr:
   """Add debug info to jaxpr, given trace-time debug info and result paths."""
-  if trace_debug is not None and result_paths is not None:
-    debug_info = core.JaxprDebugInfo(
-        trace_debug.traced_for, trace_debug.func_src_info,
-        trace_debug.arg_names, result_paths)
-  else:
-    debug_info = None
-  return jaxpr.replace(debug_info=debug_info) if debug_info else jaxpr
+  if trace_debug is None:
+    return jaxpr
+  assert (result_paths is not None) ^ (trace_debug.result_paths is not None)
+  if result_paths is None:
+    result_paths = trace_debug.result_paths()  # type: ignore
+  debug_info = core.JaxprDebugInfo(
+      trace_debug.traced_for, trace_debug.func_src_info,
+      trace_debug.arg_names, result_paths)
+  return jaxpr.replace(debug_info=debug_info)
+
+def debug_info_final(f: lu.WrappedFun, dbg: Optional[TracingDebugInfo],
+                     res_paths: Callable[[], Tuple[str, ...]]) -> lu.WrappedFun:
+  "Attach trace-time debug info and result paths lazy thunk to an lu.WrappedFun"
+  if dbg is None: return f
+  assert dbg.result_paths is None
+  res_paths_ = HashableFunction(res_paths, closure=())
+  return lu.add_debug_info(f, dbg._replace(result_paths=res_paths_))
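Two details in the new `jaxpr_debug_info` and `debug_info_final` are worth spelling out. The `^` (XOR) assert enforces exactly one source of result paths: either the caller passes them explicitly, or the tracing-time thunk stored on `trace_debug` supplies them; never both, never neither. And the thunk is wrapped in `HashableFunction(..., closure=())` because, per the linear_util hunks later in this diff, `debug_info` now participates in `WrappedFun.__hash__`/`__eq__`: a bare closure would compare by identity and defeat memoization. A toy analogue of that contract (the real `HashableFunction` lives in jax._src.util; this sketch is illustrative, not its actual implementation):

    class HashableThunk:
      """Compares by code object + declared closure, not by object identity."""
      def __init__(self, f, closure=()):
        self.f, self.closure = f, closure
      def __call__(self):
        return self.f()
      def __hash__(self):
        return hash((self.f.__code__, self.closure))
      def __eq__(self, other):
        return (self.f.__code__, self.closure) == (other.f.__code__, other.closure)

    def paths():
      return ("['a']",)
    a, b = HashableThunk(paths), HashableThunk(paths)
    print(a == b, hash(a) == hash(b))  # True True -- stable across rewrapping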
@@ -1116,6 +1116,7 @@ def stage_parallel_callable(
       event=dispatch.JAXPR_TRACE_EVENT):
     jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
         fun, global_sharded_avals, pe.debug_info_final(fun, "pmap"))
+  jaxpr = api_util.jaxpr_debug_info(jaxpr, fun.debug_info)
   jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)

   assert len(out_sharded_avals) == len(pci.out_axes), (
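Here `jaxpr_debug_info` is called with no explicit `result_paths`, so it takes the new thunk branch: the paths come from `fun.debug_info.result_paths()`, recorded lazily during the trace just performed. `fun.debug_info` itself exists thanks to the `WrappedFun` changes further down.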
@@ -1264,7 +1265,6 @@ def lower_parallel_callable(
     raise ValueError("Ordered effects not supported in `pmap`.")
   unordered_effects = list(
       effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
-  arg_names, result_names = _debug_names(jaxpr.debug_info)
   lowering_result = mlir.lower_jaxpr_to_module(
       module_name,
       closed_jaxpr,
@@ -1278,7 +1278,8 @@
       replicated_args=replicated_args,
       arg_shardings=_shardings_to_mlir_shardings(parts.arg_parts),
       result_shardings=_shardings_to_mlir_shardings(parts.out_parts),
-      arg_names=arg_names, result_names=result_names)
+      arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
+      result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths)
   module, keepalive, host_callbacks = (
       lowering_result.module, lowering_result.keepalive,
       lowering_result.host_callbacks)
@@ -2564,7 +2565,6 @@ def lower_sharding_computation(
   unordered_effects = list(
       effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
   ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects))
-  arg_names, result_names = _debug_names(jaxpr.debug_info)
   lowering_result = mlir.lower_jaxpr_to_module(
       module_name,
       closed_jaxpr,
@@ -2579,7 +2579,8 @@
       replicated_args=replicated_args,
       arg_shardings=in_op_shardings,
       result_shardings=out_op_shardings,
-      arg_names=arg_names, result_names=result_names)
+      arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
+      result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths)

   module, keepalive, host_callbacks = (
       lowering_result.module, lowering_result.keepalive,
@@ -2750,7 +2751,6 @@ def lower_mesh_computation(
       closed_jaxpr.effects))
   ordered_effects = list(effects.ordered_effects.filter_in(
       closed_jaxpr.effects))
-  arg_names, result_names = _debug_names(jaxpr.debug_info)
   lowering_result = mlir.lower_jaxpr_to_module(
       module_name,
       closed_jaxpr,
@@ -2764,8 +2764,8 @@
       replicated_args=replicated_args,
       arg_shardings=in_partitions,
       result_shardings=out_partitions,
-      arg_names=arg_names,
-      result_names=result_names)
+      arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
+      result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths)
   module, keepalive, host_callbacks = (
       lowering_result.module, lowering_result.keepalive,
       lowering_result.host_callbacks)
@@ -2791,12 +2791,6 @@
       device_assignment=list(mesh.devices.flat),
       committed=True)

-def _debug_names(
-    dbg: Optional[core.JaxprDebugInfo]
-) -> Union[Tuple[None, None],
-           Tuple[Sequence[Optional[str]], Sequence[Optional[str]]]]:
-  return (None, None) if dbg is None else (dbg.arg_names, dbg.result_paths)
-
 class MeshComputation(stages.XlaLowering):
   _hlo: Optional[ir.Module]
   _executable: Optional[MeshExecutable]
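With `_debug_names` gone, the call sites above lean on Python's short-circuiting `and`: `jaxpr.debug_info and jaxpr.debug_info.arg_names` evaluates to None when debug info is absent and to the names when present, which is exactly what the deleted helper returned as a pair. A minimal demonstration:

    # Short-circuit `and` as a None-propagating attribute access:
    class _Dbg:                       # minimal stand-in for core.JaxprDebugInfo
      arg_names = ('x', 'y')

    dbg = None
    assert (dbg and dbg.arg_names) is None        # absent -> None
    dbg = _Dbg()
    assert (dbg and dbg.arg_names) == ('x', 'y')  # present -> the names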
@@ -64,7 +64,7 @@ data must be immutable, because it will be stored in function memoization tables
 from __future__ import annotations

 from functools import partial
-from typing import Any, Tuple, Callable
+from typing import Any, Tuple, Callable, Optional, NamedTuple
 import weakref

 from jax.tree_util import tree_map
@@ -124,14 +124,15 @@ class WrappedFun:
     params: extra parameters to pass as keyword arguments to `f`, along with the
       transformed keyword arguments.
   """
-  __slots__ = ("f", "transforms", "stores", "params", "in_type")
+  __slots__ = ("f", "transforms", "stores", "params", "in_type", "debug_info")

-  def __init__(self, f, transforms, stores, params, in_type):
+  def __init__(self, f, transforms, stores, params, in_type, debug_info):
     self.f = f
     self.transforms = transforms
     self.stores = stores
     self.params = params
     self.in_type = in_type
+    self.debug_info = debug_info

   @property
   def __name__(self):
@@ -140,7 +141,7 @@ class WrappedFun:
   def wrap(self, gen, gen_static_args, out_store) -> WrappedFun:
     """Add another transform and its store."""
     return WrappedFun(self.f, ((gen, gen_static_args),) + self.transforms,
-                      (out_store,) + self.stores, self.params, None)
+                      (out_store,) + self.stores, self.params, None, None)

   def populate_stores(self, stores):
     """Copy the values from the `stores` into `self.stores`."""
@@ -199,11 +200,13 @@ class WrappedFun:
     return "Wrapped function:\n" + '\n'.join(transformation_stack) + '\nCore: ' + fun_name(self.f) + '\n'

   def __hash__(self):
-    return hash((self.f, self.transforms, self.params, self.in_type))
+    return hash((self.f, self.transforms, self.params, self.in_type,
+                 self.debug_info))

   def __eq__(self, other):
     return (self.f == other.f and self.transforms == other.transforms and
-            self.params == other.params and self.in_type == other.in_type)
+            self.params == other.params and self.in_type == other.in_type and
+            self.debug_info == other.debug_info)

 @curry
 def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun:
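`WrappedFun` feeds linear_util's memoization tables (the module docstring quoted in the hunk header above notes its data must be immutable for exactly this reason), so the new `debug_info` field must join `__hash__` and `__eq__`: two wrappings that differ only in debug info must not share a cache entry. Note also that `wrap()` passes `None, None`, so applying a transform resets `in_type` and `debug_info` rather than carrying them forward; they can be re-attached afterwards. A toy of the cache-key property, assuming only a plain dict cache:

    from typing import NamedTuple, Optional

    class Key(NamedTuple):            # stand-in for WrappedFun's hashed fields
      f_name: str
      debug_info: Optional[str]

    cache = {Key('f', 'pmap debug'): 'entry A'}
    # Because debug_info is part of the key, a different debug_info misses:
    print(Key('f', 'jit debug') in cache)   # False
    print(Key('f', 'pmap debug') in cache)  # True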
@@ -231,15 +234,15 @@ def fun_name(f):
 def wrap_init(f, params=None) -> WrappedFun:
   """Wraps function `f` as a `WrappedFun`, suitable for transformation."""
   params = () if params is None else tuple(sorted(params.items()))
-  return WrappedFun(f, (), (), params, None)
+  return WrappedFun(f, (), (), params, None, None)


-def annotate(f: WrappedFun, in_type: core.InputType) -> WrappedFun:
+def annotate(f: WrappedFun, in_type: Optional[core.InputType]) -> WrappedFun:
   assert f.in_type is None
   if in_type is None:
     return f
   _check_input_type(in_type)
-  return WrappedFun(f.f, f.transforms, f.stores, f.params, in_type)
+  return WrappedFun(f.f, f.transforms, f.stores, f.params, in_type, f.debug_info)

 def _check_input_type(in_type: core.InputType) -> None:
   # Check that in_type is syntactically well-formed
@@ -271,6 +274,24 @@ def _check_input_type(in_type: core.InputType) -> None:
   assert all(provided)


+class TracingDebugInfo(NamedTuple):
+  # Packages up trace/staging-time debug info about a func and its parameters,
+  # formed just before staging to a jaxpr and read in trace-time error messages.
+  # TODO(mattjj): delete partial_eval.DebugInfo, replace all uses with this cls
+  traced_for: str             # e.g. 'jit', 'scan', etc
+  func_src_info: str          # e.g. f'{fun.__name__} at {filename}:{lineno}'
+  arg_names: Tuple[str, ...]  # e.g. ('args[0]', ... )
+  result_paths: Optional[Callable[[], Tuple[str, ...]]]
+
+def add_debug_info(f: WrappedFun, debug_info: Optional[TracingDebugInfo]
+                   ) -> WrappedFun:
+  """Produce a new WrappedFun with debug_info attached."""
+  assert f.debug_info is None
+  if debug_info is None:
+    return f
+  return WrappedFun(f.f, f.transforms, f.stores, f.params, f.in_type, debug_info)
+
+
 def cache(call: Callable):
   """Memoization decorator for functions taking a WrappedFun as first argument.
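The relocated `TracingDebugInfo` gains the `result_paths` thunk field, and `add_debug_info` becomes the one place that attaches debug info to a `WrappedFun`, asserting it happens at most once. A minimal usage sketch against this revision (internal APIs, so subject to change):

    from jax._src import linear_util as lu

    def g(x):
      return x * 2.

    f = lu.wrap_init(g)                                    # f.debug_info is None
    dbg = lu.TracingDebugInfo('jit', 'g at example.py:1', ('x',), None)
    f = lu.add_debug_info(f, dbg)                          # new WrappedFun
    assert f.debug_info is dbg

    f2 = lu.wrap_init(g)
    assert lu.add_debug_info(f2, None) is f2               # None is a no-op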
@@ -2036,7 +2036,6 @@ class PythonPmapTest(jtu.JaxTestCase):
     self.assertEqual(jaxpr_text.count(' cos '), 2)

   def test_pmap_lower_arg_info(self):
-    raise SkipTest("arg info not plumbed to pmap yet")  # TODO(mattjj)
     def f(x, y, *args, **kwargs):
       return y['hi'] + args[1] + sum(kwargs.values())
@@ -2052,7 +2051,6 @@
     self.assertIn("kwargs['w']", mhlo_str)

   def test_pmap_lower_result_info(self):
-    raise SkipTest("arg info not plumbed to pmap yet")  # TODO(mattjj)
     def f(x, y, z):
       return {'a': x, 'b': [y]}
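With the two `raise SkipTest(...)` lines removed, these tests now run and exercise the whole pipeline: pmap lowering should embed argument names (including flattened paths like "kwargs['w']") and result paths in the module text. A runnable sketch of what they verify, paraphrasing the visible test bodies (array shapes are illustrative):

    import jax
    import jax.numpy as jnp

    def f(x, y, *args, **kwargs):
      return y['hi'] + args[1] + sum(kwargs.values())

    n = jax.local_device_count()
    a = jnp.zeros((n, 3))
    lowered = jax.pmap(f).lower(a, {'hi': a}, a, a, w=a, z=a)
    # Names like "kwargs['w']" should now appear on the module's parameters:
    print(lowered.as_text())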