mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
try remat_call partial-eval into two remat_calls
The idea here was for the resulting jaxpr to have a purely nonlinear remat_call and a linear one with no primals to evaluate. (I wanted to avoid having to recurse into all calls in _eval_primal in backward_pass.) But the issue is that this makes jaxprs not round-trippable, since the first remat_call, depending only on constants, would get partial-eval'd away at the first attempted round-trip. And we round-trip in partial_eval_jaxpr, particularly for partial eval of scan. That meant remat of scan didn't work, and that's no good!
This commit is contained in:
parent
9a8523603c
commit
b2b5049eb5
@ -187,17 +187,19 @@ def backward_pass(jaxpr, consts, freevar_vals, args, cotangents_in):
|
||||
write_primal(eqn.outvars[0], ans)
|
||||
else:
|
||||
(subjaxpr, const_vars, bound_vars), = eqn.bound_subjaxprs
|
||||
if any(is_linear(v) for v in it.chain(eqn.invars, const_vars, bound_vars)):
|
||||
if any(is_linear(v) for v in it.chain(eqn.invars, bound_vars)):
|
||||
linear_eqns.append(eqn)
|
||||
sub_consts = map(read_primal, const_vars)
|
||||
sub_freevar_vals = map(read_primal, bound_vars)
|
||||
in_vals = map(read_primal, eqn.invars)
|
||||
all_args, in_tree_def = tree_flatten((sub_consts, sub_freevar_vals, in_vals))
|
||||
fun = hashable_partial(wrap_init(_eval_primals), subjaxpr)
|
||||
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
|
||||
out_flat = eqn.primitive.bind(fun, *all_args, **eqn.params)
|
||||
ans = tree_unflatten(out_tree(), out_flat)
|
||||
map(write_primal, eqn.outvars, ans)
|
||||
else:
|
||||
assert not any(is_linear(v) for v in const_vars)
|
||||
sub_consts = map(read_primal, const_vars)
|
||||
sub_freevar_vals = map(read_primal, bound_vars)
|
||||
in_vals = map(read_primal, eqn.invars)
|
||||
all_args, in_tree_def = tree_flatten((sub_consts, sub_freevar_vals, in_vals))
|
||||
fun = hashable_partial(wrap_init(_eval_primals), subjaxpr)
|
||||
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
|
||||
out_flat = eqn.primitive.bind(fun, *all_args, **eqn.params)
|
||||
ans = tree_unflatten(out_tree(), out_flat)
|
||||
map(write_primal, eqn.outvars, ans)
|
||||
|
||||
ct_env = {}
|
||||
map(write_cotangent, jaxpr.outvars, cotangents_in)
|
||||
|
@ -411,6 +411,14 @@ def closure_convert_jaxpr(jaxpr):
|
||||
core.skip_checks or core.check_jaxpr(lifted_jaxpr)
|
||||
return lifted_jaxpr
|
||||
|
||||
def convert_freevars_jaxpr(jaxpr):
|
||||
core.skip_checks or core.check_jaxpr(jaxpr)
|
||||
lifted_jaxpr = Jaxpr(constvars=jaxpr.constvars, freevars=(),
|
||||
invars=jaxpr.freevars + jaxpr.invars,
|
||||
outvars=jaxpr.outvars, eqns=jaxpr.eqns)
|
||||
core.skip_checks or core.check_jaxpr(lifted_jaxpr)
|
||||
return lifted_jaxpr
|
||||
|
||||
def partial_eval_jaxpr(jaxpr, unknowns, instantiate):
|
||||
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
|
||||
|
||||
@ -483,54 +491,83 @@ def _remat_partial_eval(trace, f, tracers, params):
|
||||
fun, aux = partial_eval(f, trace, in_pvs)
|
||||
out_flat = remat_call_p.bind(fun, *in_consts, **params)
|
||||
out_pvs, jaxpr, env = aux()
|
||||
env = map(trace.full_raise, env)
|
||||
out_pval_consts1, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
|
||||
out_pvals1 = [PartialVal((pv, const)) for pv, const in zip(out_pvs, out_pval_consts1)]
|
||||
|
||||
# Since we traced with everything marked as unknown, but we need to know which
|
||||
# outputs are known/unknown, we use partial_eval_jaxpr to get out_unknowns.
|
||||
in_avals = [raise_to_shaped(pv) for pv in in_pvs]
|
||||
jaxpr_converted = convert_freevars_jaxpr(jaxpr)
|
||||
in_avals = ([raise_to_shaped(t.pval[0]) for t in env]
|
||||
+ [raise_to_shaped(pv) for pv in in_pvs])
|
||||
out_avals = [raise_to_shaped(pv if pv is not None else core.get_aval(const))
|
||||
for pv, const in zip(out_pvs, out_pval_consts1)]
|
||||
typed_jaxpr = core.TypedJaxpr(jaxpr, consts, in_avals, out_avals)
|
||||
in_unknowns = [t.pval[0] is not None for t in tracers]
|
||||
typed_jaxpr = core.TypedJaxpr(jaxpr_converted, consts, in_avals, out_avals)
|
||||
in_unknowns = [t.pval[0] is not None for t in it.chain(env, tracers)]
|
||||
jaxpr_1, jaxpr_2, out_unknowns = partial_eval_jaxpr(typed_jaxpr, in_unknowns, False)
|
||||
num_res = len(jaxpr_1.out_avals) - len(jaxpr_2.out_avals)
|
||||
|
||||
# First, we revise the jaxpr to be staged out not to output too much.
|
||||
typed_jaxpr = _dce_jaxpr(typed_jaxpr, out_unknowns)
|
||||
|
||||
# Next, we need values for the outputs that should be known. Since consts
|
||||
# weren't passed through Python for evaluation, we need to evaluate jaxpr_1,
|
||||
# minus the residual outputs that we don't need. When `concrete=True`, as an
|
||||
# optimization we can avoid redoing *some* redundant FLOPs, namely those that
|
||||
# produced concrete avals at the output, simply by using those as computed
|
||||
# values. For the use case of reverse-mode ad, all the primal outputs should
|
||||
# be concrete (thus not recomputed).
|
||||
# values. For the use case of reverse-mode ad in op-by-op ("eager mode")
|
||||
# evaluation, all the primal outputs should be concrete (thus not recomputed).
|
||||
to_compute = [not uk and type(pv) is not ConcreteArray
|
||||
for uk, pv in zip(out_unknowns, out_pvs)]
|
||||
jaxpr_1 = _dce_jaxpr(jaxpr_1, to_compute + [False] * num_res)
|
||||
_, in_consts = unzip2(t.pval for t in tracers)
|
||||
out_pval_consts2 = core.jaxpr_as_fun(jaxpr_1)(*in_consts)[:-num_res or None]
|
||||
jaxpr_1_primals = _dce_jaxpr(jaxpr_1, to_compute + [False] * num_res)
|
||||
_, in_consts = unzip2(t.pval for t in it.chain(env, tracers))
|
||||
out_pval_consts2 = core.jaxpr_as_fun(jaxpr_1_primals)(*in_consts)[:-num_res or None]
|
||||
out_pvals = map(_reconstruct_pval, out_pvals1, out_pval_consts2, out_unknowns)
|
||||
|
||||
# Now that we have out_pvals, the rest is just like JaxprTrace.process_call.
|
||||
# Now that we have out_pvals, the rest is just like JaxprTrace.process_call
|
||||
# except we stage out two calls: one based on jaxpr_1 for computing the
|
||||
# residuals (which in the case of reverse-mode ad involves no linear
|
||||
# variables) and the other based on jaxpr_2 for evaluating everything given
|
||||
# the residuals (which in reverse-mode ad is linear).
|
||||
instantiated_tracers = env + instantiated_tracers
|
||||
num_nonres = len(jaxpr_2.out_avals)
|
||||
jaxpr_1_res = _dce_jaxpr(jaxpr_1, [False] * num_nonres + [True] * num_res,
|
||||
prune_outputs=True)
|
||||
|
||||
const_tracers = map(trace.new_instantiated_const, consts)
|
||||
bound_subjaxpr = (jaxpr, const_tracers, map(trace.full_raise, env))
|
||||
bound_subjaxpr_1 = (jaxpr_1_res.jaxpr, const_tracers, ())
|
||||
res_avals = jaxpr_1.out_avals[num_nonres:]
|
||||
res_tracers = [JaxprTracer(trace, PartialVal((aval, unit)), None)
|
||||
for aval in res_avals]
|
||||
tracers_1 = [t if not uk else trace.new_instantiated_literal(unit)
|
||||
for t, uk in zip(instantiated_tracers, in_unknowns)]
|
||||
eqn_1 = new_eqn_recipe(tracers_1, res_tracers, remat_call_p,
|
||||
(bound_subjaxpr_1,), params)
|
||||
for t in res_tracers: t.recipe = eqn_1
|
||||
|
||||
bound_subjaxpr_2 = (jaxpr_2.jaxpr, (), ())
|
||||
out_tracers = [JaxprTracer(trace, out_pval, None) for out_pval in out_pvals]
|
||||
eqn = new_eqn_recipe(instantiated_tracers, out_tracers, remat_call_p,
|
||||
(bound_subjaxpr,), params)
|
||||
for t in out_tracers:
|
||||
t.recipe = eqn
|
||||
tracers_2 = [t if uk else trace.new_instantiated_literal(unit)
|
||||
for t, uk in zip(instantiated_tracers, in_unknowns)]
|
||||
eqn_2 = new_eqn_recipe(tracers_2 + res_tracers, out_tracers, remat_call_p,
|
||||
(bound_subjaxpr_2,), params)
|
||||
for t in out_tracers: t.recipe = eqn_2
|
||||
return out_tracers
|
||||
call_partial_eval_rules[remat_call_p] = _remat_partial_eval
|
||||
# NOTE to future self: the problem with the above strategy is that the jaxpr
|
||||
# produced wouldn't be round-trippable, in the sense that by forming two remat
|
||||
# calls we ensured the first one would be partial-eval'd away when we tried to
|
||||
# round-trip e.g. for partial eval of scan.
|
||||
|
||||
def _dce_jaxpr(typed_jaxpr, outputs):
|
||||
def _dce_jaxpr(typed_jaxpr, outputs, prune_outputs=False):
|
||||
# This dead-code elimination is pretty rudimentary, and in particular doesn't
|
||||
# nontrivially DCE through scan or other higher-order primitives.
|
||||
# nontrivially DCE through scan, call, or other higher-order primitives.
|
||||
# TODO(mattjj): better DCE
|
||||
jaxpr = typed_jaxpr.jaxpr
|
||||
outvars, out_avals = jaxpr.outvars, typed_jaxpr.out_avals
|
||||
out_pairs = [(var, aval) if output else (core.unitvar, core.abstract_unit)
|
||||
for var, aval, output in zip(outvars, out_avals, outputs)]
|
||||
if prune_outputs:
|
||||
out_pairs = [(var, aval) for var, aval, output
|
||||
in zip(outvars, out_avals, outputs) if output]
|
||||
else:
|
||||
out_pairs = [(var, aval) if output else (core.unitvar, core.abstract_unit)
|
||||
for var, aval, output in zip(outvars, out_avals, outputs)]
|
||||
new_outvars, new_out_avals = unzip2(out_pairs)
|
||||
|
||||
needed_vars = set(new_outvars)
|
||||
@ -555,3 +592,20 @@ def _reconstruct_pval(pval1, const2, unknown):
|
||||
return PartialVal((None, pv1.val))
|
||||
else:
|
||||
return PartialVal((None, const2))
|
||||
|
||||
|
||||
def move_binders_to_front(typed_jaxpr, to_move):
|
||||
assert not typed_jaxpr.jaxpr.constvars and not typed_jaxpr.jaxpr.freevars
|
||||
assert len(typed_jaxpr.in_avals) == len(to_move)
|
||||
new_invars = _move_to_front(typed_jaxpr.jaxpr.invars, to_move)
|
||||
new_jaxpr = core.Jaxpr((), (), new_invars, typed_jaxpr.jaxpr.outvars,
|
||||
typed_jaxpr.jaxpr.eqns)
|
||||
new_in_avals = _move_to_front(typed_jaxpr.in_avals, to_move)
|
||||
new_typed_jaxpr = core.TypedJaxpr(new_jaxpr, typed_jaxpr.literals,
|
||||
new_in_avals, typed_jaxpr.out_avals)
|
||||
return new_typed_jaxpr
|
||||
|
||||
def _move_to_front(lst, to_move):
|
||||
return ([elt for elt, move in zip(lst, to_move) if move] +
|
||||
[elt for elt, move in zip(lst, to_move) if not move])
|
||||
|
||||
|
@ -778,14 +778,18 @@ def _remat_translation_rule(c, jaxpr, axis_env, const_nodes, freevar_nodes, in_n
|
||||
call_translations[pe.remat_call_p] = _remat_translation_rule
|
||||
|
||||
def _foil_cse(c, x):
|
||||
rng = c.RngNormal(c.Constant(onp.array(0, dtype=onp.float32)),
|
||||
c.Constant(onp.array(1, dtype=onp.float32)),
|
||||
[])
|
||||
pred = c.Lt(rng, c.Constant(onp.finfo(onp.float32).max))
|
||||
xla_shape = c.GetShape(x)
|
||||
shape, dtype = xla_shape.dimensions(), xla_shape.numpy_dtype()
|
||||
zero = c.Broadcast(c.Constant(onp.array(0, dtype=dtype)), shape)
|
||||
return c.Select(pred, x, zero)
|
||||
if xla_shape.is_tuple():
|
||||
assert not xla_shape.tuple_shapes()
|
||||
return x
|
||||
else:
|
||||
rng = c.RngNormal(c.Constant(onp.array(0, dtype=onp.float32)),
|
||||
c.Constant(onp.array(1, dtype=onp.float32)),
|
||||
[])
|
||||
pred = c.Lt(rng, c.Constant(onp.finfo(onp.float32).max))
|
||||
shape, dtype = xla_shape.dimensions(), xla_shape.numpy_dtype()
|
||||
zero = c.Broadcast(c.Constant(onp.array(0, dtype=dtype)), shape)
|
||||
return c.Select(pred, x, zero)
|
||||
|
||||
|
||||
### lazy constants
|
||||
|
@ -675,7 +675,7 @@ def _scan_partial_eval(trace, *tracers, **kwargs):
|
||||
_, _, res_pvals = split_list(out_pvals_1, [num_carry, num_ys])
|
||||
intensive_residuals = [const for pv, const in res_pvals if pv is None]
|
||||
move = [False] * len(jaxpr_1.in_avals) + [pv is None for pv, _ in res_pvals]
|
||||
jaxpr_2_opt = _move_binders_to_front(jaxpr_2, move)
|
||||
jaxpr_2_opt = pe.move_binders_to_front(jaxpr_2, move)
|
||||
num_consts_2 = num_consts + len(intensive_residuals)
|
||||
|
||||
in_consts = (list(consts_1) + [core.unit] * num_consts +
|
||||
@ -712,21 +712,6 @@ def _scan_partial_eval(trace, *tracers, **kwargs):
|
||||
for t in out_tracers: t.recipe = eqn
|
||||
return out_tracers
|
||||
|
||||
def _move_binders_to_front(typed_jaxpr, to_move):
|
||||
assert not typed_jaxpr.jaxpr.constvars and not typed_jaxpr.jaxpr.freevars
|
||||
assert len(typed_jaxpr.in_avals) == len(to_move)
|
||||
new_invars = _move_to_front(typed_jaxpr.jaxpr.invars, to_move)
|
||||
new_jaxpr = core.Jaxpr((), (), new_invars, typed_jaxpr.jaxpr.outvars,
|
||||
typed_jaxpr.jaxpr.eqns)
|
||||
new_in_avals = _move_to_front(typed_jaxpr.in_avals, to_move)
|
||||
new_typed_jaxpr = core.TypedJaxpr(new_jaxpr, typed_jaxpr.literals,
|
||||
new_in_avals, typed_jaxpr.out_avals)
|
||||
return new_typed_jaxpr
|
||||
|
||||
def _move_to_front(lst, to_move):
|
||||
return ([elt for elt, move in zip(lst, to_move) if move] +
|
||||
[elt for elt, move in zip(lst, to_move) if not move])
|
||||
|
||||
def _promote_aval_rank(sz, aval):
|
||||
if aval is core.abstract_unit:
|
||||
return core.abstract_unit
|
||||
|
@ -1329,26 +1329,56 @@ class APITest(jtu.JaxTestCase):
|
||||
def test_remat_basic(self):
|
||||
@api.remat
|
||||
def g(x):
|
||||
return lax.sin(x), 3.
|
||||
return lax.sin(lax.sin(x)), 3.
|
||||
|
||||
def f(x):
|
||||
x, _ = g(x)
|
||||
return x
|
||||
|
||||
ans = f(2.)
|
||||
expected = onp.sin(2.)
|
||||
expected = onp.sin(onp.sin(2.))
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
ans, f_lin = api.linearize(f, 2.)
|
||||
expected = onp.sin(2.)
|
||||
expected = onp.sin(onp.sin(2.))
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
ans = f_lin(3.)
|
||||
expected = onp.cos(2.) * 3.
|
||||
expected = onp.cos(onp.sin(2.)) * onp.cos(2.) * 3.
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
jaxpr = api.make_jaxpr(f_lin)(3.)
|
||||
self.assertIn('sin', str(jaxpr))
|
||||
sin_calls = []
|
||||
cos_calls = []
|
||||
sin_impl = lax.sin_p.impl
|
||||
cos_impl = lax.cos_p.impl
|
||||
try:
|
||||
lax.sin_p.def_impl(lambda x: sin_calls.append(1) or sin_impl(x))
|
||||
lax.cos_p.def_impl(lambda x: cos_calls.append(1) or cos_impl(x))
|
||||
f_lin(3.)
|
||||
finally:
|
||||
lax.sin_p.def_impl(sin_impl)
|
||||
lax.cos_p.def_impl(cos_impl)
|
||||
self.assertEqual(len(sin_calls), 1)
|
||||
self.assertEqual(len(cos_calls), 2)
|
||||
|
||||
def test_remat_freevars(self):
|
||||
def f1(x):
|
||||
y = 2 * np.sin(x)
|
||||
z = np.cos(x) * np.sin(y)
|
||||
return z
|
||||
|
||||
def f2(x):
|
||||
y = 2 * np.sin(x)
|
||||
z = api.remat(lambda x: np.cos(x) * np.sin(y))(x)
|
||||
return z
|
||||
|
||||
ans, f_lin = api.linearize(f2, 2.)
|
||||
expected, f_lin_expected = api.linearize(f1, 2.)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
ans = f_lin(3.)
|
||||
expected = f_lin_expected(3.)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def test_remat_grad_python_control_flow(self):
|
||||
@partial(api.remat, concrete=True)
|
||||
@ -1440,6 +1470,7 @@ class APITest(jtu.JaxTestCase):
|
||||
|
||||
jaxpr = api.make_jaxpr(api.linearize(f_yesremat, 4.)[1])(1.)
|
||||
scan_eqn, = jaxpr.eqns
|
||||
import ipdb; ipdb.set_trace()
|
||||
self.assertIn(' sin ', str(scan_eqn.params['jaxpr']))
|
||||
|
||||
jaxpr = api.make_jaxpr(api.vjp(f_yesremat, 4.)[1])(1.)
|
||||
|
Loading…
x
Reference in New Issue
Block a user