mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
[better_errors] Merge the JaxprDebugInfo and TracingDebugInfo into core.DebugInfo
Previously, we had two almost identical classes: `TracingDebugInfo` and `JaxprDebugInfo`. The only difference was that `TracingDebugInfo` had a thunk to return the result paths, while `JaxprDebugInfo` had the result paths resolved to a tuple. The separation of these types provided some clarity, but also led to code duplication and required conversions as the debugging info goes from `WrappedFun` to a `Jaxpr` and then to `WrappedFun` again.
This commit is contained in:
parent
de48ce2a4c
commit
c70de6deed
@ -323,7 +323,7 @@ def checkpoint(fun: Callable, *, prevent_cse: bool = True,
|
||||
@wraps(fun)
|
||||
@api_boundary
|
||||
def fun_remat(*args, **kwargs):
|
||||
debug = api_util.tracing_debug_info(
|
||||
debug = api_util.debug_info(
|
||||
"checkpoint / remat", fun,
|
||||
args, kwargs, static_argnums=static_argnums)
|
||||
fun_, args = _remat_static_argnums(fun, static_argnums, args)
|
||||
@ -418,7 +418,7 @@ _dyn_args_fun_cached = weakref_lru_cache(_dyn_args_fun_uncached)
|
||||
def _trace_to_jaxpr(fun: Callable,
|
||||
in_tree: PyTreeDef,
|
||||
in_avals: Sequence[core.AbstractValue],
|
||||
debug: lu.TracingDebugInfo
|
||||
debug: core.DebugInfo
|
||||
) -> tuple[core.Jaxpr, Sequence[Any], PyTreeDef]:
|
||||
flat_fun, out_tree = api_util.flatten_fun(lu.wrap_init(fun), in_tree)
|
||||
try:
|
||||
@ -447,7 +447,7 @@ def saved_residuals(f: Callable,
|
||||
args, kwargs = tree_unflatten(in_tree, args)
|
||||
return f(*args, **kwargs)
|
||||
|
||||
debug_info = api_util.tracing_debug_info("saved_residuals", f, args, kwargs)
|
||||
debug_info = api_util.debug_info("saved_residuals", f, args, kwargs)
|
||||
out = api.make_jaxpr(lambda *args: api.linearize(f_, *args)[1],
|
||||
return_shape=True)(*in_leaves)
|
||||
assert isinstance(out, tuple)
|
||||
|
@ -57,11 +57,11 @@ from jax._src import pjit
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax._src.core import eval_jaxpr, shaped_abstractify, ShapedArray
|
||||
from jax._src.api_util import (
|
||||
flatten_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2, argnums_partial,
|
||||
flatten_axes, donation_vector,
|
||||
rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
|
||||
apply_flat_fun_nokwargs, check_callable, tracing_debug_info,
|
||||
result_paths, flat_out_axes)
|
||||
flatten_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2, argnums_partial,
|
||||
flatten_axes, donation_vector,
|
||||
rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
|
||||
apply_flat_fun_nokwargs, check_callable, debug_info,
|
||||
result_paths, flat_out_axes)
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lib import jax_jit
|
||||
from jax._src.lib import xla_client as xc
|
||||
@ -452,7 +452,7 @@ def value_and_grad(fun: Callable, argnums: int | Sequence[int] = 0,
|
||||
raise TypeError(f"differentiating with respect to {argnums=} requires at least "
|
||||
f"{max_argnum + 1} positional arguments to be passed by the caller, "
|
||||
f"but got only {len(args)} positional arguments.")
|
||||
dbg = tracing_debug_info('value_and_grad', fun, args, kwargs)
|
||||
dbg = debug_info('value_and_grad', fun, args, kwargs)
|
||||
|
||||
f = lu.wrap_init(fun, params=kwargs, debug_info=dbg)
|
||||
f_partial, dyn_args = argnums_partial(f, argnums, args,
|
||||
@ -1426,7 +1426,7 @@ def _prepare_pmap(fun: Callable, in_axes, out_axes, static_broadcasted_tuple,
|
||||
if in_devices is not None and len(in_devices) == 0:
|
||||
raise ValueError("'devices' argument to pmap must be non-empty, or None.")
|
||||
|
||||
dbg = tracing_debug_info(
|
||||
dbg = debug_info(
|
||||
"pmap", fun, args, kwargs,
|
||||
static_argnums=static_broadcasted_tuple)
|
||||
|
||||
|
@ -31,7 +31,6 @@ from jax._src.tree_util import (
|
||||
prefix_errors)
|
||||
from jax._src.tree_util import _replace_nones
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src.linear_util import TracingDebugInfo
|
||||
from jax._src.util import (safe_map, WrapKwArgs, Hashable, HashableFunction,
|
||||
Unhashable, safe_zip)
|
||||
from jax._src import traceback_util
|
||||
@ -582,7 +581,7 @@ def api_hook(fun, tag: str):
|
||||
return fun
|
||||
|
||||
|
||||
def tracing_debug_info(
|
||||
def debug_info(
|
||||
traced_for: str,
|
||||
fun: Callable,
|
||||
args: Sequence[Any],
|
||||
@ -594,14 +593,14 @@ def tracing_debug_info(
|
||||
# TODO(necula): check if we really need this, e.g., to speed up tracing.
|
||||
sourceinfo: str | None = None,
|
||||
signature: inspect.Signature | None = None,
|
||||
) -> TracingDebugInfo:
|
||||
) -> core.DebugInfo:
|
||||
if sourceinfo is None:
|
||||
sourceinfo = fun_sourceinfo(fun)
|
||||
if signature is None:
|
||||
signature = fun_signature(fun)
|
||||
arg_names = _non_static_arg_names(signature, args, kwargs, static_argnums,
|
||||
static_argnames)
|
||||
return TracingDebugInfo(traced_for, sourceinfo, arg_names, result_paths_thunk)
|
||||
return core.DebugInfo(traced_for, sourceinfo, arg_names, result_paths_thunk)
|
||||
|
||||
|
||||
def fun_signature(fun: Callable) -> inspect.Signature | None:
|
||||
@ -619,7 +618,7 @@ _fun_name_re = re.compile(r"(?:<built-in function (\S+)>)")
|
||||
|
||||
# TODO(mattjj): make this function internal to this module
|
||||
def fun_sourceinfo(fun: Callable) -> str:
|
||||
# See TracingDebugInfo.fun_src_info
|
||||
# See DebugInfo.func_src_info
|
||||
res = getattr(fun, "__fun_sourceinfo__", None)
|
||||
if res is not None: return res
|
||||
while isinstance(fun, partial):
|
||||
@ -684,20 +683,19 @@ def result_paths(_fun, _store, *args, **kwargs):
|
||||
|
||||
# TODO(necula): simplify this function; all it needs to do is attach the debug info to the Jaxpr
|
||||
def add_jaxpr_debug_info(jaxpr: core.Jaxpr,
|
||||
trace_debug: TracingDebugInfo | None,
|
||||
debug: core.DebugInfo | None,
|
||||
result_paths: tuple[str, ...] | None = None,
|
||||
) -> core.Jaxpr:
|
||||
"""Add debug info to jaxpr, given trace-time debug info and result paths."""
|
||||
if trace_debug is None:
|
||||
if debug is None:
|
||||
return jaxpr
|
||||
# TODO(necula): re-enable this safety check
|
||||
# assert (result_paths is not None) ^ (trace_debug.result_paths_thunk is not None)
|
||||
if result_paths is None:
|
||||
result_paths = trace_debug.result_paths_thunk() # type: ignore
|
||||
debug_info = core.JaxprDebugInfo(
|
||||
trace_debug.traced_for, trace_debug.func_src_info,
|
||||
trace_debug.arg_names, tuple(result_paths)) # type: ignore
|
||||
return jaxpr.replace(debug_info=debug_info)
|
||||
if result_paths is not None:
|
||||
debug = debug._replace(result_paths=tuple(result_paths))
|
||||
else:
|
||||
debug = debug.resolve_result_paths()
|
||||
return jaxpr.replace(debug_info=debug)
|
||||
|
||||
def hoist_obj_attrs(f, flat_args):
|
||||
idxs, objs, flat_args_ = [], [], []
|
||||
|
@ -1202,7 +1202,7 @@ def checkify(f: Callable[..., Out],
|
||||
in_tree = jtu.tree_structure(((), {}))
|
||||
closed_f = lambda: f(*args, **kwargs)
|
||||
# stage:
|
||||
debug = api_util.tracing_debug_info("checkify", f, args, kwargs)
|
||||
debug = api_util.debug_info("checkify", f, args, kwargs)
|
||||
fun_, out_tree = api_util.flatten_fun(lu.wrap_init(closed_f,
|
||||
debug_info=debug),
|
||||
in_tree)
|
||||
|
@ -82,31 +82,7 @@ EffectTypeSet = effects.EffectTypeSet
|
||||
no_effects: Effects = effects.no_effects
|
||||
|
||||
|
||||
# TODO(necula): make this an extension of TracingDebugInfo
|
||||
class JaxprDebugInfo(NamedTuple):
|
||||
# An extension of lu.TracingDebugInfo; see comments there
|
||||
traced_for: str
|
||||
func_src_info: str
|
||||
arg_names: tuple[str | None, ...]
|
||||
# This is formed after tracing, when we have concrete `result_paths`
|
||||
result_paths: tuple[str, ...] # e.g. ('[0]', '[1]', ...)
|
||||
|
||||
def safe_arg_names(self, expected: int) -> tuple[str | None, ...]:
|
||||
"""Get the arg_names with a safety check."""
|
||||
if len(self.arg_names) == expected:
|
||||
return self.arg_names
|
||||
else:
|
||||
# TODO(necula): this should not happen
|
||||
return (None,) * expected
|
||||
|
||||
def safe_result_paths(self, expected: int) -> tuple[str | None, ...]:
|
||||
"""Get the result_paths with a safety check."""
|
||||
if len(self.result_paths) == expected:
|
||||
return self.result_paths
|
||||
else:
|
||||
# TODO(necula): this should not happen
|
||||
return ("",) * expected
|
||||
|
||||
DebugInfo = lu.DebugInfo
|
||||
|
||||
class Jaxpr:
|
||||
__slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns',
|
||||
@ -117,7 +93,7 @@ class Jaxpr:
|
||||
_outvars: list[Atom]
|
||||
_eqns: list[JaxprEqn]
|
||||
_effects: Effects
|
||||
_debug_info: JaxprDebugInfo | None
|
||||
_debug_info: DebugInfo | None
|
||||
|
||||
@property
|
||||
def constvars(self) -> list[Var]:
|
||||
@ -140,13 +116,13 @@ class Jaxpr:
|
||||
return self._effects
|
||||
|
||||
@property
|
||||
def debug_info(self) -> JaxprDebugInfo | None:
|
||||
def debug_info(self) -> DebugInfo | None:
|
||||
return self._debug_info
|
||||
|
||||
def __init__(self, constvars: Sequence[Var], invars: Sequence[Var],
|
||||
outvars: Sequence[Atom], eqns: Sequence[JaxprEqn],
|
||||
effects: Effects = no_effects,
|
||||
debug_info: JaxprDebugInfo | None = None):
|
||||
debug_info: DebugInfo | None = None):
|
||||
"""
|
||||
Args:
|
||||
constvars: list of variables introduced for constants. Array constants are
|
||||
@ -157,14 +133,14 @@ class Jaxpr:
|
||||
eqns: list of equations.
|
||||
effects: set of effects. The effects on a jaxpr are a superset of the
|
||||
union of the effects for each equation.
|
||||
debug_info: optional JaxprDebugInfo.
|
||||
debug_info: optional DebugInfo.
|
||||
"""
|
||||
self._constvars = list(constvars)
|
||||
self._invars = list(invars)
|
||||
self._outvars = list(outvars)
|
||||
self._eqns = list(eqns)
|
||||
self._effects = effects
|
||||
self._debug_info = debug_info
|
||||
self._debug_info = debug_info and debug_info.resolve_result_paths()
|
||||
# TODO(necula): re-enable these safety checks
|
||||
# assert (not debug_info or len(debug_info.arg_names) == len(invars)), (debug_info, invars)
|
||||
# assert (not debug_info or len(debug_info.result_paths) == len(outvars)), (debug_info, outvars)
|
||||
|
@ -147,7 +147,7 @@ class custom_vmap:
|
||||
raise AttributeError(
|
||||
f"No batching rule defined for custom_vmap function {fun_name} "
|
||||
"using def_vmap.")
|
||||
debug = api_util.tracing_debug_info("custom_vmap", self.fun, args, {})
|
||||
debug = api_util.debug_info("custom_vmap", self.fun, args, {})
|
||||
args_flat, in_tree = tree_flatten(args)
|
||||
flat_fun, out_tree = api_util.flatten_fun_nokwargs(
|
||||
lu.wrap_init(self.fun, debug_info=debug),
|
||||
|
@ -127,12 +127,12 @@ class custom_dce:
|
||||
"def_dce."
|
||||
)
|
||||
rule_name = util.fun_name(self.dce_rule)
|
||||
debug = api_util.tracing_debug_info("custom_dce", self.fun,
|
||||
args, {},
|
||||
static_argnums=self.static_argnums)
|
||||
debug_rule = api_util.tracing_debug_info("custom_dce_rule", self.dce_rule,
|
||||
args, {},
|
||||
static_argnums=self.static_argnums)
|
||||
debug = api_util.debug_info("custom_dce", self.fun,
|
||||
args, {},
|
||||
static_argnums=self.static_argnums)
|
||||
debug_rule = api_util.debug_info("custom_dce_rule", self.dce_rule,
|
||||
args, {},
|
||||
static_argnums=self.static_argnums)
|
||||
args = api_util.resolve_kwargs(self.fun, args, kwargs)
|
||||
if self.static_argnums:
|
||||
static_argnums = set(self.static_argnums)
|
||||
|
@ -468,9 +468,9 @@ class custom_partitioning:
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
args = _resolve_kwargs(self.fun, args, kwargs)
|
||||
debug = api_util.tracing_debug_info("custom_partitioning", self.fun,
|
||||
args, kwargs,
|
||||
static_argnums=self.static_argnums)
|
||||
debug = api_util.debug_info("custom_partitioning", self.fun,
|
||||
args, kwargs,
|
||||
static_argnums=self.static_argnums)
|
||||
if self.static_argnums:
|
||||
static_argnums = set(self.static_argnums)
|
||||
args = tuple(x if i in static_argnums else x for i, x in enumerate(args))
|
||||
|
@ -147,7 +147,7 @@ def _linearize_jaxpr(
|
||||
jaxpr: core.ClosedJaxpr,
|
||||
nonzeros: tuple[bool, ...]
|
||||
) -> tuple[core.ClosedJaxpr, int, Sequence[bool], core.ClosedJaxpr]:
|
||||
dbg = lu.TracingDebugInfo.from_jaxpr(jaxpr)
|
||||
dbg = jaxpr.jaxpr.debug_info
|
||||
primal_trace = pe.DynamicJaxprTrace(dbg)
|
||||
tangent_trace = pe.DynamicJaxprTrace(dbg)
|
||||
lin_trace = LinearizeTrace(primal_trace, tangent_trace)
|
||||
|
@ -42,7 +42,6 @@ from jax._src.core import (Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval,
|
||||
mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx,
|
||||
InputType, OutputType, get_referent, JaxprEqnContext)
|
||||
from jax._src.state.types import AbstractRef
|
||||
from jax._src import tree_util
|
||||
from jax._src.tree_util import (PyTreeDef, treedef_tuple,
|
||||
tree_flatten, tree_structure)
|
||||
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
|
||||
@ -932,7 +931,7 @@ def _partial_eval_jaxpr_nounits(jaxpr: ClosedJaxpr,
|
||||
in_unknowns: Sequence[bool],
|
||||
instantiate: bool | Sequence[bool]):
|
||||
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr),
|
||||
debug_info=lu.TracingDebugInfo.from_jaxpr(jaxpr))
|
||||
debug_info=jaxpr.jaxpr.debug_info)
|
||||
|
||||
cell = []
|
||||
def fun(*known_vals_in):
|
||||
@ -1334,10 +1333,11 @@ def prune_jaxpr_outputs(jaxpr: Jaxpr, used_outputs: Sequence[bool]) -> Jaxpr:
|
||||
|
||||
def _prune_jaxpr_outputs(jaxpr: Jaxpr, used_outputs: tuple[bool, ...]) -> Jaxpr:
|
||||
outvars = [v for v, b in zip(jaxpr.outvars, used_outputs) if b]
|
||||
dbg = jaxpr.debug_info and core.JaxprDebugInfo(
|
||||
dbg = jaxpr.debug_info and core.DebugInfo(
|
||||
jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info,
|
||||
jaxpr.debug_info.arg_names,
|
||||
tuple(v for v, b in zip(jaxpr.debug_info.result_paths, used_outputs) if b))
|
||||
tuple(v for v, b in zip(jaxpr.debug_info.safe_result_paths(len(used_outputs)),
|
||||
used_outputs) if b))
|
||||
new_jaxpr = jaxpr.replace(outvars=outvars, debug_info=dbg)
|
||||
config.enable_checks.value and core.check_jaxpr(new_jaxpr)
|
||||
return new_jaxpr
|
||||
@ -1422,10 +1422,12 @@ def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: tuple[bool, ...],
|
||||
eqns = new_eqns[::-1]
|
||||
jaxpr_effects = make_jaxpr_effects(jaxpr.constvars, invars, outvars, eqns)
|
||||
|
||||
dbg = jaxpr.debug_info and core.JaxprDebugInfo(
|
||||
dbg = jaxpr.debug_info and core.DebugInfo(
|
||||
jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info,
|
||||
tuple(v for v, b in zip(jaxpr.debug_info.arg_names, used_inputs) if b),
|
||||
tuple(v for v, b in zip(jaxpr.debug_info.result_paths, used_outputs) if b))
|
||||
tuple(v for v, b in zip(jaxpr.debug_info.safe_arg_names(len(used_inputs)),
|
||||
used_inputs) if b),
|
||||
tuple(v for v, b in zip(jaxpr.debug_info.safe_result_paths(len(used_outputs)),
|
||||
used_outputs) if b))
|
||||
new_jaxpr = Jaxpr(jaxpr.constvars, invars, outvars, eqns, jaxpr_effects, dbg)
|
||||
config.enable_checks.value and core.check_jaxpr(new_jaxpr)
|
||||
|
||||
@ -1623,9 +1625,9 @@ class JaxprStackFrame:
|
||||
attrs_tracked: list[tuple[Any, str]]
|
||||
attrs_inits: list
|
||||
attrs_vars: list[Var]
|
||||
debug_info: lu.TracingDebugInfo | None
|
||||
debug_info: core.DebugInfo | None
|
||||
|
||||
def __init__(self, debug_info: lu.TracingDebugInfo | None):
|
||||
def __init__(self, debug_info: core.DebugInfo | None):
|
||||
self.gensym = core.gensym()
|
||||
self.tracer_to_var = {}
|
||||
self.constid_to_tracer = {}
|
||||
@ -1809,7 +1811,7 @@ def _inline_literals(
|
||||
class DynamicJaxprTrace(core.Trace):
|
||||
__slots__ = ("frame",)
|
||||
|
||||
def __init__(self, debug_info: lu.TracingDebugInfo | None):
|
||||
def __init__(self, debug_info: core.DebugInfo | None):
|
||||
self.frame = JaxprStackFrame(debug_info)
|
||||
|
||||
def invalidate(self):
|
||||
@ -2114,7 +2116,7 @@ def _jvp_jaxpr_zeros(f, store, in_zeros, zero_avals, *primal_tangent_avals):
|
||||
def trace_to_jaxpr_dynamic(
|
||||
fun: lu.WrappedFun,
|
||||
in_avals: Sequence[AbstractValue],
|
||||
debug_info: lu.TracingDebugInfo | None = None,
|
||||
debug_info: core.DebugInfo | None = None,
|
||||
*,
|
||||
keep_inputs: list[bool] | None = None,
|
||||
) -> tuple[Jaxpr, list[AbstractValue], list[Any],
|
||||
@ -2137,7 +2139,7 @@ def trace_to_jaxpr_dynamic(
|
||||
return jaxpr, [v.aval for v in jaxpr.outvars], consts, attrs_tracked
|
||||
|
||||
def _check_no_returned_refs(
|
||||
dbg: lu.TracingDebugInfo | None,
|
||||
dbg: core.DebugInfo | None,
|
||||
out_tracers: Sequence[DynamicJaxprTracer]
|
||||
) -> None:
|
||||
if not config.mutable_array_checks.value: return
|
||||
@ -2148,10 +2150,8 @@ def _check_no_returned_refs(
|
||||
raise ValueError(
|
||||
f"function returned a mutable array reference of type {a.str_short()}, "
|
||||
"but mutable array references cannot be returned.")
|
||||
loc = (f' at output tree path {tree_util.keystr(ls[i])}' # type: ignore
|
||||
if (dbg.result_paths_thunk and
|
||||
(ls := dbg.result_paths_thunk()) and
|
||||
ls[i]) else '')
|
||||
result_paths = dbg.resolve_result_paths().safe_result_paths(len(out_tracers))
|
||||
loc = f' at output tree path {result_paths[i]}'
|
||||
frame = t._trace.frame
|
||||
v = frame.tracer_to_var.get(id(t))
|
||||
eqn = next((e for e in frame.eqns if v in e.outvars), None)
|
||||
@ -2172,7 +2172,7 @@ def _check_no_returned_refs(
|
||||
|
||||
@profiler.annotate_function
|
||||
def trace_to_jaxpr_dynamic2(
|
||||
fun: lu.WrappedFun, debug_info: lu.TracingDebugInfo | None = None
|
||||
fun: lu.WrappedFun, debug_info: core.DebugInfo | None = None
|
||||
) -> tuple[Jaxpr, OutputType, list[Any]]:
|
||||
|
||||
trace = DynamicJaxprTrace(debug_info)
|
||||
|
@ -652,7 +652,7 @@ class ParallelCallableInfo:
|
||||
in_axes: Iterable[int | None]
|
||||
out_axes_thunk: Callable[[], Sequence[int | None]]
|
||||
avals: Sequence[core.AbstractValue]
|
||||
debug_info: api_util.TracingDebugInfo | None
|
||||
debug_info: core.DebugInfo | None
|
||||
|
||||
@cached_property
|
||||
def local_devices(self):
|
||||
@ -964,7 +964,7 @@ class UnloadedPmapExecutable:
|
||||
ordered_effects: list[core.Effect]
|
||||
keepalive: Sequence[Any]
|
||||
host_callbacks: Sequence[Any]
|
||||
jaxpr_debug_info: core.JaxprDebugInfo
|
||||
jaxpr_debug_info: core.DebugInfo
|
||||
|
||||
def build_execute_fun(self):
|
||||
input_indices = []
|
||||
@ -1004,7 +1004,7 @@ class UnloadedPmapExecutable:
|
||||
ordered_effects: list[core.Effect],
|
||||
host_callbacks: list[Any],
|
||||
keepalive: Any,
|
||||
jaxpr_debug_info: core.JaxprDebugInfo,
|
||||
jaxpr_debug_info: core.DebugInfo,
|
||||
platforms: Sequence[str],
|
||||
shape_poly_state: mlir.ShapePolyLoweringState | None = None,
|
||||
compiler_options=None):
|
||||
@ -2127,7 +2127,7 @@ MaybeLayout = Sequence[Union[DeviceLocalLayout, AutoLayout, None]]
|
||||
class AllArgsInfo(NamedTuple):
|
||||
"""Avals and debug_info for all arguments prior to DCE."""
|
||||
in_avals: Sequence[core.ShapedArray]
|
||||
debug_info: core.JaxprDebugInfo | None
|
||||
debug_info: core.DebugInfo | None
|
||||
|
||||
|
||||
@lru_cache(maxsize=2048)
|
||||
@ -3199,7 +3199,7 @@ def cc_shard_arg(x, sharding, layout):
|
||||
|
||||
|
||||
def check_arg_avals_for_call(ref_avals, arg_avals,
|
||||
jaxpr_debug_info: core.JaxprDebugInfo | None = None):
|
||||
jaxpr_debug_info: core.DebugInfo | None = None):
|
||||
if len(ref_avals) != len(arg_avals):
|
||||
raise TypeError(
|
||||
f"Computation compiled for {len(ref_avals)} inputs "
|
||||
@ -3258,7 +3258,7 @@ def check_array_xla_sharding_layout_match(
|
||||
args_after_dce,
|
||||
in_xla_shardings: Sequence[JSharding],
|
||||
in_xla_layouts: Sequence[DeviceLocalLayout],
|
||||
jaxpr_debug_info: core.JaxprDebugInfo | None,
|
||||
jaxpr_debug_info: core.DebugInfo | None,
|
||||
kept_var_idx: set[int]) -> None:
|
||||
from jax._src.array import ArrayImpl
|
||||
# jaxpr_debug_info.arg_names are before DCE, so need to DCE them.
|
||||
|
@ -53,7 +53,7 @@ def _typecheck_param(prim, param, name, msg_required, pred):
|
||||
def _initial_style_open_jaxpr(fun: Callable,
|
||||
in_tree: PyTreeDef,
|
||||
in_avals: Sequence[core.AbstractValue],
|
||||
debug_info: api_util.TracingDebugInfo):
|
||||
debug_info: core.DebugInfo):
|
||||
wrapped_fun, out_tree = api_util.flatten_fun_nokwargs(
|
||||
lu.wrap_init(fun, debug_info=debug_info),
|
||||
in_tree)
|
||||
@ -65,7 +65,7 @@ def _initial_style_open_jaxpr(fun: Callable,
|
||||
def _initial_style_jaxpr(fun: Callable,
|
||||
in_tree: PyTreeDef,
|
||||
in_avals: Sequence[core.AbstractValue],
|
||||
debug_info: api_util.TracingDebugInfo):
|
||||
debug_info: core.DebugInfo):
|
||||
jaxpr, consts, out_tree, () = _initial_style_open_jaxpr(
|
||||
fun, in_tree, in_avals, debug_info)
|
||||
closed_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr))
|
||||
@ -74,7 +74,7 @@ def _initial_style_jaxpr(fun: Callable,
|
||||
def _initial_style_jaxpr_attrs(fun: Callable,
|
||||
in_tree: PyTreeDef,
|
||||
in_avals: Sequence[core.AbstractValue],
|
||||
debug_info: api_util.TracingDebugInfo):
|
||||
debug_info: core.DebugInfo):
|
||||
jaxpr, consts, out_tree, attrs_tracked = _initial_style_open_jaxpr(
|
||||
fun, in_tree, in_avals, debug_info)
|
||||
closed_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr))
|
||||
@ -83,7 +83,7 @@ def _initial_style_jaxpr_attrs(fun: Callable,
|
||||
def _initial_style_jaxprs_with_common_consts(
|
||||
funs: Sequence[Callable],
|
||||
in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue],
|
||||
debug_infos: Sequence[api_util.TracingDebugInfo]):
|
||||
debug_infos: Sequence[core.DebugInfo]):
|
||||
# When staging the branches of a conditional into jaxprs, constants are
|
||||
# extracted from each branch and converted to jaxpr arguments. To use the
|
||||
# staged jaxprs as the branches to a conditional *primitive*, we need for
|
||||
|
@ -134,7 +134,7 @@ def switch(index, branches: Sequence[Callable], *operands,
|
||||
if (config.disable_jit.value and core.is_concrete(index)):
|
||||
return branches[int(index)](*operands)
|
||||
|
||||
dbgs = [api_util.tracing_debug_info("switch", branch, operands, {})
|
||||
dbgs = [api_util.debug_info("switch", branch, operands, {})
|
||||
for branch in branches]
|
||||
ops, ops_tree = tree_flatten(operands)
|
||||
ops_avals = tuple(map(core.get_aval, ops))
|
||||
@ -237,10 +237,10 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
|
||||
ops, ops_tree = tree_flatten(operands)
|
||||
ops_avals = tuple(map(core.get_aval, ops))
|
||||
|
||||
dbg_true_fun = api_util.tracing_debug_info("cond", true_fun, operands, {})
|
||||
dbg_true_fun = api_util.debug_info("cond", true_fun, operands, {})
|
||||
if config.mutable_array_checks.value:
|
||||
api_util._check_no_aliased_ref_args(dbg_true_fun, ops_avals, ops)
|
||||
dbg_false_fun = api_util.tracing_debug_info("cond", false_fun, operands, {})
|
||||
dbg_false_fun = api_util.debug_info("cond", false_fun, operands, {})
|
||||
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
|
||||
(true_fun, false_fun), ops_tree, ops_avals,
|
||||
[dbg_true_fun, dbg_false_fun])
|
||||
|
@ -195,7 +195,7 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
|
||||
def _create_jaxpr(init):
|
||||
init_flat = tree_leaves(init)
|
||||
_, in_tree = tree_flatten((init, xs))
|
||||
dbg = api_util.tracing_debug_info("scan", f, (init, xs), {})
|
||||
dbg = api_util.debug_info("scan", f, (init, xs), {})
|
||||
carry_avals = tuple(map(core.get_aval, init_flat))
|
||||
jaxpr, _, out_tree = _initial_style_jaxpr(
|
||||
f, in_tree, carry_avals + x_avals, dbg)
|
||||
|
@ -273,7 +273,7 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
|
||||
return carry, stacked_y
|
||||
|
||||
x_avals = [core.mapped_aval(length, 0, aval) for aval in xs_avals]
|
||||
dbg_body = api_util.tracing_debug_info("scan", f, (init, xs), {})
|
||||
dbg_body = api_util.debug_info("scan", f, (init, xs), {})
|
||||
|
||||
if config.mutable_array_checks.value:
|
||||
in_flat, in_tree = tree_flatten((init, xs))
|
||||
@ -1357,10 +1357,10 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric],
|
||||
def _create_jaxpr(init_val):
|
||||
init_vals, in_tree = tree_flatten((init_val,))
|
||||
init_avals = tuple(_map(core.get_aval, init_vals))
|
||||
cond_dbg = api_util.tracing_debug_info("while_cond", cond_fun, (init_val,), {})
|
||||
cond_dbg = api_util.debug_info("while_cond", cond_fun, (init_val,), {})
|
||||
cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(
|
||||
cond_fun, in_tree, init_avals, cond_dbg)
|
||||
body_dbg = api_util.tracing_debug_info("while_body", body_fun, (init_val,), {})
|
||||
body_dbg = api_util.debug_info("while_body", body_fun, (init_val,), {})
|
||||
body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(
|
||||
body_fun, in_tree, init_avals, body_dbg)
|
||||
if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1:
|
||||
|
@ -93,16 +93,16 @@ def custom_root(f: Callable,
|
||||
"""
|
||||
guess_flat, in_args_tree = tree_flatten((initial_guess,))
|
||||
guess_avals = tuple(_map(core.get_aval, guess_flat))
|
||||
f_debug = api_util.tracing_debug_info("custom_root", f, (initial_guess,), {})
|
||||
f_debug = api_util.debug_info("custom_root", f, (initial_guess,), {})
|
||||
f_jaxpr, f_consts, out_tree = _initial_style_jaxpr(
|
||||
f, in_args_tree, guess_avals, f_debug)
|
||||
|
||||
in_tree, = treedef_children(in_args_tree)
|
||||
_check_tree("f", "initial_guess", out_tree, in_tree, False)
|
||||
|
||||
solve_debug = api_util.tracing_debug_info("custom_root solve", solve,
|
||||
(f, initial_guess), {},
|
||||
static_argnums=(0,))
|
||||
solve_debug = api_util.debug_info("custom_root solve", solve,
|
||||
(f, initial_guess), {},
|
||||
static_argnums=(0,))
|
||||
solve_jaxpr, solve_consts, solution_tree = _initial_style_jaxpr(
|
||||
partial(solve, f), in_args_tree, guess_avals, solve_debug)
|
||||
_check_tree("solve", "initial_guess", solution_tree, in_tree, has_aux)
|
||||
@ -111,10 +111,10 @@ def custom_root(f: Callable,
|
||||
unchecked_zeros, f_jvp = api.linearize(f, x)
|
||||
return tangent_solve(f_jvp, b)
|
||||
|
||||
tangent_solve_debug = api_util.tracing_debug_info("custom_root tangent_solve",
|
||||
tangent_solve,
|
||||
(f, initial_guess), {},
|
||||
static_argnums=(0,))
|
||||
tangent_solve_debug = api_util.debug_info("custom_root tangent_solve",
|
||||
tangent_solve,
|
||||
(f, initial_guess), {},
|
||||
static_argnums=(0,))
|
||||
l_and_s_jaxpr, l_and_s_consts, out_tree = _initial_style_jaxpr(
|
||||
linearize_and_solve, treedef_tuple((in_tree,) * 2), guess_avals * 2,
|
||||
tangent_solve_debug)
|
||||
@ -265,17 +265,17 @@ def custom_linear_solve(
|
||||
|
||||
return f_aux if has_aux else f
|
||||
|
||||
matvec_debug = api_util.tracing_debug_info("custom_linear_solve",
|
||||
matvec, (b,), {})
|
||||
matvec_debug = api_util.debug_info("custom_linear_solve",
|
||||
matvec, (b,), {})
|
||||
# no auxiliary data assumed for matvec
|
||||
matvec_jaxpr, matvec_consts, out_tree = _initial_style_jaxpr(
|
||||
_shape_checked(matvec, "matvec", False), in_args_tree, b_avals,
|
||||
matvec_debug)
|
||||
_check_tree("matvec", "b", out_tree, tree, False)
|
||||
|
||||
solve_debug = api_util.tracing_debug_info("custom_linear_solve solve",
|
||||
solve, (matvec, b), {},
|
||||
static_argnums=(0,))
|
||||
solve_debug = api_util.debug_info("custom_linear_solve solve",
|
||||
solve, (matvec, b), {},
|
||||
static_argnums=(0,))
|
||||
solve_jaxpr, solve_consts, out_tree = _initial_style_jaxpr(
|
||||
_shape_checked(partial(solve, matvec), "solve", has_aux), in_args_tree, b_avals,
|
||||
solve_debug)
|
||||
@ -285,7 +285,7 @@ def custom_linear_solve(
|
||||
vecmat_jaxpr = tr_solve_jaxpr = None
|
||||
vecmat_consts = tr_solve_consts = []
|
||||
else:
|
||||
transpose_solve_debug = api_util.tracing_debug_info(
|
||||
transpose_solve_debug = api_util.debug_info(
|
||||
"custom_linear_solve transpose_solve", transpose_solve,
|
||||
(matvec, b), {}, static_argnums=(0,))
|
||||
if symmetric:
|
||||
|
@ -745,7 +745,7 @@ def _trace_composite_to_jaxpr(fun: Callable,
|
||||
in_tree: tree_util.PyTreeDef,
|
||||
in_avals: Sequence[core.AbstractValue],
|
||||
name: str,
|
||||
debug_info: api_util.TracingDebugInfo):
|
||||
debug_info: core.DebugInfo):
|
||||
flat_fun, out_tree = api_util.flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
|
||||
jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug_info)
|
||||
if any(isinstance(c, core.Tracer) for c in consts):
|
||||
@ -822,8 +822,8 @@ def composite(
|
||||
"""
|
||||
@functools.wraps(decomposition)
|
||||
def _decorator(*args, **kwargs):
|
||||
debug_info = api_util.tracing_debug_info("composite", decomposition,
|
||||
args, kwargs)
|
||||
debug_info = api_util.debug_info("composite", decomposition,
|
||||
args, kwargs)
|
||||
flat_args, in_tree = tree_util.tree_flatten(args)
|
||||
in_avals = tuple(core.get_aval(x) for x in flat_args)
|
||||
closed_jaxpr, out_tree = _trace_composite_to_jaxpr(
|
||||
|
@ -156,7 +156,7 @@ class WrappedFun:
|
||||
f_transformed: Callable,
|
||||
transforms,
|
||||
stores: tuple[Store | EqualStore | None, ...], params, in_type,
|
||||
debug_info: TracingDebugInfo | None):
|
||||
debug_info: DebugInfo | None):
|
||||
self.f = f
|
||||
self.f_transformed = f_transformed
|
||||
self.transforms = transforms
|
||||
@ -253,12 +253,10 @@ def fun_name(f):
|
||||
except:
|
||||
return str(f)
|
||||
|
||||
class TracingDebugInfo(NamedTuple):
|
||||
"""Tracing-time debugging info about a func and its arguments.
|
||||
|
||||
Formed just before staging to a jaxpr and read in trace-time error messages.
|
||||
"""
|
||||
class DebugInfo(NamedTuple):
|
||||
"""Debugging info about a func, its arguments, and results."""
|
||||
traced_for: str # e.g. 'jit', 'scan', etc
|
||||
|
||||
# e.g. f'{fun.__name__} at {filename}:{lineno}' or {fun.__name__} if we have
|
||||
# no source location information. The first word is always the function name,
|
||||
# which may be '<unknown>'.
|
||||
@ -270,23 +268,25 @@ class TracingDebugInfo(NamedTuple):
|
||||
# e.g., tangent args in jax.jvp.
|
||||
arg_names: tuple[str | None, ...]
|
||||
|
||||
# The result paths are not available while we are tracing the function,
|
||||
# instead we keep a thunk. Once we are done tracing, we use
|
||||
# `self.resolve_result_paths()` to execute the thunk and replace it with the
|
||||
# actual result paths.
|
||||
# e.g. ('[0]', '[1]', ...)
|
||||
result_paths_thunk: Callable[[], tuple[str, ...]] | None
|
||||
result_paths: tuple[str, ...] | Callable[[], tuple[str, ...]] | None
|
||||
|
||||
@classmethod
|
||||
def from_jaxpr(cls, jaxpr: core.ClosedJaxpr) -> TracingDebugInfo | None:
|
||||
jaxpr_dbg = jaxpr.jaxpr._debug_info
|
||||
if jaxpr_dbg is None: return None
|
||||
return TracingDebugInfo(jaxpr_dbg.traced_for,
|
||||
jaxpr_dbg.func_src_info,
|
||||
jaxpr_dbg.arg_names,
|
||||
lambda: jaxpr_dbg.result_paths)
|
||||
def add_result_paths(self,
|
||||
result_paths_thunk: Callable[[], tuple[str, ...]]
|
||||
) -> DebugInfo:
|
||||
assert self.result_paths is None
|
||||
return self._replace(result_paths=HashableFunction(result_paths_thunk,
|
||||
closure=()))
|
||||
|
||||
def add_result_paths(self, result_paths_thunk: Callable[[], tuple[str, ...]]
|
||||
) -> TracingDebugInfo:
|
||||
assert self.result_paths_thunk is None
|
||||
return self._replace(result_paths_thunk=HashableFunction(result_paths_thunk,
|
||||
closure=()))
|
||||
def resolve_result_paths(self) -> DebugInfo:
|
||||
"""Return a debug info with resolved result paths."""
|
||||
if callable(self.result_paths):
|
||||
return self._replace(result_paths=tuple(self.result_paths()))
|
||||
return self
|
||||
|
||||
def safe_arg_names(self, expected: int) -> tuple[str | None, ...]:
|
||||
"""Get the arg_names with a safety check."""
|
||||
@ -296,9 +296,18 @@ class TracingDebugInfo(NamedTuple):
|
||||
# TODO(necula): this should not happen
|
||||
return (None,) * expected
|
||||
|
||||
def safe_result_paths(self, expected: int) -> tuple[str, ...]:
|
||||
"""Get the result paths with a safety check.

The result paths must already be resolved: this asserts that
`result_paths` is not a callable thunk. Returns `result_paths` when it is
not None and has exactly `expected` entries; otherwise returns a tuple of
`expected` empty strings.
"""
|
||||
assert not callable(self.result_paths), self
|
||||
if self.result_paths is not None and len(self.result_paths) == expected:
|
||||
return self.result_paths
|
||||
else:
|
||||
# TODO(necula): this should not happen
|
||||
return ("",) * expected
|
||||
|
||||
|
||||
def wrap_init(f: Callable, params=None, *,
|
||||
debug_info: TracingDebugInfo | None = None) -> WrappedFun:
|
||||
debug_info: DebugInfo | None = None) -> WrappedFun:
|
||||
"""Wraps function `f` as a `WrappedFun`, suitable for transformation."""
|
||||
params_dict = {} if params is None else params
|
||||
params = () if params is None else tuple(sorted(params.items()))
|
||||
@ -341,7 +350,7 @@ def _check_input_type(in_type: core.InputType) -> None:
|
||||
provided[d.val] = True
|
||||
assert all(provided)
|
||||
|
||||
def add_debug_info(f: WrappedFun, debug_info: TracingDebugInfo | None
|
||||
def add_debug_info(f: WrappedFun, debug_info: DebugInfo | None
|
||||
) -> WrappedFun:
|
||||
"""Produce a new WrappedFun with debug_info attached."""
|
||||
assert f.debug_info is None
|
||||
|
@ -413,9 +413,9 @@ class BlockSpec:
|
||||
|
||||
fake_index_map_args, fake_index_map_kwargs = \
|
||||
index_map_tree.unflatten([False] * index_map_tree.num_leaves)
|
||||
debug = api_util.tracing_debug_info("pallas_call index_map",
|
||||
index_map_func, fake_index_map_args,
|
||||
fake_index_map_kwargs)
|
||||
debug = api_util.debug_info("pallas_call index_map",
|
||||
index_map_func, fake_index_map_args,
|
||||
fake_index_map_kwargs)
|
||||
flat_index_map_fun, index_map_out_tree_thunk = api_util.flatten_fun(
|
||||
lu.wrap_init(index_map_func, debug_info=debug), index_map_tree)
|
||||
index_map_src_info = NameAndSrcInfo.from_pallas_call(
|
||||
|
@ -1100,8 +1100,8 @@ def pallas_call_checkify_rule(error: checkify.Error,
|
||||
jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(retrace_in_avals)
|
||||
wrapped_kernel_with_err, out_tree_thunk = api_util.flatten_fun_nokwargs(
|
||||
lu.wrap_init(checked_kernel_fn), jaxpr_in_tree)
|
||||
debug = api_util.tracing_debug_info("checkify_pallas", checked_kernel_fn,
|
||||
retrace_in_avals, {})
|
||||
debug = api_util.debug_info("checkify_pallas", checked_kernel_fn,
|
||||
retrace_in_avals, {})
|
||||
with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
|
||||
final_jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
|
||||
wrapped_kernel_with_err, jaxpr_flat_avals, debug)
|
||||
@ -1167,7 +1167,7 @@ def _trace_kernel_to_jaxpr(
|
||||
wrapped_kernel_fun, kernel_in_transforms
|
||||
)
|
||||
fake_kernel_args = kernel_in_tree.unflatten(kernel_avals)
|
||||
debug = api_util.tracing_debug_info("pallas_call", fun, fake_kernel_args, {})
|
||||
debug = api_util.debug_info("pallas_call", fun, fake_kernel_args, {})
|
||||
with grid_mapping.trace_env():
|
||||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun,
|
||||
kernel_avals, debug)
|
||||
@ -1568,7 +1568,7 @@ def pallas_call(
|
||||
kernel_fun_sig = api_util.fun_signature(kernel)
|
||||
arg_names = None
|
||||
if kernel_fun_sig:
|
||||
kernel_debug_info = api_util.tracing_debug_info(
|
||||
kernel_debug_info = api_util.debug_info(
|
||||
"pallas_call kernel",
|
||||
kernel,
|
||||
[1] * len(kernel_fun_sig.parameters), {})
|
||||
|
@ -49,7 +49,7 @@ from jax._src import xla_bridge as xb
|
||||
from jax._src.api_util import (
|
||||
argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs,
|
||||
donation_vector, check_callable, resolve_argnums,
|
||||
argnames_partial_except, tracing_debug_info, result_paths, add_jaxpr_debug_info,
|
||||
argnames_partial_except, debug_info, result_paths, add_jaxpr_debug_info,
|
||||
hoist_obj_attrs, _check_no_aliased_ref_args,
|
||||
_check_no_aliased_closed_over_refs)
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
@ -548,7 +548,7 @@ def _infer_params_impl(
|
||||
ji: PjitInfo,
|
||||
pjit_mesh: mesh_lib.Mesh | None,
|
||||
resource_env: mesh_lib.ResourceEnv | None,
|
||||
dbg: lu.TracingDebugInfo,
|
||||
dbg: core.DebugInfo,
|
||||
args: tuple[Any, ...],
|
||||
kwargs: dict[str, Any],
|
||||
in_avals: tuple[core.AbstractValue, ...] | None,
|
||||
@ -733,7 +733,7 @@ def _infer_params(
|
||||
'Using `with mesh:` context manager and `jax.sharding.use_mesh`'
|
||||
' together is not allowed.')
|
||||
|
||||
dbg = tracing_debug_info(
|
||||
dbg = debug_info(
|
||||
'jit', fun, args, kwargs, static_argnums=ji.static_argnums,
|
||||
static_argnames=ji.static_argnames, sourceinfo=ji.fun_sourceinfo,
|
||||
signature=ji.fun_signature)
|
||||
@ -756,7 +756,7 @@ def _infer_params(
|
||||
entry.pjit_params = p
|
||||
return entry.pjit_params, entry.pjit_params.consts + dynargs
|
||||
|
||||
def _infer_input_type(fun: Callable, dbg: lu.TracingDebugInfo | None,
|
||||
def _infer_input_type(fun: Callable, dbg: core.DebugInfo | None,
|
||||
explicit_args) -> tuple[core.AbstractValue, ...]:
|
||||
avals = []
|
||||
try:
|
||||
@ -1302,7 +1302,7 @@ def _create_pjit_jaxpr(
|
||||
fun: lu.WrappedFun,
|
||||
in_type: core.InputType | Sequence[core.AbstractValue],
|
||||
attr_data: int,
|
||||
debug_info: lu.TracingDebugInfo,
|
||||
debug_info: core.DebugInfo,
|
||||
result_paths: Callable,
|
||||
ignored_inline: IgnoreKey
|
||||
) -> tuple[core.ClosedJaxpr, list[Any], list[core.AbstractValue],
|
||||
@ -1346,7 +1346,7 @@ def _create_pjit_jaxpr(
|
||||
def _check_and_canonicalize_out_shardings(
|
||||
out_shardings_treedef, out_shardings_leaves, out_layouts_treedef,
|
||||
out_layouts_leaves, out_tree, out_avals,
|
||||
debug_info: core.JaxprDebugInfo | None,
|
||||
debug_info: core.DebugInfo | None,
|
||||
device_or_backend_set):
|
||||
orig_out_shardings = tree_unflatten(out_shardings_treedef, out_shardings_leaves)
|
||||
if isinstance(orig_out_shardings, (UnspecifiedValue, Sharding)):
|
||||
|
@ -989,7 +989,7 @@ def _run_state_discharge_rule(in_avals: Sequence[core.AbstractValue],
|
||||
|
||||
def initial_style_jaxpr(
|
||||
fun: Callable, in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue],
|
||||
dbg: api_util.TracingDebugInfo,
|
||||
dbg: core.DebugInfo,
|
||||
) -> tuple[core.Jaxpr, list[Any], PyTreeDef]:
|
||||
return _initial_style_jaxpr(fun, in_tree, tuple(in_avals), dbg)
|
||||
|
||||
@ -997,7 +997,7 @@ def initial_style_jaxpr(
|
||||
def _initial_style_jaxpr(fun: Callable,
|
||||
in_tree: api_util.PyTreeDef,
|
||||
in_avals: Sequence[core.AbstractValue],
|
||||
debug: api_util.TracingDebugInfo):
|
||||
debug: core.DebugInfo):
|
||||
fun_, out_tree_thunk = api_util.flatten_fun_nokwargs(lu.wrap_init(fun),
|
||||
tree_util.treedef_tuple((in_tree,)))
|
||||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, in_avals, debug)
|
||||
@ -1007,7 +1007,7 @@ def _initial_style_jaxpr(fun: Callable,
|
||||
T = TypeVar('T')
|
||||
def run_state(f: Callable[..., None]) -> Callable[[T], T]:
|
||||
def wrapped(args):
|
||||
dbg = api_util.tracing_debug_info("run_state", f, (args,), {})
|
||||
dbg = api_util.debug_info("run_state", f, (args,), {})
|
||||
flat_args, in_tree = tree_util.tree_flatten(args)
|
||||
ref_avals, ref_args = unzip2(map(get_ref_aval_from_value, flat_args))
|
||||
# There may be some uninitialized values here in ref_args.
|
||||
@ -1027,7 +1027,7 @@ def run_state(f: Callable[..., None]) -> Callable[[T], T]:
|
||||
|
||||
def run_state_reference(f: Callable[..., None]):
|
||||
def wrapped(args):
|
||||
dbg = api_util.tracing_debug_info("run_state", f, (args,), {})
|
||||
dbg = api_util.debug_info("run_state", f, (args,), {})
|
||||
flat_args, in_tree = tree_util.tree_flatten(args)
|
||||
ref_avals, ref_args = unzip2(map(get_ref_aval_from_value, flat_args))
|
||||
jaxpr_, consts, _ = initial_style_jaxpr(f, in_tree, ref_avals, dbg)
|
||||
|
@ -20,13 +20,13 @@ from jax._src.core import (
|
||||
AbstractValue as AbstractValue,
|
||||
Atom as Atom,
|
||||
CallPrimitive as CallPrimitive,
|
||||
DebugInfo as DebugInfo,
|
||||
DShapedArray as DShapedArray,
|
||||
DropVar as DropVar,
|
||||
Effect as Effect,
|
||||
Effects as Effects,
|
||||
get_opaque_trace_state as get_opaque_trace_state,
|
||||
InconclusiveDimensionOperation as InconclusiveDimensionOperation,
|
||||
JaxprDebugInfo as JaxprDebugInfo,
|
||||
JaxprPpContext as JaxprPpContext,
|
||||
JaxprPpSettings as JaxprPpSettings,
|
||||
JaxprTypeError as JaxprTypeError,
|
||||
|
@ -69,13 +69,13 @@ def _collect_jaxprs(jaxpr: core.Jaxpr,
|
||||
return acc
|
||||
|
||||
|
||||
def _debug_info_to_string(dbg: api_util.TracingDebugInfo | core.JaxprDebugInfo | None) -> list[str]:
|
||||
def _debug_info_to_string(dbg: core.DebugInfo | None) -> list[str]:
|
||||
if dbg is None: return "None"
|
||||
# Strip the absolute path and the line number but check that it references
|
||||
# this file (to catch errors when the source info points in JAX internals)
|
||||
fun_src_info = re.sub(r"^(\S+)( at .*/debug_info_test.py:.*)?", "\\1", dbg.func_src_info)
|
||||
res = f"traced_for={dbg.traced_for}, fun={fun_src_info}, arg_names={','.join(dbg.arg_names)}"
|
||||
if isinstance(dbg, core.JaxprDebugInfo):
|
||||
if isinstance(dbg.result_paths, tuple):
|
||||
if dbg.result_paths:
|
||||
res += f", result_paths={','.join(dbg.result_paths)}"
|
||||
else:
|
||||
@ -221,24 +221,24 @@ class DebugInfoTest(jtu.JaxTestCase):
|
||||
def my_f(x, y, z, w):
|
||||
pass
|
||||
|
||||
dbg = api_util.tracing_debug_info("jit", my_f, (1, 2), dict(z=3, w=4))
|
||||
dbg = api_util.debug_info("jit", my_f, (1, 2), dict(z=3, w=4))
|
||||
self.assertRegex(dbg.func_src_info, r"^my_f at .*debug_info_test.py:\d+")
|
||||
self.assertEqual(dbg.arg_names, ("x", "y", "z", "w"))
|
||||
self.assertIsNone(dbg.result_paths_thunk)
|
||||
self.assertIsNone(dbg.result_paths)
|
||||
|
||||
def test_debug_info_arg_passed_as_kwarg(self):
|
||||
def my_f(x, y, z):
|
||||
pass
|
||||
|
||||
dbg = api_util.tracing_debug_info("jit", my_f, (1, 2), dict(z=3))
|
||||
dbg = api_util.debug_info("jit", my_f, (1, 2), dict(z=3))
|
||||
self.assertEqual(dbg.arg_names, ("x", "y", "z"))
|
||||
|
||||
def test_debug_info_pytrees(self):
|
||||
def my_f(x_tree, *, y_tree):
|
||||
pass
|
||||
|
||||
dbg = api_util.tracing_debug_info("jit", my_f, ((1, 2),),
|
||||
dict(y_tree=dict(z=3, w=4)))
|
||||
dbg = api_util.debug_info("jit", my_f, ((1, 2),),
|
||||
dict(y_tree=dict(z=3, w=4)))
|
||||
self.assertEqual(dbg.arg_names, ("x_tree[0]", "x_tree[1]",
|
||||
"y_tree['w']", "y_tree['z']"))
|
||||
|
||||
@ -246,43 +246,43 @@ class DebugInfoTest(jtu.JaxTestCase):
|
||||
def my_f(x, y, *, z, w):
|
||||
pass
|
||||
|
||||
dbg = api_util.tracing_debug_info("jit", my_f, (1, 2), dict(z=3, w=4),
|
||||
static_argnums=(1,),
|
||||
static_argnames=("w",))
|
||||
dbg = api_util.debug_info("jit", my_f, (1, 2), dict(z=3, w=4),
|
||||
static_argnums=(1,),
|
||||
static_argnames=("w",))
|
||||
self.assertEqual(dbg.arg_names, ("x", "z"))
|
||||
|
||||
def test_debug_info_with_pytrees_and_statics(self):
|
||||
def my_f(x, y, *, z, w):
|
||||
pass
|
||||
|
||||
dbg = api_util.tracing_debug_info("jit", my_f, ((1, 2), (2, 3)),
|
||||
dict(z=(3, 4), w=(5, 6)),
|
||||
static_argnums=(1,),
|
||||
static_argnames=("w",))
|
||||
dbg = api_util.debug_info("jit", my_f, ((1, 2), (2, 3)),
|
||||
dict(z=(3, 4), w=(5, 6)),
|
||||
static_argnums=(1,),
|
||||
static_argnames=("w",))
|
||||
self.assertEqual(dbg.arg_names, ("x[0]", "x[1]", "z[0]", "z[1]"))
|
||||
|
||||
def test_debug_info_too_many_args(self):
|
||||
def my_f(x):
|
||||
pass
|
||||
|
||||
dbg = api_util.tracing_debug_info("jit", my_f, (1, 2, 3), dict(z=3))
|
||||
dbg = api_util.debug_info("jit", my_f, (1, 2, 3), dict(z=3))
|
||||
self.assertEqual(dbg.arg_names, ('args[0]', 'args[1]', 'args[2]', "kwargs['z']"))
|
||||
|
||||
def test_debug_info_no_source_info_built_in(self):
|
||||
# built-in function "max" does not have an inspect.Signature
|
||||
dbg = api_util.tracing_debug_info("jit", max, (1,), {})
|
||||
dbg = api_util.debug_info("jit", max, (1,), {})
|
||||
self.assertEqual(dbg.func_src_info, "max")
|
||||
self.assertEqual(dbg.arg_names, ("args[0]",))
|
||||
|
||||
def test_debug_info_lambda(self):
|
||||
# a lambda is reported with the function name "<lambda>" and its real arg names
|
||||
dbg = api_util.tracing_debug_info("jit", lambda my_arg: False, (1,), {})
|
||||
dbg = api_util.debug_info("jit", lambda my_arg: False, (1,), {})
|
||||
self.assertRegex(dbg.func_src_info, r"^<lambda> at .*debug_info_test.py:\d+")
|
||||
self.assertEqual(dbg.arg_names, ("my_arg",))
|
||||
|
||||
def test_debug_info_no_source_info_not_callable(self):
|
||||
# a non-callable (False): function name is "<unknown>", arg names fall back to args[i]
|
||||
dbg = api_util.tracing_debug_info("jit", False, (1,), {})
|
||||
dbg = api_util.debug_info("jit", False, (1,), {})
|
||||
self.assertEqual(dbg.func_src_info, "<unknown>")
|
||||
self.assertEqual(dbg.arg_names, ("args[0]",))
|
||||
|
||||
@ -293,7 +293,7 @@ class DebugInfoTest(jtu.JaxTestCase):
|
||||
def __call__(self, y):
|
||||
return self.x + y
|
||||
|
||||
dbg = api_util.tracing_debug_info("jit", Foo(), (1,), {})
|
||||
dbg = api_util.debug_info("jit", Foo(), (1,), {})
|
||||
self.assertRegex(dbg.func_src_info, "<unknown>")
|
||||
self.assertEqual(dbg.arg_names, ("y",))
|
||||
|
||||
@ -307,7 +307,7 @@ class DebugInfoTest(jtu.JaxTestCase):
|
||||
def __repr__(self):
|
||||
raise NotImplementedError
|
||||
|
||||
dbg = api_util.tracing_debug_info("jit", Foo(), (1,), {})
|
||||
dbg = api_util.debug_info("jit", Foo(), (1,), {})
|
||||
self.assertRegex(dbg.func_src_info, "<unknown>")
|
||||
self.assertEqual(dbg.arg_names, ("y",))
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user