Mirror of https://github.com/ROCm/jax.git (synced 2025-04-15 19:36:06 +00:00)
[better_errors] Continue adding debug info to Jaxprs (step 3)
This follows #26078 and #26313, adding `debug_info` to more calls to `lu.wrap_init`. As part of this I have changed the primitives `custom_vjp_call_jaxpr` and `custom_lin` to take the `bwd` parameter as a `lu.WrappedFun`, which carries debug info. Previously, this was a `Callable`, but in almost all cases it was really `lu.WrappedFun.call_wrapped`.
This commit is contained in:
parent 0fb278a0b9
commit 904b74860c
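To make the convention change concrete, here is a minimal sketch (not part of the commit) of the difference between passing `bwd` as a bare `Callable` and as a `lu.WrappedFun`. It uses the JAX-internal `lu.wrap_init` and `api_util.debug_info` with the argument shapes that appear in this diff; exact signatures may vary by version.

from jax._src import api_util
from jax._src import linear_util as lu

def bwd(res, ct):
  # A custom_vjp-style backward rule: residual and cotangent in,
  # input cotangents out.
  return (res * ct,)

# Before this change: primitives stored a bare Callable, with no provenance.
bwd_callable = bwd

# After: primitives store a lu.WrappedFun, which carries debug_info ...
bwd_wrapped = lu.wrap_init(
    bwd,
    debug_info=api_util.debug_info("custom_vjp bwd", bwd, (1.0, 1.0), {}))

# ... and is invoked explicitly, matching the bwd.call_wrapped(...) calls below.
print(bwd_wrapped.debug_info)
print(bwd_wrapped.call_wrapped(2.0, 3.0))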
@@ -573,7 +573,11 @@ def jacfwd(fun: Callable, argnums: int | Sequence[int] = 0,

   @wraps(fun, docstr=docstr, argnums=argnums)
   def jacfun(*args, **kwargs):
-    f = lu.wrap_init(fun, kwargs)
+    f = lu.wrap_init(
+        fun, kwargs,
+        debug_info=debug_info(
+            "jacfwd", fun, args, kwargs,
+            static_argnums=(argnums,) if isinstance(argnums, int) else argnums))
     f_partial, dyn_args = argnums_partial(f, argnums, args,
                                           require_static_args_hashable=False)
     tree_map(partial(_check_input_dtype_jacfwd, holomorphic), dyn_args)
@@ -661,7 +665,11 @@ def jacrev(fun: Callable, argnums: int | Sequence[int] = 0,

   @wraps(fun, docstr=docstr, argnums=argnums)
   def jacfun(*args, **kwargs):
-    f = lu.wrap_init(fun, kwargs)
+    f = lu.wrap_init(
+        fun, kwargs,
+        debug_info=debug_info(
+            "jacrev", fun, args, kwargs,
+            static_argnums=(argnums,) if isinstance(argnums, int) else argnums))
     f_partial, dyn_args = argnums_partial(f, argnums, args,
                                           require_static_args_hashable=False)
     tree_map(partial(_check_input_dtype_jacrev, holomorphic, allow_int), dyn_args)
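As an illustration (not from the commit), the `debug_info` attached in jacfwd/jacrev becomes visible on the traced jaxpr; the exact attribute layout is version-dependent.

import jax
import jax.numpy as jnp

def f(x):
  return jnp.sin(x) * x

# Tracing the jacobian should now record "jacfwd" provenance on the jaxpr.
closed_jaxpr = jax.make_jaxpr(jax.jacfwd(f))(jnp.arange(3.0))
print(closed_jaxpr.jaxpr.debug_info)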
@@ -1065,18 +1065,20 @@ def lift_jvp(num_errs, num_consts, jvp_jaxpr_thunk):
     return [*primal_errs, *out_primals, *tangent_errs, *out_tangents]
   return jvp

-def custom_vjp_call_jaxpr_rule(in_err, enabled_errors, *in_vals, fun_jaxpr,
-                               fwd_jaxpr_thunk, num_consts, bwd, out_trees,
-                               symbolic_zeros):
+def custom_vjp_call_jaxpr_rule(in_err, enabled_errors, *in_vals,
+                               fun_jaxpr: core.ClosedJaxpr,
+                               fwd_jaxpr_thunk, num_consts,
+                               bwd: lu.WrappedFun, out_trees,
+                               symbolic_zeros: bool):
   err_vals, err_tree = jtu.tree_flatten(in_err)
   num_errs = err_tree.num_leaves
   checkified_fun = lu.wrap_init(
       functools.partial(checkify_jaxpr_flat, fun_jaxpr.jaxpr,
-                        fun_jaxpr.consts, enabled_errors, err_tree))
+                        fun_jaxpr.consts, enabled_errors, err_tree),
+      debug_info=fun_jaxpr.jaxpr.debug_info)
   checkified_fun, fun_metadata = _flatten_and_get_error_metadata_thunk(
       checkified_fun)

-  @lu.wrap_init
   def checkified_fwd(*args):
     # TODO(lenamartens, sharadmv): why not checkify here?
     xs, zeros = args[::2], args[1::2]
@@ -1085,10 +1087,15 @@ def custom_vjp_call_jaxpr_rule(in_err, enabled_errors, *in_vals, fun_jaxpr,
     xs_without_consts = xs[num_consts:]
     return core.eval_jaxpr(fwd_jaxpr, fwd_consts, *xs_without_consts)

-  bwd_ = lambda *args: (*(None,)*num_errs, *bwd(*args))
-  checkified_fwd, fwd_out_tree = flatten_fun_output(checkified_fwd)
+  # TODO(necula): the fwd result_paths are not quite the same as fun_jaxpr
+  checkified_fwd_wrapped = lu.wrap_init(checkified_fwd,
+                                        debug_info=fun_jaxpr.jaxpr.debug_info)
+  bwd_ = lu.wrap_init(lambda *args: (*(None,)*num_errs, *bwd.call_wrapped(*args)),
+                      debug_info=bwd.debug_info)
+  checkified_fwd_wrapped, fwd_out_tree = flatten_fun_output(checkified_fwd_wrapped)
   all_outs = custom_derivatives.custom_vjp_call_p.bind(
-      checkified_fun, checkified_fwd, bwd_, *err_vals, *in_vals, out_trees=out_trees,
+      checkified_fun, checkified_fwd_wrapped,
+      bwd_, *err_vals, *in_vals, out_trees=out_trees,
       symbolic_zeros=symbolic_zeros)
   fst, out_metadata = lu.merge_linear_aux(fun_metadata, fwd_out_tree)
   if fst:
@@ -392,7 +392,7 @@ class JaxprEqn:

 # TODO(mattjj): call typecheck rules here, so we don't form bad eqns
 def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None,
-                  ctx=None):
+                  ctx=None) -> JaxprEqn:
   source_info = source_info or source_info_util.new_source_info()
   ctx = ctx or JaxprEqnContext(
       compute_on.current_compute_type(),
@@ -374,15 +374,16 @@ class CustomJVPCallPrimitive(core.Primitive):

   def get_bind_params(self, params):
     new_params = dict(params)
-    call_jaxpr = new_params.pop('call_jaxpr')
-    num_consts = new_params.pop('num_consts')
+    call_jaxpr: core.ClosedJaxpr = new_params.pop('call_jaxpr')
+    num_consts: int = new_params.pop('num_consts')
     jvp_jaxpr_thunk = new_params.pop('jvp_jaxpr_thunk')
-    fun = lu.wrap_init(core.jaxpr_as_fun(call_jaxpr))
-    jvp = lift_jvp(num_consts, jvp_jaxpr_thunk)
+    fun = lu.wrap_init(core.jaxpr_as_fun(call_jaxpr),
+                       debug_info=call_jaxpr.jaxpr.debug_info)
+    jvp = lift_jvp(num_consts, jvp_jaxpr_thunk, call_jaxpr.jaxpr.debug_info)
     return [fun, jvp], new_params

-def lift_jvp(num_consts: int, jvp_jaxpr_thunk: Callable) -> lu.WrappedFun:
-  @lu.wrap_init
+def lift_jvp(num_consts: int, jvp_jaxpr_thunk: Callable,
+             debug_info: core.DebugInfo | None) -> lu.WrappedFun:
   def jvp(*xs):
     n, ragged = divmod(len(xs), 2)
     assert not ragged
@@ -398,7 +399,7 @@ def lift_jvp(num_consts: int, jvp_jaxpr_thunk: Callable) -> lu.WrappedFun:
                        for p, z in zip(out_primals, out_zeros)]
     assert next(nz_out_tangents_, None) is None
     return [*out_primals, *out_tangents]
-  return jvp
+  return lu.wrap_init(jvp, debug_info=debug_info)

 effects.custom_derivatives_allowed_effects.add_type(lax.InOutFeedEffect)
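The `lift_jvp` hunk above shows a pattern that recurs throughout this commit: the `@lu.wrap_init` decorator cannot attach `debug_info`, so the closure is built first and wrapped at the return site instead. A stripped-down sketch of that pattern, with an illustrative helper name (not from the commit):

from jax._src import linear_util as lu

def make_rule(scale, debug_info=None) -> lu.WrappedFun:
  def rule(*xs):
    return [scale * x for x in xs]
  # Was: @lu.wrap_init on `rule`; now the caller's debug_info is threaded in.
  return lu.wrap_init(rule, debug_info=debug_info)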
@@ -435,8 +436,9 @@ def _custom_jvp_call_transpose(params, jaxpr, args, ct, _):
 ad.primitive_transposes[custom_jvp_call_p] = _custom_jvp_call_transpose

 @weakref_lru_cache
-def _cached_closed_call_dce_instantiate(jaxpr_, used_outputs: tuple[bool, ...]
-                                        ) -> tuple[core.ClosedJaxpr, list[bool]]:
+def _cached_closed_call_dce_instantiate(jaxpr_: core.ClosedJaxpr,
+                                        used_outputs: tuple[bool, ...]
+                                        ) -> tuple[core.ClosedJaxpr, list[bool]]:
   jaxpr, consts = jaxpr_.jaxpr, jaxpr_.consts
   new_jaxpr, used_inputs = pe.dce_jaxpr(jaxpr, used_outputs, True)
   return core.ClosedJaxpr(new_jaxpr, consts), used_inputs
@@ -673,7 +675,7 @@ class custom_vjp(Generic[ReturnValue]):
     flat_fwd, out_trees = _flatten_fwd(
         fwd_, self.nondiff_argnums, self.symbolic_zeros, debug_fun,
         debug_fwd, in_tree, out_type)
-    flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees).call_wrapped
+    flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees)
     out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,
                                       *args_flat, out_trees=out_trees,
                                       symbolic_zeros=self.symbolic_zeros)
@@ -940,7 +942,9 @@ mlir.register_lowering(custom_vjp_call_jaxpr_p, mlir.lower_fun(
 def _custom_vjp_call_jaxpr_jvp(
     primals, tangents, *, fun_jaxpr: core.ClosedJaxpr,
     fwd_jaxpr_thunk: Callable[..., tuple[core.Jaxpr, Sequence[Any]]],
-    num_consts: int, bwd: Callable, out_trees: Callable, symbolic_zeros: bool):
+    num_consts: int, bwd: lu.WrappedFun,
+    out_trees: Callable[[], Sequence[PyTreeDef]],
+    symbolic_zeros: bool):
   _, args = split_list(primals, [num_consts])
   consts_dot, args_dot = split_list(tangents, [num_consts])
   if any(type(t) is not Zero for t in consts_dot):
@@ -963,7 +967,8 @@ def _custom_vjp_call_jaxpr_vmap(
     axis_data, args, in_dims, *,
     fun_jaxpr: core.ClosedJaxpr,
     fwd_jaxpr_thunk: Callable[..., tuple[core.Jaxpr, Sequence[Any]]],
-    num_consts: int, bwd: Callable, out_trees: Callable, symbolic_zeros: bool):
+    num_consts: int, bwd: lu.WrappedFun,
+    out_trees: Callable, symbolic_zeros: bool):
   args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0
           else x for x, d in zip(args, in_dims)]
   in_batched = [d is not not_mapped for d in in_dims]
@@ -1000,12 +1005,13 @@ def _custom_vjp_call_jaxpr_dce(
 ) -> tuple[list[bool], core.JaxprEqn | None]:
   if not any(used_outs) and not pe.has_effects(eqn):
     return [False] * len(eqn.invars), None

-  fun_jaxpr = eqn.params["fun_jaxpr"]
+  fun_jaxpr: core.ClosedJaxpr = eqn.params["fun_jaxpr"]
   fwd_jaxpr_thunk = eqn.params["fwd_jaxpr_thunk"]
-  bwd = eqn.params["bwd"]
-  out_trees = eqn.params["out_trees"]
-  symbolic_zeros = eqn.params["symbolic_zeros"]
+  bwd: lu.WrappedFun = eqn.params["bwd"]
+  out_trees: Callable[[], Sequence[PyTreeDef]] = eqn.params["out_trees"]
+  symbolic_zeros: bool = eqn.params["symbolic_zeros"]
+  dce_fun_jaxpr: core.ClosedJaxpr
+  used_ins: Sequence[bool]
   dce_fun_jaxpr, used_ins = _cached_closed_call_dce_instantiate(
       fun_jaxpr, tuple(used_outs))
   assert all(used_ins)
@@ -1019,7 +1025,6 @@ def _custom_vjp_call_jaxpr_dce(
         fwd_jaxpr, (True,) * num_res + tuple(used_outs))
     return dce_fwd_jaxpr.jaxpr, dce_fwd_jaxpr.consts

-  @lu.wrap_init
   def dce_bwd(*args):
     _, res_tree = out_trees()
     res, cts = split_list(args, [res_tree.num_leaves])
@@ -1035,19 +1040,21 @@ def _custom_vjp_call_jaxpr_dce(
       else:
         all_cts.append(zeros_like_aval(ct_aval))
     assert next(cts_, None) is None
-    return bwd(*res, *all_cts)
+    return bwd.call_wrapped(*res, *all_cts)

+  dce_bwd_wrapped = lu.wrap_init(dce_bwd,
+                                 debug_info=bwd.debug_info)
   outvars = [v for used, v in zip(used_outs, eqn.outvars) if used]
   new_params = dict(
       eqn.params,
       fun_jaxpr=dce_fun_jaxpr,
       fwd_jaxpr_thunk=dce_fwd_jaxpr_thunk,
-      bwd=dce_bwd.call_wrapped,
+      bwd=dce_bwd_wrapped,
   )
   new_eqn = pe.new_jaxpr_eqn(
       eqn.invars, outvars, eqn.primitive, new_params, dce_fun_jaxpr.effects,
       eqn.source_info, eqn.ctx)
-  return used_ins, new_eqn
+  return list(used_ins), new_eqn
 pe.dce_rules[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_dce

 xla.register_initial_style_primitive(custom_vjp_call_jaxpr_p)
@@ -1125,7 +1132,9 @@ def custom_gradient(fun):
   def fwd(*args, **kwargs):
     ans, rule = fun(*args, **kwargs)
     ans_flat, out_tree = tree_flatten((ans,))
-    rule, in_tree = flatten_fun_nokwargs(lu.wrap_init(rule), out_tree)
+    debug_fwd = debug_info("custom_gradient fwd", rule, (ans,), {})
+    rule, in_tree = flatten_fun_nokwargs(lu.wrap_init(rule,
+                                                      debug_info=debug_fwd), out_tree)
     ans_avals = [core.get_aval(x).to_tangent_aval() for x in ans_flat]
     jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(rule, ans_avals)
     return ans, Residuals(jaxpr, in_tree(), out_tree, consts)
@@ -1224,10 +1233,11 @@ def closure_convert(fun: Callable, *example_args) -> tuple[Callable, list[Any]]:
   """
   flat_args, in_tree = tree_flatten(example_args)
   in_avals = tuple(map(core.get_aval, flat_args))
+  debug = debug_info("closure_convert", fun, example_args, {})
   if config.check_tracer_leaks.value:
-    return _closure_convert_for_avals.__wrapped__(fun, in_tree, in_avals)
+    return _closure_convert_for_avals.__wrapped__(fun, in_tree, in_avals, debug)
   else:
-    return _closure_convert_for_avals(fun, in_tree, in_avals)
+    return _closure_convert_for_avals(fun, in_tree, in_avals, debug)

 def _maybe_perturbed(x: Any) -> bool:
   # False if x can't represent an AD-perturbed value (i.e. a value
@@ -1251,8 +1261,10 @@ def _maybe_perturbed(x: Any) -> bool:
     return True  # We can't be sure!

 @cache()
-def _closure_convert_for_avals(fun, in_tree, in_avals):
-  wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
+def _closure_convert_for_avals(fun, in_tree, in_avals,
+                               debug_info: core.DebugInfo):
+  wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun, debug_info=debug_info),
+                                               in_tree)
   jaxpr, out_pvals, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
   out_tree = out_tree()
@@ -25,7 +25,7 @@ from jax._src import config
 from jax._src import linear_util as lu
 from jax._src.interpreters import partial_eval as pe
 from jax.tree_util import (tree_flatten, tree_unflatten,
-                           register_pytree_node, Partial)
+                           register_pytree_node, Partial, PyTreeDef)
 from jax._src import core
 from jax._src import source_info_util
 from jax._src.ad_util import (
@@ -644,26 +644,28 @@ class LinearizeTrace(Trace):
     return [maybe_linearize_tracer(self, x, nz, t)
             for x, nz, t in zip(primals_out, tangent_nzs_out, tangents_out)]

-  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
-                              symbolic_zeros):
+  def process_custom_vjp_call(self, prim, fun, fwd,
+                              bwd: lu.WrappedFun, tracers,
+                              out_trees: Callable[[], Sequence[PyTreeDef]],
+                              symbolic_zeros: bool):
     primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers))
     if all(type(t) is Zero for t in tangents_in):
       return prim.bind_with_trace(self.parent_trace,
                                   (fun, fwd, bwd, *primals_in),
                                   dict(out_trees=out_trees, symbolic_zeros=symbolic_zeros))
     fwd_in = [(p, type(t) is not Zero) for p, t in zip(primals_in, tangents_in)]
-    fwd_in = [x for pair in fwd_in for x in pair]  # flatten
+    fwd_in_flat = [x for pair in fwd_in for x in pair]  # flatten
     with core.set_current_trace(self.parent_trace):
-      res_and_primals_out = fwd.call_wrapped(*fwd_in)
+      res_and_primals_out = fwd.call_wrapped(*fwd_in_flat)

     _, res_tree = out_trees()
     res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
     avals_out = [core.get_aval(x).to_tangent_aval() for x in primals_out]

-    tangents_in = map(instantiate_zeros, tangents_in)
+    tangents_in_zeros = map(instantiate_zeros, tangents_in)
     with core.set_current_trace(self.tangent_trace):
       tangents_out = custom_lin_p.bind(
-          *res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd,
+          *res, *tangents_in_zeros, num_res=res_tree.num_leaves, bwd=bwd,
           out_avals=avals_out, symbolic_zeros=symbolic_zeros)
     tangent_nzs_out = [type(t) is not Zero for t in tangents_out]
     return map(partial(maybe_linearize_tracer, self), primals_out, tangent_nzs_out, tangents_out)
@@ -975,9 +977,13 @@ def nonzero_outputs(f, store, *args, **kwargs):
   store.store([type(r) is not Zero for r in results])
   return results

-def map_transpose(primitive, params, call_jaxpr, args, ct, _):
+def map_transpose(primitive: core.Primitive, params,
+                  call_jaxpr: core.Jaxpr, args, ct, _):
   all_args, in_tree_def = tree_flatten(((), args, ct))  # empty consts
-  fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr, False)
+  # TODO(necula): use the right debug_info for the backwards pass
+  fun = lu.hashable_partial(lu.wrap_init(backward_pass,
+                                         debug_info=call_jaxpr.debug_info),
+                            call_jaxpr, False)
   fun, nz_arg_cts = nonzero_outputs(fun)
   fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
   # Preserve axis for primal arguments, skip tangents (represented as undefined primals).
@@ -1056,12 +1062,24 @@ def f_jvp_traceable(f, store, nonzeros, *primals_and_nztangents):
 def rearrange_binders(jaxpr: core.ClosedJaxpr, primals_in, tangents_in, primals_out, tangents_out):
   new_invars = _perm(primals_in, tangents_in, jaxpr.jaxpr.invars)
   new_outvars = _perm(primals_out, tangents_out, jaxpr.jaxpr.outvars)
+  new_debug_info = jaxpr.jaxpr.debug_info
+  if new_debug_info is not None:
+    new_arg_names = tuple(_perm(primals_in, tangents_in,
+                                jaxpr.jaxpr.debug_info.safe_arg_names(len(jaxpr.jaxpr.invars))))
+    new_result_paths = tuple(_perm(primals_out, tangents_out,
+                                   jaxpr.jaxpr.debug_info.safe_result_paths(len(jaxpr.jaxpr.outvars))))
+    new_debug_info = new_debug_info._replace(
+        arg_names=new_arg_names,
+        result_paths=new_result_paths,
+    )
   new_jaxpr = core.Jaxpr(jaxpr.jaxpr.constvars,
                          new_invars, new_outvars, jaxpr.jaxpr.eqns,
-                         jaxpr.jaxpr.effects)
+                         jaxpr.jaxpr.effects,
+                         new_debug_info)
   return core.ClosedJaxpr(new_jaxpr, jaxpr.consts)

-def _perm(primal_counts, tangent_counts, lst):
+def _perm(primal_counts: Sequence[int], tangent_counts: Sequence[int],
+          lst: Sequence[Any]) -> Sequence[Any]:
   n = sum(primal_counts)
   primals, tangents = lst[:n], lst[n:]
   primal_groups = split_list(primals, primal_counts[:-1])
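The `rearrange_binders` hunk permutes `arg_names` and `result_paths` alongside the binders via `_replace`, which suggests a NamedTuple-like `DebugInfo`. A toy illustration of keeping names aligned under a permutation (the class below is a stand-in, not JAX's):

from typing import NamedTuple

class DebugInfo(NamedTuple):  # stand-in for jax._src.core.DebugInfo
  traced_for: str
  arg_names: tuple
  result_paths: tuple

di = DebugInfo("example", ("x", "xdot", "y", "ydot"), ("out",))
perm = [0, 2, 1, 3]  # regroup as primals first, then tangents
di = di._replace(arg_names=tuple(di.arg_names[i] for i in perm))
print(di.arg_names)  # ('x', 'y', 'xdot', 'ydot')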
@@ -1082,14 +1100,15 @@ def raise_custom_vjp_error_on_jvp(*_, **__):
       "function.")
 custom_lin_p.def_impl(raise_custom_vjp_error_on_jvp)

-def _custom_lin_transpose(cts_out, *invals, num_res, bwd, out_avals,
+def _custom_lin_transpose(cts_out, *invals, num_res,
+                          bwd: lu.WrappedFun, out_avals,
                           symbolic_zeros):
   res, _ = split_list(invals, [num_res])
   if symbolic_zeros:
     cts_out = map(replace_internal_symbolic_zeros, cts_out)
   else:
     cts_out = map(instantiate_zeros, cts_out)
-  cts_in = bwd(*res, *cts_out)
+  cts_in = bwd.call_wrapped(*res, *cts_out)
   cts_in = map(replace_rule_output_symbolic_zeros, cts_in)
   return [None] * num_res + list(cts_in)
 primitive_transposes[custom_lin_p] = _custom_lin_transpose
@@ -799,8 +799,11 @@ def batch_jaxpr_axes(closed_jaxpr, axis_data, in_axes, out_axes_dest):
   return _batch_jaxpr_axes(closed_jaxpr, axis_data, tuple(in_axes), tuple(out_axes_dest))

 @weakref_lru_cache
-def _batch_jaxpr_axes(closed_jaxpr, axis_data, in_axes, out_axes_dest):
-  f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
+def _batch_jaxpr_axes(closed_jaxpr: core.ClosedJaxpr,
+                      axis_data: AxisData,
+                      in_axes: Sequence[int], out_axes_dest: Sequence[int]):
+  f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr),
+                   debug_info=closed_jaxpr.jaxpr.debug_info)
   f, out_axes = _batch_jaxpr_inner(f, axis_data)
   f, out_batched = _match_axes_jaxpr(f, axis_data, out_axes_dest, out_axes)
   f = _batch_jaxpr_outer(f, axis_data, in_axes)
@@ -896,7 +899,10 @@ def batch_custom_jvp_subtrace(f, store, tag, axis_data, in_dims, *in_vals):
   store.store(out_dims * 2)
   return out_primals + out_tangents

-def batch_custom_vjp_bwd(bwd, tag, axis_data, in_dims, out_dim_dests):
+def batch_custom_vjp_bwd(bwd: lu.WrappedFun, tag: core.TraceTag,
+                         axis_data: AxisData,
+                         in_dims: Callable[[], Sequence[int | None]],
+                         out_dim_dests: Sequence[int | None]) -> lu.WrappedFun:
   axis_size = axis_data.size
   axis_name = axis_data.name
   mesh_axis = axis_data.explicit_mesh_axis
@@ -907,11 +913,11 @@ def batch_custom_vjp_bwd(bwd, tag, axis_data, in_dims, out_dim_dests):
                 for x, dim in zip(args, in_dims_)]
     in_dims_ = [None if type(x) is SymbolicZero else d
                 for x, d in zip(args, in_dims_)]
-    bwd_, out_dims_thunk = batch_subtrace(lu.wrap_init(bwd), tag, axis_data, in_dims_)
+    bwd_, out_dims_thunk = batch_subtrace(bwd, tag, axis_data, in_dims_)
     bwd_ = _match_axes_and_sum(bwd_, axis_size, axis_name, mesh_axis,
                                out_dims_thunk, out_dim_dests)
     return bwd_.call_wrapped(*args)
-  return new_bwd
+  return lu.wrap_init(new_bwd, debug_info=bwd.debug_info)

 @lu.transformation2
 def _match_axes_and_sum(f, axis_size, axis_name, mesh_axis, out_dims_thunk,
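Dropping the `lu.wrap_init(bwd)` re-wrap works because `bwd` is now already a `lu.WrappedFun`, and `lu` transformations compose on it directly. A small sketch of that composition, using internal APIs with `debug_info` assumed to be preserved across transformations, as the diff's dropped re-wraps rely on:

from jax._src import api_util
from jax._src import linear_util as lu

def bwd(res, ct):
  return (res * ct,)

bwd_wrapped = lu.wrap_init(
    bwd, debug_info=api_util.debug_info("example bwd", bwd, (0.0, 0.0), {}))

@lu.transformation_with_aux2
def count_args(f, store, *args):
  # A trivial transformation: record the arity, then call through.
  store.store(len(args))
  return f(*args)

counted, n_args = count_args(bwd_wrapped)  # no extra lu.wrap_init needed
print(counted.call_wrapped(2.0, 3.0), n_args())
print(counted.debug_info)  # provenance survives the transformation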
@@ -2022,8 +2022,11 @@ class DynamicJaxprTrace(core.Trace):
     self.frame.add_eqn(eqn)
     return out_tracers

-  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
-                              symbolic_zeros):
+  def process_custom_vjp_call(self, prim: core.Primitive,
+                              fun: lu.WrappedFun,
+                              fwd: lu.WrappedFun, bwd: lu.WrappedFun, tracers,
+                              out_trees: Callable[[], Sequence[PyTreeDef]],
+                              symbolic_zeros: bool):
     tracers = map(self.to_jaxpr_tracer, tracers)
     in_avals = [t.aval for t in tracers]
     fun_jaxpr, out_avals, consts, _ = trace_to_jaxpr_dynamic(fun, in_avals)
@@ -2041,7 +2044,8 @@ class DynamicJaxprTrace(core.Trace):
     invars = map(self.getvar, tracers)
     constvars = map(self.getvar, map(self.to_jaxpr_tracer, consts))
     outvars = map(self.makevar, out_tracers)
-    eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim.initial_style,
+    eqn = new_jaxpr_eqn([*constvars, *invars], outvars,
+                        prim.initial_style,  # type: ignore[attribute-error]
                         dict(fun_jaxpr=closed_fun_jaxpr,
                              fwd_jaxpr_thunk=fwd_jaxpr_from_zeros,
                              num_consts=len(consts),
@@ -586,8 +586,10 @@ def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry,
                   for p, nz in zip(primals_out, nonzeros_out)]
   return primals_out, tangents_out

-def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
-                       jaxpr, linear, unroll, _split_transpose):
+def _scan_partial_eval(trace, *tracers, reverse: bool,
+                       length: int, num_consts: int, num_carry: int,
+                       jaxpr: core.ClosedJaxpr, linear: Sequence[bool],
+                       unroll: int, _split_transpose: bool):
   num_ys = len(jaxpr.out_avals) - num_carry
   unknowns = [not t.pval.is_known() for t in tracers]
   const_uk, init_uk, xs_uk = split_list(unknowns, [num_consts, num_carry])
@@ -612,8 +614,8 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
   del res_avals, carry_uk_out

   # Instantiate those inputs which must be treated as unknown from the fixpoint.
-  tracers = [trace.instantiate_const(t) if uk else t
-             for t, uk in zip(tracers, unknowns)]
+  tracers = tuple(trace.instantiate_const(t) if uk else t
+                  for t, uk in zip(tracers, unknowns))

   # The residual inputs and outputs of the jaxprs produced haven't yet been
   # adapted to the scan calling convention; in particular, jaxpr_known has its
@@ -638,7 +640,9 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
                  for aval in jaxpr_known.in_avals[len(const_pvals):]]
   with source_info_util.reset_name_stack():
     jaxpr_known_, invar_pvals_out, jaxpr_known_consts = pe.trace_to_jaxpr_nounits(
-        lu.wrap_init(core.jaxpr_as_fun(jaxpr_known)), const_pvals + other_pvals,
+        lu.wrap_init(core.jaxpr_as_fun(jaxpr_known),
+                     debug_info=jaxpr_known.jaxpr.debug_info),
+        const_pvals + other_pvals,
         instantiate=[True] * (len(out_uk) - sum(out_uk)) + [False] * num_res)
   jaxpr_known = pe.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr_known_), ())
   # The above trace_to_jaxpr_nounits call computed loop-invariant residuals
@@ -880,8 +884,9 @@ def _scan_transpose(cts, *args, reverse, length, num_consts,
 # transpose_scan_jaxpr :: ([res1, c, a, res2] -> b)
 #                         -> ([res1, CT c, CT b, res2] -> [CT c, CT a])
 @weakref_lru_cache
-def _transpose_scan_jaxpr(jaxpr, num_res1, num_c, num_res2,
-                          ct_ys_is_zeros):
+def _transpose_scan_jaxpr(jaxpr: core.ClosedJaxpr,
+                          num_res1: int, num_c: int, num_res2: int,
+                          ct_ys_is_zeros: Sequence[bool]):
   num_a = len(jaxpr.in_avals) - num_res1 - num_c - num_res2
   # TODO: allow input cotangent avals to be batched relative to jaxpr.in_avals
   # if an axis isn't reduced
@@ -896,7 +901,6 @@ def _transpose_scan_jaxpr(jaxpr, num_res1, num_c, num_res2,
       aval for aval, is_zero in zip(b_ys_avals, ct_ys_is_zeros) if not is_zero
   ]

-  @lu.wrap_init
   def transposed(*res1_cbar_bbar_res2):
     res1, c_bar, b_bar, ys_bar_stripped, res2 = split_list(
         res1_cbar_bbar_res2,
@@ -915,9 +919,14 @@ def _transpose_scan_jaxpr(jaxpr, num_res1, num_c, num_res2,
     a_bar = _map(ad.instantiate_zeros, a_bar)
     c_bar = _map(ad.instantiate_zeros, _map(ad.add_tangents, c_bar, new_c_bar))
     return c_bar + a_bar

+  # TODO(necula): fix arg names and results for transposed
+  transposed_wrapped = lu.wrap_init(transposed,
+                                    debug_info=jaxpr.jaxpr.debug_info)
   return _make_closed_jaxpr_attrs(
-      transposed, tuple(res1_avals + c_avals + b_carry_avals +
-                        b_ys_avals_stripped + res2_avals))
+      transposed_wrapped,
+      tuple(res1_avals + c_avals + b_carry_avals +
+            b_ys_avals_stripped + res2_avals))


 def _scan_batching_rule(axis_data, args,
@@ -1923,6 +1923,8 @@ def reduce(operands: Any,
   is undefined.
   """
   flat_operands, operand_tree = tree_util.tree_flatten(operands)
+  comp_debug = api_util.debug_info("reduce comp", computation,
+                                   (init_values, init_values), {})
   flat_init_values, init_value_tree = tree_util.tree_flatten(init_values)
   if operand_tree != init_value_tree:
     raise ValueError('Operands must have the same tree structure as init_values:'
@@ -1939,7 +1941,7 @@ def reduce(operands: Any,
   else:
     flat_init_avals = safe_map(core.get_aval, flat_init_values)
     closed_jaxpr, out_tree = _variadic_reduction_jaxpr(
-        computation, tuple(flat_init_avals), init_value_tree)
+        computation, comp_debug, tuple(flat_init_avals), init_value_tree)
     out = reduce_p.bind(*flat_operands, *flat_init_values, computation=computation,
                         jaxpr=closed_jaxpr, dimensions=tuple(dimensions))
     return tree_util.tree_unflatten(out_tree, out)
@@ -1967,10 +1969,13 @@ def _reduction_jaxpr(computation: Callable,
   return jaxpr, tuple(consts)

 @cache()
-def _variadic_reduction_jaxpr(computation, flat_avals, aval_tree):
+def _variadic_reduction_jaxpr(computation: Callable[[Any, Any], Any],
+                              debug_info: core.DebugInfo,
+                              flat_avals,
+                              aval_tree: tree_util.PyTreeDef):
   avals = tree_util.tree_unflatten(aval_tree, flat_avals)
   flat_in_avals, in_tree = tree_util.tree_flatten((avals, avals))
-  comp = lu.wrap_init(computation)
+  comp = lu.wrap_init(computation, debug_info=debug_info)
   flat_comp, out_tree = api_util.flatten_fun_nokwargs(comp, in_tree)
   jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_comp, tuple(flat_in_avals))
   if any(isinstance(c, core.Tracer) for c in consts):
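The reducer's `debug_info` is computed at the public API boundary (see the `reduce` hunk earlier), where the user's callable and example arguments are still unflattened; `(init_values, init_values)` stands in for the two operands of the binary `computation`, so the recovered arg names match its arity. A sketch of that caller-side pattern, with a hypothetical wrapper name:

from jax._src import api_util

def my_reduce_like_api(computation, init_value):
  # Hypothetical wrapper mirroring reduce/_reduce_window in this diff.
  comp_debug = api_util.debug_info("reduce comp", computation,
                                   (init_value, init_value), {})
  # comp_debug would then be threaded into _variadic_reduction_jaxpr(...).
  return comp_debug

print(my_reduce_like_api(lambda acc, x: acc + x, 0.0))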
@@ -5921,7 +5926,7 @@ def _argminmax_dtype_rule(operand, *, axes, index_dtype):

 class _ArgMinMaxReducer:

-  def __init__(self, value_comparator):
+  def __init__(self, value_comparator: Callable[[Any, Any], Any]):
     self._value_comparator = value_comparator

   def __repr__(self):
@@ -19,6 +19,7 @@ from functools import partial
 import warnings

 from jax import tree_util
+from jax._src import api_util
 from jax._src import core
 from jax._src import dispatch
 from jax._src import dtypes
@@ -55,6 +56,8 @@ def _reduce_window(
     window_dilation: Sequence[int] | None = None,
 ):
   flat_operands, operand_tree = tree_util.tree_flatten(operand)
+  comp_debug = api_util.debug_info("reduce_window comp", computation,
+                                   (init_value, init_value), {})
   flat_init_values, init_value_tree = tree_util.tree_flatten(init_value)
   if operand_tree != init_value_tree:
     raise ValueError(
@@ -88,7 +91,7 @@ def _reduce_window(
   else:
     flat_init_avals = map(core.get_aval, flat_init_values)
     jaxpr, out_tree = lax._variadic_reduction_jaxpr(
-        computation, tuple(flat_init_avals), init_value_tree
+        computation, comp_debug, tuple(flat_init_avals), init_value_tree
     )
     if operand_tree != out_tree:
       raise ValueError(
@@ -326,7 +326,7 @@ def resolve_physical_types(jaxpr: jax_core.Jaxpr, consts: Sequence[Any]):
   interp_fun = partial(
       eval_jaxpr_recursive, jaxpr, consts,
       recurse_hop_rule=resolve_physical_types)
-  wrapped = lu.wrap_init(interp_fun)
+  wrapped = lu.wrap_init(interp_fun, debug_info=jaxpr.debug_info)
   new_jaxpr, _, new_consts, () = pe.trace_to_jaxpr_dynamic(
       wrapped, kernel_avals)
   return new_jaxpr, new_consts
@@ -30,7 +30,6 @@ package(
 py_library(
     name = "jax2tf",
     srcs = ["__init__.py"],
-    srcs_version = "PY3",
     visibility = ["//visibility:public"],
     deps = [":jax2tf_internal"],
 )
@@ -42,7 +41,6 @@ py_library(
         "impl_no_xla.py",
         "jax2tf.py",
     ],
-    srcs_version = "PY3",
     # TODO: b/255503696: enable pytype
     tags = ["pytype_unchecked_annotations"],
     visibility = jax_visibility("jax2tf_internal"),
@@ -24,7 +24,6 @@ package(
 py_library(
     name = "back_compat_testdata",
     srcs = glob(["*.py"]),
-    srcs_version = "PY3",
    deps = [
        "//third_party/py/numpy",
        "//third_party/py/typing_extensions",
@@ -28,7 +28,6 @@ package(
 py_library(
     name = "flax_models",
     srcs = glob(["*.py"]),
-    srcs_version = "PY3",
     deps = [
         "//jax",
         "//third_party/py/flax:core",
@@ -1642,8 +1642,8 @@ def _promote_scalar_residuals(f: Callable, *args, **kwargs):
               for x in out_consts]
   return jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env)

-def _promote_scalar_residuals_jaxpr(jaxpr, which):
-  @lu.wrap_init
+def _promote_scalar_residuals_jaxpr(jaxpr: core.Jaxpr,
+                                    which: Sequence[bool]):
   def fun(*res_and_args):
     res, args = split_list(res_and_args, [len(jaxpr.constvars)])
     res = [_rem_singleton(x) if w else x for x, w in zip(res, which)]
@@ -1651,7 +1651,8 @@ def _promote_scalar_residuals_jaxpr(jaxpr, which):
   res_avals = [core.unmapped_aval(1, 0, v.aval) if w else v.aval
                for v, w in zip(jaxpr.constvars, which)]
   in_avals = [*res_avals, *[v.aval for v in jaxpr.invars]]
-  jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(fun, in_avals)
+  jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
+      lu.wrap_init(fun, debug_info=jaxpr.debug_info), in_avals)
   return jaxpr
@@ -1663,20 +1664,20 @@ def _unmentioned2(mesh: Mesh, names: AxisNames,
   return [n for n in _all_mesh_names_except_spmd(mesh) if n not in name_set]


-def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
+def _shard_map_transpose(out_cts, *args,
+                         jaxpr: core.Jaxpr, mesh, in_names, out_names,
                          check_rep, rewrite, auto):
   mb_div = lambda x, y: x / y if y != 1 else x
   out_cts = [ad.Zero(_shard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero
              else x if rewrite or dtypes.dtype(x) == dtypes.float0
              else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns, auto))))
              for ns, x in zip(out_names, out_cts)]
-  args = [x if type(x) is not ad.UndefinedPrimal else
-          ad.UndefinedPrimal(_shard_aval(mesh, ns, x.aval))
-          for ns, x in zip(in_names, args)]
+  args = tuple(x if type(x) is not ad.UndefinedPrimal else
+               ad.UndefinedPrimal(_shard_aval(mesh, ns, x.aval))
+               for ns, x in zip(in_names, args))
   all_args, in_tree = tree_flatten((out_cts, args))

-  @lu.wrap_init
-  def fun_trans(out_cts, args):
+  def fun_trans_callable(out_cts, args):
     res, undefs = partition_list(map(ad.is_undefined_primal, args), args)
     jaxpr_known, jaxpr_unknown, _, _ = pe.partial_eval_jaxpr_nounits(
         pe.close_jaxpr(jaxpr), map(ad.is_undefined_primal, args), False)
@@ -1690,6 +1691,8 @@ def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
             for ns, x in zip(in_names, out)]
     return out

+  fun_trans = lu.wrap_init(fun_trans_callable,
+                           debug_info=jaxpr.debug_info)
   fun_trans, nz_arg_cts = ad.nonzero_outputs(fun_trans)
   fun_trans_flat, out_tree = api_util.flatten_fun_nokwargs(fun_trans, in_tree)
@@ -1986,8 +1989,11 @@ class RewriteTrace(core.Trace):
     out_reps = out_reps[:len(out_reps) // 2]
     return map(partial(RewriteTracer, self), out_reps, out_vals)

-  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
-                              symbolic_zeros):
+  def process_custom_vjp_call(self, prim: core.Primitive, fun: lu.WrappedFun,
+                              fwd: lu.WrappedFun, bwd: lu.WrappedFun,
+                              tracers,
+                              out_trees: Callable[[], Sequence[PyTreeDef]],
+                              symbolic_zeros: bool):
     if symbolic_zeros:
       msg = ("Please open an issue at https://github.com/jax-ml/jax/issues and "
              "as a temporary workaround pass the check_rep=False argument to "
@@ -2055,13 +2061,15 @@ def _replication_rewrite_nomatch(
     jaxpr: core.ClosedJaxpr,
     in_rep: Sequence[set[AxisName]],
 ) -> tuple[core.ClosedJaxpr, list[set[AxisName]]]:
-  f = lu.wrap_init(partial(core.eval_jaxpr, jaxpr.jaxpr, jaxpr.consts))
+  f = lu.wrap_init(partial(core.eval_jaxpr, jaxpr.jaxpr, jaxpr.consts),
+                   debug_info=jaxpr.jaxpr.debug_info)
   f, out_rep = _efficient_transpose_rewrite_nomatch(f, mesh, in_rep)
   jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals)
   return core.ClosedJaxpr(jaxpr_, consts), out_rep()

 @lu.transformation_with_aux2
-def _rewrite_subtrace(f, store, tag, mesh, in_reps, *in_vals):
+def _rewrite_subtrace(f: Callable, store: lu.Store,
+                      tag: core.TraceTag, mesh: Mesh, in_reps, *in_vals):
   with core.take_current_trace() as parent_trace:
     assert len(in_reps) == len(in_vals), (len(in_reps), len(in_vals))
     t = RewriteTrace(parent_trace, tag, mesh)
@@ -2072,13 +2080,14 @@ def _rewrite_subtrace(f, store, tag, mesh, in_reps, *in_vals):
   store.store(out_reps)
   return out_vals

-def _rewrite_bwd(bwd, mesh, in_reps, reps_dst):
+def _rewrite_bwd(bwd: lu.WrappedFun,
+                 mesh: Mesh, in_reps, reps_dst) -> lu.WrappedFun:
   def new_bwd(*args):
     tag = core.TraceTag()
-    bwd_, reps_thunk = _rewrite_subtrace(lu.wrap_init(bwd), tag, mesh, in_reps())
+    bwd_, reps_thunk = _rewrite_subtrace(bwd, tag, mesh, in_reps())
     out = bwd_.call_wrapped(*args)
     return map(_match_replication, reps_thunk(), reps_dst, out)
-  return new_bwd
+  return lu.wrap_init(new_bwd, debug_info=bwd.debug_info)

 def _match_replication(src, dst, x):
   if dst - src:
@@ -865,7 +865,7 @@ sparse_rules_bcoo[sparse.todense_p] = _todense_sparse_rule
 sparse_rules_bcsr[sparse.todense_p] = _todense_sparse_rule

 def _custom_jvp_sparse_rule(spenv, *spvalues, **params):
-  call_jaxpr = params.pop('call_jaxpr')
+  call_jaxpr: core.ClosedJaxpr = params.pop('call_jaxpr')
   jvp_jaxpr_thunk = params.pop('jvp_jaxpr_thunk')
   num_consts = params.pop('num_consts')
   sp_call_jaxpr, out_tree = _sparsify_jaxpr(spenv, call_jaxpr, *spvalues)
@@ -874,7 +874,7 @@ def _custom_jvp_sparse_rule(spenv, *spvalues, **params):
     sparrs = arrays_to_spvalues(spenv, arrs)
     out = eval_sparse(call_jaxpr.jaxpr, call_jaxpr.consts, sparrs, spenv)
     return spvalues_to_arrays(spenv, out)
-  jvp = lift_jvp(num_consts, jvp_jaxpr_thunk)
+  jvp = lift_jvp(num_consts, jvp_jaxpr_thunk, call_jaxpr.jaxpr.debug_info)
   invals = spvalues_to_arrays(spenv, spvalues)
   outvals = jax.custom_derivatives.custom_jvp_call_p.bind(fun, jvp, *invals, **params)
   return arrays_to_spvalues(spenv, outvals)
@@ -96,8 +96,12 @@ class TracerSpy:
   def __init__(self):
     self.tracers = []

-  def append(self, t: core.Tracer):
+  def append(self, t: Any):
     try:
+      # We plan to do boolean conversion and catch the exception, but this works
+      # only for scalars
+      if isinstance(t, core.Tracer) and t.shape:
+        t = jnp.sum(t)
       if t:
         pass
       assert False, t
@@ -862,6 +866,32 @@ class DebugInfoTest(jtu.JaxTestCase):
         re.compile(r".*func.func public @main\(.*jax.result_info = \"\[1\]\"}"),
     ])

+  def test_vjp_remat(self):
+    tracer_spy = TracerSpy()
+    def apply_fn(inp):
+      tracer_spy.append(inp)
+      def to_remat(x):
+        tracer_spy.append(x)
+        return jax.nn.relu(x * x)
+      fn = jax.checkpoint(to_remat)
+      return jax.vjp(fn, inp)
+
+    self._check_tracers_and_jaxprs(
+        jax.jit(apply_fn),
+        2.,
+        tracer_spy=tracer_spy,
+        expected_jaxpr_debug_infos=[
+            # TODO(necula): what are these flat_index components?
+            "traced_for=jit, fun=apply_fn, arg_names=inp, result_paths=[0],[1][<flat index 0>][0][<flat index 0>][0][0]",
+            re.compile(r"traced_for=custom_jvp fun, fun=relu at .*/nn/functions.py:.*, arg_names=x, result_paths="),
+            re.compile(r"traced_for=jit, fun=relu at .*/nn/functions.py:.*, arg_names=x, result_paths="),
+        ],
+        check_tracer_arg_name=True,
+        expected_tracer_debug_infos=[
+            "traced_for=checkpoint / remat, fun=to_remat, arg_names=x, from x",
+            "traced_for=jit, fun=apply_fn, arg_names=inp, from inp",
+        ])
+
   def test_custom_jvp(self):
     tracer_spy = TracerSpy()
     @jax.custom_jvp
@@ -959,7 +989,7 @@ class DebugInfoTest(jtu.JaxTestCase):
         check_tracer_arg_name=True,
         expected_tracer_debug_infos=[
             "traced_for=custom_vjp fun, fun=my_f, arg_names=x['a'], from x['a']",
-            # TODO(necula): from None
-            "traced_for=jit, fun=to_diff, arg_names=x['a'], from None",
+            # TODO(necula): from None?
+            "traced_for=jit, fun=to_diff, arg_names=x['a'], from x['a']",
         ])
@@ -1435,6 +1465,33 @@ class DebugInfoTest(jtu.JaxTestCase):
         ],
     )

+  def test_hessian(self):
+    tracer_spy = TracerSpy()
+
+    def my_f(x):
+      tracer_spy.append(x)
+      return jnp.square(x).mean()
+
+    x = jax.random.uniform(jax.random.key(0), shape=(8, 4))
+
+    self._check_tracers_and_jaxprs(
+        jax.jit(jax.hessian(jax.jit(my_f))),
+        x,
+        expected_jaxpr_debug_infos=[
+            "traced_for=jit, fun=my_f, arg_names=x, result_paths=",
+            # TODO(necula): arg_names and result_paths?
+            "traced_for=jit, fun=my_f, arg_names=None,x, result_paths=,",
+            "traced_for=jit, fun=my_f, arg_names=x, result_paths=,,,",
+        ],
+        tracer_spy=tracer_spy,
+        check_tracer_arg_name=True,
+        expected_tracer_debug_infos=[
+            "traced_for=jit, fun=my_f, arg_names=x, from x",
+        ],
+    )
+
+    (x).block_until_ready()
+
   def test_remat(self):
     tracer_spy = TracerSpy()
     def my_f(x):
@@ -1506,7 +1563,6 @@ class DebugInfoTest(jtu.JaxTestCase):
             "traced_for=checkpoint / remat, fun=my_f, arg_names=None,None, result_paths=",
             "traced_for=shard_map, fun=my_f, arg_names=x, result_paths=",
             "traced_for=shard_map, fun=my_f, arg_names=None,None, result_paths=",
-            "None",  # TODO(necula): missing
         ],
         check_tracer_arg_name=True,
         expected_tracer_debug_infos=[
@@ -1657,6 +1713,7 @@ class DebugInfoTest(jtu.JaxTestCase):
         expected_tracer_debug_infos=[
             "traced_for=custom_linear_solve solve, fun=my_solve, arg_names=x",
             "traced_for=custom_linear_solve transpose_solve, fun=my_tr_solve, arg_names=x",
+            "traced_for=custom_linear_solve, fun=my_high_precision_dot, arg_names=b",
             "None",  # TODO(necula): there are missing debug info
         ])