mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Added reverse-mode checks
This commit is contained in:
parent
6fcee12cec
commit
f1d7ea8972
@ -2,7 +2,7 @@ from collections import namedtuple
|
||||
from functools import partial
|
||||
import numpy.random as npr
|
||||
import jax.numpy as np
|
||||
from jax import jit, jvp
|
||||
from jax import jit, jvp, vjp
|
||||
import itertools as it
|
||||
import sys
|
||||
|
||||
@ -152,14 +152,19 @@ def gen_subset(xs):
|
||||
|
||||
return gen_sized_subset(xs, npr.randint(len(xs) + 1))
|
||||
|
||||
def gen_inputs(fun):
|
||||
return [gen_array_val(v.vartype) for v in fun.in_vars]
|
||||
def gen_vals(vs):
|
||||
return [gen_array_val(v.vartype) for v in vs]
|
||||
|
||||
def jvp_fd(fun, args, directions):
|
||||
def inner_prod(xs, ys):
|
||||
xys = zip(xs, ys)
|
||||
assert all(x.shape == y.shape for x, y in xys)
|
||||
return sum(np.sum(x * y) for x, y in xys)
|
||||
|
||||
def jvp_fd(fun, args, tangents):
|
||||
EPS = 1e-4
|
||||
def eval_eps(eps):
|
||||
return fun(*[x if d is None else x + eps * d
|
||||
for x, d in zip(args, directions)])
|
||||
return fun(*[x if t is None else x + eps * t
|
||||
for x, t in zip(args, tangents)])
|
||||
|
||||
ys_neg = eval_eps(-EPS)
|
||||
ys_pos = eval_eps(EPS)
|
||||
@ -169,14 +174,18 @@ def jvp_fd(fun, args, directions):
|
||||
|
||||
def check_all_close(xs, ys, tol=1e-3):
|
||||
for x, y in zip(xs, ys):
|
||||
assert x.shape == y.shape
|
||||
# TODO(dougalm): re-enable once we've tackled the less pendantic bugs
|
||||
# assert x.dtype == y.dtype
|
||||
assert np.allclose(x, y, rtol=tol, atol=tol), \
|
||||
"Value mismatch:\n{}\n vs\n{}\n".format(x, y)
|
||||
check_close(x, y, tol)
|
||||
|
||||
def check_close(x, y, tol=1e-3):
|
||||
assert np.shape(x) == np.shape(y)
|
||||
# TODO(dougalm): re-enable once we've tackled the less pendantic bugs
|
||||
# assert x.dtype == y.dtype
|
||||
assert np.allclose(x, y, rtol=tol, atol=tol), \
|
||||
"Value mismatch:\n{}\n vs\n{}\n".format(x, y)
|
||||
|
||||
|
||||
def jit_is_identity(fun):
|
||||
vals = gen_inputs(fun)
|
||||
vals = gen_vals(fun.in_vars)
|
||||
fun = partial(eval_fun, fun)
|
||||
ans = fun(*vals)
|
||||
static_argnums = thin(range(len(vals)), 0.5)
|
||||
@ -184,21 +193,39 @@ def jit_is_identity(fun):
|
||||
check_all_close(ans, ans_jitted)
|
||||
|
||||
def jvp_matches_fd(fun):
|
||||
vals = gen_inputs(fun)
|
||||
directions = gen_inputs(fun)
|
||||
vals = gen_vals(fun.in_vars)
|
||||
tangents = gen_vals(fun.in_vars)
|
||||
fun = partial(eval_fun, fun)
|
||||
|
||||
# TODO: differentiate wrt some inputs only
|
||||
ans1, deriv1 = jvp_fd(fun, vals, directions)
|
||||
ans2, deriv2 = jvp(fun, vals, directions)
|
||||
ans1, deriv1 = jvp_fd(fun, vals, tangents)
|
||||
ans2, deriv2 = jvp(fun, vals, tangents)
|
||||
check_all_close(ans1, ans2)
|
||||
check_all_close(deriv1, deriv2)
|
||||
|
||||
|
||||
def vjp_matches_fd(fun):
|
||||
# print fun
|
||||
vals = gen_vals(fun.in_vars)
|
||||
in_tangents = gen_vals(fun.in_vars)
|
||||
in_cotangents = gen_vals(fun.out_vars)
|
||||
fun = partial(eval_fun, fun)
|
||||
|
||||
# TODO: differentiate wrt some inputs only
|
||||
ans1, out_tangents = jvp_fd(fun, vals, in_tangents)
|
||||
ans2, vjpfun = vjp(fun, *vals)
|
||||
out_cotangents = vjpfun(in_cotangents)
|
||||
check_all_close(ans1, ans2)
|
||||
inner_prod_fd = inner_prod(out_tangents, in_cotangents)
|
||||
inner_prod_ad = inner_prod(in_tangents, out_cotangents)
|
||||
check_close(inner_prod_fd, inner_prod_ad)
|
||||
|
||||
|
||||
properties = [
|
||||
jit_is_identity,
|
||||
jvp_matches_fd,
|
||||
vjp_matches_fd,
|
||||
]
|
||||
# vjp_matches_fd,
|
||||
# vmap_matches_map ]
|
||||
|
||||
def run_tests():
|
||||
|
Loading…
x
Reference in New Issue
Block a user