Added reverse-mode checks

This commit is contained in:
Dougal Maclaurin 2018-11-30 17:14:27 -05:00
parent 6fcee12cec
commit f1d7ea8972

View File

@ -2,7 +2,7 @@ from collections import namedtuple
from functools import partial
import numpy.random as npr
import jax.numpy as np
from jax import jit, jvp
from jax import jit, jvp, vjp
import itertools as it
import sys
@ -152,14 +152,19 @@ def gen_subset(xs):
return gen_sized_subset(xs, npr.randint(len(xs) + 1))
def gen_inputs(fun):
return [gen_array_val(v.vartype) for v in fun.in_vars]
def gen_vals(vs):
return [gen_array_val(v.vartype) for v in vs]
def jvp_fd(fun, args, directions):
def inner_prod(xs, ys):
xys = zip(xs, ys)
assert all(x.shape == y.shape for x, y in xys)
return sum(np.sum(x * y) for x, y in xys)
def jvp_fd(fun, args, tangents):
EPS = 1e-4
def eval_eps(eps):
return fun(*[x if d is None else x + eps * d
for x, d in zip(args, directions)])
return fun(*[x if t is None else x + eps * t
for x, t in zip(args, tangents)])
ys_neg = eval_eps(-EPS)
ys_pos = eval_eps(EPS)
@ -169,14 +174,18 @@ def jvp_fd(fun, args, directions):
def check_all_close(xs, ys, tol=1e-3):
for x, y in zip(xs, ys):
assert x.shape == y.shape
# TODO(dougalm): re-enable once we've tackled the less pendantic bugs
# assert x.dtype == y.dtype
assert np.allclose(x, y, rtol=tol, atol=tol), \
"Value mismatch:\n{}\n vs\n{}\n".format(x, y)
check_close(x, y, tol)
def check_close(x, y, tol=1e-3):
assert np.shape(x) == np.shape(y)
# TODO(dougalm): re-enable once we've tackled the less pendantic bugs
# assert x.dtype == y.dtype
assert np.allclose(x, y, rtol=tol, atol=tol), \
"Value mismatch:\n{}\n vs\n{}\n".format(x, y)
def jit_is_identity(fun):
vals = gen_inputs(fun)
vals = gen_vals(fun.in_vars)
fun = partial(eval_fun, fun)
ans = fun(*vals)
static_argnums = thin(range(len(vals)), 0.5)
@ -184,21 +193,39 @@ def jit_is_identity(fun):
check_all_close(ans, ans_jitted)
def jvp_matches_fd(fun):
vals = gen_inputs(fun)
directions = gen_inputs(fun)
vals = gen_vals(fun.in_vars)
tangents = gen_vals(fun.in_vars)
fun = partial(eval_fun, fun)
# TODO: differentiate wrt some inputs only
ans1, deriv1 = jvp_fd(fun, vals, directions)
ans2, deriv2 = jvp(fun, vals, directions)
ans1, deriv1 = jvp_fd(fun, vals, tangents)
ans2, deriv2 = jvp(fun, vals, tangents)
check_all_close(ans1, ans2)
check_all_close(deriv1, deriv2)
def vjp_matches_fd(fun):
# print fun
vals = gen_vals(fun.in_vars)
in_tangents = gen_vals(fun.in_vars)
in_cotangents = gen_vals(fun.out_vars)
fun = partial(eval_fun, fun)
# TODO: differentiate wrt some inputs only
ans1, out_tangents = jvp_fd(fun, vals, in_tangents)
ans2, vjpfun = vjp(fun, *vals)
out_cotangents = vjpfun(in_cotangents)
check_all_close(ans1, ans2)
inner_prod_fd = inner_prod(out_tangents, in_cotangents)
inner_prod_ad = inner_prod(in_tangents, out_cotangents)
check_close(inner_prod_fd, inner_prod_ad)
properties = [
jit_is_identity,
jvp_matches_fd,
vjp_matches_fd,
]
# vjp_matches_fd,
# vmap_matches_map ]
def run_tests():