mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fixed bug in vjp with constant-zero tangent outputs
This commit is contained in:
parent
2f44eba01d
commit
2e4ff400fa
@ -123,8 +123,8 @@ def lift_linearized(jaxpr, consts, io_tree, out_pval, py_args):
|
||||
def fun(*args):
|
||||
primals = pack(args) # doesn't matter what these are-they'll be ignored
|
||||
tangents = pack(args)
|
||||
ans = eval_jaxpr(jaxpr, consts, (), primals, tangents)
|
||||
return list(pe.merge_pvals(ans, out_pval))[1]
|
||||
_, ans = eval_jaxpr(jaxpr, consts, (), primals, tangents)
|
||||
return pe.merge_pvals(ans, out_pval)
|
||||
|
||||
return unflatten_fun(fun, io_tree, *py_args)
|
||||
|
||||
|
@ -64,21 +64,40 @@ def linearize(traceable, *primals):
|
||||
in_pvals = (pe.PartialVal((None, pack(primals))),
|
||||
pe.PartialVal((core.AbstractTuple(tangent_avals), core.unit)))
|
||||
jaxpr, out_pval, consts = pe.trace_to_jaxpr(jvpfun, in_pvals)
|
||||
out_pv, out_const = out_pval
|
||||
assert out_pv is None or out_pv[0] is None
|
||||
primal_out = tuple(out_const)[0]
|
||||
return primal_out, out_pval, jaxpr, consts
|
||||
|
||||
pval_primal, pval_tangent = unpair_pval(out_pval)
|
||||
aval_primal, const_primal = pval_primal
|
||||
assert aval_primal is None
|
||||
return const_primal, pval_tangent, jaxpr, consts
|
||||
|
||||
def vjp(traceable, primals):
|
||||
out_primal, _, jaxpr, consts = linearize(traceable, *primals)
|
||||
out_primal, pval, jaxpr, consts = linearize(traceable, *primals)
|
||||
def vjp_(ct):
|
||||
ct = ignore_consts(ct, pval)
|
||||
dummy_primal_and_ct = pack((core.unit, ct))
|
||||
_, arg_cts = backward_pass(jaxpr, consts, (), dummy_primal_and_ct)
|
||||
return instantiate_zeros(pack(primals), arg_cts[1])
|
||||
|
||||
return out_primal, vjp_
|
||||
|
||||
def ignore_consts(ct, pval):
|
||||
aval, const = pval
|
||||
if isinstance(aval, core.AbstractValue):
|
||||
return ct
|
||||
elif isinstance(aval, pe.JaxprTracerTuple):
|
||||
return pack(map(ignore_consts, ct, zip(aval, const)))
|
||||
elif aval is None:
|
||||
return core.unit
|
||||
else:
|
||||
raise TypeError(aval)
|
||||
|
||||
def unpair_pval(pval):
|
||||
aval, const = pval
|
||||
const_1, const_2 = const
|
||||
if aval is None:
|
||||
return (None, const_1), (None, const_2)
|
||||
else:
|
||||
aval_1, aval_2 = aval
|
||||
return (aval_1, const_1), (aval_2, const_2)
|
||||
|
||||
def backward_pass(jaxpr, consts, freevar_vals, cotangent_in):
|
||||
def write_cotangent(v, ct):
|
||||
@ -116,8 +135,8 @@ def backward_pass(jaxpr, consts, freevar_vals, cotangent_in):
|
||||
|
||||
if cts_out is zero:
|
||||
cts_out = [zero for _ in eqn.invars]
|
||||
# TODO(phawkins,dougalm): eqn.invars and cts_out can have different lengths
|
||||
for var, ct in builtins.zip(eqn.invars, cts_out):
|
||||
|
||||
for var, ct in zip(eqn.invars, cts_out):
|
||||
write_cotangent(var, ct)
|
||||
|
||||
cotangents_out = map(read_cotangent, jaxpr.invars)
|
||||
|
@ -216,7 +216,6 @@ def jvp_matches_fd(fun):
|
||||
|
||||
|
||||
def vjp_matches_fd(fun):
|
||||
# print fun
|
||||
vals = gen_vals(fun.in_vars)
|
||||
in_tangents = gen_vals(fun.in_vars)
|
||||
in_cotangents = gen_vals(fun.out_vars)
|
||||
@ -232,7 +231,6 @@ def vjp_matches_fd(fun):
|
||||
inner_prod_ad = inner_prod(in_tangents, out_cotangents)
|
||||
check_close(inner_prod_fd, inner_prod_ad)
|
||||
|
||||
|
||||
properties = [
|
||||
jit_is_identity,
|
||||
jvp_matches_fd,
|
||||
@ -246,7 +244,13 @@ def run_tests():
|
||||
cases = it.product(sizes, range(num_examples), properties)
|
||||
for i, (size, _, check_prop) in enumerate(cases):
|
||||
sys.stderr.write('\rTested: {}'.format(i))
|
||||
check_prop(gen_fun_and_types(size))
|
||||
try:
|
||||
fun = gen_fun_and_types(size)
|
||||
check_prop(fun)
|
||||
except:
|
||||
print fun
|
||||
raise
|
||||
|
||||
print "\nok"
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
x
Reference in New Issue
Block a user