try remat_call partial-eval into two remat_calls

The idea here was for the resulting jaxpr to have a purely nonlinear
remat_call and a linear one with no primals to evaluate. (I wanted to
avoid having to recurse into all calls in _eval_primals in
backward_pass.) But the issue is that this makes jaxprs not
round-trippable: the first remat_call, depending only on constants,
would get partial-eval'd away on the first attempted round-trip. And we
do round-trip in partial_eval_jaxpr, in particular for partial eval of
scan. That meant remat of scan didn't work, and that's no good!
Matthew Johnson 2019-11-27 14:28:13 -08:00 committed by Matthew Johnson
parent 9a8523603c
commit b2b5049eb5
5 changed files with 136 additions and 60 deletions
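
To make the failing use case concrete, here is a minimal repro sketch (not
part of this commit, written against the current public JAX API; the function
names are illustrative): remat applied to a scan body and then linearized,
which exercises partial eval of scan and hence the partial_eval_jaxpr
round-trip described above.

import jax
import jax.numpy as jnp
from jax import lax

@jax.remat
def cell(carry, x):
    # remat'd scan body: gets partial-eval'd when the scan is linearized
    return jnp.sin(carry) * x, jnp.cos(carry)

def f(x):
    carry, ys = lax.scan(cell, x, jnp.arange(3.0))
    return carry + ys.sum()

y, f_lin = jax.linearize(f, 2.0)   # partial eval of scan round-trips the remat jaxpr
print(y, f_lin(1.0))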

@@ -187,17 +187,19 @@ def backward_pass(jaxpr, consts, freevar_vals, args, cotangents_in):
write_primal(eqn.outvars[0], ans)
else:
(subjaxpr, const_vars, bound_vars), = eqn.bound_subjaxprs
if any(is_linear(v) for v in it.chain(eqn.invars, const_vars, bound_vars)):
if any(is_linear(v) for v in it.chain(eqn.invars, bound_vars)):
linear_eqns.append(eqn)
sub_consts = map(read_primal, const_vars)
sub_freevar_vals = map(read_primal, bound_vars)
in_vals = map(read_primal, eqn.invars)
all_args, in_tree_def = tree_flatten((sub_consts, sub_freevar_vals, in_vals))
fun = hashable_partial(wrap_init(_eval_primals), subjaxpr)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
out_flat = eqn.primitive.bind(fun, *all_args, **eqn.params)
ans = tree_unflatten(out_tree(), out_flat)
map(write_primal, eqn.outvars, ans)
else:
assert not any(is_linear(v) for v in const_vars)
sub_consts = map(read_primal, const_vars)
sub_freevar_vals = map(read_primal, bound_vars)
in_vals = map(read_primal, eqn.invars)
all_args, in_tree_def = tree_flatten((sub_consts, sub_freevar_vals, in_vals))
fun = hashable_partial(wrap_init(_eval_primals), subjaxpr)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
out_flat = eqn.primitive.bind(fun, *all_args, **eqn.params)
ans = tree_unflatten(out_tree(), out_flat)
map(write_primal, eqn.outvars, ans)
ct_env = {}
map(write_cotangent, jaxpr.outvars, cotangents_in)

@@ -411,6 +411,14 @@ def closure_convert_jaxpr(jaxpr):
core.skip_checks or core.check_jaxpr(lifted_jaxpr)
return lifted_jaxpr
def convert_freevars_jaxpr(jaxpr):
core.skip_checks or core.check_jaxpr(jaxpr)
lifted_jaxpr = Jaxpr(constvars=jaxpr.constvars, freevars=(),
invars=jaxpr.freevars + jaxpr.invars,
outvars=jaxpr.outvars, eqns=jaxpr.eqns)
core.skip_checks or core.check_jaxpr(lifted_jaxpr)
return lifted_jaxpr
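The new convert_freevars_jaxpr above lifts a jaxpr's free variables into
leading inputs. A rough Python-level analogy (illustrative only, not the jaxpr
machinery):

y = 2.0                          # free variable that g closes over
g = lambda x: y * x
g_lifted = lambda y, x: y * x    # free variable moved to the front of the inputs
assert g(3.0) == g_lifted(2.0, 3.0)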
def partial_eval_jaxpr(jaxpr, unknowns, instantiate):
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
@@ -483,54 +491,83 @@ def _remat_partial_eval(trace, f, tracers, params):
fun, aux = partial_eval(f, trace, in_pvs)
out_flat = remat_call_p.bind(fun, *in_consts, **params)
out_pvs, jaxpr, env = aux()
env = map(trace.full_raise, env)
out_pval_consts1, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
out_pvals1 = [PartialVal((pv, const)) for pv, const in zip(out_pvs, out_pval_consts1)]
# We traced with everything marked as unknown, but we need to know which
# outputs are known/unknown, so we use partial_eval_jaxpr to get out_unknowns.
in_avals = [raise_to_shaped(pv) for pv in in_pvs]
jaxpr_converted = convert_freevars_jaxpr(jaxpr)
in_avals = ([raise_to_shaped(t.pval[0]) for t in env]
+ [raise_to_shaped(pv) for pv in in_pvs])
out_avals = [raise_to_shaped(pv if pv is not None else core.get_aval(const))
for pv, const in zip(out_pvs, out_pval_consts1)]
typed_jaxpr = core.TypedJaxpr(jaxpr, consts, in_avals, out_avals)
in_unknowns = [t.pval[0] is not None for t in tracers]
typed_jaxpr = core.TypedJaxpr(jaxpr_converted, consts, in_avals, out_avals)
in_unknowns = [t.pval[0] is not None for t in it.chain(env, tracers)]
jaxpr_1, jaxpr_2, out_unknowns = partial_eval_jaxpr(typed_jaxpr, in_unknowns, False)
num_res = len(jaxpr_1.out_avals) - len(jaxpr_2.out_avals)
# First, we revise the jaxpr to be staged out so that it doesn't output too much.
typed_jaxpr = _dce_jaxpr(typed_jaxpr, out_unknowns)
# Next, we need values for the outputs that should be known. Since consts
# weren't passed through Python for evaluation, we need to evaluate jaxpr_1,
# minus the residual outputs that we don't need. When `concrete=True`, as an
# optimization we can avoid redoing *some* redundant FLOPs, namely those that
# produced concrete avals at the output, simply by using those as computed
# values. For the use case of reverse-mode ad, all the primal outputs should
# be concrete (thus not recomputed).
# values. For the use case of reverse-mode ad in op-by-op ("eager mode")
# evaluation, all the primal outputs should be concrete (thus not recomputed).
to_compute = [not uk and type(pv) is not ConcreteArray
for uk, pv in zip(out_unknowns, out_pvs)]
jaxpr_1 = _dce_jaxpr(jaxpr_1, to_compute + [False] * num_res)
_, in_consts = unzip2(t.pval for t in tracers)
out_pval_consts2 = core.jaxpr_as_fun(jaxpr_1)(*in_consts)[:-num_res or None]
jaxpr_1_primals = _dce_jaxpr(jaxpr_1, to_compute + [False] * num_res)
_, in_consts = unzip2(t.pval for t in it.chain(env, tracers))
out_pval_consts2 = core.jaxpr_as_fun(jaxpr_1_primals)(*in_consts)[:-num_res or None]
out_pvals = map(_reconstruct_pval, out_pvals1, out_pval_consts2, out_unknowns)
# Now that we have out_pvals, the rest is just like JaxprTrace.process_call.
# Now that we have out_pvals, the rest is just like JaxprTrace.process_call
# except we stage out two calls: one based on jaxpr_1 for computing the
# residuals (which in the case of reverse-mode ad involves no linear
# variables) and the other based on jaxpr_2 for evaluating everything given
# the residuals (which in reverse-mode ad is linear).
instantiated_tracers = env + instantiated_tracers
num_nonres = len(jaxpr_2.out_avals)
jaxpr_1_res = _dce_jaxpr(jaxpr_1, [False] * num_nonres + [True] * num_res,
prune_outputs=True)
const_tracers = map(trace.new_instantiated_const, consts)
bound_subjaxpr = (jaxpr, const_tracers, map(trace.full_raise, env))
bound_subjaxpr_1 = (jaxpr_1_res.jaxpr, const_tracers, ())
res_avals = jaxpr_1.out_avals[num_nonres:]
res_tracers = [JaxprTracer(trace, PartialVal((aval, unit)), None)
for aval in res_avals]
tracers_1 = [t if not uk else trace.new_instantiated_literal(unit)
for t, uk in zip(instantiated_tracers, in_unknowns)]
eqn_1 = new_eqn_recipe(tracers_1, res_tracers, remat_call_p,
(bound_subjaxpr_1,), params)
for t in res_tracers: t.recipe = eqn_1
bound_subjaxpr_2 = (jaxpr_2.jaxpr, (), ())
out_tracers = [JaxprTracer(trace, out_pval, None) for out_pval in out_pvals]
eqn = new_eqn_recipe(instantiated_tracers, out_tracers, remat_call_p,
(bound_subjaxpr,), params)
for t in out_tracers:
t.recipe = eqn
tracers_2 = [t if uk else trace.new_instantiated_literal(unit)
for t, uk in zip(instantiated_tracers, in_unknowns)]
eqn_2 = new_eqn_recipe(tracers_2 + res_tracers, out_tracers, remat_call_p,
(bound_subjaxpr_2,), params)
for t in out_tracers: t.recipe = eqn_2
return out_tracers
call_partial_eval_rules[remat_call_p] = _remat_partial_eval
# NOTE to future self: the problem with the above strategy is that the jaxpr
# produced wouldn't be round-trippable, in the sense that by forming two remat
# calls we ensured the first one would be partial-eval'd away when we tried to
# round-trip e.g. for partial eval of scan.
def _dce_jaxpr(typed_jaxpr, outputs):
def _dce_jaxpr(typed_jaxpr, outputs, prune_outputs=False):
# This dead-code elimination is pretty rudimentary, and in particular doesn't
# nontrivially DCE through scan or other higher-order primitives.
# nontrivially DCE through scan, call, or other higher-order primitives.
# TODO(mattjj): better DCE
jaxpr = typed_jaxpr.jaxpr
outvars, out_avals = jaxpr.outvars, typed_jaxpr.out_avals
out_pairs = [(var, aval) if output else (core.unitvar, core.abstract_unit)
for var, aval, output in zip(outvars, out_avals, outputs)]
if prune_outputs:
out_pairs = [(var, aval) for var, aval, output
in zip(outvars, out_avals, outputs) if output]
else:
out_pairs = [(var, aval) if output else (core.unitvar, core.abstract_unit)
for var, aval, output in zip(outvars, out_avals, outputs)]
new_outvars, new_out_avals = unzip2(out_pairs)
needed_vars = set(new_outvars)
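A small, self-contained sketch (plain Python, not JAX internals) contrasting
the two output-handling modes of _dce_jaxpr: by default dropped outputs are
replaced with a unit placeholder so the jaxpr's output arity is preserved,
while prune_outputs=True removes them entirely, as used above for the
residual-only jaxpr_1_res.

UNIT = 'unit'  # stand-in for core.unitvar / core.abstract_unit

def dce_outputs(outvars, keep, prune_outputs=False):
    if prune_outputs:
        # drop unwanted outputs entirely; callers must expect fewer results
        return [v for v, k in zip(outvars, keep) if k]
    # preserve arity so the jaxpr's type signature is unchanged
    return [v if k else UNIT for v, k in zip(outvars, keep)]

print(dce_outputs(['a', 'b', 'c'], [True, False, True]))                      # ['a', 'unit', 'c']
print(dce_outputs(['a', 'b', 'c'], [True, False, True], prune_outputs=True))  # ['a', 'c']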
@@ -555,3 +592,20 @@ def _reconstruct_pval(pval1, const2, unknown):
return PartialVal((None, pv1.val))
else:
return PartialVal((None, const2))
def move_binders_to_front(typed_jaxpr, to_move):
assert not typed_jaxpr.jaxpr.constvars and not typed_jaxpr.jaxpr.freevars
assert len(typed_jaxpr.in_avals) == len(to_move)
new_invars = _move_to_front(typed_jaxpr.jaxpr.invars, to_move)
new_jaxpr = core.Jaxpr((), (), new_invars, typed_jaxpr.jaxpr.outvars,
typed_jaxpr.jaxpr.eqns)
new_in_avals = _move_to_front(typed_jaxpr.in_avals, to_move)
new_typed_jaxpr = core.TypedJaxpr(new_jaxpr, typed_jaxpr.literals,
new_in_avals, typed_jaxpr.out_avals)
return new_typed_jaxpr
def _move_to_front(lst, to_move):
return ([elt for elt, move in zip(lst, to_move) if move] +
[elt for elt, move in zip(lst, to_move) if not move])
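For reference, a tiny usage illustration of the _move_to_front helper just
above: elements flagged True are moved, stably, ahead of the rest.

binders = ['a', 'b', 'c', 'd']
to_move = [False, True, False, True]
# _move_to_front(binders, to_move) -> ['b', 'd', 'a', 'c']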

@@ -778,14 +778,18 @@ def _remat_translation_rule(c, jaxpr, axis_env, const_nodes, freevar_nodes, in_n
call_translations[pe.remat_call_p] = _remat_translation_rule
def _foil_cse(c, x):
rng = c.RngNormal(c.Constant(onp.array(0, dtype=onp.float32)),
c.Constant(onp.array(1, dtype=onp.float32)),
[])
pred = c.Lt(rng, c.Constant(onp.finfo(onp.float32).max))
xla_shape = c.GetShape(x)
shape, dtype = xla_shape.dimensions(), xla_shape.numpy_dtype()
zero = c.Broadcast(c.Constant(onp.array(0, dtype=dtype)), shape)
return c.Select(pred, x, zero)
if xla_shape.is_tuple():
assert not xla_shape.tuple_shapes()
return x
else:
rng = c.RngNormal(c.Constant(onp.array(0, dtype=onp.float32)),
c.Constant(onp.array(1, dtype=onp.float32)),
[])
pred = c.Lt(rng, c.Constant(onp.finfo(onp.float32).max))
shape, dtype = xla_shape.dimensions(), xla_shape.numpy_dtype()
zero = c.Broadcast(c.Constant(onp.array(0, dtype=dtype)), shape)
return c.Select(pred, x, zero)
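For intuition, a rough jax.numpy-level analogue of the CSE-foiling trick used
in _foil_cse above (illustrative sketch, not the XLA builder code): gate the
value on a predicate that is always true at runtime but opaque to the
compiler, so identical rematerialized computations can't be merged.

import jax
import jax.numpy as jnp

def foil_cse(x, key):
    r = jax.random.normal(key, ())           # runtime value the compiler can't constant-fold
    pred = r < jnp.finfo(jnp.float32).max    # always True at runtime
    return jnp.where(pred, x, jnp.zeros_like(x))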
### lazy constants

@@ -675,7 +675,7 @@ def _scan_partial_eval(trace, *tracers, **kwargs):
_, _, res_pvals = split_list(out_pvals_1, [num_carry, num_ys])
intensive_residuals = [const for pv, const in res_pvals if pv is None]
move = [False] * len(jaxpr_1.in_avals) + [pv is None for pv, _ in res_pvals]
jaxpr_2_opt = _move_binders_to_front(jaxpr_2, move)
jaxpr_2_opt = pe.move_binders_to_front(jaxpr_2, move)
num_consts_2 = num_consts + len(intensive_residuals)
in_consts = (list(consts_1) + [core.unit] * num_consts +
@@ -712,21 +712,6 @@ def _scan_partial_eval(trace, *tracers, **kwargs):
for t in out_tracers: t.recipe = eqn
return out_tracers
def _move_binders_to_front(typed_jaxpr, to_move):
assert not typed_jaxpr.jaxpr.constvars and not typed_jaxpr.jaxpr.freevars
assert len(typed_jaxpr.in_avals) == len(to_move)
new_invars = _move_to_front(typed_jaxpr.jaxpr.invars, to_move)
new_jaxpr = core.Jaxpr((), (), new_invars, typed_jaxpr.jaxpr.outvars,
typed_jaxpr.jaxpr.eqns)
new_in_avals = _move_to_front(typed_jaxpr.in_avals, to_move)
new_typed_jaxpr = core.TypedJaxpr(new_jaxpr, typed_jaxpr.literals,
new_in_avals, typed_jaxpr.out_avals)
return new_typed_jaxpr
def _move_to_front(lst, to_move):
return ([elt for elt, move in zip(lst, to_move) if move] +
[elt for elt, move in zip(lst, to_move) if not move])
def _promote_aval_rank(sz, aval):
if aval is core.abstract_unit:
return core.abstract_unit

@@ -1329,26 +1329,56 @@ class APITest(jtu.JaxTestCase):
def test_remat_basic(self):
@api.remat
def g(x):
return lax.sin(x), 3.
return lax.sin(lax.sin(x)), 3.
def f(x):
x, _ = g(x)
return x
ans = f(2.)
expected = onp.sin(2.)
expected = onp.sin(onp.sin(2.))
self.assertAllClose(ans, expected, check_dtypes=False)
ans, f_lin = api.linearize(f, 2.)
expected = onp.sin(2.)
expected = onp.sin(onp.sin(2.))
self.assertAllClose(ans, expected, check_dtypes=False)
ans = f_lin(3.)
expected = onp.cos(2.) * 3.
expected = onp.cos(onp.sin(2.)) * onp.cos(2.) * 3.
self.assertAllClose(ans, expected, check_dtypes=False)
jaxpr = api.make_jaxpr(f_lin)(3.)
self.assertIn('sin', str(jaxpr))
sin_calls = []
cos_calls = []
sin_impl = lax.sin_p.impl
cos_impl = lax.cos_p.impl
try:
lax.sin_p.def_impl(lambda x: sin_calls.append(1) or sin_impl(x))
lax.cos_p.def_impl(lambda x: cos_calls.append(1) or cos_impl(x))
f_lin(3.)
finally:
lax.sin_p.def_impl(sin_impl)
lax.cos_p.def_impl(cos_impl)
self.assertEqual(len(sin_calls), 1)
self.assertEqual(len(cos_calls), 2)
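# Why these counts: f(x) = sin(sin(x)), so the linearized function applies
# t -> cos(sin(x)) * cos(x) * t. Rerunning the remat'd primal computation thus
# needs sin(x) once (to feed cos(sin(x))) and cos twice, while sin(sin(x))
# itself is never needed for the tangent -- hence 1 sin call and 2 cos calls.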
def test_remat_freevars(self):
def f1(x):
y = 2 * np.sin(x)
z = np.cos(x) * np.sin(y)
return z
def f2(x):
y = 2 * np.sin(x)
z = api.remat(lambda x: np.cos(x) * np.sin(y))(x)
return z
ans, f_lin = api.linearize(f2, 2.)
expected, f_lin_expected = api.linearize(f1, 2.)
self.assertAllClose(ans, expected, check_dtypes=False)
ans = f_lin(3.)
expected = f_lin_expected(3.)
self.assertAllClose(ans, expected, check_dtypes=False)
def test_remat_grad_python_control_flow(self):
@partial(api.remat, concrete=True)
@@ -1440,6 +1470,7 @@ class APITest(jtu.JaxTestCase):
jaxpr = api.make_jaxpr(api.linearize(f_yesremat, 4.)[1])(1.)
scan_eqn, = jaxpr.eqns
self.assertIn(' sin ', str(scan_eqn.params['jaxpr']))
jaxpr = api.make_jaxpr(api.vjp(f_yesremat, 4.)[1])(1.)