make remat_call partial-eval into one remat_call

This commit is contained in:
Matthew Johnson 2019-11-27 15:25:49 -08:00 committed by Matthew Johnson
parent b2b5049eb5
commit ac251046fc
3 changed files with 39 additions and 57 deletions

View File

@ -187,18 +187,21 @@ def backward_pass(jaxpr, consts, freevar_vals, args, cotangents_in):
write_primal(eqn.outvars[0], ans)
else:
(subjaxpr, const_vars, bound_vars), = eqn.bound_subjaxprs
assert not any(is_linear(v) for v in const_vars)
if any(is_linear(v) for v in it.chain(eqn.invars, bound_vars)):
linear_eqns.append(eqn)
else:
assert not any(is_linear(v) for v in const_vars)
sub_consts = map(read_primal, const_vars)
sub_freevar_vals = map(read_primal, bound_vars)
in_vals = map(read_primal, eqn.invars)
all_args, in_tree_def = tree_flatten((sub_consts, sub_freevar_vals, in_vals))
fun = hashable_partial(wrap_init(_eval_primals), subjaxpr)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
out_flat = eqn.primitive.bind(fun, *all_args, **eqn.params)
ans = tree_unflatten(out_tree(), out_flat)
elif eqn.primitive is not pe.remat_call_p:
ans = _eval_subjaxpr_primals(
eqn.primitive, subjaxpr, map(read_primal, const_vars),
map(read_primal, bound_vars), map(read_primal, eqn.invars), eqn.params)
map(write_primal, eqn.outvars, ans)
# we special-case remat_call here because it can be mixed linear /
# nonlinear, so we always evaluate it even if it has a linear part
if eqn.primitive is pe.remat_call_p:
ans = _eval_subjaxpr_primals(
eqn.primitive, subjaxpr, map(read_primal, const_vars),
map(read_primal, bound_vars), map(read_primal, eqn.invars), eqn.params)
map(write_primal, eqn.outvars, ans)
ct_env = {}
@ -225,6 +228,13 @@ def backward_pass(jaxpr, consts, freevar_vals, args, cotangents_in):
cotangents_out = map(read_cotangent, jaxpr.invars)
return freevar_cts, cotangents_out
def _eval_subjaxpr_primals(prim, jaxpr, consts, freevar_vals, in_vals, params):
all_args, in_tree_def = tree_flatten((consts, freevar_vals, in_vals))
fun = hashable_partial(wrap_init(_eval_primals), jaxpr)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
out_flat = prim.bind(fun, *all_args, **params)
return tree_unflatten(out_tree(), out_flat)
def _eval_primals(jaxpr, consts, freevar_vals, args):
primal_env = {}
@ -259,15 +269,13 @@ def _eval_primals(jaxpr, consts, freevar_vals, args):
write_primal(eqn.outvars[0], ans)
else:
(subjaxpr, const_vars, bound_vars), = eqn.bound_subjaxprs
sub_consts = map(read_primal, const_vars)
sub_freevar_vals = map(read_primal, bound_vars)
in_vals = map(read_primal, eqn.invars)
all_args, in_tree_def = tree_flatten((sub_consts, sub_freevar_vals, in_vals))
fun = hashable_partial(wrap_init(_eval_primals), subjaxpr)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
out_flat = eqn.primitive.bind(fun, *all_args, **eqn.params)
ans = tree_unflatten(out_tree(), out_flat)
map(write_primal, eqn.outvars, ans)
assert not any(is_linear(v) for v in const_vars)
if (eqn.primitive is pe.remat_call_p or
not any(is_linear(v) for v in it.chain(eqn.invars, bound_vars))):
ans = _eval_subjaxpr_primals(
eqn.primitive, subjaxpr, map(read_primal, const_vars),
map(read_primal, bound_vars), map(read_primal, eqn.invars), eqn.params)
map(write_primal, eqn.outvars, ans)
return map(read_primal, jaxpr.outvars)
class UndefinedPrimal(object):

View File

