mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
[better_errors] Make it explicit that debug_info is not None.
Now all internal uses of lu.wrap_init and core.Jaxpr are with actual debug info. This enables us to clean up the type declarations and to remove the checks for whether debug_info is present. For usage outside of the JAX internals, we change `jax.extend.linear_util.wrap_init` to be usable without debug_info, for temporary backwards compatibility. We emit a deprecation warning and fill in some fake debugging info. See https://github.com/jax-ml/jax/issues/26480 for more details. PiperOrigin-RevId: 726770483
This commit is contained in:
parent
60dcded2af
commit
a0812cd57e
@ -37,6 +37,15 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
This package may safely be removed if it is present on your machine; JAX now
|
||||
uses `libtpu` instead.
|
||||
|
||||
* Deprecations
|
||||
* The internal function `linear_util.wrap_init` and the constructor
|
||||
`core.Jaxpr` now must take a non-empty `core.DebugInfo` kwarg. For
|
||||
a limited time, a `DeprecationWarning` is printed if
|
||||
`jax.extend.linear_util.wrap_init` is used without debugging info.
|
||||
A downstream effect of this is that several other internal functions need debug
|
||||
info. This change does not affect public APIs.
|
||||
See https://github.com/jax-ml/jax/issues/26480 for more detail.
|
||||
|
||||
## jax 0.5.0 (Jan 17, 2025)
|
||||
|
||||
As of this release, JAX now uses
|
||||
|
@ -620,7 +620,7 @@ def fun_signature(fun: Callable) -> inspect.Signature | None:
|
||||
return None
|
||||
|
||||
def save_wrapped_fun_sourceinfo(wrapper: Callable,
|
||||
wrapped: Callable | core.DebugInfo | None) -> None:
|
||||
wrapped: Callable | core.DebugInfo) -> None:
|
||||
# Prefer this to functools.wraps because it does not create a reference to
|
||||
# the wrapped function.
|
||||
if isinstance(wrapped, core.DebugInfo):
|
||||
@ -628,7 +628,7 @@ def save_wrapped_fun_sourceinfo(wrapper: Callable,
|
||||
elif callable(wrapped):
|
||||
func_src_info = fun_sourceinfo(wrapped)
|
||||
else:
|
||||
return
|
||||
assert False, wrapped # Unreachable
|
||||
setattr(wrapper, "__fun_sourceinfo__", func_src_info)
|
||||
|
||||
_fun_name_re = re.compile(r"(?:<built-in function (\S+)>)")
|
||||
@ -716,7 +716,7 @@ def register_class_with_attrs(t: type) -> None:
|
||||
_class_with_attrs: set[type] = set()
|
||||
|
||||
# TODO(mattjj): make this function faster
|
||||
def _check_no_aliased_ref_args(dbg: core.DebugInfo | None, avals, args):
|
||||
def _check_no_aliased_ref_args(dbg: core.DebugInfo, avals, args):
|
||||
assert config.mutable_array_checks.value
|
||||
refs: dict[int, int] = {}
|
||||
for i, (a, x) in enumerate(zip(avals, args)):
|
||||
@ -730,7 +730,7 @@ def _check_no_aliased_ref_args(dbg: core.DebugInfo | None, avals, args):
|
||||
if dbg else
|
||||
f"at both flat index {dup_idx} and flat index {i}") from None
|
||||
|
||||
def _check_no_aliased_closed_over_refs(dbg: core.DebugInfo | None, consts, args) -> None:
|
||||
def _check_no_aliased_closed_over_refs(dbg: core.DebugInfo, consts, args) -> None:
|
||||
assert config.mutable_array_checks.value
|
||||
refs: set[int] = {id(core.get_referent(c)) for c in consts
|
||||
if isinstance(core.get_aval(c), AbstractRef)}
|
||||
|
@ -94,7 +94,7 @@ class Jaxpr:
|
||||
_outvars: list[Atom]
|
||||
_eqns: list[JaxprEqn]
|
||||
_effects: Effects
|
||||
_debug_info: DebugInfo | None
|
||||
_debug_info: DebugInfo
|
||||
|
||||
@property
|
||||
def constvars(self) -> list[Var]:
|
||||
@ -117,13 +117,17 @@ class Jaxpr:
|
||||
return self._effects
|
||||
|
||||
@property
|
||||
def debug_info(self) -> DebugInfo | None:
|
||||
def debug_info(self) -> DebugInfo:
|
||||
return self._debug_info
|
||||
|
||||
def __init__(self, constvars: Sequence[Var], invars: Sequence[Var],
|
||||
outvars: Sequence[Atom], eqns: Sequence[JaxprEqn],
|
||||
effects: Effects = no_effects,
|
||||
debug_info: DebugInfo | None = None):
|
||||
# We want all calls to pass a DebugInfo object, but for backwards
|
||||
# compatibility we have to allow calls when the debug_info
|
||||
# is missing.
|
||||
debug_info: DebugInfo = None, # type: ignore[annotation-type-mismatch,assignment]
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
constvars: list of variables introduced for constants. Array constants are
|
||||
@ -134,14 +138,16 @@ class Jaxpr:
|
||||
eqns: list of equations.
|
||||
effects: set of effects. The effects on a jaxpr are a superset of the
|
||||
union of the effects for each equation.
|
||||
debug_info: optional DebugInfo.
|
||||
debug_info: debugging information.
|
||||
"""
|
||||
self._constvars = list(constvars)
|
||||
self._invars = list(invars)
|
||||
self._outvars = list(outvars)
|
||||
self._eqns = list(eqns)
|
||||
self._effects = effects
|
||||
self._debug_info = debug_info and debug_info.resolve_result_paths()
|
||||
# TODO(https://github.com/jax-ml/jax/issues/26480)
|
||||
debug_info = debug_info or lu._missing_debug_info("core.Jaxpr")
|
||||
self._debug_info = debug_info.resolve_result_paths()
|
||||
# TODO(necula): re-enable these safety checks
|
||||
# assert (not debug_info or len(debug_info.arg_names) == len(invars)), (debug_info, invars)
|
||||
# assert (not debug_info or len(debug_info.result_paths) == len(outvars)), (debug_info, outvars)
|
||||
|
@ -30,8 +30,8 @@ from jax._src import traceback_util
|
||||
from jax._src.ad_util import (
|
||||
stop_gradient_p, SymbolicZero, Zero, zeros_like_aval)
|
||||
from jax._src.api_util import (
|
||||
argnums_partial, flatten_fun_nokwargs, resolve_kwargs, fun_signature,
|
||||
_non_static_arg_names, prepend_static_args, debug_info)
|
||||
argnums_partial, flatten_fun_nokwargs, resolve_kwargs,
|
||||
prepend_static_args, debug_info)
|
||||
from jax._src.errors import UnexpectedTracerError
|
||||
from jax._src.state.types import AbstractRef
|
||||
from jax._src.interpreters import ad
|
||||
@ -686,7 +686,7 @@ class custom_vjp(Generic[ReturnValue]):
|
||||
|
||||
@lu.transformation2
|
||||
def _check_primal_refs(f: Callable, nondiff_argnums: Sequence[int],
|
||||
debug_info: core.DebugInfo | None, *args):
|
||||
debug_info: core.DebugInfo, *args):
|
||||
_check_for_aliased_refs(f, nondiff_argnums, debug_info, args)
|
||||
out = f(*args)
|
||||
_check_for_returned_refs(f, out, 'primal')
|
||||
@ -694,20 +694,14 @@ def _check_primal_refs(f: Callable, nondiff_argnums: Sequence[int],
|
||||
|
||||
def _check_for_aliased_refs(f: Callable,
|
||||
nondiff_argnums: Sequence[int],
|
||||
debug: core.DebugInfo | None,
|
||||
debug: core.DebugInfo,
|
||||
args):
|
||||
leaves = tree_leaves(args)
|
||||
refs: dict[int, int] = {}
|
||||
for i, x in enumerate(leaves):
|
||||
if (isinstance((a := core.get_aval(x)), AbstractRef) and
|
||||
(dup_idx := refs.setdefault(id(core.get_referent(x)), i)) != i):
|
||||
if debug is not None:
|
||||
arg_names = debug.safe_arg_names(len(leaves))
|
||||
else:
|
||||
# TODO(necula): drop this branch
|
||||
arg_names = _non_static_arg_names(fun_signature(f), args, {}, nondiff_argnums, ())
|
||||
if arg_names is None:
|
||||
arg_names = [f'flat index {j}' for j in range(len(leaves))]
|
||||
arg_names = debug.safe_arg_names(len(leaves))
|
||||
raise ValueError(
|
||||
"only one reference to a mutable array may be passed as an argument "
|
||||
f"to a function, but custom_vjp function {f} got the same mutable "
|
||||
@ -763,8 +757,8 @@ def _check_for_tracers(x):
|
||||
def _flatten_fwd(f: Callable, store: lu.EqualStore,
|
||||
nondiff_argnums: Sequence[int],
|
||||
symbolic_zeros: bool,
|
||||
debug_primal: core.DebugInfo | None,
|
||||
debug_fwd: core.DebugInfo | None,
|
||||
debug_primal: core.DebugInfo,
|
||||
debug_fwd: core.DebugInfo,
|
||||
in_tree: PyTreeDef, maybe_out_type, *args):
|
||||
primal_name = debug_primal.func_name if debug_primal else str(f)
|
||||
fwd_name = debug_fwd.func_name if debug_fwd else "<unknown>"
|
||||
@ -1560,9 +1554,9 @@ custom_jvp_call_jaxpr_p = core.Primitive("custom_jvp_call_jaxpr")
|
||||
# simpler, but it would be worth revisiting this.
|
||||
def optimize_remat_of_custom_vjp_fwd(
|
||||
fun: Callable[..., ReturnValue],
|
||||
debug_fun: core.DebugInfo | None,
|
||||
debug_fun: core.DebugInfo,
|
||||
fwd: Callable[..., tuple[ReturnValue, Any]],
|
||||
debug_fwd: core.DebugInfo | None,
|
||||
debug_fwd: core.DebugInfo,
|
||||
nondiff_argnums: Sequence[int] = (),
|
||||
symbolic_zeros: bool = False,
|
||||
) -> Callable[..., tuple[ReturnValue, Any]]:
|
||||
|
@ -86,7 +86,7 @@ def jvpfun(f: Callable, instantiate, transform_stack, primals, tangents):
|
||||
@lu.transformation_with_aux2
|
||||
def linearize_subtrace(_f: Callable, _store: lu.Store, _tag: core.TraceTag,
|
||||
nzs_in: Sequence[bool],
|
||||
debug_info: core.DebugInfo | None,
|
||||
debug_info: core.DebugInfo,
|
||||
*primals, **params):
|
||||
with core.take_current_trace() as parent_trace:
|
||||
tangent_trace = pe.DynamicJaxprTrace(debug_info)
|
||||
@ -133,7 +133,7 @@ def jvp_subtrace_aux(f, store, tag, primals, tangents):
|
||||
return out_primals, out_tangents
|
||||
|
||||
def convert_constvars_jaxpr_constvars_at_end(jaxpr: core.Jaxpr) -> core.Jaxpr:
|
||||
dbg = jaxpr.debug_info and jaxpr.debug_info._replace(
|
||||
dbg = jaxpr.debug_info._replace(
|
||||
arg_names=jaxpr.debug_info.arg_names + (None,) * len(jaxpr.constvars))
|
||||
return core.Jaxpr(constvars=(),
|
||||
invars=jaxpr.invars + jaxpr.constvars,
|
||||
@ -768,7 +768,7 @@ def linearize_from_jvp(jvp: Callable,
|
||||
multiple_results: bool,
|
||||
nonzeros: Sequence[bool],
|
||||
user_facing_symbolic_zeros: bool, instantiate_input_zeros: bool,
|
||||
debug_info: core.DebugInfo | None,
|
||||
debug_info: core.DebugInfo,
|
||||
primals, params):
|
||||
current_name_stack = source_info_util.current_name_stack()
|
||||
with core.take_current_trace() as parent_trace:
|
||||
@ -1100,15 +1100,14 @@ def rearrange_binders(jaxpr: core.ClosedJaxpr, primals_in, tangents_in, primals_
|
||||
new_invars = _perm(primals_in, tangents_in, jaxpr.jaxpr.invars)
|
||||
new_outvars = _perm(primals_out, tangents_out, jaxpr.jaxpr.outvars)
|
||||
new_debug_info = jaxpr.jaxpr.debug_info
|
||||
if new_debug_info is not None:
|
||||
new_arg_names = tuple(_perm(primals_in, tangents_in,
|
||||
jaxpr.jaxpr.debug_info.safe_arg_names(len(jaxpr.jaxpr.invars))))
|
||||
new_result_paths = tuple(_perm(primals_out, tangents_out,
|
||||
jaxpr.jaxpr.debug_info.safe_result_paths(len(jaxpr.jaxpr.outvars))))
|
||||
new_debug_info = new_debug_info._replace(
|
||||
arg_names=new_arg_names,
|
||||
result_paths=new_result_paths,
|
||||
)
|
||||
new_arg_names = tuple(_perm(primals_in, tangents_in,
|
||||
jaxpr.jaxpr.debug_info.safe_arg_names(len(jaxpr.jaxpr.invars))))
|
||||
new_result_paths = tuple(_perm(primals_out, tangents_out,
|
||||
jaxpr.jaxpr.debug_info.safe_result_paths(len(jaxpr.jaxpr.outvars))))
|
||||
new_debug_info = new_debug_info._replace(
|
||||
arg_names=new_arg_names,
|
||||
result_paths=new_result_paths,
|
||||
)
|
||||
new_jaxpr = core.Jaxpr(jaxpr.jaxpr.constvars,
|
||||
new_invars, new_outvars, jaxpr.jaxpr.eqns,
|
||||
jaxpr.jaxpr.effects,
|
||||
|
@ -502,7 +502,8 @@ def _closed_call_param_updater(params, _, __):
|
||||
return dict(params, call_jaxpr=core.ClosedJaxpr(jaxpr, ()))
|
||||
call_param_updaters[core.closed_call_p] = _closed_call_param_updater
|
||||
|
||||
def abstract_eval_fun(fun: Callable, *avals, debug_info=None, **params):
|
||||
def abstract_eval_fun(fun: Callable, *avals,
|
||||
debug_info: core.DebugInfo, **params):
|
||||
_, avals_out, _, () = trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(fun, params, debug_info=debug_info), avals)
|
||||
assert all(isinstance(aval, AbstractValue) for aval in avals_out)
|
||||
@ -582,7 +583,7 @@ def trace_to_subjaxpr_nounits(
|
||||
f: Callable,
|
||||
trace: JaxprTrace,
|
||||
instantiate: Sequence[bool] | bool,
|
||||
debug_info: core.DebugInfo | None,
|
||||
debug_info: core.DebugInfo,
|
||||
in_pvals: Sequence[PartialVal]):
|
||||
assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals
|
||||
out_tracers, jaxpr, out_consts, env = _trace_to_subjaxpr_nounits(
|
||||
@ -595,7 +596,7 @@ def trace_to_subjaxpr_nounits(
|
||||
def trace_to_subjaxpr_nounits2(
|
||||
f: Callable,
|
||||
tag: TraceTag,
|
||||
debug_info: core.DebugInfo | None,
|
||||
debug_info: core.DebugInfo,
|
||||
instantiate: bool | Sequence[bool],
|
||||
in_pvals: Sequence[PartialVal]):
|
||||
assert isinstance(tag, TraceTag)
|
||||
@ -612,7 +613,7 @@ def trace_to_subjaxpr_nounits2(
|
||||
def _trace_to_subjaxpr_nounits(f: Callable, trace: JaxprTrace,
|
||||
instantiate: Sequence[bool] | bool,
|
||||
in_pvals: Sequence[PartialVal],
|
||||
debug_info: core.DebugInfo | None):
|
||||
debug_info: core.DebugInfo):
|
||||
in_knowns = [pval.is_known() for pval in in_pvals]
|
||||
in_consts = [pval.get_known() for pval in in_pvals if pval.is_known()]
|
||||
in_tracers = [trace.new_arg(pval) for pval in in_pvals if not pval.is_known()]
|
||||
@ -639,7 +640,7 @@ def _trace_to_subjaxpr_nounits(f: Callable, trace: JaxprTrace,
|
||||
def trace_to_subjaxpr_nounits_fwd(
|
||||
f: Callable,
|
||||
tag: TraceTag,
|
||||
debug_info: core.DebugInfo | None,
|
||||
debug_info: core.DebugInfo,
|
||||
instantiate: bool | Sequence[bool],
|
||||
in_pvals: Sequence[PartialVal]):
|
||||
assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals
|
||||
@ -669,7 +670,7 @@ def trace_to_subjaxpr_nounits_fwd(
|
||||
def trace_to_subjaxpr_nounits_fwd2(
|
||||
f: Callable,
|
||||
tag: TraceTag,
|
||||
debug_info: core.DebugInfo | None,
|
||||
debug_info: core.DebugInfo,
|
||||
instantiate: bool | Sequence[bool],
|
||||
in_pvals: Sequence[PartialVal]):
|
||||
assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals
|
||||
@ -752,13 +753,14 @@ def recipe_to_eqn(getvar: Callable[[JaxprTracer], Atom],
|
||||
def tracers_to_jaxpr(
|
||||
in_tracers: Sequence[JaxprTracer],
|
||||
out_tracers: Sequence[JaxprTracer],
|
||||
debug_info: core.DebugInfo | None,
|
||||
debug_info: core.DebugInfo,
|
||||
) -> tuple[Jaxpr, tuple[Any, ...], tuple[Any, ...]]:
|
||||
"""Constructs Jaxpr given tracers for inputs and outputs.
|
||||
|
||||
Params:
|
||||
in_tracers: the tracers that were created for the function inputs
|
||||
out_tracers: the tracers that were output by the function.
|
||||
debug_info: the debug info for the function.
|
||||
|
||||
Returns: a triple of a `Jaxpr`, a list of constant values corresponding to
|
||||
the `constvars` in the returned Jaxpr, and a list of environment values.
|
||||
@ -838,7 +840,7 @@ def tracers_to_jaxpr(
|
||||
def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr:
|
||||
"""Moves the constvars to the start of invars."""
|
||||
config.enable_checks.value and core.check_jaxpr(jaxpr)
|
||||
dbg = jaxpr.debug_info and jaxpr.debug_info._replace(
|
||||
dbg = jaxpr.debug_info._replace(
|
||||
arg_names=(None,) * len(jaxpr.constvars) + jaxpr.debug_info.arg_names)
|
||||
lifted_jaxpr = Jaxpr(constvars=(),
|
||||
invars=jaxpr.constvars + jaxpr.invars,
|
||||
@ -854,7 +856,7 @@ def convert_invars_to_constvars(jaxpr: Jaxpr, n: int) -> Jaxpr:
|
||||
return jaxpr.replace() # 'return jaxpr' would create cache reference cycle
|
||||
config.enable_checks.value and core.check_jaxpr(jaxpr)
|
||||
constvars, invars = split_list(jaxpr.invars, [n])
|
||||
dbg = jaxpr.debug_info and jaxpr.debug_info._replace(
|
||||
dbg = jaxpr.debug_info._replace(
|
||||
arg_names=jaxpr.debug_info.arg_names[n:])
|
||||
lifted_jaxpr = jaxpr.replace(constvars=tuple(constvars), invars=invars,
|
||||
debug_info=dbg)
|
||||
@ -868,7 +870,7 @@ def convert_envvars_to_constvars(jaxpr: Jaxpr, num_env_vars: int) -> Jaxpr:
|
||||
env_vars, invars = split_list(jaxpr.invars, [num_env_vars])
|
||||
converted_jaxpr = Jaxpr(constvars=jaxpr.constvars + env_vars,
|
||||
invars=invars, outvars=jaxpr.outvars, eqns=jaxpr.eqns,
|
||||
effects=jaxpr.effects)
|
||||
effects=jaxpr.effects, debug_info=jaxpr.debug_info)
|
||||
config.enable_checks.value and core.check_jaxpr(converted_jaxpr)
|
||||
return converted_jaxpr
|
||||
|
||||
@ -1363,7 +1365,7 @@ def prune_jaxpr_outputs(jaxpr: Jaxpr, used_outputs: Sequence[bool]) -> Jaxpr:
|
||||
|
||||
def _prune_jaxpr_outputs(jaxpr: Jaxpr, used_outputs: tuple[bool, ...]) -> Jaxpr:
|
||||
outvars = [v for v, b in zip(jaxpr.outvars, used_outputs) if b]
|
||||
dbg = jaxpr.debug_info and core.DebugInfo(
|
||||
dbg = core.DebugInfo(
|
||||
jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info,
|
||||
jaxpr.debug_info.arg_names,
|
||||
jaxpr.debug_info.filter_result_paths(used_outputs))
|
||||
@ -1451,7 +1453,7 @@ def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: tuple[bool, ...],
|
||||
eqns = new_eqns[::-1]
|
||||
jaxpr_effects = make_jaxpr_effects(jaxpr.constvars, invars, outvars, eqns)
|
||||
|
||||
dbg = jaxpr.debug_info and core.DebugInfo(
|
||||
dbg = core.DebugInfo(
|
||||
jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info,
|
||||
jaxpr.debug_info.filter_arg_names(used_inputs),
|
||||
jaxpr.debug_info.filter_result_paths(used_outputs))
|
||||
@ -1653,9 +1655,9 @@ class JaxprStackFrame:
|
||||
attrs_tracked: list[tuple[Any, str]]
|
||||
attrs_inits: list
|
||||
attrs_vars: list[Var]
|
||||
debug_info: core.DebugInfo | None
|
||||
debug_info: core.DebugInfo
|
||||
|
||||
def __init__(self, debug_info: core.DebugInfo | None):
|
||||
def __init__(self, debug_info: core.DebugInfo):
|
||||
self.gensym = core.gensym()
|
||||
self.tracer_to_var = {}
|
||||
self.constid_to_tracer = {}
|
||||
@ -1674,7 +1676,7 @@ class JaxprStackFrame:
|
||||
|
||||
def to_jaxpr(self, trace: DynamicJaxprTrace,
|
||||
out_tracers: Sequence[Tracer],
|
||||
debug_info: core.DebugInfo | None,
|
||||
debug_info: core.DebugInfo,
|
||||
) -> tuple[Jaxpr, list[Any], list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]:
|
||||
# It's not necessary, but we keep the tracer-to-var mapping injective:
|
||||
assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values()))
|
||||
@ -1696,7 +1698,7 @@ class JaxprStackFrame:
|
||||
return jaxpr, list(constvals), zip(init_trees, end_trees, self.attrs_tracked)
|
||||
|
||||
def to_jaxpr2(self, out_tracers: Sequence[core.Tracer],
|
||||
debug_info: core.DebugInfo | None):
|
||||
debug_info: core.DebugInfo):
|
||||
# It's not necessary, but we keep the tracer-to-var mapping injective:
|
||||
assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values()))
|
||||
constvars, constvals = unzip2(self.constvar_to_val.items())
|
||||
@ -1843,7 +1845,7 @@ def _inline_literals(
|
||||
class DynamicJaxprTrace(core.Trace):
|
||||
__slots__ = ("frame",)
|
||||
|
||||
def __init__(self, debug_info: core.DebugInfo | None):
|
||||
def __init__(self, debug_info: core.DebugInfo):
|
||||
self.frame = JaxprStackFrame(debug_info)
|
||||
|
||||
def invalidate(self):
|
||||
@ -2117,7 +2119,7 @@ class DynamicJaxprTrace(core.Trace):
|
||||
return out_tracers
|
||||
|
||||
def to_jaxpr(self, out_tracers: Sequence[Tracer],
|
||||
debug_info: core.DebugInfo | None):
|
||||
debug_info: core.DebugInfo):
|
||||
return self.frame.to_jaxpr(self, out_tracers, debug_info)
|
||||
|
||||
|
||||
@ -2180,17 +2182,13 @@ def trace_to_jaxpr_dynamic(
|
||||
return jaxpr, [v.aval for v in jaxpr.outvars], consts, attrs_tracked
|
||||
|
||||
def _check_no_returned_refs(
|
||||
dbg: core.DebugInfo | None,
|
||||
dbg: core.DebugInfo,
|
||||
out_tracers: Sequence[DynamicJaxprTracer]
|
||||
) -> None:
|
||||
if not config.mutable_array_checks.value: return
|
||||
for i, t in enumerate(out_tracers):
|
||||
a = t.aval
|
||||
if isinstance(a, AbstractRef):
|
||||
if dbg is None:
|
||||
raise ValueError(
|
||||
f"function returned a mutable array reference of type {a.str_short()}, "
|
||||
"but mutable array references cannot be returned.")
|
||||
result_paths = dbg.resolve_result_paths().safe_result_paths(len(out_tracers))
|
||||
loc = result_paths[i] and f' at output tree path {result_paths[i]}'
|
||||
frame = t._trace.frame
|
||||
@ -2469,7 +2467,8 @@ def pad_jaxpr(jaxpr: Jaxpr, consts: Sequence[Const]
|
||||
return aval
|
||||
|
||||
in_avals = [substitute(v.aval) for v in jaxpr.invars]
|
||||
eval_padded = lu.wrap_init(partial(_eval_jaxpr_padded, jaxpr, consts))
|
||||
eval_padded = lu.wrap_init(partial(_eval_jaxpr_padded, jaxpr, consts),
|
||||
debug_info=jaxpr.debug_info)
|
||||
padded_jaxpr, _, padded_consts, () = trace_to_jaxpr_dynamic(eval_padded, in_avals)
|
||||
return padded_jaxpr, padded_consts
|
||||
|
||||
|
@ -877,8 +877,8 @@ def lower_parallel_callable(
|
||||
replicated_args=replicated_args,
|
||||
arg_shardings=None,
|
||||
result_shardings=None,
|
||||
arg_names=jaxpr._debug_info and jaxpr._debug_info.safe_arg_names(len(jaxpr.invars)),
|
||||
result_names=jaxpr._debug_info and jaxpr._debug_info.safe_result_paths(len(jaxpr.outvars)),
|
||||
arg_names=jaxpr._debug_info.safe_arg_names(len(jaxpr.invars)),
|
||||
result_names=jaxpr._debug_info.safe_result_paths(len(jaxpr.outvars)),
|
||||
num_replicas=replicas.num_global_replicas,
|
||||
lowering_parameters=lowering_parameters)
|
||||
return PmapComputation(lowering_result.module,
|
||||
@ -1968,8 +1968,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
|
||||
result_shardings=out_mlir_shardings,
|
||||
in_layouts=in_layouts,
|
||||
out_layouts=out_layouts,
|
||||
arg_names=jaxpr._debug_info and jaxpr._debug_info.safe_arg_names(len(jaxpr.invars)),
|
||||
result_names=jaxpr._debug_info and jaxpr._debug_info.safe_result_paths(len(jaxpr.outvars)),
|
||||
arg_names=jaxpr._debug_info.safe_arg_names(len(jaxpr.invars)),
|
||||
result_names=jaxpr._debug_info.safe_result_paths(len(jaxpr.outvars)),
|
||||
num_replicas=nreps,
|
||||
num_partitions=num_partitions,
|
||||
all_default_mem_kind=all_default_mem_kind,
|
||||
@ -2125,7 +2125,7 @@ MaybeLayout = Sequence[Union[DeviceLocalLayout, AutoLayout, None]]
|
||||
class AllArgsInfo(NamedTuple):
|
||||
"""Avals and debug_info for all arguments prior to DCE."""
|
||||
in_avals: Sequence[core.ShapedArray]
|
||||
debug_info: core.DebugInfo | None
|
||||
debug_info: core.DebugInfo
|
||||
|
||||
|
||||
@lru_cache(maxsize=2048)
|
||||
@ -3202,17 +3202,13 @@ def cc_shard_arg(x, sharding, layout):
|
||||
|
||||
|
||||
def check_arg_avals_for_call(ref_avals, arg_avals,
|
||||
jaxpr_debug_info: core.DebugInfo | None = None):
|
||||
jaxpr_debug_info: core.DebugInfo):
|
||||
if len(ref_avals) != len(arg_avals):
|
||||
raise TypeError(
|
||||
f"Computation compiled for {len(ref_avals)} inputs "
|
||||
f"but called with {len(arg_avals)}")
|
||||
|
||||
if jaxpr_debug_info is not None:
|
||||
arg_names = [f"'{name}'" for name in jaxpr_debug_info.safe_arg_names(len(ref_avals))]
|
||||
else:
|
||||
num_args = len(ref_avals)
|
||||
arg_names = [f"{i + 1}/{num_args}" for i in range(num_args)]
|
||||
arg_names = [f"'{name}'" for name in jaxpr_debug_info.safe_arg_names(len(ref_avals))]
|
||||
|
||||
errors = []
|
||||
for ref_aval, arg_aval, name in safe_zip(ref_avals, arg_avals, arg_names):
|
||||
@ -3264,14 +3260,13 @@ def check_array_xla_sharding_layout_match(
|
||||
args_after_dce,
|
||||
in_xla_shardings: Sequence[JSharding],
|
||||
in_xla_layouts: Sequence[DeviceLocalLayout],
|
||||
jaxpr_debug_info: core.DebugInfo | None,
|
||||
jaxpr_debug_info: core.DebugInfo,
|
||||
kept_var_idx: set[int]) -> None:
|
||||
from jax._src.array import ArrayImpl
|
||||
# jaxpr_debug_info.arg_names are before DCE, so need to DCE them.
|
||||
arg_names = (
|
||||
[""] * len(args_after_dce) if jaxpr_debug_info is None
|
||||
else [a for i, a in enumerate(jaxpr_debug_info.arg_names) # type: ignore
|
||||
if i in kept_var_idx]
|
||||
[a for i, a in enumerate(jaxpr_debug_info.arg_names) # type: ignore
|
||||
if i in kept_var_idx]
|
||||
)
|
||||
errors = []
|
||||
num_errors = 5
|
||||
|
@ -73,7 +73,7 @@ for_p.skip_canonicalization = True
|
||||
|
||||
def _trace_to_jaxpr_with_refs(f: Callable, state_tree: PyTreeDef,
|
||||
state_avals: Sequence[core.AbstractValue],
|
||||
debug_info: core.DebugInfo | None,
|
||||
debug_info: core.DebugInfo,
|
||||
) -> tuple[core.Jaxpr, list[Any], PyTreeDef]:
|
||||
f, out_tree_thunk = api_util.flatten_fun_nokwargs(
|
||||
lu.wrap_init(f, debug_info=debug_info),
|
||||
|
@ -336,7 +336,7 @@ def _custom_linear_solve_impl(*args, const_lengths, jaxprs):
|
||||
|
||||
|
||||
def _tangent_linear_map(func: Callable, params, params_dot,
|
||||
debug_info: core.DebugInfo | None,
|
||||
debug_info: core.DebugInfo,
|
||||
*x):
|
||||
"""Compute the tangent of a linear map.
|
||||
|
||||
|
@ -161,7 +161,7 @@ class WrappedFun:
|
||||
f_transformed: Callable,
|
||||
transforms,
|
||||
stores: tuple[Store | EqualStore | None, ...], params, in_type,
|
||||
debug_info: DebugInfo | None):
|
||||
debug_info: DebugInfo):
|
||||
self.f = f
|
||||
self.f_transformed = f_transformed
|
||||
self.transforms = transforms
|
||||
@ -258,6 +258,7 @@ def fun_name(f):
|
||||
except:
|
||||
return str(f)
|
||||
|
||||
|
||||
class DebugInfo(NamedTuple):
|
||||
"""Debugging info about a func, its arguments, and results."""
|
||||
traced_for: str # e.g. 'jit', 'scan', etc
|
||||
@ -331,18 +332,17 @@ def _missing_debug_info(for_what: str) -> DebugInfo:
|
||||
return DebugInfo("missing_debug_info", "<missing_debug_info>", (), ())
|
||||
|
||||
def wrap_init(f: Callable, params=None, *,
|
||||
debug_info: DebugInfo | None = None) -> WrappedFun:
|
||||
debug_info: DebugInfo) -> WrappedFun:
|
||||
"""Wraps function `f` as a `WrappedFun`, suitable for transformation."""
|
||||
params_dict = {} if params is None else params
|
||||
params = () if params is None else tuple(sorted(params.items()))
|
||||
fun = WrappedFun(f, partial(f, **params_dict), (), (), params, None, debug_info)
|
||||
if debug_info:
|
||||
if debug_info.result_paths is None:
|
||||
fun, result_paths_thunk = _get_result_paths_thunk(fun)
|
||||
debug_info = debug_info._replace(
|
||||
result_paths=HashableFunction(result_paths_thunk, closure=()))
|
||||
fun = WrappedFun(fun.f, fun.f_transformed, fun.transforms, fun.stores,
|
||||
fun.params, fun.in_type, debug_info)
|
||||
if debug_info.result_paths is None:
|
||||
fun, result_paths_thunk = _get_result_paths_thunk(fun)
|
||||
debug_info = debug_info._replace(
|
||||
result_paths=HashableFunction(result_paths_thunk, closure=()))
|
||||
fun = WrappedFun(fun.f, fun.f_transformed, fun.transforms, fun.stores,
|
||||
fun.params, fun.in_type, debug_info)
|
||||
return fun
|
||||
|
||||
|
||||
|
@ -731,22 +731,20 @@ def _infer_params(
|
||||
entry.pjit_params = p
|
||||
return entry.pjit_params, entry.pjit_params.consts + dynargs
|
||||
|
||||
def _infer_input_type(fun: Callable, dbg: core.DebugInfo | None,
|
||||
def _infer_input_type(fun: Callable, dbg: core.DebugInfo,
|
||||
explicit_args) -> tuple[core.AbstractValue, ...]:
|
||||
avals = []
|
||||
try:
|
||||
for i, x in enumerate(explicit_args):
|
||||
avals.append(core.shaped_abstractify(x))
|
||||
except OverflowError:
|
||||
arg_path = (f"argument path is {dbg.arg_names[i]}" if dbg
|
||||
else f"flattened argument number is {i}")
|
||||
arg_path = f"argument path is {dbg.arg_names[i]}" # type: ignore
|
||||
raise OverflowError(
|
||||
"An overflow was encountered while parsing an argument to a jitted "
|
||||
f"computation, whose {arg_path}."
|
||||
) from None
|
||||
except TypeError:
|
||||
arg_description = (f"path {dbg.arg_names[i]}" if dbg
|
||||
else f"flattened argument number {i}")
|
||||
arg_description = f"path {dbg.arg_names[i]}" # type: ignore
|
||||
raise TypeError(
|
||||
f"Error interpreting argument to {fun} as an abstract array."
|
||||
f" The problematic value is of type {type(x)} and was passed to"
|
||||
@ -1111,7 +1109,7 @@ class PytreeLeaf:
|
||||
@util.cache(max_size=4096, trace_context_in_key=False)
|
||||
def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves,
|
||||
in_layouts_treedef, in_layouts_leaves,
|
||||
in_avals, in_tree, debug_info,
|
||||
in_avals, in_tree, debug_info: core.DebugInfo,
|
||||
device_or_backend_set, kws):
|
||||
if not kws:
|
||||
in_tree, _ = treedef_children(in_tree)
|
||||
@ -1136,11 +1134,11 @@ def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves,
|
||||
attrs_tracked = debug_info and len(debug_info.arg_names) != len(in_avals)
|
||||
if not config.dynamic_shapes.value and not attrs_tracked:
|
||||
pjit_check_aval_sharding(in_shardings_flat, in_avals,
|
||||
None if debug_info is None else debug_info.safe_arg_names(len(in_avals)),
|
||||
debug_info.safe_arg_names(len(in_avals)),
|
||||
"pjit arguments", allow_uneven_sharding=False)
|
||||
check_aval_layout_compatibility(
|
||||
in_layouts_flat, in_avals,
|
||||
None if debug_info is None else debug_info.arg_names, "jit arguments")
|
||||
debug_info.safe_arg_names(len(in_avals)), "jit arguments") # type: ignore[arg-type]
|
||||
return in_shardings_flat, in_layouts_flat
|
||||
|
||||
callsites: set[str] = set()
|
||||
@ -1167,7 +1165,7 @@ def explain_tracing_cache_miss(
|
||||
|
||||
# have we seen this function before at all?
|
||||
fun_name = getattr(fun.f, '__qualname__', fun.f)
|
||||
if debug_info is not None and debug_info.func_src_info:
|
||||
if debug_info.func_src_info:
|
||||
# TODO(necula): clean up the extraction of the source info
|
||||
_, *rest = debug_info.func_src_info.split(' at ')
|
||||
src_info = " defined at " + ' '.join(rest)
|
||||
@ -1239,7 +1237,7 @@ def explain_tracing_cache_miss(
|
||||
# have we never seen these input types (eg shapes, dtypes) before?
|
||||
types_match = [k for k in trees_match if k[1] == in_type]
|
||||
if not types_match:
|
||||
if len(in_type) < 5 and debug_info is not None:
|
||||
if len(in_type) < 5:
|
||||
in_type_str = ':\n {}'.format(', '.join(
|
||||
f'{n}: {ty.str_short(short_dtypes=True)}'
|
||||
for n, ty in zip(debug_info.arg_names, in_type)))
|
||||
@ -1251,10 +1249,7 @@ def explain_tracing_cache_miss(
|
||||
num_mismatch = sum(map(op.ne, closest_ty, in_type))
|
||||
p(f" closest seen input type signature has {num_mismatch} mismatches, including:")
|
||||
add_weak_type_hint = False
|
||||
if debug_info:
|
||||
arg_names = debug_info.safe_arg_names(len(in_type))
|
||||
else:
|
||||
arg_names = (None,) * len(in_type)
|
||||
arg_names = debug_info.safe_arg_names(len(in_type))
|
||||
|
||||
for name, ty1, ty2 in zip(arg_names, closest_ty, in_type):
|
||||
if ty1 != ty2:
|
||||
@ -1320,7 +1315,7 @@ def _create_pjit_jaxpr(
|
||||
def _check_and_canonicalize_out_shardings(
|
||||
out_shardings_treedef, out_shardings_leaves, out_layouts_treedef,
|
||||
out_layouts_leaves, out_tree, out_avals,
|
||||
debug_info: core.DebugInfo | None,
|
||||
debug_info: core.DebugInfo,
|
||||
device_or_backend_set):
|
||||
orig_out_shardings = tree_unflatten(out_shardings_treedef, out_shardings_leaves)
|
||||
if isinstance(orig_out_shardings, (UnspecifiedValue, Sharding)):
|
||||
@ -1340,11 +1335,11 @@ def _check_and_canonicalize_out_shardings(
|
||||
if not config.dynamic_shapes.value:
|
||||
pjit_check_aval_sharding(
|
||||
out_shardings_flat, out_avals,
|
||||
None if debug_info is None else debug_info.safe_result_paths(len(out_avals)),
|
||||
debug_info.safe_result_paths(len(out_avals)), # type: ignore[arg-type]
|
||||
"pjit outputs", allow_uneven_sharding=False)
|
||||
check_aval_layout_compatibility(
|
||||
out_layouts_flat, out_avals,
|
||||
None if debug_info is None else debug_info.safe_result_paths(len(out_avals)),
|
||||
debug_info.safe_result_paths(len(out_avals)), # type: ignore[arg-type]
|
||||
"jit outputs")
|
||||
return out_shardings_flat, out_layouts_flat
|
||||
|
||||
|
@ -15,6 +15,8 @@
|
||||
# Note: import <name> as <name> is required for names to be exported.
|
||||
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
|
||||
|
||||
from typing import Callable
|
||||
|
||||
from jax._src.linear_util import (
|
||||
StoreException as StoreException,
|
||||
WrappedFun as WrappedFun,
|
||||
@ -24,7 +26,14 @@ from jax._src.linear_util import (
|
||||
transformation_with_aux as transformation_with_aux,
|
||||
transformation2 as transformation2,
|
||||
transformation_with_aux2 as transformation_with_aux2,
|
||||
wrap_init as wrap_init,
|
||||
# TODO(b/396086979): remove this once we pass debug_info everywhere.
|
||||
wrap_init as _wrap_init,
|
||||
_missing_debug_info as _missing_debug_info,
|
||||
)
|
||||
|
||||
# Version of wrap_init that does not require a DebugInfo object.
|
||||
# This usage is deprecated; use api_util.debug_info() to construct a proper
|
||||
# DebugInfo object.
|
||||
def wrap_init(f: Callable, params=None, *, debug_info=None) -> WrappedFun:
|
||||
debug_info = debug_info or _missing_debug_info("linear_util.wrap_init")
|
||||
return _wrap_init(f, params, debug_info=debug_info)
|
||||
|
@ -71,8 +71,7 @@ def _collect_jaxprs(jaxpr: core.Jaxpr,
|
||||
return acc
|
||||
|
||||
|
||||
def _debug_info_to_string(dbg: core.DebugInfo | None) -> list[str]:
|
||||
if dbg is None: return "None"
|
||||
def _debug_info_to_string(dbg: core.DebugInfo) -> list[str]:
|
||||
# Strip the absolute path and the line number but check that it references
|
||||
# this file (to catch errors when the source info points in JAX internals)
|
||||
func_src_info = re.sub(r"^(\S+)( at .*.debug_info_test.py:\d+)?", "\\1", dbg.func_src_info)
|
||||
@ -294,7 +293,6 @@ class DebugInfoTest(jtu.JaxTestCase):
|
||||
def wrapper(x, y):
|
||||
return x
|
||||
|
||||
api_util.save_wrapped_fun_sourceinfo(wrapper, None) # No effect
|
||||
dbg = api_util.debug_info("test", wrapper, (1, 2), {})
|
||||
self.assertEqual("wrapper", dbg.func_name)
|
||||
|
||||
|
@ -53,7 +53,8 @@ class ExtendTest(jtu.JaxTestCase):
|
||||
self.assertIs(jex.linear_util.merge_linear_aux, linear_util.merge_linear_aux)
|
||||
self.assertIs(jex.linear_util.transformation, linear_util.transformation)
|
||||
self.assertIs(jex.linear_util.transformation_with_aux, linear_util.transformation_with_aux)
|
||||
self.assertIs(jex.linear_util.wrap_init, linear_util.wrap_init)
|
||||
# TODO(necula): revert this change once we deprecate the old wrap_init
|
||||
# self.assertIs(jex.linear_util.wrap_init, linear_util.wrap_init)
|
||||
|
||||
|
||||
class RandomTest(jtu.JaxTestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user