mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[better_errors] Improvements in propagation of debugging info
Added some documentation for `TracingDebugInfo` (docstring, comments about `arg_names`, since it was not obvious to me that this would flatten the non-static arguments). Laying the groundwork for the unification of the old `api_util.debug_info` and `partial_eval.tracing_debug_info`: we rename the former to `api_util.tracing_debug_info`, we push the calls to `fun_sourceinfo` and `fun_signature` inside it (these were done by the callers until now), and we rewrite the latter in terms of the former. We leave the actual replacement of the latter with the former throughout for a future PR. In the process, cleaned up the one case when `partial_eval.tracing_debug_info` received None for `in_tree` and `out_tree_thunk`. The function contained catch-all exception clauses to handle those, but in doing so it masked other places where we fail to collect debug info due to programming mistakes. E.g., in one place we passed a `WrappedFun` instead of a `Callable`, resulting in missing debugging info. Added more type declarations. Added a `state_test` case with a failure to track debugging information, which manifests as a leaked tracer without function provenance. This will be fixed in a subsequent PR.
This commit is contained in:
parent
d415c80b86
commit
dcf72b01f4
@ -32,8 +32,7 @@ from jax._src import linear_util as lu
|
||||
from jax._src import effects
|
||||
from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
from jax._src.api_util import (
|
||||
flatten_fun, debug_info, fun_sourceinfo, fun_signature)
|
||||
from jax._src import api_util
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
@ -42,7 +41,7 @@ from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lax import convolution as lax_convolution
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.traceback_util import api_boundary
|
||||
from jax._src.tree_util import tree_flatten, tree_unflatten, tree_structure
|
||||
from jax._src.tree_util import PyTreeDef, tree_flatten, tree_unflatten, tree_structure
|
||||
from jax._src.util import (unzip2, wraps, split_list, partition_list, safe_map,
|
||||
safe_zip, merge_lists, weakref_lru_cache)
|
||||
|
||||
@ -324,8 +323,9 @@ def checkpoint(fun: Callable, *, prevent_cse: bool = True,
|
||||
@wraps(fun)
|
||||
@api_boundary
|
||||
def fun_remat(*args, **kwargs):
|
||||
debug = debug_info("checkpoint / remat", fun_sourceinfo(fun),
|
||||
fun_signature(fun), args, kwargs, static_argnums, ())
|
||||
debug = api_util.tracing_debug_info(
|
||||
"checkpoint / remat", fun,
|
||||
args, kwargs, static_argnums=static_argnums)
|
||||
fun_, args = _remat_static_argnums(fun, static_argnums, args)
|
||||
args_flat, in_tree = tree_flatten((args, kwargs))
|
||||
in_avals = [core.shaped_abstractify(x) for x in args_flat]
|
||||
@ -415,8 +415,12 @@ _dyn_args_fun_cached = weakref_lru_cache(_dyn_args_fun_uncached)
|
||||
# This helper is similar to those in control_flow/common.py, but with
|
||||
# remat-specific errors.
|
||||
@weakref_lru_cache
|
||||
def _trace_to_jaxpr(fun, in_tree, in_avals, debug):
|
||||
flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
|
||||
def _trace_to_jaxpr(fun: Callable,
|
||||
in_tree: PyTreeDef,
|
||||
in_avals: Sequence[core.AbstractValue],
|
||||
debug: lu.TracingDebugInfo
|
||||
) -> tuple[core.Jaxpr, Sequence[Any], PyTreeDef]:
|
||||
flat_fun, out_tree = api_util.flatten_fun(lu.wrap_init(fun), in_tree)
|
||||
try:
|
||||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
|
||||
except core.ConcretizationTypeError as e:
|
||||
@ -530,7 +534,8 @@ ad.primitive_jvps[remat_p] = remat_jvp
|
||||
|
||||
effects.remat_allowed_effects.add_type(lax_internal.InOutFeedEffect)
|
||||
|
||||
def remat_partial_eval(trace, *tracers, jaxpr, **params):
|
||||
def remat_partial_eval(trace: pe.JaxprTrace, *tracers: core.Tracer,
|
||||
jaxpr: core.Jaxpr, **params):
|
||||
assert not jaxpr.constvars
|
||||
disallowed_effects = effects.remat_allowed_effects.filter_not_in(jaxpr.effects)
|
||||
if disallowed_effects:
|
||||
@ -567,7 +572,7 @@ def remat_partial_eval(trace, *tracers, jaxpr, **params):
|
||||
# set up unknown outputs with a recipe to call remat
|
||||
res_tracers = map(trace.new_instantiated_const, residuals)
|
||||
_, tracers_staged = partition_list(in_used_staged, tracers)
|
||||
in_jaxpr_tracers = res_tracers + map(trace.instantiate_const, tracers_staged)
|
||||
in_jaxpr_tracers = res_tracers + map(trace.instantiate_const, tracers_staged) # type: ignore
|
||||
out_jaxpr_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(x.aval), None)
|
||||
for x in jaxpr_unknown.outvars]
|
||||
new_params = dict(params, jaxpr=jaxpr_unknown, differentiated=True)
|
||||
|
@ -42,7 +42,6 @@ from jax._src.tree_util import (
|
||||
tree_map, tree_flatten, tree_unflatten, tree_structure, tree_transpose,
|
||||
tree_leaves, Partial, PyTreeDef, all_leaves, keystr, broadcast_prefix,
|
||||
prefix_errors, generate_key_paths, tree_flatten_with_path)
|
||||
from jax._src import api_util
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
@ -61,8 +60,8 @@ from jax._src.api_util import (
|
||||
flatten_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2, argnums_partial,
|
||||
flatten_axes, donation_vector,
|
||||
rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
|
||||
apply_flat_fun_nokwargs, check_callable, debug_info,
|
||||
result_paths, flat_out_axes, debug_info_final, fun_sourceinfo)
|
||||
apply_flat_fun_nokwargs, check_callable, tracing_debug_info,
|
||||
result_paths, flat_out_axes, debug_info_final)
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lib import jax_jit
|
||||
from jax._src.lib import xla_client as xc
|
||||
@ -456,10 +455,7 @@ def value_and_grad(fun: Callable, argnums: int | Sequence[int] = 0,
|
||||
raise TypeError(f"differentiating with respect to {argnums=} requires at least "
|
||||
f"{max_argnum + 1} positional arguments to be passed by the caller, "
|
||||
f"but got only {len(args)} positional arguments.")
|
||||
fun_src_info = fun_sourceinfo(fun)
|
||||
fun_signature = api_util.fun_signature(fun)
|
||||
dbg = debug_info('value_and_grad', fun_src_info, fun_signature,
|
||||
args, kwargs, (), ())
|
||||
dbg = tracing_debug_info('value_and_grad', fun, args, kwargs)
|
||||
|
||||
f = lu.wrap_init(fun, params=kwargs, debug_info=dbg)
|
||||
f_partial, dyn_args = argnums_partial(f, argnums, args,
|
||||
@ -1405,11 +1401,9 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
|
||||
if in_devices is not None and len(in_devices) == 0:
|
||||
raise ValueError("'devices' argument to pmap must be non-empty, or None.")
|
||||
|
||||
src = fun_sourceinfo(fun)
|
||||
signature = api_util.fun_signature(fun)
|
||||
|
||||
dbg = debug_info('pmap', src, signature, args, kwargs,
|
||||
static_broadcasted_tuple, ())
|
||||
dbg = tracing_debug_info(
|
||||
'pmap', fun, args, kwargs,
|
||||
static_argnums=static_broadcasted_tuple)
|
||||
|
||||
f = lu.wrap_init(fun)
|
||||
if static_broadcasted_tuple:
|
||||
|
@ -116,6 +116,7 @@ def flattened_fun_in_tree(
|
||||
(args, store, f is flatten_fun.args[0])
|
||||
for (f, args), store in zip(fn.transforms, fn.stores) if f in flattens)
|
||||
except ValueError:
|
||||
# When `fn` is not the result of flatten_fun or flatten_fun_nokwargs
|
||||
return None
|
||||
else:
|
||||
return in_tree, lambda: out_tree_store.val, has_kwargs # type: ignore[union-attr]
|
||||
@ -589,7 +590,7 @@ def _dtype(x):
|
||||
def api_hook(fun, tag: str):
|
||||
return fun
|
||||
|
||||
|
||||
# TODO(necula): replace usage with tracing_debug_info
|
||||
def debug_info(
|
||||
traced_for: str, fun_src_info: str | None,
|
||||
fun_signature: inspect.Signature | None,
|
||||
@ -598,12 +599,36 @@ def debug_info(
|
||||
static_argnames: tuple[str, ...]
|
||||
) -> TracingDebugInfo | None:
|
||||
"""Try to build trace-time debug info for fun when applied to args/kwargs."""
|
||||
arg_names = _arg_names(fun_signature, args, kwargs, static_argnums,
|
||||
arg_names = _non_static_arg_names(fun_signature, args, kwargs, static_argnums,
|
||||
static_argnames)
|
||||
if arg_names is None:
|
||||
return None
|
||||
return TracingDebugInfo(traced_for, fun_src_info, arg_names, None)
|
||||
|
||||
|
||||
def tracing_debug_info(
    traced_for: str,
    fun: Callable,
    args: Sequence[Any],
    kwargs: dict[str, Any],
    *,
    static_argnums: tuple[int, ...] = (),
    static_argnames: tuple[str, ...] = (),
    result_paths_thunk: Callable[[], tuple[str, ...]] | None = None,
    # TODO(necula): check if we really need this, e.g., to speed up tracing.
    sourceinfo: str | None = None,
    signature: inspect.Signature | None = None,
) -> TracingDebugInfo:
  """Build the trace-time debug info for `fun` applied to `args`/`kwargs`.

  Args:
    traced_for: label of the transformation doing the tracing; stored as-is.
    fun: the callable being traced; consulted only when `sourceinfo` or
      `signature` is not supplied.
    args: positional arguments `fun` is applied to.
    kwargs: keyword arguments `fun` is applied to.
    static_argnums: positions of arguments excluded from the argument names.
    static_argnames: names of arguments excluded from the argument names.
    result_paths_thunk: optional thunk stored unchanged in the result.
    sourceinfo: precomputed result of `fun_sourceinfo(fun)`; computed here
      when None.
    signature: precomputed result of `fun_signature(fun)`; computed here
      when None.

  Returns:
    A `TracingDebugInfo` whose `arg_names` come from `_non_static_arg_names`;
    that helper returns None when no signature is available, hence the
    `type: ignore` on the constructor call.
  """
  src_info = fun_sourceinfo(fun) if sourceinfo is None else sourceinfo
  sig = fun_signature(fun) if signature is None else signature
  non_static_names = _non_static_arg_names(
      sig, args, kwargs, static_argnums, static_argnames)
  # TODO(necula): remove type: ignore once we fix arg_names to never be None
  return TracingDebugInfo(  # type: ignore
      traced_for, src_info, non_static_names, result_paths_thunk)
|
||||
|
||||
|
||||
def fun_signature(fun: Callable) -> inspect.Signature | None:
|
||||
try:
|
||||
return inspect.signature(fun)
|
||||
@ -631,8 +656,11 @@ def fun_sourceinfo(fun: Callable) -> str | None:
|
||||
except AttributeError:
|
||||
return None
|
||||
|
||||
def _arg_names(fn_signature, args, kwargs, static_argnums, static_argnames,
|
||||
) -> tuple[str, ...] | None:
|
||||
def _non_static_arg_names(fn_signature: inspect.Signature | None,
|
||||
args: Sequence[Any], kwargs: dict[str, Any],
|
||||
static_argnums: Sequence[int],
|
||||
static_argnames: Sequence[str],
|
||||
) -> tuple[str | None, ...] | None:
|
||||
if fn_signature is None: return None
|
||||
static = object()
|
||||
static_argnums_ = _ensure_inbounds(True, len(args), static_argnums)
|
||||
@ -665,7 +693,7 @@ def add_jaxpr_debug_info(jaxpr: core.Jaxpr,
|
||||
result_paths = trace_debug.result_paths_thunk() # type: ignore
|
||||
debug_info = core.JaxprDebugInfo(
|
||||
trace_debug.traced_for, trace_debug.func_src_info,
|
||||
trace_debug.arg_names, tuple(result_paths))
|
||||
trace_debug.arg_names, tuple(result_paths)) # type: ignore
|
||||
return jaxpr.replace(debug_info=debug_info)
|
||||
|
||||
def debug_info_final(f: lu.WrappedFun, dbg: TracingDebugInfo | None,
|
||||
|
@ -140,8 +140,8 @@ class Jaxpr:
|
||||
self._eqns = list(eqns)
|
||||
self._effects = effects
|
||||
self._debug_info = debug_info
|
||||
assert (not debug_info or len(debug_info.arg_names) == len(invars) and
|
||||
len(debug_info.result_paths) == len(outvars))
|
||||
assert (not debug_info or debug_info.arg_names is None or len(debug_info.arg_names) == len(invars)), (debug_info, invars)
|
||||
assert (not debug_info or len(debug_info.result_paths) == len(outvars)), (debug_info, outvars)
|
||||
|
||||
def __str__(self):
|
||||
return str(self.pretty_print())
|
||||
|
@ -30,8 +30,8 @@ from jax._src import traceback_util
|
||||
from jax._src.ad_util import (
|
||||
stop_gradient_p, SymbolicZero, Zero, zeros_like_aval)
|
||||
from jax._src.api_util import (
|
||||
argnums_partial, flatten_fun_nokwargs, resolve_kwargs, fun_signature,
|
||||
_arg_names)
|
||||
argnums_partial, flatten_fun_nokwargs, resolve_kwargs, fun_signature,
|
||||
_non_static_arg_names)
|
||||
from jax._src.errors import UnexpectedTracerError
|
||||
from jax._src.state.types import AbstractRef
|
||||
from jax._src.interpreters import ad
|
||||
@ -636,7 +636,7 @@ def _check_for_aliased_refs(f, nondiff_argnums, args):
|
||||
for i, x in enumerate(leaves):
|
||||
if (isinstance((a := core.get_aval(x)), AbstractRef) and
|
||||
(dup_idx := refs.setdefault(id(core.get_referent(x)), i)) != i):
|
||||
arg_names = _arg_names(fun_signature(f), args, {}, nondiff_argnums, ())
|
||||
arg_names = _non_static_arg_names(fun_signature(f), args, {}, nondiff_argnums, ())
|
||||
if arg_names is None:
|
||||
arg_names = [f'flat index {j}' for j in range(len(leaves))]
|
||||
raise ValueError(
|
||||
|
@ -35,8 +35,7 @@ from jax._src import profiler
|
||||
from jax._src import source_info_util
|
||||
from jax._src import compute_on
|
||||
from jax._src import xla_metadata as xla_metadata_lib
|
||||
from jax._src.api_util import (flattened_fun_in_tree, flatten_fun_nokwargs,
|
||||
fun_sourceinfo)
|
||||
from jax._src.api_util import (flattened_fun_in_tree, flatten_fun_nokwargs)
|
||||
from jax._src.core import (Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval,
|
||||
AbstractValue, ClosedJaxpr, new_jaxpr_eqn,
|
||||
Var, DropVar, Atom,
|
||||
@ -44,9 +43,9 @@ from jax._src.core import (Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval,
|
||||
mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx,
|
||||
InputType, OutputType, get_referent, JaxprEqnContext)
|
||||
from jax._src.state.types import AbstractRef
|
||||
from jax._src import tree_util
|
||||
from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_unflatten,
|
||||
tree_flatten, tree_structure, generate_key_paths,
|
||||
keystr)
|
||||
tree_flatten, tree_structure)
|
||||
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
|
||||
merge_lists, partition_list, OrderedSet,
|
||||
as_hashable_function, weakref_lru_cache, subs_list)
|
||||
@ -579,9 +578,9 @@ def trace_to_jaxpr_nounits(
|
||||
# TODO(mattjj): superfluous wrapper...?
|
||||
@lu.transformation2
|
||||
def trace_to_subjaxpr_nounits(
|
||||
f,
|
||||
f: Callable,
|
||||
trace: JaxprTrace,
|
||||
instantiate: bool | Sequence[bool],
|
||||
instantiate: Sequence[bool] | bool,
|
||||
in_pvals: Sequence[PartialVal]):
|
||||
assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals
|
||||
out_tracers, jaxpr, out_consts, env = _trace_to_subjaxpr_nounits(
|
||||
@ -607,7 +606,9 @@ def trace_to_subjaxpr_nounits2(
|
||||
del out_tracers
|
||||
return jaxpr, (out_pvals, out_consts, env)
|
||||
|
||||
def _trace_to_subjaxpr_nounits(f, trace:JaxprTrace, instantiate, in_pvals):
|
||||
def _trace_to_subjaxpr_nounits(f: Callable, trace: JaxprTrace,
|
||||
instantiate: Sequence[bool] | bool,
|
||||
in_pvals: Sequence[PartialVal]):
|
||||
in_knowns = [pval.is_known() for pval in in_pvals]
|
||||
in_consts = [pval.get_known() for pval in in_pvals if pval.is_known()]
|
||||
in_tracers = [trace.new_arg(pval) for pval in in_pvals if not pval.is_known()]
|
||||
@ -1903,7 +1904,8 @@ class DynamicJaxprTrace(core.Trace):
|
||||
self.frame.add_eqn(eqn)
|
||||
return out_tracers if primitive.multiple_results else out_tracers.pop()
|
||||
|
||||
def process_call(self, call_primitive, f, explicit_tracers, params):
|
||||
def process_call(self, call_primitive, f: lu.WrappedFun,
|
||||
explicit_tracers, params):
|
||||
if f.in_type is None:
|
||||
f = lu.annotate(f, tuple((get_aval(t), True) for t in explicit_tracers))
|
||||
implicit_tracers = _extract_implicit_args(self, f.in_type, explicit_tracers)
|
||||
@ -1915,7 +1917,7 @@ class DynamicJaxprTrace(core.Trace):
|
||||
return core.eval_jaxpr(jaxpr, consts, *in_tracers,
|
||||
propagate_source_info=False)
|
||||
source_info = source_info_util.current()
|
||||
out_tracers = []
|
||||
out_tracers: list[Tracer] = []
|
||||
for aval, _ in out_type:
|
||||
if type(aval) is DShapedArray:
|
||||
shape = [[*consts, *in_tracers][d.val] if type(d) is InDBIdx else
|
||||
@ -2110,35 +2112,39 @@ def _jvp_jaxpr_zeros(f, store, in_zeros, zero_avals, *primal_tangent_avals):
|
||||
# Callers should be using linear_util.debug_info instead!
|
||||
def tracing_debug_info(
|
||||
fn: Callable,
|
||||
in_tree: PyTreeDef | None,
|
||||
out_tree_thunk: Callable[[], PyTreeDef] | None,
|
||||
in_tree: PyTreeDef,
|
||||
out_tree_thunk: Callable[[], PyTreeDef],
|
||||
has_kwargs: bool,
|
||||
traced_for: str
|
||||
) -> lu.TracingDebugInfo:
|
||||
src_info = fun_sourceinfo(fn)
|
||||
arg_names: tuple[str | None, ...] | None
|
||||
# TODO(necula): we should not need this function, and can use api_util.tracing_debug_info instead
|
||||
# We just have to make sure we grab the debugging information when we have
|
||||
# the unflattened args
|
||||
# TODO(necula): in general we can just pretend the leaves are booleans, but
|
||||
# when we use custom pytrees, the flattening functions may check the type
|
||||
# of the argument
|
||||
try:
|
||||
dummy_args = tree_unflatten(in_tree, [False] * in_tree.num_leaves) # type: ignore
|
||||
args, kwargs = dummy_args if has_kwargs else (dummy_args, {})
|
||||
ba = api_util.fun_signature(fn).bind(*args, **kwargs) # type: ignore
|
||||
arg_names = tuple(f'{name}{keystr(path)}' for name, dummy in ba.arguments.items()
|
||||
for path, _ in generate_key_paths(dummy))
|
||||
except:
|
||||
arg_names = None # TODO(necula): we should not need this
|
||||
def result_paths():
|
||||
try:
|
||||
out_tree = out_tree_thunk()
|
||||
dummy_result = tree_unflatten(out_tree, [False] * out_tree.num_leaves)
|
||||
except:
|
||||
return None # TODO(necula): this does not seem to be needed
|
||||
return tuple(path for path, _ in generate_key_paths(dummy_result))
|
||||
# TODO(necula): clean up the type: ignore below
|
||||
return lu.TracingDebugInfo(traced_for, src_info, arg_names, result_paths) # type: ignore[arg-type]
|
||||
# TODO(necula): remove this catch-all. Repro in batching_test:test_basic_jit
|
||||
dummy_args = ([False], {}) if has_kwargs else [False]
|
||||
args, kwargs = dummy_args if has_kwargs else (dummy_args, {}) # type: ignore
|
||||
def res_paths_thunk() -> tuple[str, ...]:
|
||||
out_tree = out_tree_thunk()
|
||||
dummy_result = tree_unflatten(out_tree, [False] * out_tree.num_leaves)
|
||||
return tuple(tree_util.keystr(path)
|
||||
for path, _ in tree_util.generate_key_paths(dummy_result))
|
||||
return api_util.tracing_debug_info(traced_for, fn, args, kwargs,
|
||||
result_paths_thunk=res_paths_thunk)
|
||||
|
||||
def tracing_debug_info_final(fn: lu.WrappedFun, traced_for: str) -> lu.TracingDebugInfo:
|
||||
in_tree, out_tree, has_kws = flattened_fun_in_tree(fn) or (None, None, False)
|
||||
return tracing_debug_info(fn.f, in_tree, out_tree, has_kws, traced_for)
|
||||
|
||||
fn_trees = flattened_fun_in_tree(fn)
|
||||
if fn_trees is None:
|
||||
# TODO(necula): eliminate this branch
|
||||
return lu.TracingDebugInfo(traced_for, api_util.fun_sourceinfo(fn.f),
|
||||
(None,), None)
|
||||
in_tree, out_tree_thunk, has_kws = fn_trees
|
||||
return tracing_debug_info(fn.f, in_tree, out_tree_thunk, has_kws, traced_for)
|
||||
|
||||
@profiler.annotate_function
|
||||
def trace_to_jaxpr_dynamic(
|
||||
@ -2178,7 +2184,7 @@ def _check_no_returned_refs(
|
||||
raise ValueError(
|
||||
f"function returned a mutable array reference of type {a.str_short()}, "
|
||||
"but mutable array references cannot be returned.")
|
||||
loc = (f' at output tree path {keystr(ls[i])}' # type: ignore
|
||||
loc = (f' at output tree path {tree_util.keystr(ls[i])}' # type: ignore
|
||||
if (dbg.result_paths_thunk and
|
||||
(ls := dbg.result_paths_thunk()) and
|
||||
ls[i]) else '')
|
||||
@ -2190,7 +2196,7 @@ def _check_no_returned_refs(
|
||||
origin_info = ('\n\nThe returned mutable array was created on line '
|
||||
f'{source_info_util.summarize(eqn.source_info)}.')
|
||||
elif v in frame.invars:
|
||||
arg_name = dbg.arg_names[frame.invars.index(v)]
|
||||
arg_name = dbg.arg_names[frame.invars.index(v)] # type: ignore
|
||||
origin_info = ('\n\nThe returned mutable array was passed in as the '
|
||||
f'argument {arg_name}.')
|
||||
else:
|
||||
|
@ -139,7 +139,7 @@ def switch(index, branches: Sequence[Callable], *operands,
|
||||
ops_avals = tuple(map(core.get_aval, ops))
|
||||
|
||||
if config.mutable_array_checks.value:
|
||||
dbg = pe.tracing_debug_info(branches[0], ops_tree, None, False, 'switch')
|
||||
dbg = pe.tracing_debug_info(branches[0], ops_tree, None, False, 'switch') # type: ignore
|
||||
_check_no_aliased_ref_args(dbg, ops_avals, ops)
|
||||
|
||||
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
|
||||
@ -238,7 +238,7 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
|
||||
ops_avals = tuple(map(core.get_aval, ops))
|
||||
|
||||
if config.mutable_array_checks.value:
|
||||
dbg = pe.tracing_debug_info(true_fun, ops_tree, None, False, 'cond')
|
||||
dbg = pe.tracing_debug_info(true_fun, ops_tree, None, False, 'cond') # type: ignore
|
||||
_check_no_aliased_ref_args(dbg, ops_avals, ops)
|
||||
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
|
||||
(true_fun, false_fun), ops_tree, ops_avals, 'cond')
|
||||
|
@ -275,7 +275,7 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
|
||||
|
||||
if config.mutable_array_checks.value:
|
||||
in_flat, in_tree = tree_flatten((init, xs))
|
||||
dbg = pe.tracing_debug_info(f, in_tree, None, False, 'scan')
|
||||
dbg = pe.tracing_debug_info(f, in_tree, None, False, 'scan') # type: ignore
|
||||
in_avals = tuple(_map(core.get_aval, in_flat))
|
||||
_check_no_aliased_ref_args(dbg, in_avals, in_flat)
|
||||
|
||||
|
@ -254,11 +254,19 @@ def fun_name(f):
|
||||
return str(f)
|
||||
|
||||
class TracingDebugInfo(NamedTuple):
|
||||
# Packages up trace/staging-time debug info about a func and its parameters,
|
||||
# formed just before staging to a jaxpr and read in trace-time error messages.
|
||||
"""Tracing-time debugging info about a func and its arguments.
|
||||
|
||||
Formed just before staging to a jaxpr and read in trace-time error messages.
|
||||
"""
|
||||
traced_for: str # e.g. 'jit', 'scan', etc
|
||||
func_src_info: str | None # e.g. f'{fun.__name__} at {filename}:{lineno}'
|
||||
arg_names: tuple[str | None, ...] # e.g. ('args[0]', ... )
|
||||
|
||||
# The paths of the flattened non-static argnames,
|
||||
# e.g. ('x', 'dict_arg["a"]', ... ).
|
||||
# Uses `None` for the args that do not correspond to user-named arguments,
|
||||
# e.g., tangent args in jax.jvp.
|
||||
arg_names: tuple[str | None, ...]
|
||||
|
||||
# e.g. ('[0]', '[1]', ...)
|
||||
result_paths_thunk: Callable[[], tuple[str, ...]] | None
|
||||
|
||||
|
@ -1810,13 +1810,11 @@ def pallas_call(
|
||||
kernel_fun_sig = api_util.fun_signature(kernel)
|
||||
arg_names = None
|
||||
if kernel_fun_sig:
|
||||
kernel_debug_info = api_util.debug_info(
|
||||
kernel_debug_info = api_util.tracing_debug_info(
|
||||
"pallas_call kernel",
|
||||
kernel_src_info,
|
||||
kernel_fun_sig,
|
||||
[1] * len(kernel_fun_sig.parameters), {}, (), ())
|
||||
if kernel_debug_info:
|
||||
arg_names = kernel_debug_info.arg_names
|
||||
kernel,
|
||||
[1] * len(kernel_fun_sig.parameters), {})
|
||||
arg_names = kernel_debug_info.arg_names
|
||||
del kernel_debug_info
|
||||
in_origins = tuple(in_path_to_input_origin(p, arg_names)
|
||||
for p in in_paths)
|
||||
@ -1909,6 +1907,10 @@ def in_path_to_input_origin(
|
||||
if isinstance(arg_idx, tree_util.SequenceKey) and arg_idx.idx < len(
|
||||
arg_names
|
||||
):
|
||||
if arg_names[arg_idx.idx] is None:
|
||||
# TODO(necula): when is this needed?
|
||||
# Repro: pallas_test:test_with_input_output_aliasing
|
||||
return f"args{tree_util.keystr(in_path)}"
|
||||
return arg_names[arg_idx.idx] + tree_util.keystr(tuple(rest_path))
|
||||
else:
|
||||
return f"args{tree_util.keystr(tuple(in_path))}"
|
||||
|
@ -49,7 +49,7 @@ from jax._src import xla_bridge as xb
|
||||
from jax._src.api_util import (
|
||||
argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs,
|
||||
donation_vector, check_callable, resolve_argnums,
|
||||
argnames_partial_except, debug_info, result_paths, add_jaxpr_debug_info,
|
||||
argnames_partial_except, debug_info, tracing_debug_info, result_paths, add_jaxpr_debug_info,
|
||||
hoist_obj_attrs, _check_no_aliased_ref_args,
|
||||
_check_no_aliased_closed_over_refs)
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
@ -176,7 +176,7 @@ class PjitInfo(NamedTuple):
|
||||
return self is other
|
||||
|
||||
|
||||
def _python_pjit_helper(fun, jit_info, *args, **kwargs):
|
||||
def _python_pjit_helper(fun: Callable, jit_info: PjitInfo, *args, **kwargs):
|
||||
p, args_flat = _infer_params(fun, jit_info, args, kwargs)
|
||||
|
||||
for arg in args_flat:
|
||||
@ -568,6 +568,14 @@ def _infer_params_impl(
|
||||
|
||||
dbg = debug_info('jit', ji.fun_sourceinfo, ji.fun_signature, args, kwargs,
|
||||
ji.static_argnums, ji.static_argnames)
|
||||
# TODO(necula): replace the above with below.
|
||||
# haiku/_src/integration:hk_transforms_test fails
|
||||
# dbg = tracing_debug_info('jit', fun, args, kwargs,
|
||||
# static_argnums=ji.static_argnums,
|
||||
# static_argnames=ji.static_argnames,
|
||||
# TODO(necula): do we really need this, e.g., for tracing speed
|
||||
# sourceinfo = ji.fun_sourceinfo,
|
||||
# signature = ji.fun_signature)
|
||||
f = lu.wrap_init(fun)
|
||||
f, res_paths = result_paths(f)
|
||||
f, dyn_args = argnums_partial_except(f, ji.static_argnums, args, allow_invalid=True)
|
||||
@ -732,8 +740,12 @@ def _infer_params(
|
||||
signature, dynargs = jax_jit.parse_arguments(
|
||||
args, tuple(kwargs.values()), tuple(kwargs.keys()), ji.static_argnums,
|
||||
ji.static_argnames, tree_util.default_registry)
|
||||
dbg = debug_info('jit', ji.fun_sourceinfo, ji.fun_signature, args, kwargs,
|
||||
ji.static_argnums, ji.static_argnames)
|
||||
dbg = tracing_debug_info('jit', fun, args, kwargs,
|
||||
static_argnums=ji.static_argnums,
|
||||
static_argnames=ji.static_argnames,
|
||||
# TODO(necula): do we really need this, e.g., for tracing speed
|
||||
sourceinfo=ji.fun_sourceinfo,
|
||||
signature=ji.fun_signature)
|
||||
avals = _infer_input_type(fun, dbg, dynargs)
|
||||
entry = _infer_params_cached(fun, ji, signature, avals, pjit_mesh, resource_env)
|
||||
if entry.pjit_params is None:
|
||||
@ -744,7 +756,9 @@ def _infer_params(
|
||||
entry.pjit_params = p
|
||||
return entry.pjit_params, entry.pjit_params.consts + dynargs
|
||||
|
||||
def _infer_input_type(fun, dbg, explicit_args) -> tuple[core.AbstractValue, ...]:
|
||||
def _infer_input_type(fun: Callable,
|
||||
dbg: lu.TracingDebugInfo | None,
|
||||
explicit_args) -> tuple[core.AbstractValue, ...]:
|
||||
avals = []
|
||||
try:
|
||||
for i, x in enumerate(explicit_args):
|
||||
@ -1672,7 +1686,7 @@ def _pjit_call_impl_python(
|
||||
if compiled._auto_spmd_lowering and config.enable_checks.value:
|
||||
pxla.check_array_xla_sharding_layout_match(
|
||||
args, compiled._in_shardings, compiled._in_layouts,
|
||||
jaxpr.jaxpr.tracing_debug_info, compiled._kept_var_idx)
|
||||
jaxpr.jaxpr._debug_info, compiled._kept_var_idx)
|
||||
if config.distributed_debug.value:
|
||||
# Defensively only perform fingerprint logic if debug logging is enabled
|
||||
# NOTE(skyewm): I didn't benchmark this
|
||||
|
@ -985,15 +985,18 @@ def _run_state_discharge_rule(in_avals: Sequence[core.AbstractValue],
|
||||
return new_invals, out_vals
|
||||
|
||||
def initial_style_jaxpr(
|
||||
fun: Callable, in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue]
|
||||
fun: Callable, in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue],
|
||||
dbg: api_util.TracingDebugInfo,
|
||||
) -> tuple[core.Jaxpr, list[Any], PyTreeDef]:
|
||||
return _initial_style_jaxpr(fun, in_tree, tuple(in_avals))
|
||||
return _initial_style_jaxpr(fun, in_tree, tuple(in_avals), dbg)
|
||||
|
||||
@weakref_lru_cache
|
||||
def _initial_style_jaxpr(fun, in_tree, in_avals):
|
||||
def _initial_style_jaxpr(fun: Callable,
|
||||
in_tree: api_util.PyTreeDef,
|
||||
in_avals: Sequence[core.AbstractValue],
|
||||
debug: api_util.TracingDebugInfo):
|
||||
fun_, out_tree_thunk = api_util.flatten_fun_nokwargs(lu.wrap_init(fun),
|
||||
tree_util.treedef_tuple((in_tree,)))
|
||||
debug = pe.tracing_debug_info(fun_, in_tree, out_tree_thunk, False, 'run_state')
|
||||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, in_avals, debug)
|
||||
return jaxpr, consts, out_tree_thunk()
|
||||
|
||||
@ -1001,10 +1004,11 @@ def _initial_style_jaxpr(fun, in_tree, in_avals):
|
||||
T = TypeVar('T')
|
||||
def run_state(f: Callable[..., None]) -> Callable[[T], T]:
|
||||
def wrapped(args):
|
||||
dbg = api_util.tracing_debug_info("run_state", f, (args,), {})
|
||||
flat_args, in_tree = tree_util.tree_flatten(args)
|
||||
ref_avals, ref_args = unzip2(map(get_ref_aval_from_value, flat_args))
|
||||
# There may be some uninitialized values here in ref_args.
|
||||
jaxpr_, consts, _ = initial_style_jaxpr(f, in_tree, ref_avals)
|
||||
jaxpr_, consts, _ = initial_style_jaxpr(f, in_tree, ref_avals, dbg)
|
||||
jaxpr = hoist_consts_to_refs(jaxpr_)
|
||||
which_linear = (False,) * (len(consts) + len(ref_args))
|
||||
refs_is_initialized = tuple(r is not uninitialized for r in ref_args)
|
||||
@ -1020,9 +1024,10 @@ def run_state(f: Callable[..., None]) -> Callable[[T], T]:
|
||||
|
||||
def run_state_reference(f: Callable[..., None]):
|
||||
def wrapped(args):
|
||||
dbg = api_util.tracing_debug_info("run_state", f, (args,), {})
|
||||
flat_args, in_tree = tree_util.tree_flatten(args)
|
||||
ref_avals, ref_args = unzip2(map(get_ref_aval_from_value, flat_args))
|
||||
jaxpr_, consts, _ = initial_style_jaxpr(f, in_tree, ref_avals)
|
||||
jaxpr_, consts, _ = initial_style_jaxpr(f, in_tree, ref_avals, dbg)
|
||||
jaxpr = hoist_consts_to_refs(jaxpr_)
|
||||
discharged_jaxpr, discharged_consts = discharge_state(jaxpr, ())
|
||||
|
||||
|
@ -1496,8 +1496,23 @@ class RunStateTest(jtu.JaxTestCase):
|
||||
def test_can_stage_run_state(self):
|
||||
def f(x):
|
||||
return run_state(lambda _: None)(x)
|
||||
jaxpr = jax.make_jaxpr(f)(2)
|
||||
self.assertIsNotNone(jaxpr.jaxpr.debug_info)
|
||||
self.assertIsNotNone(jaxpr.jaxpr.debug_info.func_src_info)
|
||||
|
||||
def test_can_stage_run_state_leaked_tracer_error(self):
|
||||
leaks = []
|
||||
def f(x):
|
||||
def my_fun(x):
|
||||
leaks.append(x)
|
||||
return None
|
||||
return run_state(my_fun)(x)
|
||||
_ = jax.make_jaxpr(f)(2)
|
||||
|
||||
with self.assertRaisesRegex(jax.errors.UnexpectedTracerError,
|
||||
"The function being traced when the value leaked was .*my_fun"):
|
||||
jax.jit(lambda _: leaks[0])(1)
|
||||
|
||||
def test_nested_run_state_captures_effects(self):
|
||||
def f(x):
|
||||
def body(x_ref):
|
||||
|
Loading…
x
Reference in New Issue
Block a user