@ -507,6 +507,9 @@ def _remat_partial_eval(trace, f, tracers, params):
jaxpr_1, jaxpr_2, out_unknowns = partial_eval_jaxpr(typed_jaxpr, in_unknowns, False)
num_res = len(jaxpr_1.out_avals) - len(jaxpr_2.out_avals)
# First, we prune the jaxpr to be staged out not to have too many outputs.
typed_jaxpr = _dce_jaxpr(typed_jaxpr, out_unknowns)
# Next, we need values for the outputs that should be known. Since consts
# weren't passed through Python for evaluation, we need to evaluate jaxpr_1,
# minus the residual outputs that we don't need. When `concrete=True`, as an
@ -521,53 +524,25 @@ def _remat_partial_eval(trace, f, tracers, params):
out_pval_consts2 = core.jaxpr_as_fun(jaxpr_1_primals)(*in_consts)[:-num_res or None]
out_pvals = map(_reconstruct_pval, out_pvals1, out_pval_consts2, out_unknowns)
# Now that we have out_pvals, the rest is just like JaxprTrace.process_call
# except we stage out two calls: one based on jaxpr_1 for computing the
# residuals (which in the case of reverse-mode ad involves no linear
# variables) and the other based on jaxpr_2 for evaluating everything given
# the residuals (which in reverse-mode ad is linear).
# Now that we have out_pvals, the rest is just like JaxprTrace.process_call.
instantiated_tracers = env + instantiated_tracers
num_nonres = len(jaxpr_2.out_avals)
jaxpr_1_res = _dce_jaxpr(jaxpr_1, [False] * num_nonres + [True] * num_res,
prune_outputs=True)
const_tracers = map(trace.new_instantiated_const, consts)
bound_subjaxpr_1 = (jaxpr_1_res.jaxpr, const_tracers, ())
res_avals = jaxpr_1.out_avals[num_nonres:]
res_tracers = [JaxprTracer(trace, PartialVal((aval, unit)), None)
for aval in res_avals]
tracers_1 = [t if not uk else trace.new_instantiated_literal(unit)
for t, uk in zip(instantiated_tracers, in_unknowns)]
eqn_1 = new_eqn_recipe(tracers_1, res_tracers, remat_call_p,
(bound_subjaxpr_1,), params)
for t in res_tracers: t.recipe = eqn_1
bound_subjaxpr_2 = (jaxpr_2.jaxpr, (), ())
bound_subjaxpr = (typed_jaxpr.jaxpr, const_tracers, ())
out_tracers = [JaxprTracer(trace, out_pval, None) for out_pval in out_pvals]
tracers_2 = [t if uk else trace.new_instantiated_literal(unit)
for t, uk in zip(instantiated_tracers, in_unknowns)]
eqn_2 = new_eqn_recipe(tracers_2 + res_tracers, out_tracers, remat_call_p,
(bound_subjaxpr_2,), params)
for t in out_tracers: t.recipe = eqn_2
eqn = new_eqn_recipe(instantiated_tracers, out_tracers, remat_call_p,
(bound_subjaxpr,), params)
for t in out_tracers: t.recipe = eqn
return out_tracers
call_partial_eval_rules[remat_call_p] = _remat_partial_eval
# NOTE to future self: the problem with the above strategy is that the jaxpr
# produced wouldn't be round-trippable, in the sense that by forming two remat
# calls we ensured the first one would be partial-eval'd away when we tried to
# round-trip e.g. for partial eval of scan.
def _dce_jaxpr(typed_jaxpr, outputs, prune_outputs=False):
def _dce_jaxpr(typed_jaxpr, outputs):
# This dead-code elimination is pretty rudimentary, and in particular doesn't
# nontrivially DCE through scan, call, or other higher-order primitives.
# TODO(mattjj): better DCE
jaxpr = typed_jaxpr.jaxpr
outvars, out_avals = jaxpr.outvars, typed_jaxpr.out_avals
if prune_outputs:
out_pairs = [(var, aval) for var, aval, output
in zip(outvars, out_avals, outputs) if output]
else:
out_pairs = [(var, aval) if output else (core.unitvar, core.abstract_unit)
for var, aval, output in zip(outvars, out_avals, outputs)]
out_pairs = [(var, aval) if output else (core.unitvar, core.abstract_unit)
for var, aval, output in zip(outvars, out_avals, outputs)]
new_outvars, new_out_avals = unzip2(out_pairs)
needed_vars = set(new_outvars)

View File

@ -1470,8 +1470,7 @@ class APITest(jtu.JaxTestCase):
jaxpr = api.make_jaxpr(api.linearize(f_yesremat, 4.)[1])(1.)
scan_eqn, = jaxpr.eqns
import ipdb; ipdb.set_trace()
self.assertIn(' sin ', str(scan_eqn.params['jaxpr']))
self.assertIn(' cos ', str(scan_eqn.params['jaxpr']))
jaxpr = api.make_jaxpr(api.vjp(f_yesremat, 4.)[1])(1.)
scan_eqn, = jaxpr.eqns