[better_errors] Add debug info to the Jaxprs formed for AD

Following #26078, we add debug info to more calls of lu.wrap_init.
George Necula 2025-01-24 12:53:51 +02:00
parent 414449e142
commit abcaec7081
22 changed files with 480 additions and 169 deletions
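
This commit applies one recurring pattern: wherever a function is wrapped with lu.wrap_init, a core.DebugInfo is now threaded through as well, either built fresh with the api_util.debug_info helper or propagated from an existing jaxpr. A minimal sketch of the two flavors (illustrative only; my_fun and the "example" tag are placeholders, not part of the diff):

    from jax._src import api_util
    from jax._src import linear_util as lu

    def my_fun(x, y):
        return x + y

    # Flavor 1: build fresh debug info from a traced-for tag, the function,
    # and the call arguments, then attach it at wrap time.
    dbg = api_util.debug_info("example", my_fun, (1.0, 2.0), {})
    wrapped = lu.wrap_init(my_fun, debug_info=dbg)

    # Flavor 2: a function derived from an existing jaxpr (a transpose, a
    # checkified body, ...) reuses that jaxpr's provenance:
    #   lu.wrap_init(derived_fn, debug_info=jaxpr.jaxpr.debug_info)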


@ -564,6 +564,7 @@ pytype_strict_library(
srcs = ["_src/interpreters/mlir.py"],
deps = [
":ad_util",
":api_util",
":config",
":core",
":dtypes",


@ -682,12 +682,13 @@ def transpose_jaxpr(jaxpr: core.ClosedJaxpr, in_linear: bool | Sequence[bool],
return _transpose_jaxpr(jaxpr, tuple(in_linear), tuple(out_zeros))
@weakref_lru_cache
def _transpose_jaxpr(jaxpr, in_lin, out_zeros):
def _transpose_jaxpr(jaxpr: core.ClosedJaxpr,
in_lin: Sequence[bool],
out_zeros: Sequence[bool]):
in_avals = ([a for a, lin in zip(jaxpr.in_avals, in_lin ) if not lin] +
[a for a, zero in zip(jaxpr.out_avals, out_zeros) if not zero])
cell = lambda: None
@lu.wrap_init
def transposed(*args_flat):
ins_flat, out_cts_flat = split_list(args_flat, [len(in_lin) - sum(in_lin)])
@ -715,7 +716,10 @@ def _transpose_jaxpr(jaxpr, in_lin, out_zeros):
in_cts_nz, _ = partition_list(in_zeros, in_cts)
return in_cts_nz
transposed_jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(transposed, in_avals)
transposed_wrapped = lu.wrap_init(transposed,
debug_info=jaxpr.jaxpr.debug_info)
transposed_jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(
transposed_wrapped, in_avals)
transposed_jaxpr = core.ClosedJaxpr(transposed_jaxpr_, consts)
return transposed_jaxpr, cell.in_cts_zero # pytype: disable=attribute-error
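
The transposed function above is derived from an existing ClosedJaxpr, so the jaxpr's own debug info is reused rather than recomputed. One way to inspect the metadata being forwarded (a sketch; the debug_info attribute may be None on jaxprs built before this change):

    import jax
    import jax.numpy as jnp

    closed = jax.make_jaxpr(lambda x: jnp.sin(x) * 2.0)(1.0)
    print(closed.jaxpr.debug_info)  # the DebugInfo that wrap_init now forwards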


@ -983,7 +983,7 @@ def vmap(fun: F,
"to the positional arguments passed to the function, "
f"but got {len(in_axes)=}, {len(args)=}")
args_flat, in_tree = tree_flatten((args, kwargs), is_leaf=batching.is_vmappable)
f = lu.wrap_init(fun)
f = lu.wrap_init(fun, debug_info=debug_info("vmap", fun, args, kwargs))
flat_fun, out_tree = batching.flatten_fun_for_vmap(f, in_tree)
in_axes_flat = flatten_axes("vmap in_axes", in_tree, (in_axes, 0), kws=True)
axis_size_ = (axis_size if axis_size is not None else
@ -1715,15 +1715,15 @@ def jvp(
0.19900084
"""
check_callable(fun)
return _jvp(lu.wrap_init(fun), primals, tangents, has_aux=has_aux)
def _jvp(fun: lu.WrappedFun, primals, tangents, has_aux=False):
"""Variant of jvp() that takes an lu.WrappedFun."""
if (not isinstance(primals, (tuple, list)) or
not isinstance(tangents, (tuple, list))):
raise TypeError("primal and tangent arguments to jax.jvp must be tuples or lists; "
f"found {type(primals).__name__} and {type(tangents).__name__}.")
return _jvp(lu.wrap_init(fun, debug_info=debug_info("jvp", fun, primals, {})),
primals, tangents, has_aux=has_aux)
def _jvp(fun: lu.WrappedFun, primals, tangents, has_aux=False):
"""Variant of jvp() that takes an lu.WrappedFun."""
ps_flat, tree_def = tree_flatten(primals)
ts_flat, tree_def_2 = tree_flatten(tangents)
if tree_def != tree_def_2:
@ -1835,7 +1835,7 @@ def linearize(fun: Callable, *primals, has_aux: bool = False
-6.676704
"""
check_callable(fun)
f = lu.wrap_init(fun)
f = lu.wrap_init(fun, debug_info=debug_info("linearize", fun, primals, {}))
primals_flat, in_tree = tree_flatten(primals)
if has_aux:
jaxtree_fun, out_tree = flatten_fun_nokwargs2(f, in_tree)
@ -1983,8 +1983,9 @@ def vjp(
raise NotImplementedError("reduce_axes argument to vjp is deprecated")
del reduce_axes
check_callable(fun)
return _vjp(
lu.wrap_init(fun), *primals, has_aux=has_aux)
wrapped_fun = lu.wrap_init(fun,
debug_info=debug_info("vjp", fun, primals, {}))
return _vjp(wrapped_fun, *primals, has_aux=has_aux)
def _vjp(fun: lu.WrappedFun, *primals, has_aux=False):
"""Variant of vjp() that takes an lu.WrappedFun."""
@ -2049,7 +2050,10 @@ def linear_transpose(fun: Callable, *primals, reduce_axes=()) -> Callable:
raise NotImplementedError("reduce_axes argument to transpose is deprecated")
del reduce_axes
primals_flat, in_tree = tree_flatten(primals)
flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
flat_fun, out_tree = flatten_fun_nokwargs(
lu.wrap_init(fun,
debug_info=debug_info("linear_transpose", fun, primals, {})),
in_tree)
in_avals = map(shaped_abstractify, primals_flat)
in_dtypes = map(dtypes.dtype, in_avals)
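
Note that the tuple/list validation of primals and tangents now runs in jvp itself, before the function is wrapped. For example:

    import jax

    jax.jvp(lambda x: x * x, (3.0,), (1.0,))  # fine: tuples
    # jax.jvp(lambda x: x * x, 3.0, 1.0)      # TypeError: primal and tangent
    #                                         # arguments must be tuples or lists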


@ -67,7 +67,8 @@ def _ensure_str_tuple(x: str | Iterable[str]) -> tuple[str, ...]:
return tuple(map(_ensure_str, x))
@lu.transformation_with_aux2
def flatten_fun(f, store, in_tree, *args_flat):
def flatten_fun(f: Callable, store: lu.Store,
in_tree: PyTreeDef, *args_flat):
py_args, py_kwargs = tree_unflatten(in_tree, args_flat)
ans = f(*py_args, **py_kwargs)
ans, out_tree = tree_flatten(ans)
@ -587,8 +588,8 @@ def debug_info(
args: Sequence[Any],
kwargs: dict[str, Any],
*,
static_argnums: tuple[int, ...] = (),
static_argnames: tuple[str, ...] = (),
static_argnums: Sequence[int] = (),
static_argnames: Sequence[str] = (),
result_paths_thunk: Callable[[], tuple[str, ...]] | None = None,
# TODO(necula): check if we really need this, e.g., to speed up tracing?
sourceinfo: str | None = None,
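
With the widened annotations, the static-argument descriptors can be any integer or string sequences, not just tuples. A hedged usage sketch of the helper:

    from jax._src import api_util

    def f(n, x):
        return x * n

    dbg = api_util.debug_info("example", f, (3, 1.0), {},
                              static_argnums=[0])  # a list is accepted now
    print(dbg.traced_for)  # "example"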


@ -361,8 +361,9 @@ def default_checkify_rule(primitive: core.Primitive, error: Error,
else:
jaxpr, consts = call_jaxpr, ()
consts_ = tuple(HashableWrapper(c) for c in consts)
partial_checkify = lu.hashable_partial(lu.wrap_init(
checkify_jaxpr_flat_hashable), jaxpr, consts_, enabled_errors, err_tree)
partial_checkify = lu.hashable_partial(
lu.wrap_init(checkify_jaxpr_flat_hashable, debug_info=jaxpr.debug_info),
jaxpr, consts_, enabled_errors, err_tree)
partial_checkify, metadata = _flatten_and_get_error_metadata_thunk(
partial_checkify)
@ -746,7 +747,7 @@ def jaxpr_to_checkify_jaxpr(
checkify_jaxpr_partial = functools.partial(checkify_jaxpr_flat, jaxpr.jaxpr,
jaxpr.consts, enabled_errors,
err_tree)
fun = lu.wrap_init(checkify_jaxpr_partial)
fun = lu.wrap_init(checkify_jaxpr_partial, debug_info=jaxpr.jaxpr.debug_info)
fun, metadata = _flatten_and_get_error_metadata_thunk(fun)
new_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun, flat_err_and_in_vals)


@ -2416,8 +2416,9 @@ call_p.def_impl(call_impl)
class ClosedCallPrimitive(CallPrimitive):
def get_bind_params(self, params):
new_params = dict(params)
jaxpr = new_params.pop('call_jaxpr')
subfun = lu.wrap_init(partial(eval_jaxpr, jaxpr.jaxpr, jaxpr.consts))
jaxpr: ClosedJaxpr = new_params.pop('call_jaxpr')
subfun = lu.wrap_init(partial(eval_jaxpr, jaxpr.jaxpr, jaxpr.consts),
debug_info=jaxpr.jaxpr.debug_info)
return [subfun], new_params
closed_call_p: ClosedCallPrimitive = ClosedCallPrimitive('closed_call')
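
get_bind_params rebuilds a callable from the stored ClosedJaxpr, and the jaxpr's debug info now travels with it. For reference, evaluating a ClosedJaxpr directly, as that subfun does, looks like this (sketch):

    import jax
    from jax._src import core

    closed = jax.make_jaxpr(lambda x: x * 2.0)(1.0)
    out, = core.eval_jaxpr(closed.jaxpr, closed.consts, 3.0)
    print(out)  # 6.0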


@ -31,7 +31,7 @@ from jax._src.ad_util import (
stop_gradient_p, SymbolicZero, Zero, zeros_like_aval)
from jax._src.api_util import (
argnums_partial, flatten_fun_nokwargs, resolve_kwargs, fun_signature,
_non_static_arg_names, prepend_static_args)
_non_static_arg_names, prepend_static_args, debug_info)
from jax._src.errors import UnexpectedTracerError
from jax._src.state.types import AbstractRef
from jax._src.interpreters import ad
@ -44,7 +44,7 @@ from jax._src.lax import lax
from jax._src.tree_util import (
tree_flatten, tree_unflatten, tree_map, treedef_is_leaf, treedef_tuple,
register_pytree_node_class, tree_leaves, tree_flatten_with_path,
tree_leaves_with_path, keystr, treedef_children)
tree_leaves_with_path, keystr, treedef_children, PyTreeDef)
from jax._src.util import (cache, safe_zip, safe_map, split_list, unzip2,
weakref_lru_cache)
@ -78,7 +78,9 @@ _stop_gradient = partial(
# like the api_util.py function, but also grabs output avals for error checking
@lu.transformation_with_aux2
def _flatten_fun_nokwargs(f, store, in_tree, *args_flat):
def _flatten_fun_nokwargs(f: Callable,
store: lu.Store, in_tree: PyTreeDef,
*args_flat):
py_args = tree_unflatten(in_tree, args_flat)
ans = f(*py_args)
ans_flat, ans_tree = tree_flatten(ans)
@ -204,7 +206,7 @@ class custom_jvp(Generic[ReturnValue]):
*jvps: a sequence of functions, one for each positional argument of the
:class:`~jax.custom_jvp` function. Each function takes as arguments
the tangent value for the corresponding primal input, the primal
output, and the ßprimal inputs. See the example below.
output, and the primal inputs. See the example below.
Returns:
None.
@ -239,28 +241,40 @@ class custom_jvp(Generic[ReturnValue]):
@traceback_util.api_boundary
def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable=invalid-annotation
primal_name = getattr(self.fun, '__name__', str(self.fun))
debug = debug_info("custom_jvp fun", self.fun, args, kwargs,
static_argnums=self.nondiff_argnums)
primal_name = debug.func_name
if not self.jvp:
msg = f"No JVP defined for custom_jvp function {primal_name} using defjvp."
raise AttributeError(msg)
jvp_name = getattr(self.jvp, '__name__', str(self.jvp))
args = resolve_kwargs(self.fun, args, kwargs)
if self.nondiff_argnums:
nondiff_argnums = set(self.nondiff_argnums)
args = tuple(_stop_gradient(x) if i in nondiff_argnums else x
for i, x in enumerate(args))
diff_argnums = [i for i in range(len(args)) if i not in nondiff_argnums]
f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), diff_argnums, args,
f_, dyn_args = argnums_partial(lu.wrap_init(self.fun, debug_info=debug),
diff_argnums, args,
require_static_args_hashable=False)
static_args = [args[i] for i in self.nondiff_argnums]
jvp = prepend_static_args(lu.wrap_init(self.jvp), static_args)
diff_args = [args[i] for i, a in enumerate(args) if i not in self.nondiff_argnums]
debug_jvp = debug_info("custom_jvp jvp", self.jvp,
(*static_args, diff_args, diff_args),
{},
static_argnums=tuple(range(len(static_args))))
jvp = prepend_static_args(lu.wrap_init(self.jvp,
debug_info=debug_jvp), static_args)
else:
f_, dyn_args = lu.wrap_init(self.fun), args
jvp = lu.wrap_init(self.jvp)
f_, dyn_args = lu.wrap_init(self.fun, debug_info=debug), args
debug_jvp = debug_info("custom_jvp jvp", self.jvp,
(args, args),
{})
jvp = lu.wrap_init(self.jvp, debug_info=debug_jvp)
args_flat, in_tree = tree_flatten(dyn_args)
flat_fun, out_type1 = _flatten_fun_nokwargs(f_, in_tree)
flat_jvp, out_type2 = _flatten_jvp(jvp, primal_name, jvp_name, in_tree,
out_type1)
flat_jvp, out_type2 = _flatten_jvp(jvp, primal_name, debug_jvp.func_name,
in_tree, out_type1)
out_flat = custom_jvp_call_p.bind(flat_fun, flat_jvp, *args_flat,
symbolic_zeros=self.symbolic_zeros)
_, (out_tree, _) = lu.merge_linear_aux(out_type1, out_type2)
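
For context, the two callables whose provenance is now recorded ("custom_jvp fun" and "custom_jvp jvp") are the ones a user supplies in the usual way; the usage itself is unchanged:

    import jax
    import jax.numpy as jnp

    @jax.custom_jvp
    def f(x):
        return jnp.sin(x)

    @f.defjvp
    def f_jvp(primals, tangents):
        (x,), (t,) = primals, tangents
        return f(x), jnp.cos(x) * t

    jax.jvp(f, (0.5,), (1.0,))  # debug info now names both f and f_jvp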
@ -611,15 +625,20 @@ class custom_vjp(Generic[ReturnValue]):
@traceback_util.api_boundary
def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable=invalid-annotation
primal_name = getattr(self.fun, '__name__', str(self.fun))
debug_fun = debug_info("custom_vjp fun", self.fun, args, kwargs,
static_argnums=self.nondiff_argnums)
if not self.fwd or not self.bwd:
msg = f"No VJP defined for custom_vjp function {primal_name} using defvjp."
msg = f"No VJP defined for custom_vjp function {debug_fun.func_name} using defvjp."
raise AttributeError(msg)
fwd_name = getattr(self.fwd, '__name__', str(self.fwd))
args = resolve_kwargs(self.fun, args, kwargs)
debug_fwd = debug_info("custom_vjp fwd", self.fwd, args, kwargs,
static_argnums=self.nondiff_argnums)
# TODO(necula): figure out how to construct the debug_bwd args
debug_bwd = debug_info("custom_vjp bwd", self.bwd, args, {})
if self.optimize_remat:
fwd = optimize_remat_of_custom_vjp_fwd(
self.fun, self.fwd, nondiff_argnums=self.nondiff_argnums,
self.fun, debug_fun, self.fwd, debug_fwd,
nondiff_argnums=self.nondiff_argnums,
symbolic_zeros=self.symbolic_zeros)
else:
fwd = self.fwd
@ -633,23 +652,27 @@ class custom_vjp(Generic[ReturnValue]):
for i in self.nondiff_argnums: _check_for_tracers(args[i])
nondiff_argnums = set(self.nondiff_argnums)
dyn_argnums = [i for i in range(len(args)) if i not in nondiff_argnums]
f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), dyn_argnums,
args, require_static_args_hashable=False)
f_, dyn_args = argnums_partial(
lu.wrap_init(self.fun, debug_info=debug_fun), dyn_argnums,
args, require_static_args_hashable=False)
static_args = [args[i] for i in self.nondiff_argnums]
fwd_, _ = argnums_partial(lu.wrap_init(fwd), dyn_argnums, args,
require_static_args_hashable=False)
bwd = prepend_static_args(lu.wrap_init(self.bwd), static_args)
fwd_, _ = argnums_partial(lu.wrap_init(fwd, debug_info=debug_fwd),
dyn_argnums, args,
require_static_args_hashable=False)
bwd = prepend_static_args(lu.wrap_init(self.bwd, debug_info=debug_bwd),
static_args)
else:
f_, dyn_args = lu.wrap_init(self.fun), args
fwd_, bwd = lu.wrap_init(fwd), lu.wrap_init(self.bwd)
f_, dyn_args = lu.wrap_init(self.fun, debug_info=debug_fun), args
fwd_ = lu.wrap_init(fwd, debug_info=debug_fwd)
bwd = lu.wrap_init(self.bwd, debug_info=debug_bwd)
args_flat, in_tree = tree_flatten(dyn_args)
in_avals = [core.get_aval(x) for x in args_flat]
if config.mutable_array_checks.value:
f_ = _check_primal_refs(f_, self.nondiff_argnums)
f_ = _check_primal_refs(f_, self.nondiff_argnums, f_.debug_info)
flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree)
flat_fwd, out_trees = _flatten_fwd(
fwd_, self.nondiff_argnums, self.symbolic_zeros, primal_name,
fwd_name, in_tree, out_type)
fwd_, self.nondiff_argnums, self.symbolic_zeros, debug_fun,
debug_fwd, in_tree, out_type)
flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees).call_wrapped
out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,
*args_flat, out_trees=out_trees,
@ -658,19 +681,27 @@ class custom_vjp(Generic[ReturnValue]):
return tree_unflatten(out_tree, out_flat)
@lu.transformation2
def _check_primal_refs(f, nondiff_argnums, *args):
_check_for_aliased_refs(f, nondiff_argnums, args)
def _check_primal_refs(f: Callable, nondiff_argnums: Sequence[int],
debug_info: core.DebugInfo | None, *args):
_check_for_aliased_refs(f, nondiff_argnums, debug_info, args)
out = f(*args)
_check_for_returned_refs(f, out, 'primal')
return out
def _check_for_aliased_refs(f, nondiff_argnums, args):
def _check_for_aliased_refs(f: Callable,
nondiff_argnums: Sequence[int],
debug: core.DebugInfo | None,
args):
leaves = tree_leaves(args)
refs: dict[int, int] = {}
for i, x in enumerate(leaves):
if (isinstance((a := core.get_aval(x)), AbstractRef) and
(dup_idx := refs.setdefault(id(core.get_referent(x)), i)) != i):
arg_names = _non_static_arg_names(fun_signature(f), args, {}, nondiff_argnums, ())
if debug is not None:
arg_names = debug.safe_arg_names(len(leaves))
else:
# TODO(necula): drop this branch
arg_names = _non_static_arg_names(fun_signature(f), args, {}, nondiff_argnums, ())
if arg_names is None:
arg_names = [f'flat index {j}' for j in range(len(leaves))]
raise ValueError(
@ -725,18 +756,24 @@ def _check_for_tracers(x):
raise UnexpectedTracerError(msg)
@partial(lu.transformation_with_aux2, use_eq_store=True)
def _flatten_fwd(f, store, nondiff_argnums, symbolic_zeros, primal_name,
fwd_name, in_tree, maybe_out_type, *args):
def _flatten_fwd(f: Callable, store: lu.EqualStore,
nondiff_argnums: Sequence[int],
symbolic_zeros: bool,
debug_primal: core.DebugInfo | None,
debug_fwd: core.DebugInfo | None,
in_tree: PyTreeDef, maybe_out_type, *args):
primal_name = debug_primal.func_name if debug_primal else str(f)
fwd_name = debug_fwd.func_name if debug_fwd else "<unknown>"
if symbolic_zeros:
args = [CustomVJPPrimal(x, z) for x, z in zip(args[::2], args[1::2])]
args = tuple(CustomVJPPrimal(x, z) for x, z in zip(args[::2], args[1::2]))
else:
args = args[::2]
py_args = tree_unflatten(in_tree, args)
if config.mutable_array_checks.value:
_check_for_aliased_refs(f, nondiff_argnums, py_args)
_check_for_aliased_refs(f, nondiff_argnums, debug_primal, py_args)
pair_out = f(*py_args)
if config.mutable_array_checks.value:
_check_for_returned_refs(f, pair_out, 'fwd')
_check_for_returned_refs(f, pair_out, "fwd")
if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2:
msg = (f"Custom VJP fwd rule {fwd_name} for function {primal_name} "
"must produce a pair (list or tuple of length two) where the first "
@ -790,7 +827,10 @@ def _flatten_fwd(f, store, nondiff_argnums, symbolic_zeros, primal_name,
return (*res, *primals_out)
@lu.transformation2
def _flatten_bwd(f, in_tree, in_avals, out_trees, *args):
def _flatten_bwd(f: Callable,
in_tree: PyTreeDef,
in_avals: Sequence[core.AbstractValue],
out_trees: Callable[[], Sequence[PyTreeDef]], *args):
out_tree, res_tree = out_trees()
assert len(args) == res_tree.num_leaves + out_tree.num_leaves
res, cts_out = split_list(args, [res_tree.num_leaves])
@ -1494,7 +1534,9 @@ custom_jvp_call_jaxpr_p = core.Primitive("custom_jvp_call_jaxpr")
# simpler, but it would be worth revisiting this.
def optimize_remat_of_custom_vjp_fwd(
fun: Callable[..., ReturnValue],
debug_fun: core.DebugInfo | None,
fwd: Callable[..., tuple[ReturnValue, Any]],
debug_fwd: core.DebugInfo | None,
nondiff_argnums: Sequence[int] = (),
symbolic_zeros: bool = False,
) -> Callable[..., tuple[ReturnValue, Any]]:
@ -1507,8 +1549,7 @@ def optimize_remat_of_custom_vjp_fwd(
def wrapped_fwd(*args, **kwargs) -> tuple[ReturnValue, Any]:
# TODO(dfm): This initial logic is duplicated from custom_vjp.__call__
# above and it would be good to consolidate it.
primal_name = getattr(fun, "__name__", str(fun))
fwd_name = getattr(fwd, "__name__", str(fwd))
fwd_name = debug_fwd.func_name if debug_fwd else str(fwd)
# Note: we use `fun` instead of `fwd` here for consistency with
# custom_vjp.__call__ above.
args = resolve_kwargs(fun, args, kwargs)
@ -1516,17 +1557,19 @@ def optimize_remat_of_custom_vjp_fwd(
for i in nondiff_argnums: _check_for_tracers(args[i])
nondiff_argnums_ = set(nondiff_argnums)
dyn_argnums = [i for i in range(len(args)) if i not in nondiff_argnums_]
f_, dyn_args = argnums_partial(lu.wrap_init(fun), dyn_argnums,
args, require_static_args_hashable=False)
fwd_, _ = argnums_partial(lu.wrap_init(fwd), dyn_argnums, args,
f_, dyn_args = argnums_partial(lu.wrap_init(fun, debug_info=debug_fun),
dyn_argnums,
args, require_static_args_hashable=False)
fwd_, _ = argnums_partial(lu.wrap_init(fwd, debug_info=debug_fwd),
dyn_argnums, args,
require_static_args_hashable=False)
else:
f_, dyn_args = lu.wrap_init(fun), args
fwd_ = lu.wrap_init(fwd)
f_, dyn_args = lu.wrap_init(fun, debug_info=debug_fun), args
fwd_ = lu.wrap_init(fwd, debug_info=debug_fwd)
args_flat, in_tree = tree_flatten(dyn_args)
flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree)
flat_fwd, out_trees = _flatten_fwd(fwd_, nondiff_argnums, False,
primal_name, fwd_name, in_tree, out_type)
debug_fun, debug_fwd, in_tree, out_type)
flat_fwd = _fix_fwd_args(flat_fwd)
in_avals = [core.get_aval(x) for x in args_flat]
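
Likewise for custom_vjp: debug_fun, debug_fwd, and debug_bwd describe the three callables a user supplies, so errors raised during tracing can name each of them. Standard usage, unchanged by this commit:

    import jax
    import jax.numpy as jnp

    @jax.custom_vjp
    def g(x):
        return jnp.sin(x)

    def g_fwd(x):
        return jnp.sin(x), jnp.cos(x)  # primal output and residuals

    def g_bwd(cos_x, ct):
        return (cos_x * ct,)           # cotangent for x

    g.defvjp(g_fwd, g_bwd)
    jax.grad(g)(0.5)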


@ -68,7 +68,7 @@ def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True,
return jvpfun(fun, instantiate, transform_stack), aux
@lu.transformation2
def jvpfun(f, instantiate, transform_stack, primals, tangents):
def jvpfun(f: Callable, instantiate, transform_stack, primals, tangents):
tag = core.TraceTag()
tangents = [Zero.from_primal_value(t) if not isinstance(t, Zero)
and dtype(t) == float0 else t for t in tangents]
@ -106,7 +106,7 @@ def linearize_subtrace(_f: Callable, _store, _tag, nzs_in, *primals, **params):
return tuple(consts) + tuple(out_primals)
@lu.transformation2
def jvp_subtrace(f, tag, primals, tangents):
def jvp_subtrace(f: Callable, tag: core.TraceTag, primals, tangents):
with core.take_current_trace() as parent_trace:
trace = JVPTrace(parent_trace, tag)
in_tracers = [maybe_jvp_tracer(trace, x, t)
@ -778,7 +778,8 @@ def linearize_from_jvp(jvp, multiple_results, nonzeros,
out_nz_tracers = [trace.to_jaxpr_tracer(r)
for (r, nz) in zip(out_tangents, out_nzs) if nz]
in_tracers = [t for t, nz in zip(tangent_args, nonzeros) if nz]
jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers)
# TODO(necula): pass debug_info here
jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers, None)
def linearized(residuals, *tangents):
nz_tangents_in = [t for (t, nz) in zip(tangents, nonzeros) if nz]
@ -932,13 +933,15 @@ def traceable(f, store, in_tree, *primals_and_tangents):
return out_flat
def call_transpose(primitive, params, call_jaxpr, args, ct, _):
def call_transpose(primitive, params, call_jaxpr: core.Jaxpr, args, ct, _):
if isinstance(call_jaxpr, core.ClosedJaxpr):
call_jaxpr, consts = call_jaxpr.jaxpr, call_jaxpr.consts
else:
consts = ()
all_args, in_tree_def = tree_flatten((consts, args, ct))
fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr, False)
fun = lu.hashable_partial(lu.wrap_init(backward_pass,
debug_info=call_jaxpr.debug_info),
call_jaxpr, False)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
update_params = call_transpose_param_updaters.get(primitive)
if update_params:
@ -950,7 +953,7 @@ def call_transpose(primitive, params, call_jaxpr, args, ct, _):
res_invars, _ = partition_list(which_lin, call_jaxpr.invars)
new_invars = [*res_invars, *call_jaxpr.outvars]
dbidx_map = {v: core.DBIdx(i) for i, v in enumerate(new_invars)}
in_type = [(v.aval.update(shape=tuple(dbidx_map.get(d, d) for d in v.aval.shape))
in_type = [(v.aval.update(shape=tuple(dbidx_map.get(d, d) for d in v.aval.shape)) # type: ignore[arg-type]
if type(v.aval) is core.DShapedArray else v.aval, True) for v in new_invars]
fun = lu.annotate(fun, tuple(in_type))
out_flat = primitive.bind(fun, *all_args, **params)
@ -1027,8 +1030,8 @@ def jvp_jaxpr(jaxpr: core.ClosedJaxpr, nonzeros: Sequence[bool],
def _jvp_jaxpr(jaxpr: core.ClosedJaxpr,
nonzeros: Sequence[bool], instantiate: Sequence[bool]):
assert len(jaxpr.in_avals) == len(nonzeros)
debug_info = jaxpr.jaxpr.debug_info
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr), debug_info=debug_info)
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr),
debug_info=jaxpr.jaxpr.debug_info)
f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate, transform_stack=False),
nonzeros)
tangent_avals = [aval.to_tangent_aval() for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz]


@ -34,6 +34,7 @@ import warnings
import numpy as np
from jax._src import ad_util
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import dtypes
@ -2156,7 +2157,8 @@ def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable:
as `avals_out`."""
def f_lowered(ctx: LoweringRuleContext, *args, **params):
f = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),)
wrapped_fun = lu.wrap_init(f, params)
wrapped_fun = lu.wrap_init(f, params,
debug_info=api_util.debug_info("lower_fun", fun, args, params))
manager = (contextlib.nullcontext() if ctx.jaxpr_eqn_ctx is None else
ctx.jaxpr_eqn_ctx.manager)
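
lower_fun traces a Python callable to lower a primitive, and that trace is now tagged "lower_fun". A hedged sketch of how such a rule is typically registered (double_p and double_impl are hypothetical, for illustration):

    from jax._src import core
    from jax._src.interpreters import mlir

    double_p = core.Primitive("double")      # hypothetical primitive
    double_p.def_abstract_eval(lambda x: x)  # output aval mirrors input

    def double_impl(x):
        return x + x                         # traceable reference impl

    mlir.register_lowering(
        double_p, mlir.lower_fun(double_impl, multiple_results=False))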


@ -236,7 +236,7 @@ class JaxprTrace(Trace['JaxprTracer']):
params, effects, source)
return out_tracer
def process_call(self, primitive, f, tracers, params):
def process_call(self, primitive, f: lu.WrappedFun, tracers, params):
tracers = map(self.to_jaxpr_tracer, tracers)
rule = call_partial_eval_rules.get(primitive)
if rule:
@ -254,7 +254,7 @@ class JaxprTrace(Trace['JaxprTracer']):
# which were unknown to the first call (corresponding to in_avals).
# Wrap f to perform the partial evaluation and plumb out aux data.
f_ = trace_to_subjaxpr_nounits_fwd(f, self.tag, False)
f_ = trace_to_subjaxpr_nounits_fwd(f, self.tag, f.debug_info, False)
f_, aux = partial_eval_wrapper_nounits(f_, tuple(in_knowns), tuple(in_avals))
# Adjust parameters (e.g. donated_invars) for the call to be evaluated now.
@ -330,7 +330,7 @@ class JaxprTrace(Trace['JaxprTracer']):
for ax, aval in zip(unk_in_axes, in_avals)]
# Wrap f to perform partial evaluation and plumb out aux data.
f = trace_to_subjaxpr_nounits2(f, self.tag, False)
f = trace_to_subjaxpr_nounits2(f, self.tag, f.debug_info, False)
f, aux = partial_eval_wrapper_nounits(f, tuple(in_knowns),
tuple(in_avals_mapped))
# Adjust params for knowns (e.g. donated_invars, in_axes, out_axes_thunk)
@ -429,7 +429,7 @@ class JaxprTrace(Trace['JaxprTracer']):
tracers = map(self.instantiate_const_abstracted, tracers)
# Because we instantiate all tracers, in_knowns is all False.
in_knowns, in_avals, () = partition_pvals([t.pval for t in tracers])
f = trace_to_subjaxpr_nounits(f, self, True)
f = trace_to_subjaxpr_nounits(f, self, True, f.debug_info)
f, aux = partial_eval_wrapper_nounits(f, (*in_knowns,), (*in_avals,))
with core.set_current_trace(self.parent_trace):
out_flat = prim.bind(f, fwd, bwd, out_trees=out_trees,
@ -445,7 +445,7 @@ class JaxprTrace(Trace['JaxprTracer']):
@_memoize
def fwd_jaxpr_thunk(*zeros):
fwd_ = _interleave_fun(fwd, zeros)
fwd_ = trace_to_subjaxpr_nounits(fwd_, self, True)
fwd_ = trace_to_subjaxpr_nounits(fwd_, self, True, fwd_.debug_info)
fwd_, aux = partial_eval_wrapper_nounits(fwd_, (*in_knowns,), (*in_avals,))
out_flat = fwd_.call_wrapped()
out_knowns, out_avals, jaxpr, env = aux()
@ -476,7 +476,10 @@ def partition_pvals(
@lu.transformation_with_aux2
def partial_eval_wrapper_nounits(
f, store, in_knowns: Sequence[bool], in_avals: Sequence[AbstractValue],
f: Callable,
store: lu.Store,
in_knowns: Sequence[bool],
in_avals: Sequence[AbstractValue],
*in_consts: Any):
in_avals_, in_consts_ = iter(in_avals), iter(in_consts)
in_pvals = [PartialVal.known(next(in_consts_)) if known else
@ -566,7 +569,7 @@ def trace_to_jaxpr_nounits(
with core.take_current_trace() as parent_trace:
trace = JaxprTrace(parent_trace, current_name_stack, TraceTag())
with core.ensure_no_leaks(trace):
fun = trace_to_subjaxpr_nounits(fun, trace, instantiate)
fun = trace_to_subjaxpr_nounits(fun, trace, instantiate, fun.debug_info)
with core.set_current_trace(trace):
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
assert not env
@ -579,10 +582,11 @@ def trace_to_subjaxpr_nounits(
f: Callable,
trace: JaxprTrace,
instantiate: Sequence[bool] | bool,
debug_info: core.DebugInfo | None,
in_pvals: Sequence[PartialVal]):
assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals
out_tracers, jaxpr, out_consts, env = _trace_to_subjaxpr_nounits(
f, trace, instantiate, in_pvals)
f, trace, instantiate, in_pvals, debug_info)
out_pvals = [t.pval for t in out_tracers]
del out_tracers
return jaxpr, (out_pvals, out_consts, env)
@ -591,6 +595,7 @@ def trace_to_subjaxpr_nounits(
def trace_to_subjaxpr_nounits2(
f: Callable,
tag: TraceTag,
debug_info: core.DebugInfo | None,
instantiate: bool | Sequence[bool],
in_pvals: Sequence[PartialVal]):
assert isinstance(tag, TraceTag)
@ -599,14 +604,15 @@ def trace_to_subjaxpr_nounits2(
with core.take_current_trace() as parent_trace:
trace = JaxprTrace(parent_trace, current_name_stack, tag)
out_tracers, jaxpr, out_consts, env = _trace_to_subjaxpr_nounits(
f, trace, instantiate, in_pvals)
f, trace, instantiate, in_pvals, debug_info)
out_pvals = [t.pval for t in out_tracers]
del out_tracers
return jaxpr, (out_pvals, out_consts, env)
def _trace_to_subjaxpr_nounits(f: Callable, trace: JaxprTrace,
instantiate: Sequence[bool] | bool,
in_pvals: Sequence[PartialVal]):
in_pvals: Sequence[PartialVal],
debug_info: core.DebugInfo | None):
in_knowns = [pval.is_known() for pval in in_pvals]
in_consts = [pval.get_known() for pval in in_pvals if pval.is_known()]
in_tracers = [trace.new_arg(pval) for pval in in_pvals if not pval.is_known()]
@ -623,7 +629,7 @@ def _trace_to_subjaxpr_nounits(f: Callable, trace: JaxprTrace,
out_tracers = [trace.instantiate_const(t) if inst else t
for inst, t in zip(instantiate, out_tracers)]
out_tracers_ = [t for t in out_tracers if not t.is_known()]
jaxpr, out_consts, env = tracers_to_jaxpr(in_tracers, out_tracers_)
jaxpr, out_consts, env = tracers_to_jaxpr(in_tracers, out_tracers_, debug_info)
return out_tracers, jaxpr, out_consts, env
# The below variant implements an optimization where residuals which are also
@ -631,8 +637,9 @@ def _trace_to_subjaxpr_nounits(f: Callable, trace: JaxprTrace,
# TODO(mattjj): update all callers to use this version, delete other version.
@lu.transformation2
def trace_to_subjaxpr_nounits_fwd(
f,
f: Callable,
tag: TraceTag,
debug_info: core.DebugInfo | None,
instantiate: bool | Sequence[bool],
in_pvals: Sequence[PartialVal]):
assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals
@ -641,7 +648,7 @@ def trace_to_subjaxpr_nounits_fwd(
trace = JaxprTrace(parent_trace, current_name_stack, tag)
with core.set_current_trace(trace):
out_tracers, jaxpr, out_consts, env = _trace_to_subjaxpr_nounits(
f, trace, instantiate, in_pvals)
f, trace, instantiate, in_pvals, debug_info)
out_pvals = [t.pval for t in out_tracers]
# Which out_consts (aka residuals) are just forwarded inputs? Check obj id.
@ -660,8 +667,9 @@ def trace_to_subjaxpr_nounits_fwd(
# than passed as redundant outputs.
@lu.transformation2
def trace_to_subjaxpr_nounits_fwd2(
f,
f: Callable,
tag: TraceTag,
debug_info: core.DebugInfo | None,
instantiate: bool | Sequence[bool],
in_pvals: Sequence[PartialVal]):
assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals
@ -669,7 +677,7 @@ def trace_to_subjaxpr_nounits_fwd2(
with core.take_current_trace() as parent_trace:
trace = JaxprTrace(parent_trace, current_name_stack, tag)
out_tracers, jaxpr, consts, env = _trace_to_subjaxpr_nounits(
f, trace, instantiate, in_pvals)
f, trace, instantiate, in_pvals, debug_info)
out_pvals = [t.pval for t in out_tracers]
# Which consts (aka residuals) are just forwarded inputs? Check obj id.
@ -743,7 +751,8 @@ def recipe_to_eqn(getvar: Callable[[JaxprTracer], Atom],
def tracers_to_jaxpr(
in_tracers: Sequence[JaxprTracer],
out_tracers: Sequence[JaxprTracer]
out_tracers: Sequence[JaxprTracer],
debug_info: core.DebugInfo | None,
) -> tuple[Jaxpr, tuple[Any, ...], tuple[Any, ...]]:
"""Constructs Jaxpr given tracers for inputs and outputs.
@ -819,7 +828,8 @@ def tracers_to_jaxpr(
outvars = map(get_atom, out_tracers) # type: ignore[arg-type]
jaxpr_effects = make_jaxpr_effects(const_vars, invars, outvars, eqns)
jaxpr = Jaxpr(const_vars, invars, # type: ignore[arg-type]
outvars, eqns, jaxpr_effects)
outvars, eqns, jaxpr_effects,
debug_info)
config.enable_checks.value and core.check_jaxpr(jaxpr)
# del getvar # needed to avoid cyclic-reference closure, apparently!
return jaxpr, const_vals, env_vals
@ -1135,7 +1145,8 @@ def _partial_eval_jaxpr_custom_cached(
staged_effects = make_jaxpr_effects(jaxpr.constvars, staged_invars,
outs_staged, staged_eqns)
jaxpr_staged = Jaxpr(jaxpr.constvars, staged_invars,
outs_staged, staged_eqns, staged_effects)
outs_staged, staged_eqns, staged_effects,
jaxpr.debug_info)
config.enable_checks.value and core.check_jaxpr(jaxpr_staged)
return (jaxpr_known, jaxpr_staged, out_unknowns, out_inst, len(residuals),
@ -1504,7 +1515,8 @@ def _move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: tuple[bool, ...]
new_invars = _move_to_front(closed_jaxpr.jaxpr.invars, to_move)
new_jaxpr = Jaxpr(closed_jaxpr.jaxpr.constvars, new_invars,
closed_jaxpr.jaxpr.outvars, closed_jaxpr.jaxpr.eqns,
closed_jaxpr.jaxpr.effects)
closed_jaxpr.jaxpr.effects,
closed_jaxpr.jaxpr.debug_info)
new_closed_jaxpr = core.ClosedJaxpr(new_jaxpr, closed_jaxpr.consts)
return new_closed_jaxpr
@ -1666,7 +1678,8 @@ class JaxprStackFrame:
set_states(self.attrs_tracked, self.attrs_inits)
return jaxpr, list(constvals), zip(init_trees, end_trees, self.attrs_tracked)
def to_jaxpr2(self, out_tracers):
def to_jaxpr2(self, out_tracers: Sequence[core.Tracer],
debug_info: core.DebugInfo | None):
# It's not necessary, but we keep the tracer-to-var mapping injective:
assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values()))
constvars, constvals = unzip2(self.constvar_to_val.items())
@ -1674,10 +1687,10 @@ class JaxprStackFrame:
jaxpr_effects = make_jaxpr_effects(constvars, self.invars, expl_outvars,
self.eqns)
jaxpr = Jaxpr(constvars, self.invars, expl_outvars, self.eqns,
jaxpr_effects)
jaxpr_effects, debug_info)
# We can't run check_jaxpr until after we normalize.
jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals)
jaxpr, constvals = _inline_literals(jaxpr, constvals)
jaxpr, constvals = _inline_literals(jaxpr, constvals) # type: ignore[assignment]
jaxpr, out_type = _add_implicit_outputs(jaxpr)
config.enable_checks.value and core.check_jaxpr(jaxpr)
return jaxpr, out_type, constvals
@ -2184,7 +2197,7 @@ def trace_to_jaxpr_dynamic2(
with core.set_current_trace(trace):
ans = fun.call_wrapped(*in_tracers)
out_tracers = map(trace.to_jaxpr_tracer, ans)
jaxpr = trace.frame.to_jaxpr2(out_tracers)
jaxpr = trace.frame.to_jaxpr2(out_tracers, fun.debug_info)
del trace, in_tracers, out_tracers, ans
return jaxpr
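
The net effect in this file is that debug info travels from the WrappedFun into the Jaxpr that tracing produces. A hedged sketch against the internal tracing entry point (internal APIs; return-value shapes per this diff):

    import numpy as np
    from jax._src import api_util, core
    from jax._src import linear_util as lu
    from jax._src.interpreters import partial_eval as pe

    def f(x):
        return x + 1.0

    dbg = api_util.debug_info("example", f, (1.0,), {})
    wrapped = lu.wrap_init(f, debug_info=dbg)
    aval = core.ShapedArray((), np.dtype("float32"))
    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped, [aval])
    print(jaxpr.debug_info)  # provenance recorded on the traced jaxpr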


@ -240,7 +240,8 @@ def _check_tree(func_name, expected_name, actual_tree, expected_tree, has_aux=Fa
def _prune_zeros(ts):
return [t for t in ts if type(t) is not ad_util.Zero]
def _make_closed_jaxpr(traceable: lu.WrappedFun, in_avals: Sequence[core.AbstractValue]):
def _make_closed_jaxpr(traceable: lu.WrappedFun,
in_avals: Sequence[core.AbstractValue]):
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(traceable, in_avals)
return core.ClosedJaxpr(jaxpr, consts)


@ -327,7 +327,7 @@ def _cond_with_per_branch_args(pred,
lambda op: false_fun(op[1]),
(true_operand, false_operand))
def _join_cond_effects(branches: Sequence[core.Jaxpr]) -> effects.Effects:
def _join_cond_effects(branches: Sequence[core.ClosedJaxpr]) -> effects.Effects:
joined_effects = set()
for b in branches:
for eff in b.effects:
@ -337,7 +337,8 @@ def _join_cond_effects(branches: Sequence[core.Jaxpr]) -> effects.Effects:
joined_effects.add(eff)
return joined_effects
def _cond_abstract_eval(*avals, branches, **_):
def _cond_abstract_eval(*avals: core.AbstractValue,
branches: Sequence[core.ClosedJaxpr], **_):
joined_effects = _join_cond_effects(branches)
disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects)
if disallowed_effects:
@ -614,10 +615,11 @@ def _merge_branch_residuals(branch_res_avals):
# This function augments branch outputs to agree with the merged residual
# format: each branch is made to return zero-filled values in the places of
# residual outputs that it does not populate.
def _join_cond_outputs(jaxprs, all_res_avals, res_aval_indices_per_jaxpr,
def _join_cond_outputs(jaxprs: Sequence[core.ClosedJaxpr],
all_res_avals, res_aval_indices_per_jaxpr,
num_non_res_outputs):
def augment_jaxpr(jaxpr, res_indices):
@lu.wrap_init
def augment_jaxpr(jaxpr: core.ClosedJaxpr,
res_indices):
def f_aug(*args):
outs_and_residuals = core.jaxpr_as_fun(jaxpr)(*args)
outs, residuals = split_list(outs_and_residuals, [num_non_res_outputs])
@ -625,19 +627,21 @@ def _join_cond_outputs(jaxprs, all_res_avals, res_aval_indices_per_jaxpr,
aug_residuals = util.subvals(aug_residuals, zip(res_indices, residuals))
return outs + list(aug_residuals)
return _make_closed_jaxpr(f_aug, jaxpr.in_avals)
wrapped_f_aug = lu.wrap_init(f_aug, debug_info=jaxpr.jaxpr.debug_info)
return _make_closed_jaxpr(wrapped_f_aug, jaxpr.in_avals)
return tuple(map(augment_jaxpr, jaxprs, res_aval_indices_per_jaxpr))
# This function augments branch inputs to agree with the merged residual format:
# each branch is made to accept all residuals, even though it will ignore those
# that it does not read.
def _join_cond_pe_staged_jaxpr_inputs(jaxprs, all_res_avals,
def _join_cond_pe_staged_jaxpr_inputs(jaxprs: Sequence[core.ClosedJaxpr],
all_res_avals,
res_aval_indices_per_jaxpr):
newvar = core.gensym(suffix='_')
all_res_vars = map(newvar, all_res_avals)
def augment_jaxpr(jaxpr, res_indices):
def augment_jaxpr(jaxpr: core.ClosedJaxpr, res_indices) -> core.ClosedJaxpr:
num_res = len(res_indices)
res_vars = jaxpr.jaxpr.invars[:num_res]
non_res_vars = jaxpr.jaxpr.invars[num_res:]
@ -646,9 +650,9 @@ def _join_cond_pe_staged_jaxpr_inputs(jaxprs, all_res_avals,
aug_invars = aug_res_vars + non_res_vars
jaxpr_aug = core.Jaxpr(jaxpr.jaxpr.constvars, aug_invars,
jaxpr.jaxpr.outvars, jaxpr.jaxpr.eqns,
jaxpr.jaxpr.effects)
jaxpr_aug = core.ClosedJaxpr(jaxpr_aug, jaxpr.consts)
return jaxpr_aug
jaxpr.jaxpr.effects,
jaxpr.jaxpr.debug_info)
return core.ClosedJaxpr(jaxpr_aug, jaxpr.consts)
return tuple(map(augment_jaxpr, jaxprs, res_aval_indices_per_jaxpr))
@ -679,8 +683,7 @@ def _cond_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn,
# Finally, update parameters and form the new eqn.
new_params = dict(eqn.params, branches=tuple(dce_branches))
new_effects = core.join_effects(*(b.effects for b in dce_branches))
new_effects = _join_cond_effects(dce_branches_)
new_effects = _join_cond_effects(dce_branches)
new_eqn = pe.new_jaxpr_eqn(
[v for v, used in zip(eqn.invars, [True, *used_inputs]) if used],
[v for v, used in zip(eqn.outvars, used_outputs) if used],
@ -693,10 +696,10 @@ def _cond_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn,
return [True, *used_inputs], new_eqn
def _transpose_cond_jaxpr(jaxpr, num_res):
def _transpose_cond_jaxpr(jaxpr: core.ClosedJaxpr,
num_res: int):
res_avals, primal_avals = split_list(jaxpr.in_avals, [num_res])
@lu.wrap_init
def transposed(*args):
res, cts_out = split_list(args, [num_res])
primals = res + [ad.UndefinedPrimal(aval) for aval in primal_avals]
@ -705,7 +708,9 @@ def _transpose_cond_jaxpr(jaxpr, num_res):
_, cts_in = split_list(cts_in, [num_res])
return map(ad.instantiate_zeros, cts_in)
return _make_closed_jaxpr(transposed, res_avals + jaxpr.out_avals)
return _make_closed_jaxpr(lu.wrap_init(transposed,
debug_info=jaxpr.jaxpr.debug_info),
res_avals + jaxpr.out_avals)
def _cond_transpose(cts, *args, branches):
index, *ops = args


@ -70,11 +70,13 @@ for_p.multiple_results = True
### Tracing utilities
def _trace_to_jaxpr_with_refs(f, state_tree: PyTreeDef,
state_avals: Sequence[core.AbstractValue]
def _trace_to_jaxpr_with_refs(f: Callable, state_tree: PyTreeDef,
state_avals: Sequence[core.AbstractValue],
debug_info: core.DebugInfo | None,
) -> tuple[core.Jaxpr, list[Any], PyTreeDef]:
f, out_tree_thunk = api_util.flatten_fun_nokwargs(
lu.wrap_init(f), treedef_tuple((tree_structure(0), state_tree)))
lu.wrap_init(f, debug_info=debug_info),
treedef_tuple((tree_structure(0), state_tree)))
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
f, state_avals)
return jaxpr, consts, out_tree_thunk()
@ -129,12 +131,13 @@ def for_loop(nsteps: int | Sequence[int],
rest_steps, functools.partial(body, i), vals, unroll=unroll)
tree_map(lambda ref, val: ref_set(ref, (), val), refs, vals)
return for_loop(outer_step, wrapped_body, init_state, unroll=unroll)
dbg = api_util.debug_info("for_loop", body, (0, init_state), {})
nsteps, = nsteps
flat_state, state_tree = tree_flatten(init_state)
state_avals = map(state_utils.val_to_ref_aval, flat_state)
idx_aval = core.ShapedArray((), dtypes.canonicalize_dtype(np.int64))
jaxpr, consts, out_tree = _trace_to_jaxpr_with_refs(
body, state_tree, [idx_aval, *state_avals])
body, state_tree, [idx_aval, *state_avals], dbg)
if out_tree != tree_structure(None):
raise Exception("`body` should not return anything.")
jaxpr = state_utils.hoist_consts_to_refs(jaxpr, index=1)
@ -212,6 +215,7 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
tree_map(lambda c_ref, c: ref_set(c_ref, (), c), carry_refs, carry)
tree_map(lambda y_ref, y: ref_set(y_ref, (i,), y), ys_refs, y)
assert isinstance(length, int)
api_util.save_wrapped_fun_sourceinfo(for_body, f)
init, _, ys = for_loop(length, for_body, (init, xs, ys), reverse=reverse,
unroll=unroll)
return init, ys
@ -572,7 +576,6 @@ def _for_partial_eval_custom(saveable, in_unknowns, in_inst, eqn):
newvar = core.gensym()
resvars = map(newvar, res_avals)
@lu.wrap_init
def known(*known_vals):
empty_res = map(ad_util.zeros_like_aval, res_avals)
jaxpr_known_args = [*known_vals, *empty_res]
@ -581,7 +584,8 @@ def _for_partial_eval_custom(saveable, in_unknowns, in_inst, eqn):
reverse=reverse, which_linear=jaxpr_known_which_linear,
unroll=unroll)
call_jaxpr_, _, call_jaxpr_consts, () = pe.trace_to_jaxpr_dynamic(
known, [v.aval for v in known_invars])
lu.wrap_init(known, debug_info=jaxpr.debug_info),
[v.aval for v in known_invars])
call_jaxpr = core.ClosedJaxpr(call_jaxpr_, call_jaxpr_consts)
eqn_known = pe.new_jaxpr_eqn(known_invars, [*known_outvars, *resvars],
core.closed_call_p, dict(call_jaxpr=call_jaxpr),
@ -596,14 +600,14 @@ def _for_partial_eval_custom(saveable, in_unknowns, in_inst, eqn):
which_linear=which_linear_unknown,
unroll=unroll)
@lu.wrap_init
def staged(*res_and_refs):
out_flat = for_p.bind(*res_and_refs, **params_staged)
_, ans = split_list(out_flat, [num_res])
_, ans = partition_list(out_inst, ans)
return ans
call_jaxpr_, _, call_jaxpr_consts, () = pe.trace_to_jaxpr_dynamic(
staged, [v.aval for v in [*resvars, *eqn.invars]])
lu.wrap_init(staged, debug_info=jaxpr_staged.debug_info),
[v.aval for v in [*resvars, *eqn.invars]])
assert len(jaxpr_staged.invars) - 1 == len(call_jaxpr_.invars)
call_jaxpr = core.ClosedJaxpr(call_jaxpr_, call_jaxpr_consts)
_, outvars = partition_list(out_inst, eqn.outvars)
@ -621,7 +625,7 @@ def _convert_outputs_to_writes(
assert not jaxpr.constvars, "Jaxpr shouldn't have constvars."
in_avals = [v.aval for v in jaxpr.invars] # [i, *orig_ref_avals]
@lu.wrap_init
def eval_jaxpr(i, *refs):
# We split the refs into the original input refs and the dummy residual
# refs.
@ -641,7 +645,8 @@ def _convert_outputs_to_writes(
v.aval.dtype)) # pytype: disable=attribute-error
for v, loop_invar in zip(jaxpr.outvars, loop_invar_res)]
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
eval_jaxpr, [*in_avals, *res_ref_avals])
lu.wrap_init(eval_jaxpr, debug_info=jaxpr.debug_info),
[*in_avals, *res_ref_avals])
assert not consts
return jaxpr, [core.ShapedArray(a.shape, a.dtype) for a in res_ref_avals] # pytype: disable=attribute-error
@ -650,7 +655,6 @@ def _convert_inputs_to_reads(
loop_invar_res: Sequence[bool]) -> core.Jaxpr:
assert not jaxpr.constvars, "Jaxpr should not have constvars"
@lu.wrap_init
def eval_jaxpr(i, *refs):
residual_refs, orig_refs = split_list(refs, [num_res])
residual_vals = [r[()] if loop_invar else r[i] for r, loop_invar
@ -667,7 +671,8 @@ def _convert_inputs_to_reads(
for aval, loop_invar in zip(res_val_avals, loop_invar_res)]
jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(
eval_jaxpr, [i_aval, *res_ref_avals, *orig_ref_avals])
lu.wrap_init(eval_jaxpr, debug_info=jaxpr.debug_info),
[i_aval, *res_ref_avals, *orig_ref_avals])
return jaxpr
def transpose_jaxpr(jaxpr: core.Jaxpr, which_linear: list[bool]) -> core.Jaxpr:
@ -693,7 +698,8 @@ def transpose_jaxpr(jaxpr: core.Jaxpr, which_linear: list[bool]) -> core.Jaxpr:
ad.backward_pass(tangent_jaxpr, False, (), (*primals_args, *ct_args), ())
return []
jaxpr_trans, _, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(trans), [v.aval for v in jaxpr.invars])
lu.wrap_init(trans, debug_info=jaxpr.debug_info),
[v.aval for v in jaxpr.invars])
return jaxpr_trans
def _for_transpose(in_cts, *args, jaxpr, nsteps, reverse, which_linear, unroll):
@ -746,8 +752,9 @@ def discharged_for_loop(nsteps, body, init_state, *, reverse: bool = False):
flat_state, state_tree = tree_flatten(init_state)
state_avals = map(state_utils.val_to_ref_aval, flat_state)
idx_aval = core.ShapedArray((), dtypes.canonicalize_dtype(np.int64))
debug = api_util.debug_info("discharged_for_loop", body, (0, init_state), {})
jaxpr, consts, out_tree = _trace_to_jaxpr_with_refs(
body, state_tree, [idx_aval, *state_avals])
body, state_tree, [idx_aval, *state_avals], debug)
if out_tree != tree_structure(None):
raise Exception("`body` should not return anything.")
discharged_jaxpr, discharged_consts = discharge_state(jaxpr, consts)
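
For reference, for_loop's body receives a step index plus mutable refs for the carried state; its trace is now tagged "for_loop" (and "discharged_for_loop" for the discharged variant). A usage sketch, assuming the internal module path:

    import jax.numpy as jnp
    from jax._src.lax.control_flow.for_loop import for_loop

    def body(i, x_ref):
        x_ref[i] = x_ref[i] * 2.0  # mutate one element per step

    out = for_loop(4, body, jnp.arange(4.0))  # out == [0., 2., 4., 6.]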


@ -511,7 +511,7 @@ def _empty_array(prefix, length_spec, aval):
eval_jaxpr_p = core.Primitive('eval_jaxpr')
eval_jaxpr_p.multiple_results = True
def _stage_jaxpr(trace, *tracers, jaxpr):
def _stage_jaxpr(trace: pe.JaxprTrace, *tracers, jaxpr: core.ClosedJaxpr):
params = dict(call_jaxpr=jaxpr)
return trace.default_process_primitive(core.closed_call_p, tracers, params)
pe.custom_staging_rules[eval_jaxpr_p] = _stage_jaxpr


@ -334,7 +334,9 @@ def _custom_linear_solve_impl(*args, const_lengths, jaxprs):
return x
def _tangent_linear_map(func, params, params_dot, *x):
def _tangent_linear_map(func: Callable, params, params_dot,
debug_info: core.DebugInfo | None,
*x):
"""Compute the tangent of a linear map.
Assuming ``func(*params, *x)`` is linear in ``x`` and computes ``A @ x``,
@ -342,7 +344,7 @@ def _tangent_linear_map(func, params, params_dot, *x):
"""
assert any(type(p) is not ad_util.Zero for p in params_dot)
zeros = _map(ad_util.Zero.from_primal_value, x)
_, out_tangent = ad.jvp(lu.wrap_init(func)).call_wrapped(
_, out_tangent = ad.jvp(lu.wrap_init(func, debug_info=debug_info)).call_wrapped(
params + list(x), params_dot + zeros)
return out_tangent
@ -369,7 +371,8 @@ def _custom_linear_solve_jvp(primals, tangents, const_lengths, jaxprs):
rhs = b_dot
else:
matvec_tangents = _tangent_linear_map(
core.jaxpr_as_fun(jaxprs.matvec), params.matvec, params_dot.matvec, *x_leaves)
core.jaxpr_as_fun(jaxprs.matvec), params.matvec, params_dot.matvec,
jaxprs.matvec.jaxpr.debug_info, *x_leaves)
rhs = _map(ad.add_tangents, b_dot, _map(operator.neg, matvec_tangents))
x_dot = linear_solve_p.bind(*(_flatten(params) + rhs), **kwargs)


@ -1847,8 +1847,8 @@ def reduce(operands: Any,
return tree_util.tree_unflatten(out_tree, out)
@cache()
def _reduction_jaxpr(computation, aval):
@lu.wrap_init
def _reduction_jaxpr(computation: Callable,
aval: core.AbstractValue):
def comp(x, y):
result = computation(x, y)
if not (isinstance(result, core.Tracer) or core.valid_jaxtype(result)):
@ -1857,7 +1857,11 @@ def _reduction_jaxpr(computation, aval):
f"Reduction functions should only return an array.\n"
f"Full return value: {result}")
return (result,)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(comp, (aval, aval))
comp_wrapped = lu.wrap_init(
comp,
debug_info=api_util.debug_info("reduction_jaxpr", computation,
(aval, aval), {}))
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(comp_wrapped, (aval, aval))
if any(isinstance(c, core.Tracer) for c in consts):
raise NotImplementedError(
"Reduction computations can't close over Tracers. Please open an issue "

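The computation traced by _reduction_jaxpr is the binary function passed to lax.reduce, whose jaxpr is now tagged "reduction_jaxpr". The user-level entry point:

    import jax.numpy as jnp
    from jax import lax

    x = jnp.arange(6.0).reshape(2, 3)
    lax.reduce(x, 0.0, lax.add, dimensions=(1,))  # row sums: [3., 12.]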

@ -282,6 +282,10 @@ class DebugInfo(NamedTuple):
return self._replace(result_paths=tuple(self.result_paths()))
return self
@property
def func_name(self) -> str:
return self.func_src_info.split(" ")[0]
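
func_src_info is rendered as "name at file:line" (e.g., "my_f at /tmp/example.py:3"), so the first whitespace-separated token is the function name. A quick check of the new property (sketch):

    from jax._src import api_util

    def my_f(x):
        return x

    dbg = api_util.debug_info("test", my_f, (1.0,), {})
    assert dbg.func_name == "my_f"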
def safe_arg_names(self, expected: int) -> tuple[str | None, ...]:
"""Get the arg_names with a safety check."""
if len(self.arg_names) == expected:
@ -328,7 +332,13 @@ def wrap_init(f: Callable, params=None, *,
@transformation_with_aux2
def _get_result_paths_thunk(_fun: Callable, _store: Store, *args, **kwargs):
ans = _fun(*args, **kwargs)
_store.store([keystr(path) for path, _ in generate_key_paths(ans)])
result_paths = [keystr(path) for path, _ in generate_key_paths(ans)]
if _store:
# In some instances a lu.WrappedFun is called multiple times, e.g.,
# the bwd function in a custom_vjp
assert _store.val == result_paths, (_store, result_paths)
else:
_store.store(result_paths)
return ans
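
The stored result paths are pytree key paths of the function's output, as produced by the helpers used above. For instance:

    from jax.tree_util import keystr
    from jax._src.tree_util import generate_key_paths

    ans = {"a": 1.0, "b": (2.0, 3.0)}
    print([keystr(p) for p, _ in generate_key_paths(ans)])
    # ["['a']", "['b'][0]", "['b'][1]"]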
def annotate(f: WrappedFun, in_type: core.InputType | None) -> WrappedFun:


@ -2352,7 +2352,8 @@ pe.partial_eval_jaxpr_custom_rules[pjit_p] = \
@lu.cache
def _pjit_transpose_trace(fun, in_avals):
def _pjit_transpose_trace(fun: lu.WrappedFun,
in_avals: Sequence[core.AbstractValue]):
transpose_jaxpr, _, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
fun, in_avals)
transpose_jaxpr = core.ClosedJaxpr(transpose_jaxpr, consts)
@ -2360,13 +2361,15 @@ def _pjit_transpose_trace(fun, in_avals):
def _pjit_transpose(cts_in, *primals_in,
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
jaxpr: core.ClosedJaxpr,
in_shardings, out_shardings, in_layouts, out_layouts,
resource_env, donated_invars, name, keep_unused, inline,
compiler_options_kvs):
def prune_type(ty, xs, maybe_zeros):
return tuple(x for x, mz in zip(xs, maybe_zeros) if type(mz) is not ty)
body = lu.wrap_init(ad.closed_backward_pass)
body = lu.wrap_init(ad.closed_backward_pass,
debug_info=jaxpr.jaxpr._debug_info)
body = lu.hashable_partial(body, jaxpr, False)
primals_and_nz_cts_in, in_treedef = tree_flatten((primals_in, cts_in))
body, cts_out_treedef_thunk = flatten_fun_nokwargs(body, in_treedef)


@ -79,7 +79,8 @@ def discharge_state(jaxpr: core.Jaxpr, consts: Sequence[Any], * ,
if isinstance(v.aval, AbstractRef) and d
else v.aval for v, d in zip(jaxpr.invars, should_discharge)]
eval_jaxpr = lu.wrap_init(partial(_eval_jaxpr_discharge_state, jaxpr,
should_discharge, consts))
should_discharge, consts),
debug_info=jaxpr.debug_info)
new_jaxpr, _ , new_consts, () = pe.trace_to_jaxpr_dynamic(eval_jaxpr, in_avals)
return new_jaxpr, new_consts


@ -29,6 +29,7 @@ import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec
from jax._src import ad_checkpoint
from jax._src import ad_util
from jax._src import api_util
from jax._src import callback
from jax._src import config
from jax._src import core
@ -58,7 +59,6 @@ from jax._src.lib.mlir.dialects import sdy
from jax._src.util import (HashableFunction, HashablePartial, unzip2,
as_hashable_function, memoize, partition_list,
split_list, subs_list2)
from jax.api_util import flatten_fun_nokwargs, argnums_partial
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
@ -160,9 +160,10 @@ def _shard_map(f: Callable, mesh: Mesh | AbstractMesh, in_specs: Specs,
@util.wraps(f)
@traceback_util.api_boundary
def wrapped(*args):
fun = lu.wrap_init(f)
fun = lu.wrap_init(f,
debug_info=api_util.debug_info("shard_map", f, args, {}))
args_flat, in_tree = tree_flatten(args)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree)
try: in_specs_flat = broadcast_prefix(in_specs, args,
is_leaf=lambda x: x is None)
except ValueError:
@ -170,7 +171,7 @@ def _shard_map(f: Callable, mesh: Mesh | AbstractMesh, in_specs: Specs,
raise e('shard_map in_specs') from None
dyn_argnums, in_specs_flat = unzip2((i, s) for i, s in enumerate(in_specs_flat)
if s is not None)
fun, args_flat = argnums_partial(fun, dyn_argnums, args_flat, False)
fun, args_flat = api_util.argnums_partial(fun, dyn_argnums, args_flat, False)
_check_specs_vs_args(f, mesh, in_tree, in_specs, dyn_argnums, in_specs_flat, args_flat)
in_names_flat = tuple(map(_canonicalize_spec, in_specs_flat))
@ -460,13 +461,16 @@ class ShardMapPrimitive(core.Primitive):
return self._true_bind(*args, **params)
def bind_with_trace(self, trace, fun_and_args, params):
fun: lu.WrappedFun
fun, *args = fun_and_args
return trace.process_shard_map(shard_map_p, fun, args, **params)
def get_bind_params(self, params):
new_params = dict(params)
jaxpr = new_params.pop('jaxpr')
subfun = lu.hashable_partial(lu.wrap_init(core.eval_jaxpr), jaxpr, ())
jaxpr: core.Jaxpr = new_params.pop('jaxpr')
subfun = lu.hashable_partial(lu.wrap_init(core.eval_jaxpr,
debug_info=jaxpr.debug_info),
jaxpr, ())
axes = new_params.pop('out_names')
new_params['out_names_thunk'] = HashableFunction(lambda: axes, closure=axes)
return [subfun], new_params
@ -1511,7 +1515,8 @@ def _shard_map_jvp(trace, shard_map_p, f, tracers, mesh, in_names,
return [ad.JVPTracer(trace, p, t) for p, t in zip(primal_out, tangent_out)]
ad.JVPTrace.process_shard_map = _shard_map_jvp
def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names,
def _shard_map_partial_eval(trace: pe.JaxprTrace, shard_map_p,
f: lu.WrappedFun, tracers, mesh, in_names,
out_names_thunk, check_rep, rewrite, auto):
tracers = map(trace.to_jaxpr_tracer, tracers)
in_pvals = [t.pval for t in tracers]
@ -1519,7 +1524,7 @@ def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names,
unk_in_names, known_in_names = pe.partition_list(in_knowns, in_names)
all_names = _all_mesh_names_except_spmd(mesh, trace)
in_avals_sharded = map(partial(_shard_aval, mesh), unk_in_names, in_avals)
f = pe.trace_to_subjaxpr_nounits_fwd2(f, trace.tag, False)
f = pe.trace_to_subjaxpr_nounits_fwd2(f, trace.tag, f.debug_info, False)
f = _promote_scalar_residuals(f)
f_known, aux = pe.partial_eval_wrapper_nounits(
f, (*in_knowns,), (*in_avals_sharded,))
@ -1545,7 +1550,7 @@ def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names,
res_names = [known_in_names[f1] if f1 is not None else
known_out_names_[f2] if f2 is not None else
{0: all_names} for f1, f2 in zip(in_fwd, out_fwd)]
unk_in_names = (*res_names,) + ({},) * len(env) + (*unk_in_names,)
unk_in_names = (*res_names,) + ({},) * len(env) + (*unk_in_names,) # type: ignore[assignment]
const_tracers = map(trace.new_instantiated_const, res)
env_tracers = map(trace.to_jaxpr_tracer, env)
unk_arg_tracers = [t for t in tracers if not t.is_known()]
@ -1628,7 +1633,7 @@ def _promote_scalar_residuals_lin(f, linearize_outs_thunk, *args, **kwargs):
return residuals + primals
@lu.transformation2
def _promote_scalar_residuals(f, *args, **kwargs):
def _promote_scalar_residuals(f: Callable, *args, **kwargs):
jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) = f(*args, **kwargs)
which = [f1 is None and f2 is None and not v.aval.shape
for f1, f2, v in zip(in_fwds, out_fwds, jaxpr.constvars)]
@ -1686,7 +1691,7 @@ def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
return out
fun_trans, nz_arg_cts = ad.nonzero_outputs(fun_trans)
fun_trans_flat, out_tree = flatten_fun_nokwargs(fun_trans, in_tree)
fun_trans_flat, out_tree = api_util.flatten_fun_nokwargs(fun_trans, in_tree)
new_in_names = \
[n for n, x in zip(out_names, out_cts) if type(x) is not ad.Zero] + \
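
shard_map now builds its debug info up front via api_util.debug_info("shard_map", f, args, {}), so shard-map traces carry the wrapped function's name and argument names. The user-facing call shape, for context (single-device mesh for illustration):

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.experimental.shard_map import shard_map
    from jax.sharding import Mesh, PartitionSpec as P

    mesh = Mesh(np.array(jax.devices()[:1]), ("x",))
    f = shard_map(lambda a: a * 2.0, mesh=mesh,
                  in_specs=P("x"), out_specs=P("x"))
    f(jnp.arange(8.0))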


@ -9520,8 +9520,11 @@ class CustomVJPTest(jtu.JaxTestCase):
def fwd(x):
return np.array([2.0])*x*x/np.array([1.0]), (x,)
fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(fun, fwd)
x = jnp.linspace(0, 5.0, 10)
fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(
fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}),
fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {}))
self.assertAllClose(jax.jit(fwd)(x)[0], 2*x*x) # Shouldn't hit custom DCE
self.assertAllClose(jax.jit(lambda x: fwd(x)[0])(x), x) # Should be DCEed
@ -9530,8 +9533,10 @@ class CustomVJPTest(jtu.JaxTestCase):
return (np.array([1.0])*x)[0]
def fwd(x):
return (np.array([2.0])*x*x/np.array([1.0]))[0], (x,)
fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(fun, fwd)
x = jnp.linspace(0, 5.0, 10)
fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(
fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}),
fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {}))
self.assertAllClose(jax.jit(jax.vmap(fwd))(x)[0], 2*x*x)
self.assertAllClose(jax.jit(lambda x: jax.vmap(fwd)(x)[0])(x), x)
@ -9540,11 +9545,15 @@ class CustomVJPTest(jtu.JaxTestCase):
return x
def fwd(x):
return x*x, (x,)
fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(fun, fwd)
x = jnp.linspace(0, 5.0, 10)
fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(
fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}),
fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {}))
def g(x):
return jax.lax.cond(True, fwd, lambda x: (2.0 * x, (x,)), x)
x = jnp.linspace(0, 5.0, 10)
self.assertAllClose(jax.jit(g)(x)[0], x*x)
self.assertAllClose(jax.jit(lambda x: g(x)[0])(x), x)
@ -9553,7 +9562,10 @@ class CustomVJPTest(jtu.JaxTestCase):
return x**2
def fwd_(x):
return x*x, (x,)
fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(fun, fwd_)
fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(
fun, api_util.debug_info("custom_vjp fun", fun, (3.2,), {}),
fwd_, api_util.debug_info("custom_vjp fwd", fwd_, (3.2,), {}))
calc = jax.jvp(fwd, (3.2,), (1.0,))
expected = jax.jvp(fwd_, (3.2,), (1.0,))
self.assertAllClose(calc, expected)


@ -31,9 +31,13 @@ import jax.custom_transpose
from jax.experimental import checkify
import jax.experimental.custom_dce
from jax.experimental import pallas as pl
from jax.experimental.shard_map import shard_map
import jax.numpy as jnp
import jax.scipy as jsp
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from jax._src import api_util
from jax._src.ad_checkpoint import saved_residuals
from jax._src import config
@ -159,14 +163,17 @@ class DebugInfoTest(jtu.JaxTestCase):
t_debug_info = _debug_info_to_string(t._debug_info)
if check_tracer_arg_name:
msg = str(exc)
m = re.match(r".* while tracing the function (.+) for (.+)\. .* depends on the value of the argument ([^\n]+)\.",
m = re.match(r".* while tracing the function (.+) for ([^.]+)\.",
msg,
re.DOTALL)
self.assertIsNotNone(m, msg)
self.assertEqual(t._debug_info.func_src_info, m.group(1))
self.assertEqual(t._debug_info.traced_for, m.group(2))
m = re.match(r".* depends on the value of the argument ([^\n]+)\.",
msg,
re.DOTALL)
found_tracer_debug_infos.append(
f"{t_debug_info}, from {m.group(3)}")
f"{t_debug_info}, from {m.group(1) if m else None}")
else:
found_tracer_debug_infos.append(t_debug_info)
else:
@ -804,7 +811,8 @@ class DebugInfoTest(jtu.JaxTestCase):
jnp.float32(1.), (jnp.float32(2.),), [jnp.float32(3.)],
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=x,y[0], result_paths=",
"None", # TODO(necula): missing debug info
# TODO(necula): arg_names?
"traced_for=jit, fun=my_f, arg_names=None,None,None,None, result_paths=['a'],['b'][0][0]",
],
tracer_spy=tracer_spy,
check_tracer_arg_name=True,
@ -837,8 +845,8 @@ class DebugInfoTest(jtu.JaxTestCase):
"traced_for=jit, fun=<lambda>, arg_names=x,y,res_ct, result_paths=[0],[1]",
# TODO(necula): result_paths
"traced_for=jit, fun=my_g, arg_names=u,v, result_paths=",
# TODO(necula): missing debug info
"None",
# TODO(necula): arg_names
"traced_for=jit, fun=my_g, arg_names=None,None,u,v, result_paths=['c'],['d']",
],
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
@ -854,6 +862,138 @@ class DebugInfoTest(jtu.JaxTestCase):
re.compile(r".*func.func public @main\(.*jax.result_info = \"\[1\]\"}"),
])
def test_custom_jvp(self):
tracer_spy = TracerSpy()
@jax.custom_jvp
def my_fun(x, y, c=1.):
tracer_spy.append(y)
return c * (x + y)
def my_jvp(primals, tangents):
x, y, c = primals
t_x, t_y, t_c = tangents
tracer_spy.append(t_y)
return my_fun(x, y, c), t_c
my_fun.defjvp(my_jvp)
def top_f(x, y):
return jnp.square(my_fun(x, y, c=2.)).sum()
self._check_tracers_and_jaxprs(
jax.jit(lambda a: jax.jvp(top_f, (a, a),
(jnp.ones_like(a), jnp.ones_like(a)))),
42.,
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=<lambda>, arg_names=a, result_paths=[0],[1]",
"traced_for=custom_jvp fun, fun=my_fun, arg_names=x,y,c, result_paths=",
],
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
"traced_for=jit, fun=<lambda>, arg_names=a, from None",
"traced_for=custom_jvp fun, fun=my_fun, arg_names=x,y,c, from y",
])
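The expected strings serialize the jaxpr's debug_info fields (traced_for, func_src_info, arg_names, result_paths). A minimal sketch of reading those fields directly, assuming make_jaxpr attaches a debug_info to the top-level jaxpr the way the jit path tested here does:

import jax
import jax.numpy as jnp

@jax.custom_jvp
def my_fun(x, y, c=1.):
  return c * (x + y)

@my_fun.defjvp
def my_jvp(primals, tangents):
  x, y, c = primals
  t_x, t_y, _ = tangents
  return my_fun(x, y, c), c * (t_x + t_y)

closed = jax.make_jaxpr(
    lambda a: jax.jvp(lambda x, y: my_fun(x, y, c=2.), (a, a), (1., 1.)))(42.)
dbg = closed.jaxpr.debug_info
print(dbg.traced_for, dbg.func_src_info, dbg.arg_names)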
def test_custom_jvp_nondiff_args(self):
tracer_spy = TracerSpy()
def top_f(xy):
tracer_spy.append(xy[0])
@functools.partial(jax.custom_jvp, nondiff_argnums=(0,))
def my_g(h, xy):
x, y = xy
tracer_spy.append(x)
return h(x)
@my_g.defjvp
def my_g_jvp(h, primals, tangents):
(x, y), = primals
(xt, yt), = tangents
tracer_spy.append(xt)
return my_g(h, (x, y)), 2. * xt
h = lambda y: xy[0] + y # capture x
return my_g(h, xy)
self._check_tracers_and_jaxprs(
jax.jit(lambda a, b: jax.jvp(top_f, ((a, b),),
((jnp.ones_like(a), jnp.ones_like(b)),))),
42., 43.,
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
# TODO(necula): arg_names
"traced_for=jit, fun=<lambda>, arg_names=None,a,b, result_paths=[0],[1]",
"traced_for=custom_jvp fun, fun=my_g, arg_names=None,xy[0],xy[1], result_paths=",
],
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
"traced_for=custom_jvp fun, fun=my_g, arg_names=xy[0],xy[1], from xy[0]",
# TODO(necula): from None
"traced_for=jit, fun=<lambda>, arg_names=a,b, from None",
"None", # TODO(necula): None
])
def test_custom_vjp(self):
tracer_spy = TracerSpy()
@jax.custom_vjp
def my_f(x):
tracer_spy.append(x["a"])
return {"b": jnp.sin(x["a"])}
def my_f_fwd(x):
tracer_spy.append(x["a"])
return my_f(x), {"r": jnp.cos(x["a"])}
def my_f_bwd(res, g):
tracer_spy.append(g["b"])
cos_x = res["r"]
return ({"a": 2 * cos_x * g["b"]},)
my_f.defvjp(my_f_fwd, my_f_bwd)
def to_diff(x):
return my_f(x)["b"]
self._check_tracers_and_jaxprs(
jax.jit(jax.grad(to_diff)),
{"a" : 3.},
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=to_diff, arg_names=x['a'], result_paths=['a']",
"traced_for=custom_vjp fun, fun=my_f, arg_names=x['a'], result_paths=['b']",
],
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
"traced_for=custom_vjp fun, fun=my_f, arg_names=x['a'], from x['a']",
# TODO(necula): from None
"traced_for=jit, fun=to_diff, arg_names=x['a'], from None",
"traced_for=jit, fun=to_diff, arg_names=x['a'], from x['a']",
])
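The bracketed names here (x['a'], ['b']) are JAX pytree key paths, prefixed with the parameter name for arg_names. A standalone illustration of where they come from:

import jax

x = {"a": 3.0}
leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(x)
for path, leaf in leaves_with_paths:
  # keystr renders the dict key path as "['a']"; the expectations above
  # prepend the parameter name, giving x['a'].
  print(jax.tree_util.keystr(path), leaf)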
def test_custom_vjp_nondiff_args(self):
tracer_spy = TracerSpy()
@functools.partial(jax.custom_vjp, nondiff_argnums=(0,))
def app(f, xy):
tracer_spy.append(xy[0])
return f(xy)
def app_fwd(f, xy):
tracer_spy.append(xy[0])
return app(f, xy), jnp.cos(xy[0])
def app_rev(f, cos_x0, g):
tracer_spy.append(cos_x0)
tracer_spy.append(g)
return ((cos_x0 * g, cos_x0),)
app.defvjp(app_fwd, app_rev)
self._check_tracers_and_jaxprs(
jax.jit(jax.grad(lambda xy: app(lambda x: 2 * x[0], xy))),
(3., 3.),
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=<lambda>, arg_names=xy[0],xy[1], result_paths=[0],[1]",
"traced_for=custom_vjp fun, fun=app, arg_names=xy[0],xy[1], result_paths=",
],
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
"traced_for=jit, fun=<lambda>, arg_names=xy[0],xy[1], from xy[0]",
"traced_for=custom_vjp fun, fun=app, arg_names=xy[0],xy[1], from xy[0]",
# TODO(necula): from None
"traced_for=jit, fun=<lambda>, arg_names=xy[0],xy[1], from None",
])
def test_vmap_of_nested_jit(self):
tracer_spy = TracerSpy()
@ -995,8 +1135,12 @@ class DebugInfoTest(jtu.JaxTestCase):
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
'traced_for=jit, fun=my_f, arg_names=x,y, result_paths=',
# TODO(necula): some Jaxprs without debug info
'None',
# TODO(necula): arg_names? result_paths?
"traced_for=cond, fun=my_true_branch, arg_names=None, result_paths=,",
"traced_for=cond, fun=my_false_branch, arg_names=None, result_paths=,",
"traced_for=cond, fun=my_true_branch, arg_names=a,b, result_paths=[0],[1]",
"traced_for=cond, fun=my_false_branch, arg_names=c,d, result_paths=[0],[1]",
"traced_for=checkpoint / remat, fun=my_g, arg_names=None,None, result_paths=,",
],
expected_tracer_debug_infos=[
'traced_for=cond, fun=my_true_branch, arg_names=a,b',
@ -1020,7 +1164,9 @@ class DebugInfoTest(jtu.JaxTestCase):
@jax.jit
def my_f(x, as_):
tracer_spy.append(x)
return jax.remat(lambda *args: for_loop.scan(f, *args))(c, as_)
def to_remat(a, b):
return for_loop.scan(f, a, b)
return jax.remat(to_remat)(c, as_)
def the_grad(c, as_):
tracer_spy.append(c)
@ -1035,13 +1181,21 @@ class DebugInfoTest(jtu.JaxTestCase):
"traced_for=jit, fun=the_grad, arg_names=c,as_, result_paths=[0],[1]",
# TODO(necula): bad result paths
"traced_for=jit, fun=my_f, arg_names=x,as_, result_paths=,,",
'None', # TODO(necula): some Jaxprs without debug info
# TODO(necula): arg_names?
"traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2], result_paths=",
"traced_for=for_loop, fun=f, arg_names=None,None,None, result_paths=,",
"traced_for=for_loop, fun=f, arg_names=None,None,None,None,None,None, result_paths=,",
"traced_for=for_loop, fun=f, arg_names=None,None,None,None,None,None,None,None,None,None,None, result_paths=",
"traced_for=for_loop, fun=f, arg_names=None,None,None,None,None,None,None,None,None,None,None,None,None,None,None, result_paths=,",
"traced_for=checkpoint / remat, fun=to_remat, arg_names=None,None,None, result_paths=,",
"traced_for=jit, fun=my_f, arg_names=None,None,x,as_, result_paths=",
],
expected_tracer_debug_infos=[
"traced_for=jit, fun=the_grad, arg_names=c,as_",
"traced_for=scan, fun=f, arg_names=c,a",
"traced_for=jit, fun=my_f, arg_names=x,as_",
'None', # TODO(necula): some missing debug info
# TODO(necula): arg_names
"traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2]",
],
expected_lowering_lines=[
re.compile(r".*func.func public @main\(%arg0: tensor<f..> loc\(\"c\"\)"),
@ -1321,12 +1475,44 @@ class DebugInfoTest(jtu.JaxTestCase):
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=x, result_paths=",
# TODO(necula): some Jaxprs without debug info
"None"],
# TODO(necula): arg_names?
"traced_for=checkpoint / remat, fun=my_g, arg_names=None,None, result_paths=",
],
expected_tracer_debug_infos=[
"traced_for=checkpoint / remat, fun=my_g, arg_names=y"
])
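The "checkpoint / remat" records come from tracing the rematerialized function. A minimal sketch of the pattern these expectations cover, using the public jax.checkpoint (a.k.a. jax.remat) API:

import jax
import jax.numpy as jnp

def my_g(y):
  return jnp.sin(jnp.sin(y))

def my_f(x):
  # Residuals of my_g are recomputed on the backward pass.
  return jax.checkpoint(my_g)(x)

print(jax.grad(my_f)(1.0))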
def test_remat_shard_map(self):
tracer_spy = TracerSpy()
if len(jax.devices()) < 2:
self.skipTest("requires at least 2 devices")
# This tests remat-of-shard_map.
mesh = Mesh(np.array(jax.devices()[:2]), ('x',))
# Check that param updating is handled.
@jax.remat
@functools.partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
def my_f(x):
tracer_spy.append(x)
return jnp.sin(jnp.sin(x))
self._check_tracers_and_jaxprs(
jax.jit(jax.grad(lambda x: my_f(x).sum())),
jnp.arange(2, dtype=np.float32),
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
# TODO(necula): arg_names
"traced_for=jit, fun=<lambda>, arg_names=x, result_paths=",
"traced_for=checkpoint / remat, fun=my_f, arg_names=None,None, result_paths=",
"traced_for=shard_map, fun=my_f, arg_names=x, result_paths=",
"traced_for=shard_map, fun=my_f, arg_names=None,None, result_paths=",
"None", # TODO(necula): missing
],
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
"None" # TODO(necula): missing
])
def test_remat_saved_residuals(self):
@functools.partial(jax.remat,
static_argnums=(1,),
@ -1359,7 +1545,7 @@ class DebugInfoTest(jtu.JaxTestCase):
# TODO(necula): this should not be pointing into the JAX internals
re.compile(r"traced_for=jit, fun=checked_fun at .*jax/_src/checkify.py:.*, arg_names=args\[0\]"),
re.compile(r"traced_for=jit, fun=argsort at .*numpy/lax_numpy.py:.*, arg_names=a, result_paths="),
"None", # TODO(necula): missing tracer debug info
"traced_for=pmap, fun=my_f, arg_names=my_x, result_paths=[0]",
],
expected_tracer_debug_infos=[
"traced_for=pmap, fun=my_f, arg_names=my_x",
@ -1495,7 +1681,8 @@ class DebugInfoTest(jtu.JaxTestCase):
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=<lambda>, arg_names=x, result_paths=[0],[1]",
"None", # TODO(necula): there are missing Jaxpr debug info
# TODO(necula): internal function?
re.compile(r"traced_for=custom_jvp fun, fun=_custom_root at .*/control_flow/solves.py:.*, arg_names=args\[0\], result_paths=\[0\]"),
],
expected_tracer_debug_infos=[
"traced_for=custom_root, fun=my_f, arg_names=x",