[better_errors] Continue adding debug info to Jaxprs (step 3)

This follows #26078 and #26313, adding `debug_info` to
more calls to `lu.wrap_init`.

As part of this I have changed the primitives `custom_vjp_call_jaxpr`
and `custom_lin` to take the `bwd` parameter as a `lu.WrappedFun`,
which carries debug info. Previously, this was a `Callable`, but in
almost all cases it was really `lu.WrappedFun.call_wrapped`.
George Necula 2025-02-05 19:17:47 +02:00
parent 0fb278a0b9
commit 904b74860c
17 changed files with 234 additions and 99 deletions
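
To make the `bwd` change concrete, below is a minimal, self-contained sketch of the idiom (toy stand-ins, not the real `lu.WrappedFun`/`lu.wrap_init`): a wrapped function carries `debug_info` alongside the callable, so passing the wrapper itself, rather than its bound `call_wrapped` method, keeps the metadata available to downstream rules.

# Toy sketch of the WrappedFun idiom (hypothetical stand-ins, not
# jax._src.linear_util itself).
from dataclasses import dataclass
from typing import Any, Callable

@dataclass(frozen=True)
class WrappedFun:
    f: Callable[..., Any]
    debug_info: Any = None

    def call_wrapped(self, *args: Any) -> Any:
        return self.f(*args)

def wrap_init(f: Callable[..., Any], *, debug_info: Any = None) -> WrappedFun:
    return WrappedFun(f, debug_info)

def my_bwd(res, ct):
    return (res * ct,)

bwd = wrap_init(my_bwd, debug_info="traced_for=custom_vjp bwd, fun=my_bwd")

# Before this change consumers received the bare callable
# `bwd.call_wrapped`, which drops the metadata; now they receive `bwd`
# itself and can do both:
cts = bwd.call_wrapped(2.0, 3.0)  # invoke exactly as before
provenance = bwd.debug_info       # metadata still attached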

View File

@@ -573,7 +573,11 @@ def jacfwd(fun: Callable, argnums: int | Sequence[int] = 0,
@wraps(fun, docstr=docstr, argnums=argnums)
def jacfun(*args, **kwargs):
f = lu.wrap_init(fun, kwargs)
f = lu.wrap_init(
fun, kwargs,
debug_info=debug_info(
"jacfwd", fun, args, kwargs,
static_argnums=(argnums,) if isinstance(argnums, int) else argnums))
f_partial, dyn_args = argnums_partial(f, argnums, args,
require_static_args_hashable=False)
tree_map(partial(_check_input_dtype_jacfwd, holomorphic), dyn_args)
@@ -661,7 +665,11 @@ def jacrev(fun: Callable, argnums: int | Sequence[int] = 0,
@wraps(fun, docstr=docstr, argnums=argnums)
def jacfun(*args, **kwargs):
f = lu.wrap_init(fun, kwargs)
f = lu.wrap_init(
fun, kwargs,
debug_info=debug_info(
"jacrev", fun, args, kwargs,
static_argnums=(argnums,) if isinstance(argnums, int) else argnums))
f_partial, dyn_args = argnums_partial(f, argnums, args,
require_static_args_hashable=False)
tree_map(partial(_check_input_dtype_jacrev, holomorphic, allow_int), dyn_args)
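
For orientation, here is the new `jacfwd` wiring from the hunk above in isolation; this sketch uses the private `jax._src` APIs exactly as they appear in this diff, and they may change without notice.

# Sketch of the jacfwd entry-point change (private jax._src internals as
# of this commit).
import jax.numpy as jnp
from jax._src import linear_util as lu
from jax._src.api_util import debug_info

def fun(x, y):
    return jnp.sin(x) * y

args, kwargs = (1.0, 2.0), {}
argnums = 0
f = lu.wrap_init(
    fun, kwargs,
    debug_info=debug_info(
        "jacfwd", fun, args, kwargs,
        static_argnums=(argnums,) if isinstance(argnums, int) else argnums))
# The wrapped function now records what it was traced for, its name, and
# its argument names:
print(f.debug_info)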

View File

@@ -1065,18 +1065,20 @@ def lift_jvp(num_errs, num_consts, jvp_jaxpr_thunk):
return [*primal_errs, *out_primals, *tangent_errs, *out_tangents]
return jvp
def custom_vjp_call_jaxpr_rule(in_err, enabled_errors, *in_vals, fun_jaxpr,
fwd_jaxpr_thunk, num_consts, bwd, out_trees,
symbolic_zeros):
def custom_vjp_call_jaxpr_rule(in_err, enabled_errors, *in_vals,
fun_jaxpr: core.ClosedJaxpr,
fwd_jaxpr_thunk, num_consts,
bwd: lu.WrappedFun, out_trees,
symbolic_zeros: bool):
err_vals, err_tree = jtu.tree_flatten(in_err)
num_errs = err_tree.num_leaves
checkified_fun = lu.wrap_init(
functools.partial(checkify_jaxpr_flat, fun_jaxpr.jaxpr,
fun_jaxpr.consts, enabled_errors, err_tree))
fun_jaxpr.consts, enabled_errors, err_tree),
debug_info=fun_jaxpr.jaxpr.debug_info)
checkified_fun, fun_metadata = _flatten_and_get_error_metadata_thunk(
checkified_fun)
@lu.wrap_init
def checkified_fwd(*args):
# TODO(lenamartens, sharadmv): why not checkify here?
xs, zeros = args[::2], args[1::2]
@@ -1085,10 +1087,15 @@ def custom_vjp_call_jaxpr_rule(in_err, enabled_errors, *in_vals, fun_jaxpr,
xs_without_consts = xs[num_consts:]
return core.eval_jaxpr(fwd_jaxpr, fwd_consts, *xs_without_consts)
bwd_ = lambda *args: (*(None,)*num_errs, *bwd(*args))
checkified_fwd, fwd_out_tree = flatten_fun_output(checkified_fwd)
# TODO(necula): the fwd result_paths are not quite the same as fun_jaxpr
checkified_fwd_wrapped = lu.wrap_init(checkified_fwd,
debug_info=fun_jaxpr.jaxpr.debug_info)
bwd_ = lu.wrap_init(lambda *args: (*(None,)*num_errs, *bwd.call_wrapped(*args)),
debug_info=bwd.debug_info)
checkified_fwd_wrapped, fwd_out_tree = flatten_fun_output(checkified_fwd_wrapped)
all_outs = custom_derivatives.custom_vjp_call_p.bind(
checkified_fun, checkified_fwd, bwd_, *err_vals, *in_vals, out_trees=out_trees,
checkified_fun, checkified_fwd_wrapped,
bwd_, *err_vals, *in_vals, out_trees=out_trees,
symbolic_zeros=symbolic_zeros)
fst, out_metadata = lu.merge_linear_aux(fun_metadata, fwd_out_tree)
if fst:
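
A note on the `@lu.wrap_init` decorator that the hunk above removes: the bare decorator form cannot pass `debug_info`, so the rule now wraps explicitly and reuses the debug info of the jaxpr it shadows. A toy sketch of that substitution (hypothetical stand-ins, not the checkify rule itself):

# Toy sketch: explicit wrap instead of a bare @wrap_init decorator, so a
# debug_info value can be attached at wrap time.
from dataclasses import dataclass
from typing import Any, Callable

@dataclass(frozen=True)
class WrappedFun:
    f: Callable[..., Any]
    debug_info: Any = None

def wrap_init(f: Callable[..., Any], *, debug_info: Any = None) -> WrappedFun:
    return WrappedFun(f, debug_info)

def make_checkified_fwd(fun_jaxpr_debug_info):
    def checkified_fwd(*args):
        xs, zeros = args[::2], args[1::2]  # same flattening as the rule above
        return xs
    # Explicit wrap: the decorator form would produce a WrappedFun with no
    # debug_info, losing the provenance of the function it shadows.
    return wrap_init(checkified_fwd, debug_info=fun_jaxpr_debug_info)

fwd = make_checkified_fwd("traced_for=custom_vjp fwd")
assert fwd.debug_info == "traced_for=custom_vjp fwd"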

View File

@@ -392,7 +392,7 @@ class JaxprEqn:
# TODO(mattjj): call typecheck rules here, so we don't form bad eqns
def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None,
ctx=None):
ctx=None) -> JaxprEqn:
source_info = source_info or source_info_util.new_source_info()
ctx = ctx or JaxprEqnContext(
compute_on.current_compute_type(),

View File

@@ -374,15 +374,16 @@ class CustomJVPCallPrimitive(core.Primitive):
def get_bind_params(self, params):
new_params = dict(params)
call_jaxpr = new_params.pop('call_jaxpr')
num_consts = new_params.pop('num_consts')
call_jaxpr: core.ClosedJaxpr = new_params.pop('call_jaxpr')
num_consts: int = new_params.pop('num_consts')
jvp_jaxpr_thunk = new_params.pop('jvp_jaxpr_thunk')
fun = lu.wrap_init(core.jaxpr_as_fun(call_jaxpr))
jvp = lift_jvp(num_consts, jvp_jaxpr_thunk)
fun = lu.wrap_init(core.jaxpr_as_fun(call_jaxpr),
debug_info=call_jaxpr.jaxpr.debug_info)
jvp = lift_jvp(num_consts, jvp_jaxpr_thunk, call_jaxpr.jaxpr.debug_info)
return [fun, jvp], new_params
def lift_jvp(num_consts: int, jvp_jaxpr_thunk: Callable) -> lu.WrappedFun:
@lu.wrap_init
def lift_jvp(num_consts: int, jvp_jaxpr_thunk: Callable,
debug_info: core.DebugInfo | None) -> lu.WrappedFun:
def jvp(*xs):
n, ragged = divmod(len(xs), 2)
assert not ragged
@@ -398,7 +399,7 @@ def lift_jvp(num_consts: int, jvp_jaxpr_thunk: Callable) -> lu.WrappedFun:
for p, z in zip(out_primals, out_zeros)]
assert next(nz_out_tangents_, None) is None
return [*out_primals, *out_tangents]
return jvp
return lu.wrap_init(jvp, debug_info=debug_info)
effects.custom_derivatives_allowed_effects.add_type(lax.InOutFeedEffect)
@@ -435,8 +436,9 @@ def _custom_jvp_call_transpose(params, jaxpr, args, ct, _):
ad.primitive_transposes[custom_jvp_call_p] = _custom_jvp_call_transpose
@weakref_lru_cache
def _cached_closed_call_dce_instantiate(jaxpr_, used_outputs: tuple[bool, ...]
) -> tuple[core.ClosedJaxpr, list[bool]]:
def _cached_closed_call_dce_instantiate(jaxpr_: core.ClosedJaxpr,
used_outputs: tuple[bool, ...]
) -> tuple[core.ClosedJaxpr, list[bool]]:
jaxpr, consts = jaxpr_.jaxpr, jaxpr_.consts
new_jaxpr, used_inputs = pe.dce_jaxpr(jaxpr, used_outputs, True)
return core.ClosedJaxpr(new_jaxpr, consts), used_inputs
@@ -673,7 +675,7 @@ class custom_vjp(Generic[ReturnValue]):
flat_fwd, out_trees = _flatten_fwd(
fwd_, self.nondiff_argnums, self.symbolic_zeros, debug_fun,
debug_fwd, in_tree, out_type)
flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees).call_wrapped
flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees)
out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,
*args_flat, out_trees=out_trees,
symbolic_zeros=self.symbolic_zeros)
@@ -940,7 +942,9 @@ mlir.register_lowering(custom_vjp_call_jaxpr_p, mlir.lower_fun(
def _custom_vjp_call_jaxpr_jvp(
primals, tangents, *, fun_jaxpr: core.ClosedJaxpr,
fwd_jaxpr_thunk: Callable[..., tuple[core.Jaxpr, Sequence[Any]]],
num_consts: int, bwd: Callable, out_trees: Callable, symbolic_zeros: bool):
num_consts: int, bwd: lu.WrappedFun,
out_trees: Callable[[], Sequence[PyTreeDef]],
symbolic_zeros: bool):
_, args = split_list(primals, [num_consts])
consts_dot, args_dot = split_list(tangents, [num_consts])
if any(type(t) is not Zero for t in consts_dot):
@@ -963,7 +967,8 @@ def _custom_vjp_call_jaxpr_vmap(
axis_data, args, in_dims, *,
fun_jaxpr: core.ClosedJaxpr,
fwd_jaxpr_thunk: Callable[..., tuple[core.Jaxpr, Sequence[Any]]],
num_consts: int, bwd: Callable, out_trees: Callable, symbolic_zeros: bool):
num_consts: int, bwd: lu.WrappedFun,
out_trees: Callable, symbolic_zeros: bool):
args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0
else x for x, d in zip(args, in_dims)]
in_batched = [d is not not_mapped for d in in_dims]
@@ -1000,12 +1005,13 @@ def _custom_vjp_call_jaxpr_dce(
) -> tuple[list[bool], core.JaxprEqn | None]:
if not any(used_outs) and not pe.has_effects(eqn):
return [False] * len(eqn.invars), None
fun_jaxpr = eqn.params["fun_jaxpr"]
fun_jaxpr: core.ClosedJaxpr = eqn.params["fun_jaxpr"]
fwd_jaxpr_thunk = eqn.params["fwd_jaxpr_thunk"]
bwd = eqn.params["bwd"]
out_trees = eqn.params["out_trees"]
symbolic_zeros = eqn.params["symbolic_zeros"]
bwd: lu.WrappedFun = eqn.params["bwd"]
out_trees: Callable[[], Sequence[PyTreeDef]] = eqn.params["out_trees"]
symbolic_zeros: bool = eqn.params["symbolic_zeros"]
dce_fun_jaxpr: core.ClosedJaxpr
used_ins: Sequence[bool]
dce_fun_jaxpr, used_ins = _cached_closed_call_dce_instantiate(
fun_jaxpr, tuple(used_outs))
assert all(used_ins)
@@ -1019,7 +1025,6 @@ def _custom_vjp_call_jaxpr_dce(
fwd_jaxpr, (True,) * num_res + tuple(used_outs))
return dce_fwd_jaxpr.jaxpr, dce_fwd_jaxpr.consts
@lu.wrap_init
def dce_bwd(*args):
_, res_tree = out_trees()
res, cts = split_list(args, [res_tree.num_leaves])
@@ -1035,19 +1040,21 @@ def _custom_vjp_call_jaxpr_dce(
else:
all_cts.append(zeros_like_aval(ct_aval))
assert next(cts_, None) is None
return bwd(*res, *all_cts)
return bwd.call_wrapped(*res, *all_cts)
dce_bwd_wrapped = lu.wrap_init(dce_bwd,
debug_info=bwd.debug_info)
outvars = [v for used, v in zip(used_outs, eqn.outvars) if used]
new_params = dict(
eqn.params,
fun_jaxpr=dce_fun_jaxpr,
fwd_jaxpr_thunk=dce_fwd_jaxpr_thunk,
bwd=dce_bwd.call_wrapped,
bwd=dce_bwd_wrapped,
)
new_eqn = pe.new_jaxpr_eqn(
eqn.invars, outvars, eqn.primitive, new_params, dce_fun_jaxpr.effects,
eqn.source_info, eqn.ctx)
return used_ins, new_eqn
return list(used_ins), new_eqn
pe.dce_rules[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_dce
xla.register_initial_style_primitive(custom_vjp_call_jaxpr_p)
@@ -1125,7 +1132,9 @@ def custom_gradient(fun):
def fwd(*args, **kwargs):
ans, rule = fun(*args, **kwargs)
ans_flat, out_tree = tree_flatten((ans,))
rule, in_tree = flatten_fun_nokwargs(lu.wrap_init(rule), out_tree)
debug_fwd = debug_info("custom_gradient fwd", rule, (ans,), {})
rule, in_tree = flatten_fun_nokwargs(lu.wrap_init(rule,
debug_info=debug_fwd), out_tree)
ans_avals = [core.get_aval(x).to_tangent_aval() for x in ans_flat]
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(rule, ans_avals)
return ans, Residuals(jaxpr, in_tree(), out_tree, consts)
@@ -1224,10 +1233,11 @@ def closure_convert(fun: Callable, *example_args) -> tuple[Callable, list[Any]]:
"""
flat_args, in_tree = tree_flatten(example_args)
in_avals = tuple(map(core.get_aval, flat_args))
debug = debug_info("closure_convert", fun, example_args, {})
if config.check_tracer_leaks.value:
return _closure_convert_for_avals.__wrapped__(fun, in_tree, in_avals)
return _closure_convert_for_avals.__wrapped__(fun, in_tree, in_avals, debug)
else:
return _closure_convert_for_avals(fun, in_tree, in_avals)
return _closure_convert_for_avals(fun, in_tree, in_avals, debug)
def _maybe_perturbed(x: Any) -> bool:
# False if x can't represent an AD-perturbed value (i.e. a value
@@ -1251,8 +1261,10 @@ def _maybe_perturbed(x: Any) -> bool:
return True # We can't be sure!
@cache()
def _closure_convert_for_avals(fun, in_tree, in_avals):
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
def _closure_convert_for_avals(fun, in_tree, in_avals,
debug_info: core.DebugInfo):
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun, debug_info=debug_info),
in_tree)
jaxpr, out_pvals, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
out_tree = out_tree()
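
The `closure_convert` hunk above computes debug info once at the public entry point and threads it through the cached helper, including the `__wrapped__` bypass taken under tracer-leak checking. A rough sketch of that shape (toy cache and hypothetical names; JAX's `cache()` and the real signatures differ):

# Toy sketch of threading debug info through a cached helper.
import functools

@functools.lru_cache(maxsize=None)
def _convert_for_avals(fun, in_avals, debug):
    return f"{debug}: traced {fun.__name__} for avals {in_avals}"

def closure_convert_sketch(fun, *example_args, check_tracer_leaks=False):
    debug = f"traced_for=closure_convert, fun={fun.__name__}"
    in_avals = tuple(type(a).__name__ for a in example_args)
    if check_tracer_leaks:
        # Bypass the cache (lru_cache exposes the original function via
        # __wrapped__), mirroring the check_tracer_leaks branch above.
        return _convert_for_avals.__wrapped__(fun, in_avals, debug)
    return _convert_for_avals(fun, in_avals, debug)

print(closure_convert_sketch(abs, 1.0))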

View File

@@ -25,7 +25,7 @@ from jax._src import config
from jax._src import linear_util as lu
from jax._src.interpreters import partial_eval as pe
from jax.tree_util import (tree_flatten, tree_unflatten,
register_pytree_node, Partial)
register_pytree_node, Partial, PyTreeDef)
from jax._src import core
from jax._src import source_info_util
from jax._src.ad_util import (
@@ -644,26 +644,28 @@ class LinearizeTrace(Trace):
return [maybe_linearize_tracer(self, x, nz, t)
for x, nz, t in zip(primals_out, tangent_nzs_out, tangents_out)]
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
symbolic_zeros):
def process_custom_vjp_call(self, prim, fun, fwd,
bwd: lu.WrappedFun, tracers,
out_trees: Callable[[], Sequence[PyTreeDef]],
symbolic_zeros: bool):
primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers))
if all(type(t) is Zero for t in tangents_in):
return prim.bind_with_trace(self.parent_trace,
(fun, fwd, bwd, *primals_in),
dict(out_trees=out_trees, symbolic_zeros=symbolic_zeros))
fwd_in = [(p, type(t) is not Zero) for p, t in zip(primals_in, tangents_in)]
fwd_in = [x for pair in fwd_in for x in pair] # flatten
fwd_in_flat = [x for pair in fwd_in for x in pair] # flatten
with core.set_current_trace(self.parent_trace):
res_and_primals_out = fwd.call_wrapped(*fwd_in)
res_and_primals_out = fwd.call_wrapped(*fwd_in_flat)
_, res_tree = out_trees()
res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
avals_out = [core.get_aval(x).to_tangent_aval() for x in primals_out]
tangents_in = map(instantiate_zeros, tangents_in)
tangents_in_zeros = map(instantiate_zeros, tangents_in)
with core.set_current_trace(self.tangent_trace):
tangents_out = custom_lin_p.bind(
*res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd,
*res, *tangents_in_zeros, num_res=res_tree.num_leaves, bwd=bwd,
out_avals=avals_out, symbolic_zeros=symbolic_zeros)
tangent_nzs_out = [type(t) is not Zero for t in tangents_out]
return map(partial(maybe_linearize_tracer, self), primals_out, tangent_nzs_out, tangents_out)
@@ -975,9 +977,13 @@ def nonzero_outputs(f, store, *args, **kwargs):
store.store([type(r) is not Zero for r in results])
return results
def map_transpose(primitive, params, call_jaxpr, args, ct, _):
def map_transpose(primitive: core.Primitive, params,
call_jaxpr: core.Jaxpr, args, ct, _):
all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts
fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr, False)
# TODO(necula): use the right debug_info for the backwards pass
fun = lu.hashable_partial(lu.wrap_init(backward_pass,
debug_info=call_jaxpr.debug_info),
call_jaxpr, False)
fun, nz_arg_cts = nonzero_outputs(fun)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
# Preserve axis for primal arguments, skip tangents (represented as undefined primals).
@@ -1056,12 +1062,24 @@ def f_jvp_traceable(f, store, nonzeros, *primals_and_nztangents):
def rearrange_binders(jaxpr: core.ClosedJaxpr, primals_in, tangents_in, primals_out, tangents_out):
new_invars = _perm(primals_in, tangents_in, jaxpr.jaxpr.invars)
new_outvars = _perm(primals_out, tangents_out, jaxpr.jaxpr.outvars)
new_debug_info = jaxpr.jaxpr.debug_info
if new_debug_info is not None:
new_arg_names = tuple(_perm(primals_in, tangents_in,
jaxpr.jaxpr.debug_info.safe_arg_names(len(jaxpr.jaxpr.invars))))
new_result_paths = tuple(_perm(primals_out, tangents_out,
jaxpr.jaxpr.debug_info.safe_result_paths(len(jaxpr.jaxpr.outvars))))
new_debug_info = new_debug_info._replace(
arg_names=new_arg_names,
result_paths=new_result_paths,
)
new_jaxpr = core.Jaxpr(jaxpr.jaxpr.constvars,
new_invars, new_outvars, jaxpr.jaxpr.eqns,
jaxpr.jaxpr.effects)
jaxpr.jaxpr.effects,
new_debug_info)
return core.ClosedJaxpr(new_jaxpr, jaxpr.consts)
def _perm(primal_counts, tangent_counts, lst):
def _perm(primal_counts: Sequence[int], tangent_counts: Sequence[int],
lst: Sequence[Any]) -> Sequence[Any]:
n = sum(primal_counts)
primals, tangents = lst[:n], lst[n:]
primal_groups = split_list(primals, primal_counts[:-1])
@@ -1082,14 +1100,15 @@ def raise_custom_vjp_error_on_jvp(*_, **__):
"function.")
custom_lin_p.def_impl(raise_custom_vjp_error_on_jvp)
def _custom_lin_transpose(cts_out, *invals, num_res, bwd, out_avals,
def _custom_lin_transpose(cts_out, *invals, num_res,
bwd: lu.WrappedFun, out_avals,
symbolic_zeros):
res, _ = split_list(invals, [num_res])
if symbolic_zeros:
cts_out = map(replace_internal_symbolic_zeros, cts_out)
else:
cts_out = map(instantiate_zeros, cts_out)
cts_in = bwd(*res, *cts_out)
cts_in = bwd.call_wrapped(*res, *cts_out)
cts_in = map(replace_rule_output_symbolic_zeros, cts_in)
return [None] * num_res + list(cts_in)
primitive_transposes[custom_lin_p] = _custom_lin_transpose
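
The `rearrange_binders` hunk above is the subtle part of this file: when a jaxpr's invars and outvars are permuted, the `arg_names` and `result_paths` recorded in its `debug_info` must be permuted in lockstep, or the names would label the wrong binders. A self-contained sketch of that bookkeeping (toy `DebugInfo`; the field names follow the diff):

# Toy sketch: permute debug metadata in lockstep with jaxpr binders.
from typing import Any, NamedTuple, Sequence

class DebugInfo(NamedTuple):
    arg_names: tuple[str, ...]
    result_paths: tuple[str, ...]

def _perm(primal_counts: Sequence[int], tangent_counts: Sequence[int],
          lst: Sequence[Any]) -> list[Any]:
    # Interleave primal and tangent groups, as _perm does above.
    n = sum(primal_counts)
    primals, tangents = list(lst[:n]), list(lst[n:])
    def split(xs, counts):
        out = []
        for c in counts:
            out.append(xs[:c])
            xs = xs[c:]
        return out
    out: list[Any] = []
    for p, t in zip(split(primals, primal_counts),
                    split(tangents, tangent_counts)):
        out += [*p, *t]
    return out

dbg = DebugInfo(arg_names=("x", "y", "x_dot", "y_dot"),
                result_paths=("out", "out_dot"))
new_dbg = dbg._replace(
    arg_names=tuple(_perm([1, 1], [1, 1], dbg.arg_names)),
    result_paths=tuple(_perm([1], [1], dbg.result_paths)))
assert new_dbg.arg_names == ("x", "x_dot", "y", "y_dot")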

View File

@@ -799,8 +799,11 @@ def batch_jaxpr_axes(closed_jaxpr, axis_data, in_axes, out_axes_dest):
return _batch_jaxpr_axes(closed_jaxpr, axis_data, tuple(in_axes), tuple(out_axes_dest))
@weakref_lru_cache
def _batch_jaxpr_axes(closed_jaxpr, axis_data, in_axes, out_axes_dest):
f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
def _batch_jaxpr_axes(closed_jaxpr: core.ClosedJaxpr,
axis_data: AxisData,
in_axes: Sequence[int], out_axes_dest: Sequence[int]):
f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr),
debug_info=closed_jaxpr.jaxpr.debug_info)
f, out_axes = _batch_jaxpr_inner(f, axis_data)
f, out_batched = _match_axes_jaxpr(f, axis_data, out_axes_dest, out_axes)
f = _batch_jaxpr_outer(f, axis_data, in_axes)
@@ -896,7 +899,10 @@ def batch_custom_jvp_subtrace(f, store, tag, axis_data, in_dims, *in_vals):
store.store(out_dims * 2)
return out_primals + out_tangents
def batch_custom_vjp_bwd(bwd, tag, axis_data, in_dims, out_dim_dests):
def batch_custom_vjp_bwd(bwd: lu.WrappedFun, tag: core.TraceTag,
axis_data: AxisData,
in_dims: Callable[[], Sequence[int | None]],
out_dim_dests: Sequence[int | None]) -> lu.WrappedFun:
axis_size = axis_data.size
axis_name = axis_data.name
mesh_axis = axis_data.explicit_mesh_axis
@@ -907,11 +913,11 @@ def batch_custom_vjp_bwd(bwd, tag, axis_data, in_dims, out_dim_dests):
for x, dim in zip(args, in_dims_)]
in_dims_ = [None if type(x) is SymbolicZero else d
for x, d in zip(args, in_dims_)]
bwd_, out_dims_thunk = batch_subtrace(lu.wrap_init(bwd), tag, axis_data, in_dims_)
bwd_, out_dims_thunk = batch_subtrace(bwd, tag, axis_data, in_dims_)
bwd_ = _match_axes_and_sum(bwd_, axis_size, axis_name, mesh_axis,
out_dims_thunk, out_dim_dests)
return bwd_.call_wrapped(*args)
return new_bwd
return lu.wrap_init(new_bwd, debug_info=bwd.debug_info)
@lu.transformation2
def _match_axes_and_sum(f, axis_size, axis_name, mesh_axis, out_dims_thunk,

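`batch_custom_vjp_bwd` above now takes and returns a `lu.WrappedFun`, copying `debug_info` from input to output so the batched backward function keeps the original's provenance. In isolation (private `jax._src` APIs as used in this diff; subject to change):

# Sketch of the WrappedFun-in, WrappedFun-out contract with inherited
# debug_info (private jax._src internals as of this commit).
from jax._src import linear_util as lu
from jax._src.api_util import debug_info

def my_bwd(res, ct):
    return (res * ct,)

bwd = lu.wrap_init(
    my_bwd, debug_info=debug_info("custom_vjp bwd", my_bwd, (1.0, 1.0), {}))

def rewrap_bwd(bwd: lu.WrappedFun) -> lu.WrappedFun:
    def new_bwd(*args):
        return bwd.call_wrapped(*args)  # a real rule would batch here
    return lu.wrap_init(new_bwd, debug_info=bwd.debug_info)  # inherit it

assert rewrap_bwd(bwd).debug_info == bwd.debug_info
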
View File

@@ -2022,8 +2022,11 @@ class DynamicJaxprTrace(core.Trace):
self.frame.add_eqn(eqn)
return out_tracers
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
symbolic_zeros):
def process_custom_vjp_call(self, prim: core.Primitive,
fun: lu.WrappedFun,
fwd: lu.WrappedFun, bwd: lu.WrappedFun, tracers,
out_trees: Callable[[], Sequence[PyTreeDef]],
symbolic_zeros: bool):
tracers = map(self.to_jaxpr_tracer, tracers)
in_avals = [t.aval for t in tracers]
fun_jaxpr, out_avals, consts, _ = trace_to_jaxpr_dynamic(fun, in_avals)
@@ -2041,7 +2044,8 @@ class DynamicJaxprTrace(core.Trace):
invars = map(self.getvar, tracers)
constvars = map(self.getvar, map(self.to_jaxpr_tracer, consts))
outvars = map(self.makevar, out_tracers)
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim.initial_style,
eqn = new_jaxpr_eqn([*constvars, *invars], outvars,
prim.initial_style, # type: ignore[attribute-error]
dict(fun_jaxpr=closed_fun_jaxpr,
fwd_jaxpr_thunk=fwd_jaxpr_from_zeros,
num_consts=len(consts),

View File

@@ -586,8 +586,10 @@ def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry,
for p, nz in zip(primals_out, nonzeros_out)]
return primals_out, tangents_out
def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
jaxpr, linear, unroll, _split_transpose):
def _scan_partial_eval(trace, *tracers, reverse: bool,
length: int, num_consts: int, num_carry: int,
jaxpr: core.ClosedJaxpr, linear: Sequence[bool],
unroll: int, _split_transpose: bool):
num_ys = len(jaxpr.out_avals) - num_carry
unknowns = [not t.pval.is_known() for t in tracers]
const_uk, init_uk, xs_uk = split_list(unknowns, [num_consts, num_carry])
@@ -612,8 +614,8 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
del res_avals, carry_uk_out
# Instantiate those inputs which must be treated as unknown from the fixpoint.
tracers = [trace.instantiate_const(t) if uk else t
for t, uk in zip(tracers, unknowns)]
tracers = tuple(trace.instantiate_const(t) if uk else t
for t, uk in zip(tracers, unknowns))
# The residual inputs and outputs of the jaxprs produced haven't yet been
# adapted to the scan calling convention; in particular, jaxpr_known has its
@@ -638,7 +640,9 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
for aval in jaxpr_known.in_avals[len(const_pvals):]]
with source_info_util.reset_name_stack():
jaxpr_known_, invar_pvals_out, jaxpr_known_consts = pe.trace_to_jaxpr_nounits(
lu.wrap_init(core.jaxpr_as_fun(jaxpr_known)), const_pvals + other_pvals,
lu.wrap_init(core.jaxpr_as_fun(jaxpr_known),
debug_info=jaxpr_known.jaxpr.debug_info),
const_pvals + other_pvals,
instantiate=[True] * (len(out_uk) - sum(out_uk)) + [False] * num_res)
jaxpr_known = pe.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr_known_), ())
# The above trace_to_jaxpr_nounits call computed loop-invariant residuals
@@ -880,8 +884,9 @@ def _scan_transpose(cts, *args, reverse, length, num_consts,
# transpose_scan_jaxpr :: ([res1, c, a, res2] -> b)
# -> ([res1, CT c, CT b, res2] -> [CT c, CT a])
@weakref_lru_cache
def _transpose_scan_jaxpr(jaxpr, num_res1, num_c, num_res2,
ct_ys_is_zeros):
def _transpose_scan_jaxpr(jaxpr: core.ClosedJaxpr,
num_res1: int, num_c: int, num_res2: int,
ct_ys_is_zeros: Sequence[bool]):
num_a = len(jaxpr.in_avals) - num_res1 - num_c - num_res2
# TODO: allow input cotangent avals to be batched relative to jaxpr.in_avals
# if an axis isn't reduced
@@ -896,7 +901,6 @@ def _transpose_scan_jaxpr(jaxpr, num_res1, num_c, num_res2,
aval for aval, is_zero in zip(b_ys_avals, ct_ys_is_zeros) if not is_zero
]
@lu.wrap_init
def transposed(*res1_cbar_bbar_res2):
res1, c_bar, b_bar, ys_bar_stripped, res2 = split_list(
res1_cbar_bbar_res2,
@@ -915,9 +919,14 @@ def _transpose_scan_jaxpr(jaxpr, num_res1, num_c, num_res2,
a_bar = _map(ad.instantiate_zeros, a_bar)
c_bar = _map(ad.instantiate_zeros, _map(ad.add_tangents, c_bar, new_c_bar))
return c_bar + a_bar
# TODO(necula): fix arg names and results for transposed
transposed_wrapped = lu.wrap_init(transposed,
debug_info=jaxpr.jaxpr.debug_info)
return _make_closed_jaxpr_attrs(
transposed, tuple(res1_avals + c_avals + b_carry_avals +
b_ys_avals_stripped + res2_avals))
transposed_wrapped,
tuple(res1_avals + c_avals + b_carry_avals +
b_ys_avals_stripped + res2_avals))
def _scan_batching_rule(axis_data, args,

View File

@@ -1923,6 +1923,8 @@ def reduce(operands: Any,
is undefined.
"""
flat_operands, operand_tree = tree_util.tree_flatten(operands)
comp_debug = api_util.debug_info("reduce comp", computation,
(init_values, init_values), {})
flat_init_values, init_value_tree = tree_util.tree_flatten(init_values)
if operand_tree != init_value_tree:
raise ValueError('Operands must have the same tree structure as init_values:'
@@ -1939,7 +1941,7 @@
else:
flat_init_avals = safe_map(core.get_aval, flat_init_values)
closed_jaxpr, out_tree = _variadic_reduction_jaxpr(
computation, tuple(flat_init_avals), init_value_tree)
computation, comp_debug, tuple(flat_init_avals), init_value_tree)
out = reduce_p.bind(*flat_operands, *flat_init_values, computation=computation,
jaxpr=closed_jaxpr, dimensions=tuple(dimensions))
return tree_util.tree_unflatten(out_tree, out)
@@ -1967,10 +1969,13 @@ def _reduction_jaxpr(computation: Callable,
return jaxpr, tuple(consts)
@cache()
def _variadic_reduction_jaxpr(computation, flat_avals, aval_tree):
def _variadic_reduction_jaxpr(computation: Callable[[Any, Any], Any],
debug_info: core.DebugInfo,
flat_avals,
aval_tree: tree_util.PyTreeDef):
avals = tree_util.tree_unflatten(aval_tree, flat_avals)
flat_in_avals, in_tree = tree_util.tree_flatten((avals, avals))
comp = lu.wrap_init(computation)
comp = lu.wrap_init(computation, debug_info=debug_info)
flat_comp, out_tree = api_util.flatten_fun_nokwargs(comp, in_tree)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_comp, tuple(flat_in_avals))
if any(isinstance(c, core.Tracer) for c in consts):
@@ -5921,7 +5926,7 @@ def _argminmax_dtype_rule(operand, *, axes, index_dtype):
class _ArgMinMaxReducer:
def __init__(self, value_comparator):
def __init__(self, value_comparator: Callable[[Any, Any], Any]):
self._value_comparator = value_comparator
def __repr__(self):
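
The `reduce` hunk above builds debug info for the user's binary `computation`, passing `(init_values, init_values)` as the example arguments because the reducer is always applied to two accumulator-shaped values. In isolation (private `jax._src` APIs as used in this diff):

# Sketch of the reduce() debug-info construction (private jax._src
# internals as of this commit).
from jax._src import api_util

def computation(acc, x):
    return acc + x

init_values = 0.0
comp_debug = api_util.debug_info("reduce comp", computation,
                                 (init_values, init_values), {})
# A pair of init-value-shaped example arguments stands in for the two
# reducer operands when recording arg names.
print(comp_debug)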

View File

@@ -19,6 +19,7 @@ from functools import partial
import warnings
from jax import tree_util
from jax._src import api_util
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
@@ -55,6 +56,8 @@ def _reduce_window(
window_dilation: Sequence[int] | None = None,
):
flat_operands, operand_tree = tree_util.tree_flatten(operand)
comp_debug = api_util.debug_info("reduce_window comp", computation,
(init_value, init_value), {})
flat_init_values, init_value_tree = tree_util.tree_flatten(init_value)
if operand_tree != init_value_tree:
raise ValueError(
@@ -88,7 +91,7 @@
else:
flat_init_avals = map(core.get_aval, flat_init_values)
jaxpr, out_tree = lax._variadic_reduction_jaxpr(
computation, tuple(flat_init_avals), init_value_tree
computation, comp_debug, tuple(flat_init_avals), init_value_tree
)
if operand_tree != out_tree:
raise ValueError(

View File

@@ -326,7 +326,7 @@ def resolve_physical_types(jaxpr: jax_core.Jaxpr, consts: Sequence[Any]):
interp_fun = partial(
eval_jaxpr_recursive, jaxpr, consts,
recurse_hop_rule=resolve_physical_types)
wrapped = lu.wrap_init(interp_fun)
wrapped = lu.wrap_init(interp_fun, debug_info=jaxpr.debug_info)
new_jaxpr, _, new_consts, () = pe.trace_to_jaxpr_dynamic(
wrapped, kernel_avals)
return new_jaxpr, new_consts
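
This small hunk shows the retracing pattern the commit applies throughout: when a jaxpr is re-interpreted into a new jaxpr, re-attach the original `debug_info` to the wrapped interpreter so provenance can survive the round trip. A sketch (private `jax._src` APIs; whether the new jaxpr inherits the info end-to-end depends on this commit series, so treat this as illustrative):

# Sketch: retrace a jaxpr while preserving its debug_info (private
# jax._src internals as used in this diff).
import jax
from jax._src import core, linear_util as lu
from jax._src.interpreters import partial_eval as pe

def f(x):
    return x * 2.0

closed = jax.make_jaxpr(f)(1.0)
wrapped = lu.wrap_init(core.jaxpr_as_fun(closed),
                       debug_info=closed.jaxpr.debug_info)
new_jaxpr, _, new_consts, () = pe.trace_to_jaxpr_dynamic(
    wrapped, closed.in_avals)
print(new_jaxpr.debug_info)  # provenance carried over from `closed`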

View File

@@ -30,7 +30,6 @@ package(
py_library(
name = "jax2tf",
srcs = ["__init__.py"],
srcs_version = "PY3",
visibility = ["//visibility:public"],
deps = [":jax2tf_internal"],
)
@@ -42,7 +41,6 @@ py_library(
"impl_no_xla.py",
"jax2tf.py",
],
srcs_version = "PY3",
# TODO: b/255503696: enable pytype
tags = ["pytype_unchecked_annotations"],
visibility = jax_visibility("jax2tf_internal"),

View File

@@ -24,7 +24,6 @@ package(
py_library(
name = "back_compat_testdata",
srcs = glob(["*.py"]),
srcs_version = "PY3",
deps = [
"//third_party/py/numpy",
"//third_party/py/typing_extensions",

View File

@@ -28,7 +28,6 @@ package(
py_library(
name = "flax_models",
srcs = glob(["*.py"]),
srcs_version = "PY3",
deps = [
"//jax",
"//third_party/py/flax:core",

View File

@@ -1642,8 +1642,8 @@ def _promote_scalar_residuals(f: Callable, *args, **kwargs):
for x in out_consts]
return jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env)
def _promote_scalar_residuals_jaxpr(jaxpr, which):
@lu.wrap_init
def _promote_scalar_residuals_jaxpr(jaxpr: core.Jaxpr,
which: Sequence[bool]):
def fun(*res_and_args):
res, args = split_list(res_and_args, [len(jaxpr.constvars)])
res = [_rem_singleton(x) if w else x for x, w in zip(res, which)]
@@ -1651,7 +1651,8 @@ def _promote_scalar_residuals_jaxpr(jaxpr, which):
res_avals = [core.unmapped_aval(1, 0, v.aval) if w else v.aval
for v, w in zip(jaxpr.constvars, which)]
in_avals = [*res_avals, *[v.aval for v in jaxpr.invars]]
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(fun, in_avals)
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(fun, debug_info=jaxpr.debug_info), in_avals)
return jaxpr
@@ -1663,20 +1664,20 @@ def _unmentioned2(mesh: Mesh, names: AxisNames,
return [n for n in _all_mesh_names_except_spmd(mesh) if n not in name_set]
def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
def _shard_map_transpose(out_cts, *args,
jaxpr: core.Jaxpr, mesh, in_names, out_names,
check_rep, rewrite, auto):
mb_div = lambda x, y: x / y if y != 1 else x
out_cts = [ad.Zero(_shard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero
else x if rewrite or dtypes.dtype(x) == dtypes.float0
else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns, auto))))
for ns, x in zip(out_names, out_cts)]
args = [x if type(x) is not ad.UndefinedPrimal else
ad.UndefinedPrimal(_shard_aval(mesh, ns, x.aval))
for ns, x in zip(in_names, args)]
args = tuple(x if type(x) is not ad.UndefinedPrimal else
ad.UndefinedPrimal(_shard_aval(mesh, ns, x.aval))
for ns, x in zip(in_names, args))
all_args, in_tree = tree_flatten((out_cts, args))
@lu.wrap_init
def fun_trans(out_cts, args):
def fun_trans_callable(out_cts, args):
res, undefs = partition_list(map(ad.is_undefined_primal, args), args)
jaxpr_known, jaxpr_unknown, _, _ = pe.partial_eval_jaxpr_nounits(
pe.close_jaxpr(jaxpr), map(ad.is_undefined_primal, args), False)
@@ -1690,6 +1691,8 @@ def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
for ns, x in zip(in_names, out)]
return out
fun_trans = lu.wrap_init(fun_trans_callable,
debug_info=jaxpr.debug_info)
fun_trans, nz_arg_cts = ad.nonzero_outputs(fun_trans)
fun_trans_flat, out_tree = api_util.flatten_fun_nokwargs(fun_trans, in_tree)
@@ -1986,8 +1989,11 @@ class RewriteTrace(core.Trace):
out_reps = out_reps[:len(out_reps) // 2]
return map(partial(RewriteTracer, self), out_reps, out_vals)
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
symbolic_zeros):
def process_custom_vjp_call(self, prim: core.Primitive, fun: lu.WrappedFun,
fwd: lu.WrappedFun, bwd: lu.WrappedFun,
tracers,
out_trees: Callable[[], Sequence[PyTreeDef]],
symbolic_zeros: bool):
if symbolic_zeros:
msg = ("Please open an issue at https://github.com/jax-ml/jax/issues and "
"as a temporary workaround pass the check_rep=False argument to "
@@ -2055,13 +2061,15 @@ def _replication_rewrite_nomatch(
jaxpr: core.ClosedJaxpr,
in_rep: Sequence[set[AxisName]],
) -> tuple[core.ClosedJaxpr, list[set[AxisName]]]:
f = lu.wrap_init(partial(core.eval_jaxpr, jaxpr.jaxpr, jaxpr.consts))
f = lu.wrap_init(partial(core.eval_jaxpr, jaxpr.jaxpr, jaxpr.consts),
debug_info=jaxpr.jaxpr.debug_info)
f, out_rep = _efficient_transpose_rewrite_nomatch(f, mesh, in_rep)
jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals)
return core.ClosedJaxpr(jaxpr_, consts), out_rep()
@lu.transformation_with_aux2
def _rewrite_subtrace(f, store, tag, mesh, in_reps, *in_vals):
def _rewrite_subtrace(f: Callable, store: lu.Store,
tag: core.TraceTag, mesh: Mesh, in_reps, *in_vals):
with core.take_current_trace() as parent_trace:
assert len(in_reps) == len(in_vals), (len(in_reps), len(in_vals))
t = RewriteTrace(parent_trace, tag, mesh)
@@ -2072,13 +2080,14 @@ def _rewrite_subtrace(f, store, tag, mesh, in_reps, *in_vals):
store.store(out_reps)
return out_vals
def _rewrite_bwd(bwd, mesh, in_reps, reps_dst):
def _rewrite_bwd(bwd: lu.WrappedFun,
mesh: Mesh, in_reps, reps_dst) -> lu.WrappedFun:
def new_bwd(*args):
tag = core.TraceTag()
bwd_, reps_thunk = _rewrite_subtrace(lu.wrap_init(bwd), tag, mesh, in_reps())
bwd_, reps_thunk = _rewrite_subtrace(bwd, tag, mesh, in_reps())
out = bwd_.call_wrapped(*args)
return map(_match_replication, reps_thunk(), reps_dst, out)
return new_bwd
return lu.wrap_init(new_bwd, debug_info=bwd.debug_info)
def _match_replication(src, dst, x):
if dst - src:

View File

@@ -865,7 +865,7 @@ sparse_rules_bcoo[sparse.todense_p] = _todense_sparse_rule
sparse_rules_bcsr[sparse.todense_p] = _todense_sparse_rule
def _custom_jvp_sparse_rule(spenv, *spvalues, **params):
call_jaxpr = params.pop('call_jaxpr')
call_jaxpr: core.ClosedJaxpr = params.pop('call_jaxpr')
jvp_jaxpr_thunk = params.pop('jvp_jaxpr_thunk')
num_consts = params.pop('num_consts')
sp_call_jaxpr, out_tree = _sparsify_jaxpr(spenv, call_jaxpr, *spvalues)
@@ -874,7 +874,7 @@ def _custom_jvp_sparse_rule(spenv, *spvalues, **params):
sparrs = arrays_to_spvalues(spenv, arrs)
out = eval_sparse(call_jaxpr.jaxpr, call_jaxpr.consts, sparrs, spenv)
return spvalues_to_arrays(spenv, out)
jvp = lift_jvp(num_consts, jvp_jaxpr_thunk)
jvp = lift_jvp(num_consts, jvp_jaxpr_thunk, call_jaxpr.jaxpr.debug_info)
invals = spvalues_to_arrays(spenv, spvalues)
outvals = jax.custom_derivatives.custom_jvp_call_p.bind(fun, jvp, *invals, **params)
return arrays_to_spvalues(spenv, outvals)

View File

@@ -96,8 +96,12 @@ class TracerSpy:
def __init__(self):
self.tracers = []
def append(self, t: core.Tracer):
def append(self, t: Any):
try:
# We plan to do boolean conversion and catch the exception, but this works
# only for scalars
if isinstance(t, core.Tracer) and t.shape:
t = jnp.sum(t)
if t:
pass
assert False, t
@@ -862,6 +866,32 @@ class DebugInfoTest(jtu.JaxTestCase):
re.compile(r".*func.func public @main\(.*jax.result_info = \"\[1\]\"}"),
])
def test_vjp_remat(self):
tracer_spy = TracerSpy()
def apply_fn(inp):
tracer_spy.append(inp)
def to_remat(x):
tracer_spy.append(x)
return jax.nn.relu(x * x)
fn = jax.checkpoint(to_remat)
return jax.vjp(fn, inp)
self._check_tracers_and_jaxprs(
jax.jit(apply_fn),
2.,
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
# TODO(necula): what are these flat_index components?
"traced_for=jit, fun=apply_fn, arg_names=inp, result_paths=[0],[1][<flat index 0>][0][<flat index 0>][0][0]",
re.compile(r"traced_for=custom_jvp fun, fun=relu at .*/nn/functions.py:.*, arg_names=x, result_paths="),
re.compile(r"traced_for=jit, fun=relu at .*/nn/functions.py:.*, arg_names=x, result_paths="),
],
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
"traced_for=checkpoint / remat, fun=to_remat, arg_names=x, from x",
"traced_for=jit, fun=apply_fn, arg_names=inp, from inp",
])
def test_custom_jvp(self):
tracer_spy = TracerSpy()
@jax.custom_jvp
@@ -959,7 +989,7 @@ class DebugInfoTest(jtu.JaxTestCase):
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
"traced_for=custom_vjp fun, fun=my_f, arg_names=x['a'], from x['a']",
# TODO(necula): from None
# TODO(necula): from None?
"traced_for=jit, fun=to_diff, arg_names=x['a'], from None",
"traced_for=jit, fun=to_diff, arg_names=x['a'], from x['a']",
])
@@ -1435,6 +1465,33 @@ class DebugInfoTest(jtu.JaxTestCase):
],
)
def test_hessian(self):
tracer_spy = TracerSpy()
def my_f(x):
tracer_spy.append(x)
return jnp.square(x).mean()
x = jax.random.uniform(jax.random.key(0), shape=(8, 4))
self._check_tracers_and_jaxprs(
jax.jit(jax.hessian(jax.jit(my_f))),
x,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=x, result_paths=",
# TODO(necula): arg_names and result_paths?
"traced_for=jit, fun=my_f, arg_names=None,x, result_paths=,",
"traced_for=jit, fun=my_f, arg_names=x, result_paths=,,,",
],
tracer_spy=tracer_spy,
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=x, from x",
],
)
(x).block_until_ready()
def test_remat(self):
tracer_spy = TracerSpy()
def my_f(x):
@@ -1506,7 +1563,6 @@ class DebugInfoTest(jtu.JaxTestCase):
"traced_for=checkpoint / remat, fun=my_f, arg_names=None,None, result_paths=",
"traced_for=shard_map, fun=my_f, arg_names=x, result_paths=",
"traced_for=shard_map, fun=my_f, arg_names=None,None, result_paths=",
"None", # TODO(necula): missing
],
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
@@ -1657,6 +1713,7 @@ class DebugInfoTest(jtu.JaxTestCase):
expected_tracer_debug_infos=[
"traced_for=custom_linear_solve solve, fun=my_solve, arg_names=x",
"traced_for=custom_linear_solve transpose_solve, fun=my_tr_solve, arg_names=x",
"traced_for=custom_linear_solve, fun=my_high_precision_dot, arg_names=b",
"None", # TODO(necula): there are missing debug info
])