mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
make remat_call partial-eval into one remat_call
This commit is contained in:
parent
b2b5049eb5
commit
ac251046fc
@ -187,18 +187,21 @@ def backward_pass(jaxpr, consts, freevar_vals, args, cotangents_in):
|
||||
write_primal(eqn.outvars[0], ans)
|
||||
else:
|
||||
(subjaxpr, const_vars, bound_vars), = eqn.bound_subjaxprs
|
||||
assert not any(is_linear(v) for v in const_vars)
|
||||
if any(is_linear(v) for v in it.chain(eqn.invars, bound_vars)):
|
||||
linear_eqns.append(eqn)
|
||||
else:
|
||||
assert not any(is_linear(v) for v in const_vars)
|
||||
sub_consts = map(read_primal, const_vars)
|
||||
sub_freevar_vals = map(read_primal, bound_vars)
|
||||
in_vals = map(read_primal, eqn.invars)
|
||||
all_args, in_tree_def = tree_flatten((sub_consts, sub_freevar_vals, in_vals))
|
||||
fun = hashable_partial(wrap_init(_eval_primals), subjaxpr)
|
||||
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
|
||||
out_flat = eqn.primitive.bind(fun, *all_args, **eqn.params)
|
||||
ans = tree_unflatten(out_tree(), out_flat)
|
||||
elif eqn.primitive is not pe.remat_call_p:
|
||||
ans = _eval_subjaxpr_primals(
|
||||
eqn.primitive, subjaxpr, map(read_primal, const_vars),
|
||||
map(read_primal, bound_vars), map(read_primal, eqn.invars), eqn.params)
|
||||
map(write_primal, eqn.outvars, ans)
|
||||
|
||||
# we special-case remat_call here because it can be mixed linear /
|
||||
# nonlinear, so we always evaluate it even if it has a linear part
|
||||
if eqn.primitive is pe.remat_call_p:
|
||||
ans = _eval_subjaxpr_primals(
|
||||
eqn.primitive, subjaxpr, map(read_primal, const_vars),
|
||||
map(read_primal, bound_vars), map(read_primal, eqn.invars), eqn.params)
|
||||
map(write_primal, eqn.outvars, ans)
|
||||
|
||||
ct_env = {}
|
||||
@ -225,6 +228,13 @@ def backward_pass(jaxpr, consts, freevar_vals, args, cotangents_in):
|
||||
cotangents_out = map(read_cotangent, jaxpr.invars)
|
||||
return freevar_cts, cotangents_out
|
||||
|
||||
def _eval_subjaxpr_primals(prim, jaxpr, consts, freevar_vals, in_vals, params):
|
||||
all_args, in_tree_def = tree_flatten((consts, freevar_vals, in_vals))
|
||||
fun = hashable_partial(wrap_init(_eval_primals), jaxpr)
|
||||
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
|
||||
out_flat = prim.bind(fun, *all_args, **params)
|
||||
return tree_unflatten(out_tree(), out_flat)
|
||||
|
||||
def _eval_primals(jaxpr, consts, freevar_vals, args):
|
||||
primal_env = {}
|
||||
|
||||
@ -259,15 +269,13 @@ def _eval_primals(jaxpr, consts, freevar_vals, args):
|
||||
write_primal(eqn.outvars[0], ans)
|
||||
else:
|
||||
(subjaxpr, const_vars, bound_vars), = eqn.bound_subjaxprs
|
||||
sub_consts = map(read_primal, const_vars)
|
||||
sub_freevar_vals = map(read_primal, bound_vars)
|
||||
in_vals = map(read_primal, eqn.invars)
|
||||
all_args, in_tree_def = tree_flatten((sub_consts, sub_freevar_vals, in_vals))
|
||||
fun = hashable_partial(wrap_init(_eval_primals), subjaxpr)
|
||||
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
|
||||
out_flat = eqn.primitive.bind(fun, *all_args, **eqn.params)
|
||||
ans = tree_unflatten(out_tree(), out_flat)
|
||||
map(write_primal, eqn.outvars, ans)
|
||||
assert not any(is_linear(v) for v in const_vars)
|
||||
if (eqn.primitive is pe.remat_call_p or
|
||||
not any(is_linear(v) for v in it.chain(eqn.invars, bound_vars))):
|
||||
ans = _eval_subjaxpr_primals(
|
||||
eqn.primitive, subjaxpr, map(read_primal, const_vars),
|
||||
map(read_primal, bound_vars), map(read_primal, eqn.invars), eqn.params)
|
||||
map(write_primal, eqn.outvars, ans)
|
||||
return map(read_primal, jaxpr.outvars)
|
||||
|
||||
class UndefinedPrimal(object):
|
||||
|
@ -507,6 +507,9 @@ def _remat_partial_eval(trace, f, tracers, params):
|
||||
jaxpr_1, jaxpr_2, out_unknowns = partial_eval_jaxpr(typed_jaxpr, in_unknowns, False)
|
||||
num_res = len(jaxpr_1.out_avals) - len(jaxpr_2.out_avals)
|
||||
|
||||
# First, we prune the jaxpr to be staged out not to have too many outputs.
|
||||
typed_jaxpr = _dce_jaxpr(typed_jaxpr, out_unknowns)
|
||||
|
||||
# Next, we need values for the outputs that should be known. Since consts
|
||||
# weren't passed through Python for evaluation, we need to evaluate jaxpr_1,
|
||||
# minus the residual outputs that we don't need. When `concrete=True`, as an
|
||||
@ -521,53 +524,25 @@ def _remat_partial_eval(trace, f, tracers, params):
|
||||
out_pval_consts2 = core.jaxpr_as_fun(jaxpr_1_primals)(*in_consts)[:-num_res or None]
|
||||
out_pvals = map(_reconstruct_pval, out_pvals1, out_pval_consts2, out_unknowns)
|
||||
|
||||
# Now that we have out_pvals, the rest is just like JaxprTrace.process_call
|
||||
# except we stage out two calls: one based on jaxpr_1 for computing the
|
||||
# residuals (which in the case of reverse-mode ad involves no linear
|
||||
# variables) and the other based on jaxpr_2 for evaluating everything given
|
||||
# the residuals (which in reverse-mode ad is linear).
|
||||
# Now that we have out_pvals, the rest is just like JaxprTrace.process_call.
|
||||
instantiated_tracers = env + instantiated_tracers
|
||||
num_nonres = len(jaxpr_2.out_avals)
|
||||
jaxpr_1_res = _dce_jaxpr(jaxpr_1, [False] * num_nonres + [True] * num_res,
|
||||
prune_outputs=True)
|
||||
|
||||
const_tracers = map(trace.new_instantiated_const, consts)
|
||||
bound_subjaxpr_1 = (jaxpr_1_res.jaxpr, const_tracers, ())
|
||||
res_avals = jaxpr_1.out_avals[num_nonres:]
|
||||
res_tracers = [JaxprTracer(trace, PartialVal((aval, unit)), None)
|
||||
for aval in res_avals]
|
||||
tracers_1 = [t if not uk else trace.new_instantiated_literal(unit)
|
||||
for t, uk in zip(instantiated_tracers, in_unknowns)]
|
||||
eqn_1 = new_eqn_recipe(tracers_1, res_tracers, remat_call_p,
|
||||
(bound_subjaxpr_1,), params)
|
||||
for t in res_tracers: t.recipe = eqn_1
|
||||
|
||||
bound_subjaxpr_2 = (jaxpr_2.jaxpr, (), ())
|
||||
bound_subjaxpr = (typed_jaxpr.jaxpr, const_tracers, ())
|
||||
out_tracers = [JaxprTracer(trace, out_pval, None) for out_pval in out_pvals]
|
||||
tracers_2 = [t if uk else trace.new_instantiated_literal(unit)
|
||||
for t, uk in zip(instantiated_tracers, in_unknowns)]
|
||||
eqn_2 = new_eqn_recipe(tracers_2 + res_tracers, out_tracers, remat_call_p,
|
||||
(bound_subjaxpr_2,), params)
|
||||
for t in out_tracers: t.recipe = eqn_2
|
||||
eqn = new_eqn_recipe(instantiated_tracers, out_tracers, remat_call_p,
|
||||
(bound_subjaxpr,), params)
|
||||
for t in out_tracers: t.recipe = eqn
|
||||
return out_tracers
|
||||
call_partial_eval_rules[remat_call_p] = _remat_partial_eval
|
||||
# NOTE to future self: the problem with the above strategy is that the jaxpr
|
||||
# produced wouldn't be round-trippable, in the sense that by forming two remat
|
||||
# calls we ensured the first one would be partial-eval'd away when we tried to
|
||||
# round-trip e.g. for partial eval of scan.
|
||||
|
||||
def _dce_jaxpr(typed_jaxpr, outputs, prune_outputs=False):
|
||||
def _dce_jaxpr(typed_jaxpr, outputs):
|
||||
# This dead-code elimination is pretty rudimentary, and in particular doesn't
|
||||
# nontrivially DCE through scan, call, or other higher-order primitives.
|
||||
# TODO(mattjj): better DCE
|
||||
jaxpr = typed_jaxpr.jaxpr
|
||||
outvars, out_avals = jaxpr.outvars, typed_jaxpr.out_avals
|
||||
if prune_outputs:
|
||||
out_pairs = [(var, aval) for var, aval, output
|
||||
in zip(outvars, out_avals, outputs) if output]
|
||||
else:
|
||||
out_pairs = [(var, aval) if output else (core.unitvar, core.abstract_unit)
|
||||
for var, aval, output in zip(outvars, out_avals, outputs)]
|
||||
out_pairs = [(var, aval) if output else (core.unitvar, core.abstract_unit)
|
||||
for var, aval, output in zip(outvars, out_avals, outputs)]
|
||||
new_outvars, new_out_avals = unzip2(out_pairs)
|
||||
|
||||
needed_vars = set(new_outvars)
|
||||
|
@ -1470,8 +1470,7 @@ class APITest(jtu.JaxTestCase):
|
||||
|
||||
jaxpr = api.make_jaxpr(api.linearize(f_yesremat, 4.)[1])(1.)
|
||||
scan_eqn, = jaxpr.eqns
|
||||
import ipdb; ipdb.set_trace()
|
||||
self.assertIn(' sin ', str(scan_eqn.params['jaxpr']))
|
||||
self.assertIn(' cos ', str(scan_eqn.params['jaxpr']))
|
||||
|
||||
jaxpr = api.make_jaxpr(api.vjp(f_yesremat, 4.)[1])(1.)
|
||||
scan_eqn, = jaxpr.eqns
|
||||
|
Loading…
x
Reference in New Issue
Block a user