mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Made some subset of vjp/jvp inputs static in quickercheck. Exposing bugs.
This commit is contained in:
parent
f1d7ea8972
commit
2f44eba01d
@ -183,6 +183,16 @@ def check_close(x, y, tol=1e-3):
|
||||
assert np.allclose(x, y, rtol=tol, atol=tol), \
|
||||
"Value mismatch:\n{}\n vs\n{}\n".format(x, y)
|
||||
|
||||
def partial_argnums(f, args, dyn_argnums):
|
||||
fixed_args = [None if i in dyn_argnums else arg for i, arg in enumerate(args)]
|
||||
def f_(*dyn_args):
|
||||
args = fixed_args[:]
|
||||
for i, arg in zip(dyn_argnums, dyn_args):
|
||||
args[i] = arg
|
||||
return f(*args)
|
||||
|
||||
dyn_args = [args[i] for i in dyn_argnums]
|
||||
return f_, dyn_args
|
||||
|
||||
def jit_is_identity(fun):
|
||||
vals = gen_vals(fun.in_vars)
|
||||
@ -196,8 +206,9 @@ def jvp_matches_fd(fun):
|
||||
vals = gen_vals(fun.in_vars)
|
||||
tangents = gen_vals(fun.in_vars)
|
||||
fun = partial(eval_fun, fun)
|
||||
|
||||
# TODO: differentiate wrt some inputs only
|
||||
dyn_argnums = thin(range(len(vals)), 0.5)
|
||||
tangents = [tangents[i] for i in dyn_argnums]
|
||||
fun, vals = partial_argnums(fun, vals, dyn_argnums)
|
||||
ans1, deriv1 = jvp_fd(fun, vals, tangents)
|
||||
ans2, deriv2 = jvp(fun, vals, tangents)
|
||||
check_all_close(ans1, ans2)
|
||||
@ -210,8 +221,9 @@ def vjp_matches_fd(fun):
|
||||
in_tangents = gen_vals(fun.in_vars)
|
||||
in_cotangents = gen_vals(fun.out_vars)
|
||||
fun = partial(eval_fun, fun)
|
||||
|
||||
# TODO: differentiate wrt some inputs only
|
||||
dyn_argnums = thin(range(len(vals)), 0.5)
|
||||
in_tangents = [in_tangents[i] for i in dyn_argnums]
|
||||
fun, vals = partial_argnums(fun, vals, dyn_argnums)
|
||||
ans1, out_tangents = jvp_fd(fun, vals, in_tangents)
|
||||
ans2, vjpfun = vjp(fun, *vals)
|
||||
out_cotangents = vjpfun(in_cotangents)
|
||||
@ -235,6 +247,7 @@ def run_tests():
|
||||
for i, (size, _, check_prop) in enumerate(cases):
|
||||
sys.stderr.write('\rTested: {}'.format(i))
|
||||
check_prop(gen_fun_and_types(size))
|
||||
print "\nok"
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
Loading…
x
Reference in New Issue
Block a user