Made some subset of vjp/jvp inputs static in quickercheck. Exposing bugs.

This commit is contained in:
Dougal Maclaurin 2018-12-03 09:52:19 -05:00
parent f1d7ea8972
commit 2f44eba01d

View File

@ -183,6 +183,16 @@ def check_close(x, y, tol=1e-3):
assert np.allclose(x, y, rtol=tol, atol=tol), \
"Value mismatch:\n{}\n vs\n{}\n".format(x, y)
def partial_argnums(f, args, dyn_argnums):
fixed_args = [None if i in dyn_argnums else arg for i, arg in enumerate(args)]
def f_(*dyn_args):
args = fixed_args[:]
for i, arg in zip(dyn_argnums, dyn_args):
args[i] = arg
return f(*args)
dyn_args = [args[i] for i in dyn_argnums]
return f_, dyn_args
def jit_is_identity(fun):
vals = gen_vals(fun.in_vars)
@ -196,8 +206,9 @@ def jvp_matches_fd(fun):
vals = gen_vals(fun.in_vars)
tangents = gen_vals(fun.in_vars)
fun = partial(eval_fun, fun)
# TODO: differentiate wrt some inputs only
dyn_argnums = thin(range(len(vals)), 0.5)
tangents = [tangents[i] for i in dyn_argnums]
fun, vals = partial_argnums(fun, vals, dyn_argnums)
ans1, deriv1 = jvp_fd(fun, vals, tangents)
ans2, deriv2 = jvp(fun, vals, tangents)
check_all_close(ans1, ans2)
@ -210,8 +221,9 @@ def vjp_matches_fd(fun):
in_tangents = gen_vals(fun.in_vars)
in_cotangents = gen_vals(fun.out_vars)
fun = partial(eval_fun, fun)
# TODO: differentiate wrt some inputs only
dyn_argnums = thin(range(len(vals)), 0.5)
in_tangents = [in_tangents[i] for i in dyn_argnums]
fun, vals = partial_argnums(fun, vals, dyn_argnums)
ans1, out_tangents = jvp_fd(fun, vals, in_tangents)
ans2, vjpfun = vjp(fun, *vals)
out_cotangents = vjpfun(in_cotangents)
@ -235,6 +247,7 @@ def run_tests():
for i, (size, _, check_prop) in enumerate(cases):
sys.stderr.write('\rTested: {}'.format(i))
check_prop(gen_fun_and_types(size))
print "\nok"
if __name__ == "__main__":
run_tests()