Fixed bug in vjp with constant-zero tangent outputs

This commit is contained in:
Dougal Maclaurin 2018-12-03 22:24:46 -05:00
parent 2f44eba01d
commit 2e4ff400fa
3 changed files with 36 additions and 13 deletions

View File

@ -123,8 +123,8 @@ def lift_linearized(jaxpr, consts, io_tree, out_pval, py_args):
def fun(*args):
primals = pack(args) # doesn't matter what these are-they'll be ignored
tangents = pack(args)
ans = eval_jaxpr(jaxpr, consts, (), primals, tangents)
return list(pe.merge_pvals(ans, out_pval))[1]
_, ans = eval_jaxpr(jaxpr, consts, (), primals, tangents)
return pe.merge_pvals(ans, out_pval)
return unflatten_fun(fun, io_tree, *py_args)

View File

@ -64,21 +64,40 @@ def linearize(traceable, *primals):
in_pvals = (pe.PartialVal((None, pack(primals))),
pe.PartialVal((core.AbstractTuple(tangent_avals), core.unit)))
jaxpr, out_pval, consts = pe.trace_to_jaxpr(jvpfun, in_pvals)
out_pv, out_const = out_pval
assert out_pv is None or out_pv[0] is None
primal_out = tuple(out_const)[0]
return primal_out, out_pval, jaxpr, consts
pval_primal, pval_tangent = unpair_pval(out_pval)
aval_primal, const_primal = pval_primal
assert aval_primal is None
return const_primal, pval_tangent, jaxpr, consts
def vjp(traceable, primals):
out_primal, _, jaxpr, consts = linearize(traceable, *primals)
out_primal, pval, jaxpr, consts = linearize(traceable, *primals)
def vjp_(ct):
ct = ignore_consts(ct, pval)
dummy_primal_and_ct = pack((core.unit, ct))
_, arg_cts = backward_pass(jaxpr, consts, (), dummy_primal_and_ct)
return instantiate_zeros(pack(primals), arg_cts[1])
return out_primal, vjp_
def ignore_consts(ct, pval):
aval, const = pval
if isinstance(aval, core.AbstractValue):
return ct
elif isinstance(aval, pe.JaxprTracerTuple):
return pack(map(ignore_consts, ct, zip(aval, const)))
elif aval is None:
return core.unit
else:
raise TypeError(aval)
def unpair_pval(pval):
aval, const = pval
const_1, const_2 = const
if aval is None:
return (None, const_1), (None, const_2)
else:
aval_1, aval_2 = aval
return (aval_1, const_1), (aval_2, const_2)
def backward_pass(jaxpr, consts, freevar_vals, cotangent_in):
def write_cotangent(v, ct):
@ -116,8 +135,8 @@ def backward_pass(jaxpr, consts, freevar_vals, cotangent_in):
if cts_out is zero:
cts_out = [zero for _ in eqn.invars]
# TODO(phawkins,dougalm): eqn.invars and cts_out can have different lengths
for var, ct in builtins.zip(eqn.invars, cts_out):
for var, ct in zip(eqn.invars, cts_out):
write_cotangent(var, ct)
cotangents_out = map(read_cotangent, jaxpr.invars)

View File

@ -216,7 +216,6 @@ def jvp_matches_fd(fun):
def vjp_matches_fd(fun):
# print fun
vals = gen_vals(fun.in_vars)
in_tangents = gen_vals(fun.in_vars)
in_cotangents = gen_vals(fun.out_vars)
@ -232,7 +231,6 @@ def vjp_matches_fd(fun):
inner_prod_ad = inner_prod(in_tangents, out_cotangents)
check_close(inner_prod_fd, inner_prod_ad)
properties = [
jit_is_identity,
jvp_matches_fd,
@ -246,7 +244,13 @@ def run_tests():
cases = it.product(sizes, range(num_examples), properties)
for i, (size, _, check_prop) in enumerate(cases):
sys.stderr.write('\rTested: {}'.format(i))
check_prop(gen_fun_and_types(size))
try:
fun = gen_fun_and_types(size)
check_prop(fun)
except:
print fun
raise
print "\nok"
if __name__ == "__main__":